diff --git a/MindFlow/applications/cfd/acoustic/README.md b/MindFlow/applications/cfd/acoustic/README.md new file mode 100644 index 0000000000000000000000000000000000000000..181434fb706d6673c00ec54d97254afe0d917338 --- /dev/null +++ b/MindFlow/applications/cfd/acoustic/README.md @@ -0,0 +1,144 @@ +# 2D Acoustic Wave Equation CBS Solver + +## Overview + +The solution of the acoustic wave equation is a core technology in fields such as medical ultrasound and geological exploration. Large-scale acoustic wave equation solving faces challenges in computing power and storage. Acoustic wave equation solvers generally use frequency domain solving algorithms and time domain solving algorithms. The representative of time domain solving algorithms is the Time Domain Finite Difference (TDFD) method, and the frequency domain solving algorithms include Frequency Domain Finite Difference (FDFD), Finite Element Method (FEM), and Convergent Born Series (CBS) iterative method. The CBS method is widely recognized in the engineering and academic communities due to its low memory requirements and the absence of dispersion errors. In particular, [Osnabrugge et al. (2016)](https://linkinghub.elsevier.com/retrieve/pii/S0021999116302595) solved the convergence problem of this method, making the application of the CBS method have broader prospects. The AI model based on the CBS computational structure is a typical representative of the dual-driven paradigm of physics and AI, including [Stanziola et al. (2022)](http://arxiv.org/abs/2212.04948), [Zeng et al. (2023)](http://arxiv.org/abs/2312.15575), etc. + +This case will demonstrate how to call the CBS API provided by MindFlow to solve the 2D acoustic wave equation. + +## Theoretical Background + +### Problem Description + +In the solution of the acoustic wave equation, the velocity field and source information are input parameters, and the output is the spatiotemporal distribution of the wavefield. + +The expression of the 2D acoustic wave equation is as follows: + +| Time Domain Expression | Frequency Domain Expression | +| ----------------------------------------------------- | ------------------------------------------------- | +| $\frac{\partial^2u}{\partial t^2} - c^2 \Delta u = f$ | $-\omega^2 \hat{u} - c^2 \Delta\hat{u} = \hat{f}$ | + +where + +- $u(\bold{x},t) \;\; [L]$ Deformation displacement (pressure divided by density), scalar +- $c(\bold{x}) \;\; [L/T]$ Wave velocity, scalar +- $f(\bold{x},t) \;\; [L/T^2]$ Excitation source (volume distribution force), scalar + +During actual solving, in order to reduce the parameter dimensions, the parameters are usually made dimensionless first, and then the dimensionless equations are solved against the dimensionless parameters, and finally the dimensions of the solutions are recovered. By selecting $\omega$, $\hat{f}$, and $d$ (grid spacing, required to be equal in all directions in this case) to normalize the frequency domain equation, the dimensionless frequency domain equation can be obtained: + +$$ +u^* + c^{*2} \tilde{\Delta} + f^* = 0 +$$ + +where + +- $u^* = \hat{u} \omega^2 / \hat{f}$ Dimensionless deformation displacement +- $c^* = c / (\omega d)$ Dimensionless wave velocity +- $\tilde{\Delta}$ Normalized Laplace operator, i.e., the Laplace operator when the grid spacing is 1 +- $f^*$ Mask marking the source position, with 1 at the source location and 0 at other positions + +### CBS Introduction + +Here is a brief introduction to the theory of the CBS method. For further understanding, please refer to [Osnabrugge et al. (2016)](https://linkinghub.elsevier.com/retrieve/pii/S0021999116302595). + +**Original Born Series** + +First, the frequency domain acoustic wave equation is expressed in the following equivalent form +$$ +k^2 \hat{u} + \Delta \hat{u} +S = 0 +$$ +where $k=\frac{\omega}{c}$ and $S=\frac{\hat{f}}{c^2}$. The non-uniform wave number field $k$ is decomposed into a uniform background potential $k_0$ and a scattering potential $V$: $k^2 = V + k_0^2 + i\epsilon$, where $\epsilon$ is a small quantity that ensures the stability of the iteration, and the final solution of the equation is independent of the specific values of $k_0$ and $\epsilon$. The equation for a single iteration is +$$ +(k_0^2 + i\epsilon) \hat{u} + \Delta \hat{u} = -V \hat{u} - S +$$ +Treating the right-hand side as a known quantity, the solution of this equation is +$$ +\hat{u} = G (V \hat{u} + S) +\qquad +G = \mathcal{F}^{-1} \frac1{|\bold{p}|^2 - k_0^2 - i\epsilon} \mathcal{F} +$$ +Substituting the solution of each iteration back into the right-hand side and performing the next iteration, the iterative expression is obtained +$$ +\hat{u}_{k+1} = GV\hat{u}_k + GS = (1 + GV + GVGV + \cdots)GS +$$ +**Convergent Born Series** + +To ensure convergence, preprocessing and reasonable selection of the value of $epsilon$ are required. Define the preprocessing operator as $\gamma = \frac{i}{\epsilon} V$, and take $\epsilon \geq \max\{|k^2 - k_0^2|\}$. Multiply both sides of the iteration equation by $\gamma$ and rearrange, we get +$$ +\hat{u} = (\gamma GV - \gamma + 1) \hat{u} + \gamma GS +$$ +Let $\gamma GV - \gamma + 1 = M$, then the iteration equation becomes +$$ +\hat{u}_{k+1} = M \hat{u}_k + \gamma GS = (1 + M + M^2 + \cdots) \gamma GS +$$ +In matrix form +$$ +\begin{bmatrix} \hat{u}_k \\ S \end{bmatrix} = +\begin{bmatrix} M & \gamma G \\ 0 & 1 \end{bmatrix}^k +\begin{bmatrix} 0 \\ S \end{bmatrix} +$$ +In the actual program implementation, in order to reduce the number of Fourier transforms, the following equivalent form of the iteration equation is used +$$ +\hat{u}_{k+1} = \hat{u}_k + \gamma [G(V\hat{u}_k + S) - \hat{u}_k] +$$ + +## Case Design + +The content is translated into English as follows: + +- Non-dimensionalization of input parameters; +- Non-dimensionalization of frequency domain 2D acoustic wave equation CBS solver; +- Dimensional restoration of the solution; +- Time-frequency transformation of the solution. + +The core solving process is parallelized for different source locations and different frequency points. Due to the large number of frequency points, it is divided into `n_batches` batches to solve sequentially along the frequency direction. The required input for the case is placed in the `dataset/` directory in the form of files, and the file names are passed in through `config.yaml`. The output results include the solution of the non-dimensionalized equation in the frequency domain `u_star.npy`, the dimensional final solution converted to the time domain `u_time.npy`, and the visualization animation of the time domain solution `wave.gif`. + +## Quick Start + +To facilitate direct verification by users, preset inputs are provided [here](https://download-mindspore.osinfra.cn/mindscience/mindflow/dataset/applications/cfd/acoustic). Please download the data and put them in `./dataset` in the case directory. The data include the velocity field `velocity.npy`, source location list `srclocs.csv`, and source waveform `srcwaves.csv`. Users can modify the input parameters according to the file format. + +### Method 1: Running the `solve_acoustic.py` script + +```shell +python solve_acoustic.py --config_file_path ./configs.yaml --device_id 0 --mode GRAPH +``` + +Where + +`--config_file_path` represents the path of the configuration file, with a default value of `./config.yaml`; + +`--device_id` represents the ID of the computing card used, which can be filled in according to the actual situation, and the most idle one will be automatically selected from all the computing cards by default; + +`--mode` represents the running mode, 'GRAPH' represents static graph mode, and 'PYNATIVE' represents dynamic graph mode. + +### Method 2: Running Jupyter Notebook + +You can use the [Chinese version](./acoustic_CN.ipynb) and the [English version](./acoustic.ipynb) of Jupyter Notebook to run the training and validation code line by line. + +## Result Display + +The evolution of the wave field excited by different source locations for the same velocity model over time is shown in the following figure. + +![wave.gif](images/wave.gif) + +The iterative convergence process of the equation residual is shown in the following figure, with each line representing a frequency point. The number of iterations required to reach the convergence threshold varies for different frequency points, and the number of iterations in the same batch depends on the slowest converging frequency point. + +![errors.png](images/errors.png) + +## Performance + +| Parameter | Ascend | +|:----------------------:|:--------------------------:| +| Hardware | Ascend NPU | +| MindSpore Version | >=2.3.0 | +| Dataset | [Marmousi velocity model](https://en.wikipedia.org/wiki/Marmousi_model) slices, included in the `dataset/` path of the case | +| Number of Parameters | No trainable parameters | +| Solver Parameters | batch_size=300, tol=1e-3, max_iter=10000 | +| Convergence Iterations | batch 0: 1320, batch 1: 560, batch 2: 620, batch 3: 680| +| Solver Speed (ms/iteration) | 500 | + +## Contributors + +gitee id: [WhFanatic](https://gitee.com/WhFanatic) + +email: hainingwang1995@gmail.com \ No newline at end of file diff --git a/MindFlow/applications/cfd/acoustic/README_CN.md b/MindFlow/applications/cfd/acoustic/README_CN.md new file mode 100644 index 0000000000000000000000000000000000000000..0dce5c291805563f772b230634c8b0e2eb5349a5 --- /dev/null +++ b/MindFlow/applications/cfd/acoustic/README_CN.md @@ -0,0 +1,146 @@ +# 2D 声波方程 CBS 求解 + +## 概述 + +声波方程求解是医疗超声、地质勘探等领域中的核心技术,大规模声波方程求解面临算力和存储的挑战。声波方程求解器一般采用频域求解算法和时域求解算法,时域求解算法的代表是时域有限差分法 (TDFD),频域求解算法包括频域有限差分法 (FDFD)、有限元法 (FEM) 和 CBS (Convergent Born series) 迭代法。CBS 方法由于不引入频散误差,且求解的内存需求低,因此受到工程和学术界的广泛关注。尤其是 [Osnabrugge et al. (2016)](https://linkinghub.elsevier.com/retrieve/pii/S0021999116302595) 解决了该方法的收敛性问题,使得 CBS 方法的应用具有更广阔的前景。基于 CBS 的计算结构所提出的 AI 模型也是物理与 AI 双驱动范式的典型代表,包括 [Stanziola et al. (2022)](http://arxiv.org/abs/2212.04948),[Zeng et al. (2023)](http://arxiv.org/abs/2312.15575) 等。 + +本案例将演示如何调用 MindFlow 提供的 CBS API 实现二维声波方程的求解。 + +## 理论背景 + +### 问题描述 + +声波方程求解中,波速场和震源信息是输入参数,求解输出的是时空分布的波场。 + +二维声波方程的表达式如下 + +| 时域表达式 | 频域表达式 | +| ----------------------------------------------------- | ------------------------------------------------- | +| $\frac{\partial^2u}{\partial t^2} - c^2 \Delta u = f$ | $-\omega^2 \hat{u} - c^2 \Delta\hat{u} = \hat{f}$ | + +其中 + +- $u(\bold{x},t) \;\; [L]$ 变形位移 (压强除以密度),标量 +- $c(\bold{x}) \;\; [L/T]$ 波速,标量 +- $f(\bold{x},t) \;\; [L/T^2]$ 震源激励 (体积分布力),标量 + +实际求解中,为了降低参数维度,一般先将参数无量纲化,然后针对无量纲方程和参数进行求解,最后恢复解的量纲。选取 $\omega$、$\hat{f}$ 和 $d$(网格间距,本案例要求网格在各方向间距相等)对频域方程做无量纲化,可得频域无量纲方程: + +$$ +u^* + c^{*2} \tilde{\Delta} + f^* = 0 +$$ + +其中 + +- $u^* = \hat{u} \omega^2 / \hat{f}$ 为无量纲变形位移 +- $c^* = c / (\omega d)$ 为无量纲波速 +- $\tilde{\Delta}$ 为归一化 Laplace 算子,即网格间距均为 1 时的 Laplace 算子 +- $f^*$ 为标记震源位置的 mask,即在震源作用点为 1,其余位置为 0 + +### CBS 方法介绍 + +此处对 CBS 方法的理论推导作简单介绍,读者可参考 [Osnabrugge et al. (2016)](https://linkinghub.elsevier.com/retrieve/pii/S0021999116302595) 进一步了解。 + +**原始 Born Series** + +首先将频域声波方程表达为以下等价形式 +$$ +k^2 \hat{u} + \Delta \hat{u} +S = 0 +$$ +其中 $k=\omega/c$,$S=\hat{f}/c^2$。将非均匀波数场 $k$ 拆分为均匀背景势 $k_0$ 和散射势 $V$:$k^2 = V + k_0^2 + i\epsilon$,其中 $\epsilon$ 为保持迭代稳定的小量,方程的最终解与 $k_0, \epsilon$ 的具体取值无关。得到单次迭代求解的方程 +$$ +(k_0^2 + i\epsilon) \hat{u} + \Delta \hat{u} = -V \hat{u} - S +$$ +将右端项视为已知量,该方程的解为 +$$ +\hat{u} = G (V \hat{u} + S) +\qquad +G = \mathcal{F}^{-1} \frac1{|\bold{p}|^2 - k_0^2 - i\epsilon} \mathcal{F} +$$ +将每轮迭代的求解结果代回右端项,进行下一轮迭代,得迭代表达式 +$$ +\hat{u}_{k+1} = GV\hat{u}_k + GS = (1 + GV + GVGV + \cdots)GS +$$ +**收敛 Born Series** + +为了保证收敛性,需做一定预处理以及合理选取 $\epsilon$ 的值。定义预处理子 $\gamma = \frac{i}{\epsilon} V$,并取 $\epsilon \ge \max{|k^2 - k_0^2|}$,将迭代的等式两端同乘 $\gamma$ 并整理,可得 +$$ +\hat{u} = (\gamma GV - \gamma + 1) \hat{u} + \gamma GS +$$ +记 $\gamma GV - \gamma + 1 = M$,则迭代式变为 +$$ +\hat{u}_{k+1} = M \hat{u}_k + \gamma GS = (1 + M + M^2 + \cdots) \gamma GS +$$ +矩阵形式 +$$ +\begin{bmatrix} \hat{u}_k \\ S \end{bmatrix} = +\begin{bmatrix} M & \gamma G \\ 0 & 1 \end{bmatrix}^k +\begin{bmatrix} 0 \\ S \end{bmatrix} +$$ +实际程序植入时,为了减少 Fourier 变换的次数,采用以下等价形式的迭代式 +$$ +\hat{u}_{k+1} = \hat{u}_k + \gamma [G(V\hat{u}_k + S) - \hat{u}_k] +$$ + +## 案例设计 + +具体包含以下步骤 + +- 输入参数无量纲化; +- 频域无量纲化 2D 声波方程 CBS 求解; +- 求解结果恢复量纲化; +- 求解结果时频转换。 + +其中核心求解的过程针对不同震源位置和不同频点同时并行求解,由于频点数可能较多,因此沿频率方向分为 `n_batches` 个批次依次求解。 + +案例所需的输入以文件的形式放置于 `dataset/` 中,文件名通过 `config.yaml` 传入。输出的结果为频域无量纲方程的解 `u_star.npy`、转换到时域的有量纲最终解 `u_time.npy`、针对时域解制作的可视化动图 `wave.gif`。 + +## 快速开始 + +为了方便用户直接验证,本案例在本[链接](https://download-mindspore.osinfra.cn/mindscience/mindflow/dataset/applications/cfd/acoustic)中提供了预置的输入数据,请下载所需要的数据集,并保存在 `./dataset` 目录下。数据集包括速度场 `velocity.npy`、震源位置列表 `srclocs.csv`、震源波形 `srcwaves.csv`。用户可仿照输入文件格式自行修改输入参数。 + +### 运行方式一:`solve_acoustic.py` 脚本 + +```shell +python solve_acoustic.py --config_file_path ./configs.yaml --device_id 0 --mode GRAPH +``` + +其中, + +`--config_file_path`表示配置文件的路径,默认值'./config.yaml'; + +`--device_id`表示使用的计算卡编号,可按照实际情况填写,默认从所有计算卡中自动选取最空闲的一张; + +`--mode`表示运行的模式,'GRAPH'表示静态图模式, 'PYNATIVE'表示动态图模式。 + +### 运行方式二:运行 Jupyter Notebook + +您可以使用[中文版](./acoustic_CN.ipynb)和[英文版](./acoustic.ipynb)Jupyter Notebook 逐行运行训练和验证代码。 + +## 结果展示 + +针对同一个速度模型,不同震源位置激发的波场随时间演化过程如下图所示。 + +![wave.gif](images/wave.gif) + +方程残差的迭代收敛过程如下图所示,每根线代表一个频点。不同频点达到收敛阈值所需的迭代次数不同,同一批次的迭代次数取决于收敛最慢的频点。 + +![errors.png](images/errors.png) + +## 性能 + +| 参数 | Ascend | +|:----------------------:|:--------------------------:| +| 硬件资源 | 昇腾 NPU | +| MindSpore版本 | >=2.3.0 | +| 数据集 | [Marmousi 速度模型](https://en.wikipedia.org/wiki/Marmousi_model)切片,包含在案例 `dataset/` 路径中 | +| 参数量 | 无可学习参数 | +| 求解参数 | batch_size=300, tol=1e-3, max_iter=10000 | +| 收敛所需迭代数 | batch 0: 1320, batch 1: 560, batch 2: 620, batch 3: 680| +| 求解速度(ms/iteration) | 500 | + +## 贡献者 + +gitee id: [WhFanatic](https://gitee.com/WhFanatic) + +email: hainingwang1995@gmail.com \ No newline at end of file diff --git a/MindFlow/applications/cfd/acoustic/acoustic.ipynb b/MindFlow/applications/cfd/acoustic/acoustic.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..425cbeb80ebfa93098fe357f83ac7dd893f0887e --- /dev/null +++ b/MindFlow/applications/cfd/acoustic/acoustic.ipynb @@ -0,0 +1,276 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 2D acoustic problem\n", + "\n", + "[![Download Notebook](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_notebook.png)](https://obs.dualstack.cn-north-4.myhuaweicloud.com/mindspore-website/notebook/master/mindflow/zh_cn/cfd_solver/mindspore_acoustic.ipynb) [![Download Sample Codes](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_download_code.png)](https://obs.dualstack.cn-north-4.myhuaweicloud.com/mindspore-website/notebook/master/mindflow/zh_cn/cfd_solver/mindspore_acoustic.py) [![View Source Files](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.png)](https://gitee.com/mindspore/docs/blob/master/docs/mindflow/docs/source_zh_cn/cfd_solver/acoustic.ipynb)\n", + "\n", + "## Environment Installation\n", + "\n", + "This case requires **MindSpore >= 2.3.0-rc2** version. Please refer to [MindSpore Installation](https://www.mindspore.cn/install) for details.\n", + "\n", + "In addition, you need to install **MindFlow >=0.2.0** version. If it is not installed in the current environment, please follow the instructions below to choose the backend and version for installation." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "mindflow_version = \"0.3.0\" # update if needed\n", + "\n", + "# Only NPU is supported.\n", + "!pip uninstall -y mindflow-ascend\n", + "!pip install mindflow-ascend==$mindflow_version" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Overview\n", + "\n", + "Solving the acoustic wave equation is a core technology in fields such as medical ultrasound and geological exploration. Large-scale acoustic wave equation solvers face challenges in terms of computational power and storage. Solvers for the wave equation generally use either frequency domain algorithms or time domain algorithms. The representative time domain algorithm is the Time Domain Finite Difference (TDFD) method, while frequency domain algorithms include Frequency Domain Finite Difference (FDFD), Finite Element Method (FEM), and Convergent Born Series (CBS) iterative method. The CBS method, due to its low memory requirement and absence of dispersion error, has gained widespread attention in engineering and academia. In particular, [Osnabrugge et al. (2016)](https://linkinghub.elsevier.com/retrieve/pii/S0021999116302595) have addressed the convergence issue of this method, expanding the application prospects of the CBS method.\n", + "\n", + "This case study will demonstrate how to invoke the CBS API provided by MindFlow to solve the two-dimensional acoustic wave equation." + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Problem Description\n", + "\n", + "In solving the acoustic wave equation, the input parameters are the velocity field and source information, and the output is the spatiotemporal distribution of the wave field.\n", + "\n", + "The expression of the two-dimensional acoustic wave equation is as follows\n", + "\n", + "| Time Domain Expression | Frequency Domain Expression |\n", + "| ----------------------------------------------------- | ------------------------------------------------- |\n", + "| $\\frac{\\partial^2u}{\\partial t^2} - c^2 \\Delta u = f$ | $-\\omega^2 \\hat{u} - c^2 \\Delta\\hat{u} = \\hat{f}$ |\n", + "\n", + "Where\n", + "\n", + "- $u(\\bold{x},t) \\;\\; [L]$ is the deformation displacement (pressure divided by density), a scalar.\n", + "- $c(\\bold{x}) \\;\\; [L/T]$ is the wave velocity, a scalar.\n", + "- $f(\\bold{x},t) \\;\\; [L/T^2]$ is the excitation source (volume distributed force), a scalar.\n", + "\n", + "In practical solving, in order to reduce the parameter dimension, the parameters are generally made dimensionless first, and then the dimensionless equations and parameters are solved, and finally the dimensional solutions are restored. By selecting $\\omega$, $\\hat{f}$, and $d$ (grid spacing, which requires equal spacing in all directions) to nondimensionalize the frequency domain equation, we can obtain the dimensionless frequency domain equation:\n", + "\n", + "$$\n", + "u^* + c^{*2} \\tilde{\\Delta} + f^* = 0\n", + "$$\n", + "\n", + "Where\n", + "\n", + "- $u^* = \\hat{u} \\omega^2 / \\hat{f}$ is the dimensionless deformation displacement.\n", + "- $c^* = c / (\\omega d)$ is the dimensionless wave velocity.\n", + "- $\\tilde{\\Delta}$ is the normalized Laplace operator, which is the Laplace operator when the grid spacing is 1.\n", + "- $f^*$ the mask that marks the source position, with a value of 1 at the source and 0 at other positions.\n", + "\n", + "The `src` package in this case can be downloaded at [src](https://gitee.com/mindspore/mindscience/tree/master/MindFlow/applications/cfd/acoustic/src)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import numpy as np\n", + "import scipy\n", + "import pandas as pd\n", + "import mindspore as ms\n", + "from mindspore import Tensor\n", + "\n", + "from mindflow.utils import load_yaml_config\n", + "\n", + "from cbs.cbs import CBS\n", + "from src import visual\n", + "from solve_acoustic import solve_cbs" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Define input parameters and output sampling method\n", + "\n", + "The required inputs for this case are dimensional 2D velocity field, source location list, and source waveform. The input file name is specified in the [config.yaml](https://gitee.com/mindspore/mindscience/blob/master/MindFlow/applications/cfd/acoustic/config.yaml) file. For user convenience, pre-set inputs are provided [here](https://download-mindspore.osinfra.cn/mindscience/mindflow/dataset/applications/cfd/acoustic). Please download the data and put them in `./dataset` in the case directory. The data include the velocity field `velocity.npy`, source location list `srclocs.csv`, and source waveform `srcwaves.csv`. Users can modify the input parameters based on the input file format.\n", + "\n", + "The output is a spatiotemporal distribution of the wavefield. To specify how the output is sampled in time and frequency, parameters such as `dt` and `nt` need to be specified in the [config.yaml](https://gitee.com/mindspore/mindscience/blob/master/MindFlow/applications/cfd/acoustic/config.yaml) file.\n", + "\n", + "Since the sampling rate of the input source waveform in time may differ from the required output, interpolation needs to be performed." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "ms.set_context(device_target='Ascend', device_id=0, mode=ms.GRAPH_MODE)\n", + "\n", + "config = load_yaml_config('config.yaml')\n", + "\n", + "data_config = config['data']\n", + "solve_config = config['solve']\n", + "summary_config = config['summary']\n", + "\n", + "# read time & frequency points\n", + "dt = solve_config['dt']\n", + "nt = solve_config['nt']\n", + "ts = np.arange(nt) * dt\n", + "omegas_all = np.fft.rfftfreq(nt) * (2 * np.pi / dt)\n", + "\n", + "# read source locations\n", + "df = pd.read_csv(os.path.join(data_config['root_dir'], data_config['source_locations']), index_col=0)\n", + "slocs = df[['y', 'x']].values # shape (ns, 2)\n", + "\n", + "# read & interp source wave\n", + "df = pd.read_csv(os.path.join(data_config['root_dir'], data_config['source_wave']))\n", + "inter_func = scipy.interpolate.interp1d(df.t, df.f, bounds_error=False, fill_value=0)\n", + "src_waves = inter_func(ts) # shape (nt)\n", + "src_amplitudes = np.fft.rfft(src_waves) # shape (nt//2+1)\n", + "\n", + "# read velocity array\n", + "velo = np.load(os.path.join(data_config['root_dir'], data_config['velocity_field']))\n", + "nz, nx = velo.shape\n", + "dx = data_config['velocity_dx']" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Select desired frequency points\n", + "\n", + "With the output sampling method determined, all the desired frequency points are in turn determined. However, in order to reduce computational load, it is also possible to select only a portion of the frequency points for calculation, while obtaining the remaining frequency points through interpolation. The specific frequency point downsampling method is specified by the `downsample_mode` and `downsample_rate` in the [config.yaml](https://gitee.com/mindspore/mindscience/blob/master/MindFlow/applications/cfd/acoustic/config.yaml) file. The default is no downsampling, which means solving all frequency points except $\\omega=0$." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "# select omegas\n", + "no = len(omegas_all) // solve_config['downsample_rate']\n", + "\n", + "if solve_config['downsample_mode'] == 'exp':\n", + " omegas_sel = np.exp(np.linspace(np.log(omegas_all[1]), np.log(omegas_all[-1]), no))\n", + "elif solve_config['downsample_mode'] == 'square':\n", + " omegas_sel = np.linspace(omegas_all[1]**.5, omegas_all[-1]**.5, no)**2\n", + "else:\n", + " omegas_sel = np.linspace(omegas_all[1], omegas_all[-1], no)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Perform simulation\n", + "\n", + "Define the relevant arrays as Tensors, call `solve_cbs()`, and execute the solution on the NPU. Due to memory limitations, the solution process is executed in batches in the frequency domain. The number of batches is specified by the user in `config.yaml` and does not need to be divisible by the number of frequency points (allowing the size of the last batch to be different from the other batches). After the solution is completed, the frequency domain solution results will be saved to the file `u_star.npy`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "### send to NPU and perform computation\n", + "os.makedirs(summary_config['root_dir'], exist_ok=True)\n", + "velo = Tensor(velo, dtype=ms.float32, const_arg=True)\n", + "cbs = CBS((nz, nx), remove_pml=False)\n", + "\n", + "ur, ui = solve_cbs(cbs, velo, slocs, omegas_sel, dx=dx, n_batches=solve_config['n_batches']) # shape (ns, no, len(receiver_zs), nx)\n", + "\n", + "u_star = np.squeeze(ur.numpy() + 1j * ui.numpy()) # shape (ns, no, len(krs), nx)\n", + "np.save(os.path.join(summary_config['root_dir'], 'u_star.npy'), u_star)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Post-processing\n", + "\n", + "CBS solves the dimensionless frequency domain equation, but downstream tasks often require observing the evolution process of dimensional wavefields in the time domain. Therefore, the final solution is restored to dimensional and converted back to the time domain. The restoration method is given by $\\hat{u} = u^* hat{f} / \\omega^2$. If downsampling is performed on the frequency points in the \"Select desired frequency points\" step, interpolation along the frequency direction is required here to restore the solutions for all frequency points. Then, perform a Fourier inverse transform on the dimensional frequency domain wavefield $\\hat{u}$ to obtain the time domain wavefield $u$. Save the time domain wavefield to the file `u_time.npy`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# recover dimension and interpolate to full frequency domain\n", + "u_star /= omegas_sel.reshape(-1, 1, 1)**2\n", + "u_star = scipy.interpolate.interp1d(omegas_sel, u_star, axis=1, kind='cubic', bounds_error=False, fill_value=0)(omegas_all)\n", + "u_star *= src_amplitudes.reshape(-1, 1, 1)\n", + "\n", + "# transform to time domain\n", + "u_time = np.fft.irfft(u_star, axis=1)\n", + "np.save(os.path.join(summary_config['root_dir'], 'u_time.npy'), u_time)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Finally, read the time-domain wave field and visualize." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# visualize the result\n", + "u_time = np.load(os.path.join(summary_config['root_dir'], 'u_time.npy'))\n", + "visual.anim(velo.numpy(), u_time, ts, os.path.join(summary_config['root_dir'], 'wave.gif'))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.6 (main, Nov 14 2022, 16:10:14) [GCC 11.3.0]" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "916dbcbb3f70747c44a77c7bcd40155683ae19c65e1c03b4aa3499c5328201f1" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/MindFlow/applications/cfd/acoustic/acoustic_CN.ipynb b/MindFlow/applications/cfd/acoustic/acoustic_CN.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..1b631abc53ed49cabd3cc0113747c656da9a9dd2 --- /dev/null +++ b/MindFlow/applications/cfd/acoustic/acoustic_CN.ipynb @@ -0,0 +1,276 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 二维声波问题\n", + "\n", + "[![下载Notebook](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_notebook.png)](https://obs.dualstack.cn-north-4.myhuaweicloud.com/mindspore-website/notebook/master/mindflow/zh_cn/cfd_solver/mindspore_acoustic.ipynb) [![下载样例代码](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_download_code.png)](https://obs.dualstack.cn-north-4.myhuaweicloud.com/mindspore-website/notebook/master/mindflow/zh_cn/cfd_solver/mindspore_acoustic.py) [![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.png)](https://gitee.com/mindspore/docs/blob/master/docs/mindflow/docs/source_zh_cn/cfd_solver/acoustic.ipynb)\n", + "\n", + "## 环境安装\n", + "\n", + "本案例要求 **MindSpore >= 2.3.0-rc2** 版本。具体请查看 [MindSpore安装](https://www.mindspore.cn/install)。\n", + "\n", + "此外,你需要安装 **MindFlow >=0.2.0** 版本。如果当前环境还没有安装,请按照下列方式选择后端和版本进行安装。" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "mindflow_version = \"0.3.0\" # update if needed\n", + "\n", + "# Only NPU is supported.\n", + "!pip uninstall -y mindflow-ascend\n", + "!pip install mindflow-ascend==$mindflow_version" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 概述\n", + "\n", + "声波方程求解是医疗超声、地质勘探等领域中的核心技术,大规模声波方程求解面临算力和存储的挑战。声波方程求解器一般采用频域求解算法和时域求解算法,时域求解算法的代表是时域有限差分法 (TDFD),频域求解算法包括频域有限差分法 (FDFD)、有限元法 (FEM) 和 CBS (Convergent Born series) 迭代法。CBS 方法由于不引入频散误差,且求解的内存需求低,因此受到工程和学术界的广泛关注。尤其是 [Osnabrugge et al. (2016)](https://linkinghub.elsevier.com/retrieve/pii/S0021999116302595) 解决了该方法的收敛性问题,使得 CBS 方法的应用具有更广阔的前景。\n", + "\n", + "本案例将演示如何调用 MindFlow 提供的 CBS API 实现二维声波方程的求解。" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 问题描述\n", + "\n", + "声波方程求解中,波速场和震源信息是输入参数,求解输出的是时空分布的波场。\n", + "\n", + "二维声波方程的表达式如下\n", + "\n", + "| 时域表达式 | 频域表达式 |\n", + "| ----------------------------------------------------- | ------------------------------------------------- |\n", + "| $\\frac{\\partial^2u}{\\partial t^2} - c^2 \\Delta u = f$ | $-\\omega^2 \\hat{u} - c^2 \\Delta\\hat{u} = \\hat{f}$ |\n", + "\n", + "其中\n", + "\n", + "- $u(\\bold{x},t) \\;\\; [L]$ 变形位移 (压强除以密度),标量\n", + "- $c(\\bold{x}) \\;\\; [L/T]$ 波速,标量\n", + "- $f(\\bold{x},t) \\;\\; [L/T^2]$ 震源激励 (体积分布力),标量\n", + "\n", + "实际求解中,为了降低参数维度,一般先将参数无量纲化,然后针对无量纲方程和参数进行求解,最后恢复解的量纲。选取 $\\omega$、$\\hat{f}$ 和 $d$(网格间距,本案例要求网格在各方向间距相等)对频域方程做无量纲化,可得频域无量纲方程:\n", + "\n", + "$$\n", + "u^* + c^{*2} \\tilde{\\Delta} + f^* = 0\n", + "$$\n", + "\n", + "其中\n", + "\n", + "- $u^* = \\hat{u} \\omega^2 / \\hat{f}$ 为无量纲变形位移\n", + "- $c^* = c / (\\omega d)$ 为无量纲波速\n", + "- $\\tilde{\\Delta}$ 为归一化 Laplace 算子,即网格间距均为 1 时的 Laplace 算子\n", + "- $f^*$ 为标记震源位置的 mask,即在震源作用点为 1,其余位置为 0\n", + "\n", + "本案例中 `src` 包可以在 [src](https://gitee.com/mindspore/mindscience/tree/master/MindFlow/applications/cfd/acoustic/src) 下载。" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import numpy as np\n", + "import scipy\n", + "import pandas as pd\n", + "import mindspore as ms\n", + "from mindspore import Tensor\n", + "\n", + "from mindflow.utils import load_yaml_config\n", + "\n", + "from cbs.cbs import CBS\n", + "from src import visual\n", + "from solve_acoustic import solve_cbs" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 定义输入参数和输出采样方式\n", + "\n", + "本案例所需的输入为有量纲 2D 速度场、震源位置列表、震源波形,在文件 [config.yaml](https://gitee.com/mindspore/mindscience/blob/master/MindFlow/applications/cfd/acoustic/config.yaml) 中指定输入文件名。为了方便用户直接验证,本案例在本[链接](https://download-mindspore.osinfra.cn/mindscience/mindflow/dataset/applications/cfd/acoustic)中提供了预置的输入数据,请下载所需要的数据集,并保存在 `./dataset` 目录下。数据集包括速度场 `velocity.npy`、震源位置列表 `srclocs.csv`、震源波形 `srcwaves.csv`。用户可仿照输入文件格式自行修改输入参数。\n", + "\n", + "输出为时空分布的波场,为了明确输出如何在时间和频率上采样,需在 [config.yaml](https://gitee.com/mindspore/mindscience/blob/master/MindFlow/applications/cfd/acoustic/config.yaml) 文件中指定 `dt`, `nt` 等参数。\n", + "\n", + "由于输入的震源波形在时间上的采样率可能与输出所要求的不一致,因此需对其进行插值。" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "ms.set_context(device_target='Ascend', device_id=0, mode=ms.GRAPH_MODE)\n", + "\n", + "config = load_yaml_config('config.yaml')\n", + "\n", + "data_config = config['data']\n", + "solve_config = config['solve']\n", + "summary_config = config['summary']\n", + "\n", + "# read time & frequency points\n", + "dt = solve_config['dt']\n", + "nt = solve_config['nt']\n", + "ts = np.arange(nt) * dt\n", + "omegas_all = np.fft.rfftfreq(nt) * (2 * np.pi / dt)\n", + "\n", + "# read source locations\n", + "df = pd.read_csv(os.path.join(data_config['root_dir'], data_config['source_locations']), index_col=0)\n", + "slocs = df[['y', 'x']].values # shape (ns, 2)\n", + "\n", + "# read & interp source wave\n", + "df = pd.read_csv(os.path.join(data_config['root_dir'], data_config['source_wave']))\n", + "inter_func = scipy.interpolate.interp1d(df.t, df.f, bounds_error=False, fill_value=0)\n", + "src_waves = inter_func(ts) # shape (nt)\n", + "src_amplitudes = np.fft.rfft(src_waves) # shape (nt//2+1)\n", + "\n", + "# read velocity array\n", + "velo = np.load(os.path.join(data_config['root_dir'], data_config['velocity_field']))\n", + "nz, nx = velo.shape\n", + "dx = data_config['velocity_dx']" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 选取待求频点\n", + "\n", + "确定了输出采样方式即确定了所有待求频点。但为了减少计算量,也可以只选择部分频点进行求解,其余频点通过插值获得。具体的频点降采样方式由 [config.yaml](https://gitee.com/mindspore/mindscience/blob/master/MindFlow/applications/cfd/acoustic/config.yaml) 文件中的 `downsample_mode`, `downsample_rate` 指定。默认为不做降采样,即求解除 $\\omega=0$ 之外的所有频点。" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "# select omegas\n", + "no = len(omegas_all) // solve_config['downsample_rate']\n", + "\n", + "if solve_config['downsample_mode'] == 'exp':\n", + " omegas_sel = np.exp(np.linspace(np.log(omegas_all[1]), np.log(omegas_all[-1]), no))\n", + "elif solve_config['downsample_mode'] == 'square':\n", + " omegas_sel = np.linspace(omegas_all[1]**.5, omegas_all[-1]**.5, no)**2\n", + "else:\n", + " omegas_sel = np.linspace(omegas_all[1], omegas_all[-1], no)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 执行仿真\n", + "\n", + "将相关数组定义为 Tensor,调用 `solve_cbs()`,在 NPU 执行求解。由于显存限制,求解过程在频点维度分批执行,batch 数量由用户在 `config.yaml` 中指定,不要求整除频点数(允许最后一个 batch 的 size 与其他 batch 不一致)。求解完成后,会保存频域求解结果到文件 `u_star.npy`。" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "### send to NPU and perform computation\n", + "os.makedirs(summary_config['root_dir'], exist_ok=True)\n", + "velo = Tensor(velo, dtype=ms.float32, const_arg=True)\n", + "cbs = CBS((nz, nx), remove_pml=False)\n", + "\n", + "ur, ui = solve_cbs(cbs, velo, slocs, omegas_sel, dx=dx, n_batches=solve_config['n_batches']) # shape (ns, no, len(receiver_zs), nx)\n", + "\n", + "u_star = np.squeeze(ur.numpy() + 1j * ui.numpy()) # shape (ns, no, len(krs), nx)\n", + "np.save(os.path.join(summary_config['root_dir'], 'u_star.npy'), u_star)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 仿真结果后处理\n", + "\n", + "CBS 求解的是无量纲化的频域方程,但下游任务通常希望在时域观察有量纲波场的演化过程,因此最后将求解结果恢复量纲化并转回时域。恢复量纲的方式为 $\\hat{u} = u^* \\hat{f} / \\omega^2$,若在前述的“选取待求频点”步骤中对频点做了降采样,则在此处需沿频率方向插值恢复所有频点的解。然后对有量纲的频域波场 $\\hat{u}$ 做 Fourier 反变换得到时域波场 $u$。将时域波场保存至文件 `u_time.npy`。" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# recover dimension and interpolate to full frequency domain\n", + "u_star /= omegas_sel.reshape(-1, 1, 1)**2\n", + "u_star = scipy.interpolate.interp1d(omegas_sel, u_star, axis=1, kind='cubic', bounds_error=False, fill_value=0)(omegas_all)\n", + "u_star *= src_amplitudes.reshape(-1, 1, 1)\n", + "\n", + "# transform to time domain\n", + "u_time = np.fft.irfft(u_star, axis=1)\n", + "np.save(os.path.join(summary_config['root_dir'], 'u_time.npy'), u_time)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "最后,读取时域波场并可视化。" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# visualize the result\n", + "u_time = np.load(os.path.join(summary_config['root_dir'], 'u_time.npy'))\n", + "visual.anim(velo.numpy(), u_time, ts, os.path.join(summary_config['root_dir'], 'wave.gif'))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.6 (main, Nov 14 2022, 16:10:14) [GCC 11.3.0]" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "916dbcbb3f70747c44a77c7bcd40155683ae19c65e1c03b4aa3499c5328201f1" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/MindFlow/applications/cfd/acoustic/cbs/__init__.py b/MindFlow/applications/cfd/acoustic/cbs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d1443867989bc930cbf052499c248607032bcd35 --- /dev/null +++ b/MindFlow/applications/cfd/acoustic/cbs/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== diff --git a/MindFlow/applications/cfd/acoustic/cbs/cbs.py b/MindFlow/applications/cfd/acoustic/cbs/cbs.py new file mode 100644 index 0000000000000000000000000000000000000000..5706ae37a731318da617af0da922065f1ae0c95a --- /dev/null +++ b/MindFlow/applications/cfd/acoustic/cbs/cbs.py @@ -0,0 +1,266 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The CBS (convergen Born series) API""" +from math import factorial +import numpy as np +import mindspore as ms +from mindspore import Tensor, nn, ops, numpy as mnp, lazy_inline + +from .dft import MyDFTn, MyiDFTn + + +class CBSBlock(nn.Cell): + ''' The computation procedures for each iteration in CBS ''' + @lazy_inline + def __init__(self, shape): + ''' + No trainable parameters, but the dft cells needs initialization + Args: + shape: tuple of int, only the spatial shape, not including the batch and channel dimensions + ''' + super().__init__() + self.dft_cell = MyDFTn(shape) + self.idft_cell = MyiDFTn(shape) + + # Scattering potential calculation for real and imaginary parts + def op_v(self, ur, ui, vr, vi): + wr = ur * vr - ui * vi + wi = ur * vi + ui * vr + return wr, wi + + # Vectorized Helmholtz Green function for real and imaginary parts + def op_g(self, ur, ui, gr, gi): + fur, fui = self.dft_cell(ur, ui) + gur = gr * fur - gi * fui + gui = gi * fur + gr * fui + wr, wi = self.idft_cell(gur, gui) + return wr, wi + + # Vectorized Born iteration for real and imaginary parts + def construct(self, ur, ui, vr, vi, gr, gi, rhs, eps): + ''' run one iteration and return the incremental ''' + vur, vui = self.op_v(ur, ui, vr, vi) + gvr, gvi = self.op_g(vur + rhs, vui, gr, gi) + vgr, vgi = self.op_v(gvr - ur, gvi - ui, vr, vi) + + # eps > 0: Convergent Born series; eps == 0: Original Born Series + cond = ops.broadcast_to(eps, ur.shape) > 0 + dur = ops.select(cond, -vgi / (eps + 1e-8), gvr - ur) # '* (-1.)' comes from imag part multiplying i/eps + dui = ops.select(cond, vgr / (eps + 1e-8), gvi - ui) + + return ops.stack([dur, dui]) # return a single Tensor for compatibility with nn.SequentialCell + +class CBS(nn.Cell): + ''' The CBS cell for solving 2D acoustic equation ''' + def __init__(self, + shape, + n_iter=20, + pml_size=60, + alpha=1.0, + rampup=12, + remove_pml=True, + epsilon=None, + ): + """Configurations of the CBS solver + + Args: + shape (tuple[int]): only the spatial shape, not including the batch and channel dimensions + n_iter (int, optional): number of iterations in a single call. Defaults to 20. + pml_size (int, optional): number of grid layers to pad on each boundary for the wave to attenuate. + Defaults to 60. + alpha (float, optional): the strength of wave attenuation in PML layers. Defaults to 1.0. + rampup (int, optional): the smoothness of transition from interior domain to PML layers. Defaults to 12. + remove_pml (bool, optional): whether to remove the PML layers for the output. Defaults to True. + epsilon (float, optional): the small value to stabilize the iteration. + Defaults to None, calculating epsilon automatically. + """ + super().__init__() + + self.n_iter = n_iter + self.pml_size = pml_size + self.alpha = alpha + self.rampup = rampup + self.remove_pml = remove_pml + self.epsilon = epsilon + + shape_padded = tuple(n + 2 * pml_size for n in shape) + + dxs = (1.0, 1.0) + p_sq = sum(np.meshgrid( + *[np.fft.fftfreq(n, d)**2 for n, d in zip(shape_padded, dxs)], + indexing="ij")) * (2 * np.pi)**2 + self.p_sq = Tensor(p_sq, dtype=ms.float32, const_arg=True) + + pml_mask = 1 - np.pad(np.ones(shape), pml_size) + self.pml_mask = Tensor(pml_mask, dtype=ms.float32, const_arg=True) + + self.cbs_block = CBSBlock(shape_padded) + + def cbs_params(self, c_star, f_star): + ''' compute constant variables for CBS iteration ''' + pml_size = self.pml_size + nz, nx = c_star.shape[-2:] + dxs = (1.0, 1.0) + omg = 1.0 + + # source field + rhs = ops.pad(f_star / c_star**2, [pml_size] * 4) # (batch, 1, nz_padded, nx_padded) + + # homogeneous k field + k_max = omg / ops.amin(c_star, axis=(-2, -1), keepdims=True) + k_min = omg / ops.amax(c_star, axis=(-2, -1), keepdims=True) + k0 = ops.sqrt(0.5 * (k_max**2 + k_min**2)) # (batch, 1, 1, 1) + + # heterogeneous k field + ksq_r, ksq_i = self.cbs_pml( + (nz, nx), dxs, k_max, pml_size, self.alpha, self.rampup) # (batch, 1, nz_padded, nx_padded) + + ksq_r = ksq_r * self.pml_mask + ops.pad((omg / c_star)**2, [pml_size] * 4) * (1 - self.pml_mask) + ksq_i = ksq_i * self.pml_mask + + eps = ops.amax((ksq_r - k0**2)**2 + ksq_i**2, axis=(-2, -1), keepdims=True)**.5 # (batch, 1, 1, 1) + + # if epsilon given by user, use original BS instead of CBS + if isinstance(self.epsilon, (float, int)): + eps = self.epsilon * ops.ones_like(eps) + + # field variables needed by operator V & G + vr = ksq_r - k0**2 # (batch, 1, nz_padded, nx_padded) + vi = ksq_i - eps # (batch, 1, nz_padded, nx_padded) + gr = 1. / ((self.p_sq - k0**2)**2 + eps**2) * (self.p_sq - k0**2) # (batch, 1, nz_padded, nx_padded) + gi = 1. / ((self.p_sq - k0**2)**2 + eps**2) * eps # (batch, 1, nz_padded, nx_padded) + + return vr, vi, gr, gi, rhs, eps * (self.epsilon is None) + + @staticmethod + def cbs_pml(shape, dxs, k0, pml_size, alpha, rampup): + ''' construct the heterogeneous k field with PML BC embedded ''' + shape_padded = tuple(n + 2 * pml_size for n in shape) + + def num(x): + num_real = (alpha ** 2) * (rampup - alpha * x) * ((alpha * x) ** (rampup - 1)) + num_imag = (alpha ** 2) * (2 * k0 * x) * ((alpha * x) ** (rampup - 1)) + return num_real, num_imag + + def den(x): + return sum([(alpha * x) ** i / float(factorial(i)) for i in range(rampup + 1)]) * factorial(rampup) + + def transform_fun(x): + num_real, num_imag = num(x) + den_x = den(x) + transform_real, transform_imag = num_real / den_x, num_imag / den_x + return transform_real, transform_imag + + diff = ops.stack(mnp.meshgrid( + *[((ops.abs(mnp.linspace(1 - n, n - 1, n)) - n) / 2 + pml_size) * d for n, d in zip(shape_padded, dxs)], + indexing="ij"), axis=0) + + diff *= (diff > 0).astype(ms.float32) / 4. + + dist = ops.norm(diff, dim=0) + k_k0_real, k_k0_imag = transform_fun(dist) + ksq_r = k_k0_real + k0 ** 2 + ksq_i = k_k0_imag + + return ksq_r, ksq_i + + def construct(self, c_star, f_star, ur_init=None, ui_init=None): + ''' + Run the solver to solve non-dimensionalized 2D acoustic equation for given c* and f* + Args: + c_star: float (batch_size, 1, nz, nx), the non-dimensionalized velocity field + f_star: float (batch_size, 1, nz, nx), the mask marking out the source locations + ur_init, ui_init: float (batch_size, 1, NZ, NX), initial wave field for iteration, real & imag parts. + If remove_pml is True, NZ = nz, NX = nx, otherwise NZ = nz + 2 * pml_size, NX = nx + 2 * pml_size. + Default is None, which means initialize from 0. + ''' + vr, vi, gr, gi, rhs, eps = self.cbs_params(c_star, f_star) + + n0 = self.remove_pml * self.pml_size + n1 = (ur_init is None or self.remove_pml) * self.pml_size + n2 = (ui_init is None or self.remove_pml) * self.pml_size + + # construct initial field + if ur_init is None: + ur_init = ops.zeros_like(c_star, dtype=ms.float32) # (batch, 1, nz, nx) + if ui_init is None: + ui_init = ops.zeros_like(c_star, dtype=ms.float32) # (batch, 1, nz, nx) + + # pad initial field + ur = ops.pad(ur_init, padding=[n1] * 4, value=0) # note: better padding (with gradual damping) can be applied + ui = ops.pad(ui_init, padding=[n2] * 4, value=0) # (batch, 1, nz_padded, nx_padded) + + # start iteration + errs_list = [] + + for _ in range(self.n_iter): + dur, dui = self.cbs_block(ur, ui, vr, vi, gr, gi, rhs, eps) + ur += dur + ui += dui + + # calculate iteration residual + errs = (ops.sum(dur**2 + dui**2, dim=(-2, -1)) / ops.sum(ur**2 + ui**2, dim=(-2, -1)))**.5 + errs_list.append(errs) + + # remove pml layer + nz, nx = ur.shape[-2:] + ur = ur[..., n0:nz - n0, n0:nx - n0] + ui = ui[..., n0:nz - n0, n0:nx - n0] + ui *= -1. + # note: the conjugate here is because we define Fourier modes differently to JAX in that the frequencies + # are opposite, leading to opposite attenuation in PML, and finally the conjugation in results + + return ur, ui, errs_list + + def solve(self, + c_star, + f_star, + ur_init=None, + ui_init=None, + tol=1e-3, + max_iter=10000, + remove_pml=True, + print_info=True, + ): + """A convenient method for solving the equation to a given tolerance + + Args: + tol (float, optional): the tolerance for the relative error. Defaults to 1e-3. + """ + msg = 'PML layers cannot be removed during iteration, but can be removed for the final result' + assert not self.remove_pml, msg + + ur, ui, errs_list = self(c_star, f_star, ur_init, ui_init) + + for ep in range(max_iter // self.n_iter): + err_max = float(errs_list[-1].max()) + err_min = float(errs_list[-1].min()) + err_ave = float(errs_list[-1].mean()) + + if print_info: + print(f'step {(ep + 1) * self.n_iter}, max error {err_max:.6f}', end=', ') + print(f'min error {err_min:.6f}, mean error {err_ave:.6f}') + + if err_max < tol: + break + + ur, ui, errs = self(c_star, f_star, ur, -ui) + errs_list += errs + + if remove_pml and self.pml_size: + ur = ur[..., self.pml_size:-self.pml_size, self.pml_size:-self.pml_size] + ui = ui[..., self.pml_size:-self.pml_size, self.pml_size:-self.pml_size] + + return ur, ui, errs_list diff --git a/MindFlow/applications/cfd/acoustic/cbs/dft.py b/MindFlow/applications/cfd/acoustic/cbs/dft.py new file mode 100644 index 0000000000000000000000000000000000000000..a0f5109a4558c9ab748f36878b763d94950ba101 --- /dev/null +++ b/MindFlow/applications/cfd/acoustic/cbs/dft.py @@ -0,0 +1,100 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +''' provide complex dft based on the real dft API in mindflow.dft ''' +import numpy as np +import mindspore as ms +from mindspore import nn, ops, numpy as mnp +from mindflow.cell.neural_operators.dft import dft1, dft2, dft3 + + +class MyDFTn(nn.Cell): + def __init__(self, shape): + super().__init__() + assert len(shape) in (1, 2, 3), 'only ndim 1, 2, 3 supported' + + n = shape[-1] + ndim = len(shape) + modes = tuple([_ // 2 for _ in shape[-ndim:-1]] + [n // 2 + 1]) if ndim > 1 else n // 2 + 1 + + self.shape = tuple(shape) + self.dft_cell = { + 1: dft1, + 2: dft2, + 3: dft3, + }[ndim](shape, modes) + + # use mask to assemble slices of Tensors, avoiding dynamic shape + # bug note: for unknown reasons, GRAPH_MODE cannot work with mask Tensors allocated using ops.ones() + mask_x0 = np.ones(n//2 + 1) + mask_xm = np.ones(n//2 + 1) + mask_y0 = np.ones(shape) + mask_z0 = np.ones(shape) + mask_x0[0] = 0 + mask_xm[-1] = 0 + if ndim > 1: + mask_y0[..., 0, :] = 0 + if ndim > 2: + mask_z0[..., 0, :, :] = 0 + + self.mask_x0 = ms.Tensor(mask_x0, dtype=ms.float32, const_arg=True) + self.mask_xm = ms.Tensor(mask_xm, dtype=ms.float32, const_arg=True) + self.mask_y0 = ms.Tensor(mask_y0, dtype=ms.float32, const_arg=True) + self.mask_z0 = ms.Tensor(mask_z0, dtype=ms.float32, const_arg=True) + + def construct(self, ar, ai): + shape = tuple(self.shape) + n = shape[-1] + ndim = len(shape) + scale = float(np.prod(shape) ** .5) + + assert ai is None or ar.shape == ai.shape + assert ar.shape[-ndim:] == shape + + brr, bri = self.dft_cell((ar, ar * 0)) + + # n-D Fourier transform with last axis being real-transformed, output dimension (..., m, n//2+1) + if ai is None: + return brr * scale, bri * scale + + # n-D complex Fourier transform, output dimension (..., m, n) + # call dft for real & imag parts separately and then assemble + bir, bii = self.dft_cell((ai, ai * 0)) + + br_half1 = ops.pad((brr - bii) * self.mask_xm, [0, n//2 - 1]) + bi_half1 = ops.pad((bri + bir) * self.mask_xm, [0, n//2 - 1]) + # bug note: mnp.roll() & mnp.flip are ok, but ops.roll() only supports GPU, ops.flip() has bug in MS2.4.0 + br_half2 = mnp.roll(mnp.flip(ops.pad((brr + bii) * self.mask_x0, [n//2 - 1, 0]), axis=-1), n//2, axis=-1) + bi_half2 = mnp.roll(mnp.flip(ops.pad((bir - bri) * self.mask_x0, [n//2 - 1, 0]), axis=-1), n//2, axis=-1) + if ndim > 1: + br_half2 = br_half2 * (1 - self.mask_y0) + mnp.roll(mnp.flip(br_half2 * self.mask_y0, axis=-2), 1, axis=-2) + bi_half2 = bi_half2 * (1 - self.mask_y0) + mnp.roll(mnp.flip(bi_half2 * self.mask_y0, axis=-2), 1, axis=-2) + if ndim > 2: + br_half2 = br_half2 * (1 - self.mask_z0) + mnp.roll(mnp.flip(br_half2 * self.mask_z0, axis=-3), 1, axis=-3) + bi_half2 = bi_half2 * (1 - self.mask_z0) + mnp.roll(mnp.flip(bi_half2 * self.mask_z0, axis=-3), 1, axis=-3) + + br = br_half1 + br_half2 + bi = bi_half1 + bi_half2 + + return br * scale, bi * scale + +class MyiDFTn(MyDFTn): + def __init__(self, shape): + super().__init__(shape) + + def construct(self, ar, ai): + ndim = len(self.shape) + scale = float(np.prod(ar.shape[-ndim:])) + br, bi = super().construct(ar, -ai) + return br / scale, -bi / scale diff --git a/MindFlow/applications/cfd/acoustic/config.yaml b/MindFlow/applications/cfd/acoustic/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..85b96730a7fe517351740aeddd41ed4d7e22bfce --- /dev/null +++ b/MindFlow/applications/cfd/acoustic/config.yaml @@ -0,0 +1,16 @@ +data: + root_dir: 'dataset' + velocity_field: 'velocity.npy' + velocity_dx: 16. # grid interval of the velocity matrix + source_wave: 'srcwaves.csv' + source_locations: 'srclocs.csv' + +solve: + dt: 0.02 # time interval of the output + nt: 300 # number of time points of the output, must be even (required by rfft) + downsample_mode: 'linear' # way to downsample the frequency points, options: linear, exp, square + downsample_rate: 1 # only 1/downsample_rate frequency points will be solved + n_batches: 4 # the number of batches for frequencies to be diveded into + +summary: + root_dir: 'results' diff --git a/MindFlow/applications/cfd/acoustic/images/errors.png b/MindFlow/applications/cfd/acoustic/images/errors.png new file mode 100644 index 0000000000000000000000000000000000000000..4875564467a436f74d8afa8e85488a44331c40cf Binary files /dev/null and b/MindFlow/applications/cfd/acoustic/images/errors.png differ diff --git a/MindFlow/applications/cfd/acoustic/images/wave.gif b/MindFlow/applications/cfd/acoustic/images/wave.gif new file mode 100644 index 0000000000000000000000000000000000000000..db5c62334ec40048959008f814ad7e61ede9cb96 Binary files /dev/null and b/MindFlow/applications/cfd/acoustic/images/wave.gif differ diff --git a/MindFlow/applications/cfd/acoustic/solve_acoustic.py b/MindFlow/applications/cfd/acoustic/solve_acoustic.py new file mode 100644 index 0000000000000000000000000000000000000000..45ee3e3da0b271c26394ae5e4da6344f27fea945 --- /dev/null +++ b/MindFlow/applications/cfd/acoustic/solve_acoustic.py @@ -0,0 +1,169 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""""Solve 2D acoustic equation""""" +import os +import argparse +import numpy as np +from scipy.interpolate import interp1d +import pandas as pd +import mindspore as ms +from mindspore import ops, Tensor, numpy as mnp + +from mindflow.utils import load_yaml_config + +from cbs.cbs import CBS +from src import utils, visual + + +def solve_cbs(cbs, velo, slocs, omegas, receiver_zs=None, dx=1., n_batches=1): + ''' + Solve for different source locations and frequencies using CBS (Convergent Born series) solver + Args: + velo: 2d Tensor, the velocity field + slocs: (ns, 2) array, the source locations (z, x coordinates) to be solved + omegas: 1d array, the frequencies to be solved on + receiver_zs: 1d array, z coordinates of signal receivers. + Default is None, which means all signals will be received + dx: float, the grid interval along x & z directions + n_batches: int, the number of batches for frequencies to be diveded into + Returns: + u_real, u_imag: + ''' + no = len(omegas) + ns = len(slocs) + nz, nx = velo.shape + + if receiver_zs is None: + receiver_zs = np.arange(nz) * dx + + krs = Tensor(np.rint(np.divide(receiver_zs, dx)), dtype=ms.int32, const_arg=False) + omegas = Tensor(omegas, dtype=ms.float32, const_arg=False) + + masks = Tensor(utils.sloc2mask(slocs, (nz, nx), (dx, dx)), dtype=ms.float32, const_arg=False) # shape (ns, nz, nx) + + urs = [] # note: do hold the solution of each batch in list and cat to Tensor later + uis = [] # note: do not hold them by modifying Tensor slices, dynamic shape and error would be caused + errs = [] + + for n, i in enumerate(range(0, no, no // n_batches)): + j = i + min(no // n_batches, no - i) + + print(f'batch {n}, omega {float(omegas[i]):.4f} ~ {float(omegas[j-1]):.4f}') + + c_star = velo / dx / omegas[i:j].reshape(-1, 1, 1) + f_star = masks.reshape(ns, 1, nz, nx) + c_star, f_star = mnp.broadcast_arrays(c_star, f_star) + + c_star = c_star.reshape(-1, 1, *c_star.shape[2:]) # shape (ns * no, 1, nz, nx) + f_star = f_star.reshape(-1, 1, *f_star.shape[2:]) # shape (ns * no, 1, nz, nx) + + ur, ui, err = cbs.solve(c_star, f_star, tol=1e-3) + + urs.append(ur[..., krs, :].reshape(ns, -1, len(krs), nx)) + uis.append(ui[..., krs, :].reshape(ns, -1, len(krs), nx)) + errs.append(np.reshape(err, (-1, ns, j - i))) + + u_real = ops.cat(urs, axis=1) # shape (ns, no, len(krs), nx) + u_imag = ops.cat(uis, axis=1) # shape (ns, no, len(krs), nx) + + return u_real, u_imag, errs + + +def main(config): + data_config = config['data'] + solve_config = config['solve'] + summary_config = config['summary'] + + # read time & frequency points + dt = solve_config['dt'] + nt = solve_config['nt'] + ts = np.arange(nt) * dt + omegas_all = np.fft.rfftfreq(nt) * (2 * np.pi / dt) + + # read source locations + df = pd.read_csv(os.path.join(data_config['root_dir'], data_config['source_locations']), index_col=0) + slocs = df[['y', 'x']].values # shape (ns, 2) + + # read & interp source wave + df = pd.read_csv(os.path.join(data_config['root_dir'], data_config['source_wave'])) + inter_func = interp1d(df.t, df.f, bounds_error=False, fill_value=0) + src_waves = inter_func(ts) # shape (nt) + src_amplitudes = np.fft.rfft(src_waves) # shape (nt//2+1) + + # read velocity array + velo = np.load(os.path.join(data_config['root_dir'], data_config['velocity_field'])) + nz, nx = velo.shape + dx = data_config['velocity_dx'] + + # select omegas + no = len(omegas_all) // solve_config['downsample_rate'] + + if solve_config['downsample_mode'] == 'exp': + omegas_sel = np.exp(np.linspace(np.log(omegas_all[1]), np.log(omegas_all[-1]), no)) + elif solve_config['downsample_mode'] == 'square': + omegas_sel = np.linspace(omegas_all[1]**.5, omegas_all[-1]**.5, no)**2 + else: + omegas_sel = np.linspace(omegas_all[1], omegas_all[-1], no) + + # send to NPU and perform computation + os.makedirs(summary_config['root_dir'], exist_ok=True) + velo = Tensor(velo, dtype=ms.float32, const_arg=True) + cbs = CBS((nz, nx), remove_pml=False) + + ur, ui, errs = solve_cbs( + cbs, velo, slocs, omegas_sel, dx=dx, n_batches=solve_config['n_batches']) # shape (ns, no, len(receiver_zs), nx) + + u_star = np.squeeze(ur.numpy() + 1j * ui.numpy()) # shape (ns, no, len(krs), nx) + np.save(os.path.join(summary_config['root_dir'], 'u_star.npy'), u_star) + + # recover dimension and interpolate to full frequency domain + u_star /= omegas_sel.reshape(-1, 1, 1)**2 + u_star = interp1d(omegas_sel, u_star, axis=1, kind='cubic', bounds_error=False, fill_value=0)(omegas_all) + u_star *= src_amplitudes.reshape(-1, 1, 1) + + # transform to time domain + u_time = np.fft.irfft(u_star, axis=1) + np.save(os.path.join(summary_config['root_dir'], 'u_time.npy'), u_time) + + # visualize the result + u_time = np.load(os.path.join(summary_config['root_dir'], 'u_time.npy')) + visual.anim(velo.numpy(), u_time, ts, os.path.join(summary_config['root_dir'], 'wave.gif')) + visual.plot_errs(errs, os.path.join(summary_config['root_dir'], 'errors.png')) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description="Solve 2D acoustic equation with CBS") + parser.add_argument( + "--mode", + type=str, + default="GRAPH", + choices=["GRAPH", "PYNATIVE"], + help="Running in GRAPH_MODE OR PYNATIVE_MODE", + ) + parser.add_argument( + "--device_id", + type=int, + default=utils.choose_free_npu(), + help="ID of the target device", + ) + parser.add_argument("--config_file_path", type=str, default="./config.yaml") + args = parser.parse_args() + + ms.set_context( + device_target='Ascend', + device_id=args.device_id, + mode=ms.GRAPH_MODE if args.mode.upper().startswith("GRAPH") else ms.PYNATIVE_MODE) + + main(load_yaml_config(args.config_file_path)) diff --git a/MindFlow/applications/cfd/acoustic/src/__init__.py b/MindFlow/applications/cfd/acoustic/src/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d1443867989bc930cbf052499c248607032bcd35 --- /dev/null +++ b/MindFlow/applications/cfd/acoustic/src/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== diff --git a/MindFlow/applications/cfd/acoustic/src/utils.py b/MindFlow/applications/cfd/acoustic/src/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..df9853954780b5b44e265579bdc89081f94d57e0 --- /dev/null +++ b/MindFlow/applications/cfd/acoustic/src/utils.py @@ -0,0 +1,89 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""util functions""" +import os +import numpy as np + + +def choose_free_npu(index='HBM Usage Rate', n=None): + ''' + Call the 'npu-smi' command on Linux to look for the most available NPU + Args: + index: str, the index for NPU availability, default is HBM Usage Rate + n: int, number of NPUs to return + Returns: + device_id: int, the device_id for the most available and healthy NPU, -1 means no healthy NPU exists + ''' + usages = [(999, -1)] + + for i in range(8): + # check whether the i-th NPU exists and works healthily + info = os.popen(f'npu-smi info -t health -i {i}') + + exist_flag = False + healthy_flag = True + + for s in info: + if 'Health' in s: + exist_flag = True + if s.split(':')[-1].strip() != 'OK': + healthy_flag = False + break + + if not exist_flag: continue + if not healthy_flag: continue + + # check the usage of the i-th NPU + info = os.popen(f'npu-smi info -t usages -i {i}') + + for s in info: + if index in s: + usages.append((float(s.split(':')[-1]), i)) # record the HBM usage rate of the current NPU + break + + if not n: + return min(usages)[1] + + return [i for usage, i in sorted(usages)[:n]] + +def sloc2mask(slocs, shape, dxs=None): + ''' + Convert source locations to masks with numpy + Args: + slocs: 2d array (ns, ndim), ns is the number of source locations, + and the last dimension indicates the coordinates in the order of (z, y, x). + shape: 1d array (ndim) + dxs: 1d array (ndim), the grid intervals in each dimension + Returns: + mask: (ndim+1)-d array, the first dimension is the batch dimension, + and the last ndim dimensions are space dimensions + ''' + if dxs is None: + dxs = np.ones_like(shape) + + assert np.shape(slocs)[-1] == len(shape) == len(dxs) + + mask = np.zeros([*np.shape(slocs)[:-1], *shape]) + + for i, sloc in enumerate(np.reshape(slocs, [-1, len(shape)])): + sidx = np.rint(np.divide(sloc, dxs)).astype(int) + mask.reshape(-1, *shape)[tuple([i, *sidx])] = 1 + + return mask + +def mask2sloc(masks, ndim, dxs=None): + ''' convert masks to source locations with numpy ''' + sidxs = np.argwhere(masks)[:, -ndim:].reshape(*masks.shape[:-ndim], ndim) + return sidxs if dxs is None else np.multiply(sidxs, dxs) diff --git a/MindFlow/applications/cfd/acoustic/src/visual.py b/MindFlow/applications/cfd/acoustic/src/visual.py new file mode 100644 index 0000000000000000000000000000000000000000..34ed3a497b7748e98d2d925b32477273be706e6d --- /dev/null +++ b/MindFlow/applications/cfd/acoustic/src/visual.py @@ -0,0 +1,84 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""visualize the results and computing history""" +import numpy as np +import matplotlib.pyplot as plt +import matplotlib.animation as animation + + +def anim(velo, us, ts, figname): + ''' + Animate the wave field in time domain and generate gif + ''' + ns, nt, _, _ = us.shape + + nrows = 1 + ncols = ns + 1 + + fig, axs = plt.subplots( + nrows, ncols, sharex=True, sharey=True, squeeze=False, + constrained_layout=True, figsize=(3 * ncols, 3 * nrows)) + + axs[0, 0].contourf(velo, cmap='seismic', extend='both') + axs[0, 0].set_title('velocity') + + for ax in axs.ravel(): + ax.yaxis.set_inverted(True) + ax.set_xticks([]) + ax.set_yticks([]) + ax.set_aspect('equal', adjustable='box') + + for i, ax in enumerate(axs[0, 1:]): + ax.set_title(f'shot_{i}') + for ax in axs[-1]: + ax.set_xlabel('x') + for ax in axs[:, 0]: + ax.set_ylabel('z') + + handles = [] + + def run(n): + print(f'animating wave field {n} / {nt}') + + while handles: + handles.pop().remove() + + for j, u in enumerate(us): + ax = axs[0, j + 1] + cnt = ax.contourf(u[n], cmap='bwr', extend='both', levels=np.linspace(-3, 3, 10) * u.std()) + handles.append(cnt) + ax.set_title(f"t = {ts[n]:.3f} [s]") + + ani = animation.FuncAnimation(fig, run, frames=range(0, nt, 5), interval=100, repeat_delay=1000, repeat=True) + + ani.save(figname, dpi=100) + plt.close() + +def plot_errs(errs, figname): + ''' + Plot the convergence history + Args: + errs: list of arrays, each array has shape (iters, ns, no_batchsize) + ''' + fig, ax = plt.subplots(constrained_layout=True) + for err in errs: + data = np.mean(err, axis=1) # average out the slocs dimension + ax.semilogy(data) + ax.grid() + ax.legend(['different frequencies']) + ax.set_xlabel('iteration') + ax.set_ylabel('residual') + fig.savefig(figname, dpi=300) + plt.close()