diff --git a/CONTRIBUTION.md b/CONTRIBUTION.md
index 5004a1092fc13d5b66d2cc50465a2c60c63f1666..083b5fa59f5a782de96d37be7626f46fc8451fa3 100644
--- a/CONTRIBUTION.md
+++ b/CONTRIBUTION.md
@@ -1,8 +1,9 @@
-# **Mindscience贡献指南**
+# Mindscience贡献指南
-- [Mindscience贡献指南](#Mindscience贡献指南)
+
+- [Mindscience贡献指南](#mindscience贡献指南)
- [如何成为mindscience仓的贡献者](#如何成为mindscience仓的贡献者)
- [一、提交合并请求,为mindscience仓贡献自己的力量](#一提交合并请求为mindscience仓贡献自己的力量)
- [二、新增测试用例,看护代码功能](#二新增测试用例看护代码功能)
@@ -12,8 +13,17 @@
- [三、查看仓及分支信息](#三查看仓及分支信息)
- [四、修改代码后提交commit以及多个commit合并](#四修改代码后提交commit以及多个commit合并)
- [五、更新本地代码,同步仓代码——冲突解决](#五更新本地代码同步仓代码冲突解决)
- - [六、将本地修改代码推到远端仓,提起向主仓合并请求的PR](#六将本地修改代码推到远端仓提起向主仓合并请求的pr)
+ - [六、将本地修改代码推到远端仓,提起向主仓合并请求的pr](#六将本地修改代码推到远端仓提起向主仓合并请求的pr)
- [七、附加](#七附加)
+ - [代码开发合入CheckList](#代码开发合入checklist)
+ - [API代码](#api代码)
+ - [案例目录格式](#案例目录格式)
+ - [单个案例目录格式](#单个案例目录格式)
+ - [多个案例目录格式](#多个案例目录格式)
+ - [训练文件格式](#训练文件格式)
+ - [配置文件格式](#配置文件格式)
+ - [README文件格式](#readme文件格式)
+ - [Jupyter Notebook文件格式](#jupyter-notebook文件格式)
@@ -39,7 +49,7 @@
-- 点击新建合并请求后,需要进行源分支名的选择,目标分支名的选择,标题输入,以及简要说明修改点等操作( **注意:合并标题格式为[SPONG]+内容** )
+- 点击新建合并请求后,需要进行源分支名的选择,目标分支名的选择,标题输入,以及简要说明修改点等操作(**注意:合并标题格式为[SPONG]+内容**)
- 在新建合并请求的右下角需进行关联Issue操作,每个合并请求的合入都要有对应的Issue,如果没有相关的Issue,可以自行创建,请记得关联完Issue后将(合并后关闭提到的Issue)前面勾勾取消,然后点击创建合并请求操作
@@ -47,7 +57,7 @@
-- 关联Issue处如果没有可选择的Issue关联,可以在主仓新建一个Issue,如果有则直接忽略此步。在主仓中点击新建Issue,根据合并请求的类型选择对应Issue类型,输入标题后,点击创建即可,这样在新建合并请求的关联Issue操作中就可以选择刚刚创建的Issue( **注意:Issue标题格式为[SPONGE]+内容** )
+- 关联Issue处如果没有可选择的Issue关联,可以在主仓新建一个Issue,如果有则直接忽略此步。在主仓中点击新建Issue,根据合并请求的类型选择对应Issue类型,输入标题后,点击创建即可,这样在新建合并请求的关联Issue操作中就可以选择刚刚创建的Issue(**注意:Issue标题格式为[SPONGE]+内容**)

@@ -71,7 +81,7 @@
### **二、新增测试用例,看护代码功能**
-- 对于贡献者来说,如果需要新增门禁冒烟测试用例来维护自己代码功能,可以在代码目录的mindscience/tests/st下,新增测试用例代码,这样可以保证其他人合入代码时不会影响自己代码功能( **注意:测试用例运行时间必须尽量短,受设备资源限制,太久的用例不适合作为门禁用例看护** )
+- 对于贡献者来说,如果需要新增门禁冒烟测试用例来维护自己代码功能,可以在代码目录的mindscience/tests/st下,新增测试用例代码,这样可以保证其他人合入代码时不会影响自己代码功能(**注意:测试用例运行时间必须尽量短,受设备资源限制,太久的用例不适合作为门禁用例看护**)
- 系统级测试用例,此阶段的用例是在whl包安装完成后启动,因此可以调用whl包中的任何函数,需要注意,系统级测试用例中需要添加(import pytest),并且在函数入口处新增pytest的标识,该标识可以使门禁任务识别到函数入口
@@ -303,3 +313,303 @@ git commit --amend
```bash
git reset --hard commit_id
```
+
+## 代码开发合入CheckList
+
+本文档介绍如何向MindFlow合入代码,包括合入前需要准备的文件、数据,合入步骤以及需要注意的事项,帮助贡献者更高效地进行代码合入。
+
+如果缺少调试代码的硬件环境,可以参考[启智社区云脑使用指南](https://download-mindspore.osinfra.cn/mindscience/mindflow/tutorials/%E5%90%AF%E6%99%BA%E6%8C%87%E5%8D%97.pdf), [NPU使用录屏](https://download-mindspore.osinfra.cn/mindscience/mindflow/tutorials/npu%E4%BD%BF%E7%94%A8.MP4), [GPU使用录屏](https://download-mindspore.osinfra.cn/mindscience/mindflow/tutorials/gpu%E4%BD%BF%E7%94%A8.MP4)。
+
+### API代码
+
+API代码主要指合入`MindFlow/mindflow`目录的代码,为各案例提供高效、易用的公共API接口,因此编写API代码时需要注意以下几点:
+
+1、考虑在多个案例上的可扩展性,避免'Hard Code',在维度、深度等参量上预留足够的入参,以供用户根据实际情况选择,注意非法入参的检查(可参考本节末尾的示意代码);
+
+2、入参命名上,MindFlow追求尽量统一,因此新的API合入时,需要与原有API的入参尽量对齐,新的入参命名可与Committer讨论确定;
+
+3、API的存放位置需根据MindFlow的套件架构决定,注意更新`__init__.py`文件和`cmake/package.cmake`文件;
+
+4、API文档包含两部分,一个是代码注释部分,一个是`mindscience/docs/api_python/mindflow`和`mindscience/docs/api_python_en/mindflow`中的中英文文档;
+
+5、API需有相关测试用例进行维护,保证其随时可用;测试用例提交在`mindscience/tests`中,可根据具体用例修改,但运行时间不宜过长,结果检查需合理;
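+
+下面是一个入参检查的极简示意(假设性代码,仅用于说明上述第1条;类名与参数名均为虚构,并非仓库真实接口):
+
+```python
+from mindspore import nn
+
+
+class MyOperator(nn.Cell):
+    """示意:维度、深度等作为入参暴露,并对非法取值进行检查。"""
+
+    def __init__(self, in_channels, hidden_channels=128, n_layers=4):
+        super().__init__()
+        if not isinstance(in_channels, int) or in_channels <= 0:
+            raise ValueError(f"in_channels should be a positive int, but got {in_channels}")
+        if not isinstance(n_layers, int) or n_layers <= 0:
+            raise ValueError(f"n_layers should be a positive int, but got {n_layers}")
+        layers = [nn.Dense(in_channels, hidden_channels)]
+        for _ in range(n_layers - 1):
+            layers.append(nn.Dense(hidden_channels, hidden_channels))
+        self.net = nn.SequentialCell(layers)
+
+    def construct(self, x):
+        return self.net(x)
+```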
+
+### 案例目录格式
+
+案例代码主要指合入`MindFlow/applications`目录的代码,需要根据研究范式,归入`physics_driven`、`data_driven`、`data_mechanism_fusion`、`cfd`几个目录中。
+
+- **必须** Jupyter Notebook中英文:为用户提供逐行的代码实现方式,详细讲解案例的实现方式和运行结果。
+
+- **必须** `images`:包含了README、notebook等文件里的所有图片。
+
+- **必须** `src`:为了保证训练代码的整洁性,可以抽取的函数和类可以统一放在src目录中,`__init__.py`一般为必须,`dataset.py`中包含数据集相关函数和类,`model.py`中包含模型相关函数和类,`utils.py`中包含工具函数和类,外部文件的调用统一从src导入。
+
+- **必须** 参数文件:案例中具体参数的配置,一般采用yaml文件,为了方便查看,按照优化器、模型等进行分类。
+
+- **必须** 训练脚本:案例的训练和验证脚本,在训练时除特殊情况,必须有测试集进行验证;训练脚本中的代码应该尽量简洁,复杂的调用封装到后端函数里。
+
+> **注意**:类和函数中需要避免'Hard Code',变量名需要有实际含义;尽量避免使用'Magic Number',必要时需在注释里说明(见下方示意);超过50行的代码可以考虑抽取出函数调用,减少重复代码;函数的功能尽可能单一,遵从'高内聚,低耦合'原则。
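+
+以下是避免'Magic Number'与函数抽取的简单示意(假设性代码,常量取值仅作举例):
+
+```python
+# 反例:83没有含义,难以维护
+# mask = data[:, :83]
+
+# 正例:常量有实际含义,必要时在注释中说明来源
+NUM_VALID_CHANNELS = 83  # 有效通道数,由数据集格式决定(示意)
+
+
+def get_valid_channels(data):
+    """抽取出的工具函数,功能单一,便于复用。"""
+    return data[:, :NUM_VALID_CHANNELS]
+```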
+
+### 单个案例目录格式
+
+单一案例的代码以[`PINNs求解Burgers`](./applications/physics_driven/burgers)为例,代码目录分为以下结构:
+
+```shell
+.
+├── images
+│   ├── background.png
+│   └── result.png
+├── src
+│   ├── __init__.py
+│   ├── dataset.py
+│   ├── model.py
+│   └── utils.py
+├── configs
+│   └── fno1d.yaml
+├── README.md
+├── README_CN.md
+├── problem.ipynb
+├── problem_CN.ipynb
+├── burgers_cfg.yaml
+├── eval.py
+└── train.py
+```
+
+### 多个案例目录格式
+
+有时,多个案例会使用相同的模型和方法,仅数据集不同。为了避免代码和文档的重复,`src`目录下统一存放所有案例公共的代码和每个案例自定义的代码,`images`目录统一存放图片文件;`README.md`文件在总体上介绍模型方法和所有案例,`problem.ipynb`文件介绍具体的案例代码。所有案例具有相同的入口,在命令行里通过指定参数来确定运行的具体案例(运行示例见下方代码块之后),文件格式如下:
+
+```shell
+.
+├── images
+│   ├── background.png
+│   ├── result1.png
+│   ├── result2.png
+│   └── result3.png
+├── src
+│   ├── __init__.py
+│   ├── dataset.py
+│   ├── model.py
+│   └── utils.py
+├── configs
+│   └── fno1d.yaml
+├── README.md
+├── README_CN.md
+├── problem.ipynb
+├── problem_CN.ipynb
+├── problem_cfg.yaml
+├── eval.py
+└── train.py
+```
+
+外层训练/测试文件调用的方式如下:
+
+```python
+...
+parser = argparse.ArgumentParser(description="Cae-Lstm")
+parser.add_argument("--case", type=str, default='riemann', choices=['riemann', 'kh', 'sod'],
+ help="Which case to run")
+...
+args = parser.parse_args()
+...
+model = Model()
+if args.case == 'riemann':
+ dataset = create_riemann_dataset()
+elif args.case == 'kh':
+ dataset = create_kh_dataset()
+else:
+ dataset = create_sod_dataset()
+model.train(dataset)
+...
+```
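+
+按上述入参定义,运行指定案例的命令形如:
+
+```shell
+python train.py --case riemann
+```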
+
+### 训练文件格式
+
+训练文件train.py为模型训练的入口,格式如下:
+
+```python
+import os
+import time
+import datetime
+import argparse
+import numpy as np
+
+from mindspore import context, nn, Tensor, set_seed, ops, data_sink, jit, save_checkpoint
+from mindspore import dtype as mstype
+
+from mindflow import FNO1D, load_yaml_config, get_warmup_cosine_annealing_lr
+from mindflow.pde import FlowWithLoss
+
+from src import create_training_dataset, visual, calculate_l2_error
+# 相关依赖导入,按照python官方库、第三方库、mindflow、src的顺序导入,导入mindflow时,精确到二级目录
+
+set_seed(123456)
+np.random.seed(123456)
+# 设置随机数
+
+def parse_args():
+ '''Parse input args'''
+ parser = argparse.ArgumentParser(description='Problem description')
+ parser.add_argument("--config_file_path", type=str, default="./config.yaml")
+ parser.add_argument("--device_target", type=str, default="GPU", choices=["GPU", "Ascend"],
+ help="The target device to run, support 'Ascend', 'GPU'")
+ parser.add_argument("--device_id", type=int, default=3, help="ID of the target device")
+ parser.add_argument("--mode", type=str, default="GRAPH", choices=["GRAPH", "PYNATIVE"],
+ help="Context mode, support 'GRAPH', 'PYNATIVE'")
+ parser.add_argument("--save_graphs", type=bool, default=False, choices=[True, False],
+ help="Whether to save intermediate compilation graphs")
+ parser.add_argument("--save_graphs_path", type=str, default="./graphs")
+ input_args = parser.parse_args()
+ return input_args
+
+
+def train(input_args):
+ use_ascend = context.get_context(attr_key='device_target') == "Ascend"
+ # 读取训练配置
+ config = load_yaml_config(input_args.config_file_path)
+
+ # 创建训练集和测试集
+    train_dataset, test_dataset = create_training_dataset(config["data"], shuffle=True)
+ # 初始化模型
+ model = Model(config)
+
+    problem = FlowWithLoss(model)
+    # 优化器(示意):具体超参从配置文件读取
+    optimizer = nn.Adam(model.trainable_params(), learning_rate=config["optimizer"]["lr"])
+ # 前向函数
+ def forward_fn(data, label):
+ ...
+
+ grad_fn = ops.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=False)
+ # 训练的前向和反向过程
+ @jit
+ def train_step(data, label):
+ ...
+ # 数据下沉
+ sink_process = data_sink(train_step, train_dataset, 1)
+
+    # 训练流程
+    for epoch in range(1, config["epochs"] + 1):
+        time_beg = time.time()
+        model.set_train()
+        step_train_loss = sink_process()
+        # 训练和验证函数,采用MindSpore函数式编程范式编写,注意打印内容尽量统一
+        print(f"epoch: {epoch} train loss: {step_train_loss} epoch time: {time.time() - time_beg:.2f}s")
+        # 验证
+        if epoch % config['eval_interval'] == 0:
+            model.set_train(False)
+            print("================================Start Evaluation================================")
+            eval_loss = calculate_l2_error(model, test_dataset)
+            print(f"epoch: {epoch} eval loss: {eval_loss}")
+            print("=================================End Evaluation=================================")
+ if epoch % config['save_ckpt_interval'] == 0:
+ save_checkpoint(model, 'my_model.ckpt')
+
+
+if __name__ == '__main__':
+ print(f"pid: {os.getpid()}")
+ print(datetime.datetime.now())
+ # 读取脚本入参
+ args = parse_args()
+
+ context.set_context(mode=context.GRAPH_MODE if args.mode.upper().startswith("GRAPH") else context.PYNATIVE_MODE,
+ device_target=args.device_target,
+ device_id=args.device_id)
+ print(f"Running in {args.mode.upper()} mode, using device id: {args.device_id}.")
+ # context设置,由于Ascend和GPU使用的差异,需要使用use_ascend变量进行判断
+ start_time = time.time()
+ # 调用训练函数
+ train(args)
+ print("End-to-End total time: {}s".format(time.time() - start_time))
+```
+
+### 配置文件格式
+
+参数按照模型、数据、优化器等类别分类,放在"./configs"目录下,配置中的路径参数都是相对于根目录的路径。参数命名需遵循统一规范,格式如下:
+
+```yaml
+model:
+ in_channels: 3
+ out_channels: 3
+ height: 192
+ width: 384
+ encoder_depth: 6
+ decoder_depth: 6
+ decoder_num_heads: 16
+
+data:
+  train_dataset_path: "./dataset/train.npy"
+  test_dataset_path: "./dataset/test.npy"
+ grid_path: "./dataset/grid.npy"
+ batch_size: 32
+
+optimizer:
+ epochs: 1000
+ lr: 0.0005
+ wave_level: 1
+```
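+
+训练脚本中可按类别读取上述配置,示意如下(`load_yaml_config`为mindflow提供的接口,配置文件路径为假设值):
+
+```python
+from mindflow import load_yaml_config
+
+config = load_yaml_config("./configs/problem.yaml")
+model_params = config["model"]          # 模型相关参数
+data_params = config["data"]            # 数据相关参数
+optimizer_params = config["optimizer"]  # 优化器相关参数
+```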
+
+### README文件格式
+
+总目录中的README对整体背景、技术路线、结果进行讲解;在每个案例中,可以分别从案例的角度描述,注意整体和局部的详略关系,避免重复描述和重复代码。
+
+【必须】README.md和README_CN.md,中英文README文件,一般包含以下部分:
+
+```md
+# 标题
+
+## 概述
+
+简单介绍一下案例的背景、方法、数据集、效果等。
+
+## 快速开始
+
+为用户提供快速运行脚本的方法,一般提供脚本调用和Jupyter Notebook两种方式。其中,脚本调用需要展示启动命令的入参含义
+
+### 训练方式一:在命令行中调用`train.py`脚本
+
+python train.py --config_file_path ./configs/burgers.yaml --mode GRAPH --device_target Ascend --device_id 0
+
+其中,
+`--config_file_path`表示参数文件的路径,默认值'./configs/burgers.yaml';
+
+`--mode`表示运行的模式,'GRAPH'表示静态图模式, 'PYNATIVE'表示动态图模式,默认值'GRAPH';
+
+`--device_target`表示使用的计算平台类型,可以选择'Ascend'或'GPU',默认值'Ascend';
+
+`--device_id`表示使用的计算卡编号,可按照实际情况填写,默认值0;
+
+### 训练方式二:运行Jupyter Notebook
+
+您可以使用中英文版本的Jupyter Notebook(附链接)逐行运行训练和验证代码。
+
+## 结果展示
+
+用1-2张图的方式展示模型推理的效果,最好为gif。
+
+## 性能
+
+如果案例涉及到GPU和Ascend双后端,则需要用表格的形式展示训练的主要性能指标进行对比。
+
+| 参数 | NPU | GPU |
+|:-------------------:|:------------------------:|:------------:|
+| 硬件资源 | Ascend, 显存32G | NVIDIA V100, 显存32G |
+| MindSpore版本 | >=2.0.0 | >=2.0.0 |
+| 数据集 | [Burgers数据集](https://download.mindspore.cn/mindscience/mindflow/dataset/applications/physics_driven/burgers_pinns/) | [Burgers数据集](https://download.mindspore.cn/mindscience/mindflow/dataset/applications/physics_driven/burgers_pinns/) |
+| 参数量 | 6e4 | 6e4 |
+| 训练参数 | batch_size=8192, steps_per_epoch=1, epochs=15000 | batch_size=8192, steps_per_epoch=1, epochs=15000 |
+| 测试参数 | batch_size=8192, steps=4 | batch_size=8192, steps=4 |
+| 优化器 | Adam | Adam |
+| 训练损失(MSE) | 0.001 | 0.0001 |
+| 验证损失(RMSE) | 0.010 | 0.008 |
+| 训练速度(ms/step) | 10 | 130 |
+
+## 贡献者
+
+gitee id: [id](开发者gitee个人空间的链接)
+
+email: myemail@163.com
+
+```
+
+### Jupyter Notebook文件格式
+
+Jupyter Notebook文件格式可参考[2D_steady_CN.ipynb](https://gitee.com/mindspore/mindscience/blob/master/MindFlow/applications/data_driven/airfoil/2D_steady/2D_steady_CN.ipynb)。
+
+将主要代码模块从训练脚本中抽出,有序分块放入Jupyter Notebook文件。Jupyter Notebook一般包含`概述`、`问题背景`、`技术路径`、`依赖导入`、`数据集制作`、`模型搭建`、`模型训练`、`结果展示`等部分。在每个部分,应当对代码重要内容进行说明,保证按照说明执行代码块能正常运行。
diff --git a/MindChemistry/applications/crystalflow/.gitignore b/MindChemistry/applications/crystalflow/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..208f562873e612552476fe64804fe4353ab8f6dd
--- /dev/null
+++ b/MindChemistry/applications/crystalflow/.gitignore
@@ -0,0 +1,11 @@
+dataset/
+*.log
+*.npy
+ckpt/
+dataset.zip
+rank_0/
+test_mind_cspnet.py
+torch2ms_ckpt/
+*.ipynb
+*.ckpt
+ignore/
diff --git a/mindscience/ccsrc/__init__.py b/mindscience/ccsrc/__init__.py
index 69930d7b2c715acde745900751d5393a42e54afa..34c6f9f7fbf9098082bfe036fe13775e0411e675 100644
--- a/mindscience/ccsrc/__init__.py
+++ b/mindscience/ccsrc/__init__.py
@@ -13,3 +13,5 @@
# limitations under the License.
# ============================================================================
"""init"""
+
+__all__ = []
diff --git a/mindscience/common/__init__.py b/mindscience/common/__init__.py
index 5853261e40e99966b0519166f79334673d6857ac..de4785404ebb24dd07a835fa561ff070bb9c0243 100644
--- a/mindscience/common/__init__.py
+++ b/mindscience/common/__init__.py
@@ -17,6 +17,9 @@ from .lr_scheduler import get_poly_lr, get_multi_step_lr, get_warmup_cosine_anne
from .losses import get_loss_metric, WaveletTransformLoss, MTLWeightedLoss, RelativeRMSELoss
from .derivatives import batched_hessian, batched_jacobian
from .optimizers import AdaHessian
+from .math import get_grid_1d, get_grid_2d, get_grid_3d
+from .utils import to_2tuple, to_3tuple, unpatchify, patchify, get_2d_sin_cos_pos_embed, \
+    pixel_shuffle, pixel_unshuffle, PixelShuffle, PixelUnshuffle, SpectralNorm
__all__ = ["get_poly_lr",
"get_multi_step_lr",
@@ -28,6 +31,10 @@ __all__ = ["get_poly_lr",
"batched_hessian",
"batched_jacobian",
"AdaHessian",
+ "get_grid_1d", "get_grid_2d", "get_grid_3d",
+ "to_2tuple", "to_3tuple", "unpatchify", "patchify", "get_2d_sin_cos_pos_embed",
+ "pixel_shuffle", "pixel_unshuffle", "PixelShuffle", "PixelUnshuffle",
+ "SpectralNorm"
]
__all__.sort()
diff --git a/mindscience/common/math.py b/mindscience/common/math.py
index 474425204c5665566af88218bf9f300a03722122..55018d1b575bfa50cef9c93e717015b6f83245ed 100644
--- a/mindscience/common/math.py
+++ b/mindscience/common/math.py
@@ -17,7 +17,7 @@ math operators
'''
import numpy as np
-from ..cell.utils import to_2tuple, to_3tuple
+from .utils import to_2tuple, to_3tuple
__all__ = ['get_grid_1d', 'get_grid_2d', 'get_grid_3d']
diff --git a/mindscience/common/utils.py b/mindscience/common/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..a41d34c0a03b857ae75261fa52f1809fb61b9371
--- /dev/null
+++ b/mindscience/common/utils.py
@@ -0,0 +1,552 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+'''utils'''
+import numpy as np
+import mindspore.ops.operations as P
+from mindspore import nn, ops, Parameter, numpy as msnp
+from mindspore.common.initializer import initializer, Normal
+
+
+def to_3tuple(t):
+ """
+ Args:
+ t (Union[int, tuple(int)]): The grid height and width.
+
+ Returns:
+ Same as input or a tuple as (t,t,t).
+
+ """
+ return t if isinstance(t, tuple) else (t, t, t)
+
+
+def to_2tuple(t):
+ """
+ Args:
+ t (Union[int, tuple(int)]): The grid height and width.
+
+ Returns:
+ Same as input or a tuple as (t,t).
+
+ """
+ return t if isinstance(t, tuple) else (t, t)
+
+
+def get_2d_sin_cos_pos_embed(embed_dim, grid_size):
+ r"""
+ Args:
+ embed_dim (int): The output dimension for each position.
+ grid_size (tuple(int)): The grid height and width.
+
+ Returns:
+ The numpy array with shape of :math:`(1, grid\_height*grid\_width, embed\_dim)`
+
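+    Examples:
+        >>> # Illustrative shape check (assumed usage).
+        >>> pos_embed = get_2d_sin_cos_pos_embed(16, (4, 4))
+        >>> print(pos_embed.shape)
+        (1, 16, 16)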
+ """
+ grid_size = to_2tuple(grid_size)
+ grid_height = np.arange(grid_size[0], dtype=np.float32)
+ grid_width = np.arange(grid_size[1], dtype=np.float32)
+ grid = np.meshgrid(grid_width, grid_height) # here w goes first
+ grid = np.stack(grid, axis=0)
+
+ grid = grid.reshape([2, 1, grid_size[0], grid_size[1]])
+ pos_embed = get_2d_sin_cos_pos_embed_from_grid(embed_dim, grid)
+ pos_embed = np.expand_dims(pos_embed, 0)
+ return pos_embed
+
+
+def get_2d_sin_cos_pos_embed_from_grid(embed_dim, grid):
+ r"""
+ use half of dimensions to encode grid_height
+
+ Args:
+ embed_dim (int): output dimension for each position.
+        grid (numpy.array): The grid of positions to be encoded, with shape :math:`(2, 1, grid\_height, grid\_width)`.
+
+    Returns:
+        The numpy array with shape of :math:`(grid\_height*grid\_width, embed\_dim)`
+ """
+ emb_height = get_1d_sin_cos_pos_embed_from_grid(
+ embed_dim // 2, grid[0]) # (H*W, D/2)
+ emb_width = get_1d_sin_cos_pos_embed_from_grid(
+ embed_dim // 2, grid[1]) # (H*W, D/2)
+
+ emb = np.concatenate([emb_height, emb_width], axis=1) # (H*W, D)
+ return emb
+
+
+def get_1d_sin_cos_pos_embed_from_grid(embed_dim, pos):
+ r"""
+ Args:
+ embed_dim (int): output dimension for each position.
+        pos (numpy.array): A numpy array of positions to be encoded: size :math:`(M,)`.
+
+ Returns:
+ The numpy array with shape of :math:`(M, embed\_dim)`
+ """
+ omega = np.arange(embed_dim // 2, dtype=np.float32)
+ omega /= embed_dim / 2.
+ omega = 1. / 10000 ** omega # (D/2,)
+
+ pos = pos.reshape(-1) # (M,)
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
+
+ emb_sin = np.sin(out) # (M, D/2)
+ emb_cos = np.cos(out) # (M, D/2)
+
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
+ return emb
+
+
+def patchify(label, patch_size=16):
+    r"""
+    Split a 2D field into flattened patches.
+
+    Args:
+        label (numpy.array): The label to be patchified, with shape :math:`(H, W, C)`.
+        patch_size (int): The patch size of image. Default: ``16``.
+
+    Returns:
+        The numpy array with shape of :math:`(H*W/patch\_size^2, patch\_size^2*C)`.
+    """
+ label_shape = label.shape
+ label = np.reshape(label, (label_shape[0] // patch_size,
+ patch_size,
+ label_shape[1] // patch_size,
+ patch_size,
+ label_shape[2]))
+ label = np.transpose(label, (0, 2, 1, 3, 4))
+ label_new_shape = label.shape
+ label = np.reshape(label, (label_new_shape[0] * label_new_shape[1],
+ label_new_shape[2] * label_new_shape[3] * label_new_shape[4]))
+ return label
+
+
+def unpatchify(labels, img_size=(192, 384), patch_size=16, nchw=False):
+    """
+    Merge flattened patches back into a 2D field.
+
+    Args:
+        labels (Tensor): The patchified labels with shape (N, H*W/patch_size^2, patch_size^2*C).
+        img_size (tuple(int)): Input image size. Default: ``(192, 384)``.
+        patch_size (int): The patch size of image. Default: ``16``.
+        nchw (bool): If True, the output is in NCHW layout. Default: ``False``.
+
+    Returns:
+        The tensor with shape of (N, H, W, C), or (N, C, H, W) if `nchw` is True.
+    """
+ label_shape = labels.shape
+ output_dim = label_shape[-1] // (patch_size * patch_size)
+ labels = P.Reshape()(labels, (label_shape[0],
+ img_size[0] // patch_size,
+ img_size[1] // patch_size,
+ patch_size,
+ patch_size,
+ output_dim))
+
+ labels = P.Transpose()(labels, (0, 1, 3, 2, 4, 5))
+ labels = P.Reshape()(labels, (label_shape[0],
+ img_size[0],
+ img_size[1],
+ output_dim))
+ if nchw:
+ labels = P.Transpose()(labels, (0, 3, 1, 2))
+ return labels
+
+
+class SpectralNorm(nn.Cell):
+ """Applies spectral normalization to a parameter in the given module.
+
+ Spectral normalization stabilizes the training of discriminators (critics)
+ in Generative Adversarial Networks (GANs) by rescaling the weight tensor
+ with spectral norm.
+
+ Args:
+ module (nn.Cell): containing module.
+        n_power_iterations (int): number of power iterations to calculate spectral norm. Default: ``1``.
+        dim (int): dimension corresponding to number of outputs. Default: ``0``.
+        eps (float): epsilon for numerical stability in calculating norms. Default: ``1e-12``.
+
+ Inputs:
+ - **input** - The positional parameter of containing module.
+ - **kwargs** - The keyword parameter of containing module.
+
+ Outputs:
+ The forward propagation of containing module.
+ """
+ def __init__(
+ self,
+ module,
+ n_power_iterations: int = 1,
+ dim: int = 0,
+ eps: float = 1e-12
+ ) -> None:
+ super(SpectralNorm, self).__init__()
+ self.parametrizations = module
+ self.weight = module.weight.astype("float16")
+ self.use_weight_norm = True
+ ndim = self.weight.ndim
+ if dim >= ndim or dim < -ndim:
+ raise IndexError("Dimension out of range (expected to be in range of "
+ f"[-{ndim}, {ndim - 1}] but got {dim})")
+
+ if n_power_iterations <= 0:
+ raise ValueError('Expected n_power_iterations to be positive, but '
+ 'got n_power_iterations={}'.format(n_power_iterations))
+ self.dim = dim if dim >= 0 else dim + ndim
+ self.eps = eps
+ self.l2_normalize = ops.L2Normalize(epsilon=self.eps)
+ self.expand_dims = ops.ExpandDims()
+ self.assign = P.Assign()
+ if ndim > 1:
+ self.n_power_iterations = n_power_iterations
+ weight_mat = self._reshape_weight_to_matrix()
+
+ h, w = weight_mat.shape
+ u = initializer(Normal(1.0, 0), [h]).init_data()
+ v = initializer(Normal(1.0, 0), [w]).init_data()
+ self._u = Parameter(self.l2_normalize(u), requires_grad=False)
+ self._v = Parameter(self.l2_normalize(v), requires_grad=False)
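+            # Warm-start u and v with 15 power iterations at construction time.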
+ self._u, self._v = self._power_method(weight_mat, 15)
+
+ def construct(self, *inputs, **kwargs):
+ """SpectralNorm forward function"""
+ if self.weight.ndim == 1:
+ # Faster and more exact path, no need to approximate anything
+            weight = self.l2_normalize(self.weight)
+            self.assign(self.parametrizations.weight, weight)
+ else:
+ weight_mat = self._reshape_weight_to_matrix()
+ if self.use_weight_norm:
+ self._u, self._v = self._power_method(weight_mat, self.n_power_iterations)
+            # Copy u and v so gradients do not flow through the power iteration.
+ u = self._u.copy()
+ v = self._v.copy()
+ weight_mat = weight_mat.astype("float32")
+ sigma = ops.tensor_dot(u, msnp.multi_dot([weight_mat, self.expand_dims(v, -1)]), 1)
+
+ self.assign(self.parametrizations.weight, self.weight / sigma)
+
+ return self.parametrizations(*inputs, **kwargs)
+
+ def remove_weight_norm(self):
+ self.use_weight_norm = False
+
+ def _power_method(self, weight_mat, n_power_iterations):
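+        """Approximate the leading left/right singular vectors of weight_mat by power iteration."""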
+ for _ in range(n_power_iterations):
+ weight_mat = weight_mat.astype("float32")
+ self._u = self.l2_normalize(msnp.multi_dot([weight_mat, self.expand_dims(self._v, -1)]).flatten())
+ temp = msnp.multi_dot([weight_mat.T, self.expand_dims(self._u, -1)]).flatten()
+ self._v = self.l2_normalize(temp)
+ return self._u, self._v
+
+ def _reshape_weight_to_matrix(self):
+ # Precondition
+ if self.dim != 0:
+ # permute dim to front
+            input_perm = [d for d in range(self.weight.ndim) if d != self.dim]
+ input_perm.insert(0, self.dim)
+
+ self.weight = ops.transpose(self.weight, input_perm)
+
+ return self.weight.reshape(self.weight.shape[0], -1)
+
+
+def pixel_shuffle(x, upscale_factor):
+ r"""
+ Applies a pixel_shuffle operation over an input signal composed of several input planes. This is useful for
+    implementing efficient sub-pixel convolution with a stride of :math:`1/r`. For more details, refer to
+    `Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network
+    <https://arxiv.org/abs/1609.05158>`_ .
+
+ Typically, the `x` is of shape :math:`(*, C \times r^2, H, W)` , and the output is of shape
+ :math:`(*, C, H \times r, W \times r)`, where `r` is an upscale factor and `*` is zero or more batch dimensions.
+
+ Args:
+ x (Tensor): Tensor of shape :math:`(*, C \times r^2, H, W)` . The dimension of `x` is larger than 2, and the
+ length of third to last dimension can be divisible by `upscale_factor` squared.
+ upscale_factor (int): factor to increase spatial resolution by, and is a positive integer.
+
+ Returns:
+ - **output** (Tensor) - Tensor of shape :math:`(*, C, H \times r, W \times r)` .
+
+ Raises:
+ ValueError: If `upscale_factor` is not a positive integer.
+ ValueError: If the length of third to last dimension is not divisible by `upscale_factor` squared.
+ TypeError: If the dimension of `x` is less than 3.
+
+ Supported Platforms:
+ ``Ascend`` ``GPU`` ``CPU``
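+
+    Examples:
+        >>> # Minimal shape illustration (assumed usage).
+        >>> import numpy as np
+        >>> from mindspore import Tensor
+        >>> import mindspore.common.dtype as mstype
+        >>> x = Tensor(np.zeros((1, 4, 3, 3)), mstype.float32)
+        >>> print(pixel_shuffle(x, 2).shape)
+        (1, 1, 6, 6)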
+ """
+ idx = x.shape
+ length = len(idx)
+ if length < 3:
+ raise TypeError(f"For pixel_shuffle, the dimension of `x` should be larger than 2, but got {length}.")
+ pre = idx[:-3]
+ c, h, w = idx[-3:]
+ if c % upscale_factor ** 2 != 0:
+        raise ValueError("For 'pixel_shuffle', the length of third to last dimension is not divisible "
+                         "by `upscale_factor` squared.")
+ c = c // upscale_factor ** 2
+ input_perm = (pre + (c, upscale_factor, upscale_factor, h, w))
+ reshape = ops.Reshape()
+ x = reshape(x, input_perm)
+ input_perm = [i for i in range(length - 2)]
+ input_perm = input_perm + [length, length - 2, length + 1, length - 1]
+ input_perm = tuple(input_perm)
+ transpose = ops.Transpose()
+ x = transpose(x, input_perm)
+ x = reshape(x, (pre + (c, upscale_factor * h, upscale_factor * w)))
+ return x
+
+
+class PixelShuffle(nn.Cell):
+ r"""
+ Applies a pixelshuffle operation over an input signal composed of several input planes. This is useful for
+    implementing efficient sub-pixel convolution with a stride of :math:`1/r`. For more details, refer to
+    `Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network
+    <https://arxiv.org/abs/1609.05158>`_ .
+
+ Typically, the input is of shape :math:`(*, C \times r^2, H, W)` , and the output is of shape
+ :math:`(*, C, H \times r, W \times r)`, where r is an upscale factor and * is zero or more batch dimensions.
+
+ Args:
+ upscale_factor (int): factor to increase spatial resolution by, and is a positive integer.
+
+ Inputs:
+ - **x** (Tensor) - Tensor of shape :math:`(*, C \times r^2, H, W)` . The dimension of `x` is larger than 2, and
+ the length of third to last dimension can be divisible by `upscale_factor` squared.
+
+ Outputs:
+ - **output** (Tensor) - Tensor of shape :math:`(*, C, H \times r, W \times r)` .
+
+ Raises:
+ ValueError: If `upscale_factor` is not a positive integer.
+ ValueError: If the length of third to last dimension of `x` is not divisible by `upscale_factor` squared.
+ TypeError: If the dimension of `x` is less than 3.
+
+ Supported Platforms:
+ ``Ascend`` ``GPU`` ``CPU``
+ """
+ def __init__(self, upscale_factor):
+ super(PixelShuffle, self).__init__()
+ self.upscale_factor = upscale_factor
+
+ def construct(self, x):
+ return pixel_shuffle(x, self.upscale_factor)
+
+
+def pixel_unshuffle(x, downscale_factor):
+ r"""
+ Applies a pixel_unshuffle operation over an input signal composed of several input planes. For more details, refer
+    to `Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network
+    <https://arxiv.org/abs/1609.05158>`_ .
+
+ Typically, the input is of shape :math:`(*, C, H \times r, W \times r)` , and the output is of shape
+ :math:`(*, C \times r^2, H, W)` , where `r` is a downscale factor and `*` is zero or more batch dimensions.
+
+ Args:
+ x (Tensor): Tensor of shape :math:`(*, C, H \times r, W \times r)` . The dimension of `x` is larger than 2,
+ and the length of second to last dimension or last dimension can be divisible by `downscale_factor` .
+ downscale_factor (int): factor to decrease spatial resolution by, and is a positive integer.
+
+ Returns:
+ - **output** (Tensor) - Tensor of shape :math:`(*, C \times r^2, H, W)` .
+
+ Raises:
+ ValueError: If `downscale_factor` is not a positive integer.
+ ValueError: If the length of second to last dimension or last dimension is not divisible by `downscale_factor` .
+ TypeError: If the dimension of `x` is less than 3.
+
+ Supported Platforms:
+ ``Ascend`` ``GPU`` ``CPU``
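+
+    Examples:
+        >>> # Minimal shape illustration (assumed usage).
+        >>> import numpy as np
+        >>> from mindspore import Tensor
+        >>> import mindspore.common.dtype as mstype
+        >>> x = Tensor(np.zeros((1, 1, 6, 6)), mstype.float32)
+        >>> print(pixel_unshuffle(x, 2).shape)
+        (1, 4, 3, 3)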
+ """
+ idx = x.shape
+ length = len(idx)
+ if length < 3:
+ raise TypeError(f"For pixel_unshuffle, the dimension of `x` should be larger than 2, but got {length}.")
+ pre = idx[:-3]
+ c, h, w = idx[-3:]
+ if h % downscale_factor != 0 or w % downscale_factor != 0:
+ raise ValueError("For 'pixel_unshuffle', the length of second to last 2 dimension should be divisible "
+ "by downscale_factor.")
+ h = h // downscale_factor
+ w = w // downscale_factor
+ input_perm = (pre + (c, h, downscale_factor, w, downscale_factor))
+ reshape = ops.Reshape()
+ x = reshape(x, input_perm)
+ input_perm = [i for i in range(length - 2)]
+ input_perm = input_perm + [length - 1, length + 1, length - 2, length]
+ input_perm = tuple(input_perm)
+ transpose = ops.Transpose()
+ x = transpose(x, input_perm)
+ x = reshape(x, (pre + (c * downscale_factor * downscale_factor, h, w)))
+ return x
+
+
+class PixelUnshuffle(nn.Cell):
+ r"""
+ Applies a pixelunshuffle operation over an input signal composed of several input planes. For more details, refer to
+    `Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network
+    <https://arxiv.org/abs/1609.05158>`_ .
+
+ Typically, the input is of shape :math:`(*, C, H \times r, W \times r)` , and the output is of shape
+ :math:`(*, C \times r^2, H, W)` , where r is a downscale factor and * is zero or more batch dimensions.
+
+ Args:
+ downscale_factor (int): factor to decrease spatial resolution by, and is a positive integer.
+
+ Inputs:
+ - **x** (Tensor) - Tensor of shape :math:`(*, C, H \times r, W \times r)` . The dimension of `x` is larger than
+ 2, and the length of second to last dimension or last dimension can be divisible by `downscale_factor` .
+
+ Outputs:
+ - **output** (Tensor) - Tensor of shape :math:`(*, C \times r^2, H, W)` .
+
+ Raises:
+ ValueError: If `downscale_factor` is not a positive integer.
+ ValueError: If the length of second to last dimension or last dimension is not divisible by `downscale_factor` .
+ TypeError: If the dimension of `x` is less than 3.
+
+ Supported Platforms:
+ ``Ascend`` ``GPU`` ``CPU``
+ """
+ def __init__(self, downscale_factor):
+ super(PixelUnshuffle, self).__init__()
+ self.downscale_factor = downscale_factor
+
+ def construct(self, x):
+ return pixel_unshuffle(x, self.downscale_factor)
diff --git a/mindscience/data/__init__.py b/mindscience/data/__init__.py
index e50ddf9db637fe4190923ceb14df0b1664006f2b..bfd2d98bb35c5a2593b082fa426e88dcc78e9523 100644
--- a/mindscience/data/__init__.py
+++ b/mindscience/data/__init__.py
@@ -18,3 +18,6 @@ init
from .earth import *
from .flow import *
+
+
+__all__ = []
+__all__.extend(earth.__all__)
+__all__.extend(flow.__all__)
diff --git a/mindscience/data/flow/__init__.py b/mindscience/data/flow/__init__.py
index 705e9c2bc4bc2b34e6bfc28893965a46e210d971..4449101f6417e2dbbf008c5ec4717ed903089fea 100644
--- a/mindscience/data/flow/__init__.py
+++ b/mindscience/data/flow/__init__.py
@@ -15,4 +15,8 @@
"""
init
"""
+from .earth import *
+from .flow import *
+
+__all__ = []
+__all__.extend(earth.__all__)
+__all__.extend(flow.__all__)
diff --git a/mindscience/data/flow/geometry/__init__.py b/mindscience/data/flow/geometry/__init__.py
index 055a97648845b164b10d1fdb8615d25c6a79e44b..530edf88e1e9e9e19a4e210e73ec1e7e71d89400 100644
--- a/mindscience/data/flow/geometry/__init__.py
+++ b/mindscience/data/flow/geometry/__init__.py
@@ -45,6 +45,4 @@ __all__ = [
"CSGUnion",
"CSGXOR",
"generate_sampling_config",
-]
-
-__all__.sort()
+]
diff --git a/mindscience/distributed/__init__.py b/mindscience/distributed/__init__.py
index 705e9c2bc4bc2b34e6bfc28893965a46e210d971..69a14b29e1ced3fa627e5dada3f5f6ba239fdc1c 100644
--- a/mindscience/distributed/__init__.py
+++ b/mindscience/distributed/__init__.py
@@ -16,3 +16,4 @@
init
"""
+__all__ = []
diff --git a/mindscience/e3nn/__init__.py b/mindscience/e3nn/__init__.py
index 705e9c2bc4bc2b34e6bfc28893965a46e210d971..69a14b29e1ced3fa627e5dada3f5f6ba239fdc1c 100644
--- a/mindscience/e3nn/__init__.py
+++ b/mindscience/e3nn/__init__.py
@@ -16,3 +16,4 @@
init
"""
+__all__ = []
diff --git a/mindscience/gnn/__init__.py b/mindscience/gnn/__init__.py
index 705e9c2bc4bc2b34e6bfc28893965a46e210d971..69a14b29e1ced3fa627e5dada3f5f6ba239fdc1c 100644
--- a/mindscience/gnn/__init__.py
+++ b/mindscience/gnn/__init__.py
@@ -16,3 +16,4 @@
init
"""
+__all__ = []
diff --git a/mindscience/models/__init__.py b/mindscience/models/__init__.py
index 760c830fb13232806fe677474d8c873f9fe4fa96..e33e2ab90a363210a6b8313ba23b98088e46bc62 100644
--- a/mindscience/models/__init__.py
+++ b/mindscience/models/__init__.py
@@ -15,20 +15,15 @@
"""
init
"""
-from .demnet import *
-from .diffusion import *
+from .diffuser import *
from .GraphCast import *
from .layers import *
from .neural_operator import *
-from .pde import *
from .transformer import *
__all__ = []
-__all__.extend(demnet.__all__)
-__all__.extend(dgmr.__all__)
-__all__.extend(diffusion.__all__)
+__all__.extend(diffuser.__all__)
__all__.extend(GraphCast.__all__)
__all__.extend(layers.__all__)
__all__.extend(neural_operator.__all__)
-__all__.extend(pde.__all__)
__all__.extend(transformer.__all__)
diff --git a/mindscience/models/diffusion/__init__.py b/mindscience/models/diffuser/__init__.py
similarity index 84%
rename from mindscience/models/diffusion/__init__.py
rename to mindscience/models/diffuser/__init__.py
index 04ad4035d5d130b5f253c916a4938183f233287c..5ae9561763d3fc7d8be39b4a2c0a54af5cc14fb4 100644
--- a/mindscience/models/diffusion/__init__.py
+++ b/mindscience/models/diffuser/__init__.py
@@ -17,3 +17,5 @@ init
"""
from .diffusion import DDPMPipeline, DDPMScheduler, DDIMPipeline, DDIMScheduler, DiffusionTrainer
from .diffusion_transformer import DiffusionTransformer, ConditionDiffusionTransformer
+
+__all__ = ["DDPMPipeline", "DDPMScheduler", "DDIMPipeline", "DDIMScheduler", "DiffusionTrainer",
+           "DiffusionTransformer", "ConditionDiffusionTransformer"]
diff --git a/mindscience/models/diffusion/diffusion.py b/mindscience/models/diffuser/diffusion.py
similarity index 100%
rename from mindscience/models/diffusion/diffusion.py
rename to mindscience/models/diffuser/diffusion.py
diff --git a/mindscience/models/diffusion/diffusion_transformer.py b/mindscience/models/diffuser/diffusion_transformer.py
similarity index 99%
rename from mindscience/models/diffusion/diffusion_transformer.py
rename to mindscience/models/diffuser/diffusion_transformer.py
index 98e94bd49713851705ded2918731d9e1bbf61330..c99553134e83456db952d3fcd3622cdbf5175100 100644
--- a/mindscience/models/diffusion/diffusion_transformer.py
+++ b/mindscience/models/diffuser/diffusion_transformer.py
@@ -19,7 +19,7 @@ import math
import numpy as np
from mindspore import nn, ops, Tensor
from mindspore import dtype as mstype
-from mindflow.cell import AttentionBlock
+from ..transformer import TransformerBlock
class Mlp(nn.Cell):
@@ -67,7 +67,7 @@ class Transformer(nn.Cell):
self.hidden_channels = hidden_channels
self.layers = layers
self.blocks = nn.CellList([
- AttentionBlock(
+ TransformerBlock(
in_channels=hidden_channels,
num_heads=heads,
drop_mode="dropout",
diff --git a/mindscience/models/layers/__init__.py b/mindscience/models/layers/__init__.py
index 45b505eb8ec5ba3c7d805b58e512f52923191a37..929173e0fd7af51a269eb6311a6e5716d00f4700 100644
--- a/mindscience/models/layers/__init__.py
+++ b/mindscience/models/layers/__init__.py
@@ -17,6 +17,6 @@ init
"""
from .activation import get_activation
from .basic_block import LinearBlock, ResBlock, InputScale, FCSequential, MultiScaleFCSequential, DropPath
-from .attention import Attention, MultiHeadAttention, AttentionBlock
-from .vit import ViT
from .unet2d import UNet2D
+
+__all__ = ["get_activation", "LinearBlock", "ResBlock", "InputScale", "FCSequential",
+           "MultiScaleFCSequential", "DropPath", "UNet2D"]
diff --git a/mindscience/models/layers/basic_block.py b/mindscience/models/layers/basic_block.py
index f3f5c34ec14c443ef1e603751c2b5841909a6489..b26e28199dbae5828538e05a2cb54c97f36df7b3 100644
--- a/mindscience/models/layers/basic_block.py
+++ b/mindscience/models/layers/basic_block.py
@@ -25,7 +25,7 @@ from mindspore import Tensor, Parameter
from mindspore.ops.primitive import constexpr
from .activation import get_activation
-from ..utils.check_func import check_param_type
+from ...utils.check_func import check_param_type
__all__ = ['LinearBlock', 'ResBlock', 'InputScale', 'FCSequential',
'MultiScaleFCSequential', 'DropPath']
diff --git a/mindscience/models/layers/unet2d.py b/mindscience/models/layers/unet2d.py
index e8a6026875c396303edd7ba31a60a859cb3be43c..8bc0e77207eb69ff080da0f3204fd04e31ce3277 100644
--- a/mindscience/models/layers/unet2d.py
+++ b/mindscience/models/layers/unet2d.py
@@ -20,7 +20,7 @@ import mindspore.ops as ops
from mindspore.ops import operations as P
from .activation import get_activation
-from ..utils.check_func import check_param_type
+from ...utils.check_func import check_param_type
class DoubleConv(nn.Cell):
diff --git a/mindscience/models/model_factory.py b/mindscience/models/model_factory.py
index 8db9ccf2023c595c11fb7972d6e4ad883ab0526a..932f07365d2f403737e1e29c31e15131379c685e 100644
--- a/mindscience/models/model_factory.py
+++ b/mindscience/models/model_factory.py
@@ -17,7 +17,7 @@ init
"""
import os
-from mindspore import ops, Tensor, nn, load_param_into_net, load_checkpoint
+import mindspore
+from mindspore import ops, Tensor, nn, load_param_into_net, load_checkpoint, save_checkpoint
class SciModule(nn.Cell):
@@ -28,10 +28,44 @@ class SciModule(nn.Cell):
assert ckpt_file.endswith('.ckpt') and os.path.exists(ckpt_file)
param_dict = load_checkpoint(ckpt_file)
load_param_into_net(self, param_dict)
+ print(f"Load checkpoint from {ckpt_file}")
- def set_grad(self):
- super.set_grad()
+
+    def save(self, ckpt_file):
+        """Save model parameters to a checkpoint file."""
+        ckpt_dir = os.path.dirname(ckpt_file)
+        if ckpt_dir:
+            os.makedirs(ckpt_dir, exist_ok=True)
+        save_checkpoint(self, ckpt_file)
+
+    def set_grad(self, special_layer_patterns=None):
+        """Set requires_grad only for layers whose names match the given patterns."""
+        super().set_grad()
+        # By default, match layers whose names contain "encoder" or "decoder".
+        special_layer_patterns = special_layer_patterns or ["encoder", "decoder"]
+        for param in self.get_parameters():
+            param.requires_grad = any(pattern in param.name for pattern in special_layer_patterns)
@property
def num_params(self):
- return sum([i.numel() for i in self.trainable_params()])
\ No newline at end of file
+ return sum([i.numel() for i in self.trainable_params()])
+
+    def set_precision(self, precision=mindspore.float32, special_layer_patterns=None,
+                      blacklist_patterns=None, whitelist_patterns=None):
+        """Set the compute precision of sub-cells according to name patterns."""
+        # Default blacklist/whitelist of layer-name patterns.
+        blacklist = blacklist_patterns or ["softmax", "layernorm", "batchnorm"]
+        whitelist = whitelist_patterns or []
+
+        for name, cell in self.cells_and_names():
+            # Blacklist check: force float32.
+            if any(pattern in name.lower() for pattern in blacklist):
+                cell.to_float(mindspore.float32)
+                for param in cell.get_parameters():
+                    param.set_data(param.data.astype(mindspore.float32))
+            # Whitelist check: apply the user-specified precision.
+            elif any(pattern in name.lower() for pattern in whitelist):
+                cell.to_float(precision)
+                for param in cell.get_parameters():
+                    param.set_data(param.data.astype(precision))
+            # Other layers: handled according to special_layer_patterns.
+            elif special_layer_patterns and any(pattern in name.lower() for pattern in special_layer_patterns):
+                cell.to_float(precision)
+                for param in cell.get_parameters():
+                    param.set_data(param.data.astype(precision))
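+
+
+# Hypothetical usage sketch (illustrative names, not repository code):
+#     net = MySciNet()  # a subclass of SciModule
+#     net.set_precision(mindspore.float16, special_layer_patterns=["attention"])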
diff --git a/mindscience/models/neural_operator/__init__.py b/mindscience/models/neural_operator/__init__.py
index b7532f16dce56fd66336e3c02ec67899f8a21ced..8910560ef21b63aad5d3fde67c3773137bfedbfe 100644
--- a/mindscience/models/neural_operator/__init__.py
+++ b/mindscience/models/neural_operator/__init__.py
@@ -14,11 +14,14 @@
# ============================================================================
"""init"""
from .fno import FNOBlocks, FNO1D, FNO2D, FNO3D
+from .ffno import FFNOBlocks, FFNO1D, FFNO2D, FFNO3D
from .kno1d import KNO1D
from .kno2d import KNO2D
from .pdenet import PDENet
from .percnn import PeRCNN
from .sno import SNO, SNO1D, SNO2D, SNO3D
-__all__ = ["FNOBlocks", "FNO1D", "FNO2D", "FNO3D", "KNO1D", "KNO2D", "PDENet", "PeRCNN",
+__all__ = ["FNOBlocks", "FNO1D", "FNO2D", "FNO3D",
+ "FFNOBlocks", "FFNO1D", "FFNO2D", "FFNO3D",
+ "KNO1D", "KNO2D", "PDENet", "PeRCNN",
"SNO", "SNO1D", "SNO2D", "SNO3D"]
diff --git a/mindscience/models/neural_operator/ffno.py b/mindscience/models/neural_operator/ffno.py
new file mode 100644
index 0000000000000000000000000000000000000000..35a9116509ad5ecae6da74167e06e335f5866356
--- /dev/null
+++ b/mindscience/models/neural_operator/ffno.py
@@ -0,0 +1,804 @@
+'''
+# Copyright 2023 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+'''
+# pylint: disable=W0235
+
+from mindspore import nn, ops, Tensor, Parameter, ParameterTuple, mint
+from mindspore.common.initializer import XavierNormal, initializer
+import mindspore.common.dtype as mstype
+
+from .ffno_sp import SpectralConv1d, SpectralConv2d, SpectralConv3d
+from ...common.math import get_grid_1d, get_grid_2d, get_grid_3d
+from ...utils.check_func import check_param_type
+
+
+class FFNOBlocks(nn.Cell):
+ r"""
+ The FFNOBlock, which usually accompanied by a Lifting Layer ahead and a Projection Layer behind,
+ is a part of Factorized Fourier Neural Operator. It contains a Factorized Fourier Layer. The details can be found
+    in `A. Tran, A. Mathews, et al.: Factorized Fourier Neural Operators <https://arxiv.org/abs/2111.13802>`_ .
+
+ Args:
+ in_channels (int): The number of channels in the input space.
+ out_channels (int): The number of channels in the output space.
+ n_modes (Union[int, list(int)]): The number of modes reserved after linear transformation in Fourier Layer.
+ resolutions (Union[int, list(int)]): The resolutions of the input tensor.
+ factor (int): The number of neurons in the hidden layer of a feedforward network. Default: ``1``.
+ n_ff_layers (int): The number of layers (hidden layers) in the feedforward neural network. Default: ``2``.
+ ff_weight_norm (bool): Whether to do weight normalization in feedforward or not. Used as a reserved function
+ interface, the weight normalization is not supported in feedforward. Default: ``False``.
+ layer_norm (bool): Whether to do layer normalization in feedforward or not. Default: ``True``.
+ dropout (float): The value of percent be dropped when applying dropout regularization. Default: ``0.0``.
+ r_padding (int): The number used to pad a tensor on the right in a certain dimension. Pad the domain if
+ input is non-periodic. Default: ``0``.
+ use_fork (bool): Whether to perform forecasting or not. Default: ``False``.
+ forecast_ff (Feedforward): The feedforward network of generating "backcast" output. Default: ``None``.
+ backcast_ff (Feedforward): The feedforward network of generating "forecast" output. Default: ``None``.
+        fourier_weight (ParameterTuple[Parameter]): The fourier weight for transforming data in the frequency
+            domain, given as a ParameterTuple of Parameter with a length of 2N.
+
+            - Even indices (0, 2, 4, ...) represent the real parts of the complex parameter.
+            - Odd indices (1, 3, 5, ...) represent the imaginary parts of the complex parameter.
+            - Default: ``None``, meaning no data is provided.
+ dft_compute_dtype (dtype.Number): The computation type of DFT in SpectralConv. Default: ``mstype.float32``.
+        ffno_compute_dtype (dtype.Number): The computation type of MLP in ffno skip. Default: ``mstype.float32``.
+ Should be ``mstype.float32`` or ``mstype.float16``. mstype.float32 is recommended for the GPU backend,
+ mstype.float16 is recommended for the Ascend backend.
+
+ Inputs:
+ - **x** (Tensor) - Tensor of shape :math:`(batch\_size, in\_channels, resolution)`.
+
+ Outputs:
+ Tensor, the output of this FFNOBlocks.
+
+        - **output** (Tensor) - Tensor of shape :math:`(batch\_size, out\_channels, resolution)`.
+
+ Raises:
+ TypeError: If `in_channels` is not an int.
+ TypeError: If `out_channels` is not an int.
+ TypeError: If `factor` is not an int.
+ TypeError: If `n_ff_layers` is not an int.
+ TypeError: If `ff_weight_norm` is not a Boolean value.
+ ValueError: If `ff_weight_norm` is not ``False``.
+ TypeError: If `layer_norm` is not a Boolean value.
+ TypeError: If `dropout` is not a float.
+ TypeError: If `r_padding` is not an int.
+ TypeError: If `use_fork` is not a Boolean value.
+
+ Supported Platforms:
+ ``Ascend``
+
+    Examples:
+ >>> import numpy as np
+ >>> from mindspore import Tensor
+ >>> import mindspore.common.dtype as mstype
+        >>> from mindscience.models import FFNOBlocks
+ >>> data = Tensor(np.ones([2, 128, 128, 2]), mstype.float32)
+ >>> net = FFNOBlocks(in_channels=2, out_channels=2, n_modes=[20, 20], resolutions=[128, 128])
+ >>> out0, out1 = net(data)
+ >>> print(data.shape, out0.shape, out1.shape)
+ (2, 128, 128, 2) (2, 128, 128, 2) (2, 128, 128, 2)
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ n_modes,
+ resolutions,
+ factor=1,
+ n_ff_layers=2,
+ ff_weight_norm=False,
+ layer_norm=True,
+ dropout=0.0,
+ r_padding=0,
+ use_fork=False,
+ forecast_ff=None,
+ backcast_ff=None,
+ fourier_weight=None,
+ dft_compute_dtype=mstype.float32,
+ ffno_compute_dtype=mstype.float32
+ ):
+ super().__init__()
+ check_param_type(in_channels, "in_channels", data_type=int)
+ check_param_type(out_channels, "out_channels", data_type=int)
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.n_modes, self.resolutions = validate_and_expand_dimensions(
+ 1, n_modes, resolutions, False)
+
+ check_param_type(factor, "factor", data_type=int)
+ check_param_type(n_ff_layers, "n_ff_layers", data_type=int)
+ check_param_type(ff_weight_norm, "ff_weight_norm", data_type=bool)
+ check_param_type(layer_norm, "layer_norm", data_type=bool)
+ check_param_type(dropout, "dropout", data_type=float)
+ check_param_type(r_padding, 'r_padding', data_type=int)
+
+ if ff_weight_norm:
+ raise ValueError(
+ f"The weight normalization is not supported in feedforward\
+ but got value of ff_weight_norm {ff_weight_norm}")
+
+ if r_padding < 0:
+ raise ValueError(
+ f"The right padding value cannot be negative\
+ but got value of r_padding {r_padding}")
+
+ check_param_type(use_fork, "use_fork", data_type=bool)
+ self.factor = factor
+ self.ff_weight_norm = ff_weight_norm
+ self.n_ff_layers = n_ff_layers
+ self.layer_norm = layer_norm
+ self.dropout = dropout
+ self.r_padding = r_padding
+ self.use_fork = use_fork
+ self.forecast_ff = forecast_ff
+ self.backcast_ff = backcast_ff
+ self.fourier_weight = fourier_weight
+ self.dft_compute_dtype = dft_compute_dtype
+ self.ffno_compute_dtype = ffno_compute_dtype
+
+ if len(self.resolutions) == 1:
+ spectral_conv = SpectralConv1d
+ elif len(self.resolutions) == 2:
+ spectral_conv = SpectralConv2d
+ elif len(self.resolutions) == 3:
+ spectral_conv = SpectralConv3d
+ else:
+ raise ValueError(
+ f"The length of input resolutions dimensions should be in [1, 2, 3], but got: {len(self.resolutions)}")
+
+ self._convs = spectral_conv(self.in_channels,
+ self.out_channels,
+ self.n_modes,
+ self.resolutions,
+ forecast_ff=self.forecast_ff,
+ backcast_ff=self.backcast_ff,
+ fourier_weight=self.fourier_weight,
+ factor=self.factor,
+ ff_weight_norm=self.ff_weight_norm,
+ n_ff_layers=self.n_ff_layers,
+ layer_norm=self.layer_norm,
+ use_fork=self.use_fork,
+ dropout=self.dropout,
+ r_padding=self.r_padding,
+ compute_dtype=self.dft_compute_dtype,
+ filter_mode='full')
+
+ def construct(self, x: Tensor):
+ b, _ = self._convs(x)
+ x = ops.add(x, b)
+ return x, b
+
+
+def validate_and_expand_dimensions(dim, n_modes, resolutions, is_validate_dim=True):
+ """validate and expand the dimension of inputs"""
+ if isinstance(n_modes, int):
+ n_modes = [n_modes] * dim
+ if isinstance(resolutions, int):
+ resolutions = [resolutions] * dim
+
+ n_modes_num = len(n_modes)
+ resolutions_num = len(resolutions)
+
+ if is_validate_dim:
+ if n_modes_num != dim:
+ raise ValueError(
+ f"The dimension of n_modes should be equal to {dim} when using FFNO{dim}D\
+ but got dimension of n_modes {n_modes_num}")
+ if resolutions_num != dim:
+ raise ValueError(
+ f"The dimension of resolutions should be equal to {dim} when using FFNO{dim}D\
+ but got dimension of resolutions {resolutions_num}")
+ if n_modes_num != resolutions_num:
+ raise ValueError(
+ f"The dimension of n_modes should be equal to that of resolutions\
+ but got dimension of n_modes {n_modes_num} and dimension of resolutions {resolutions_num}")
+
+ return n_modes, resolutions
+
+
+class FFNO(nn.Cell):
+ r"""
+ The FFNO base class, which usually contains a Lifting Layer, a Factorized Fourier Block Layer and a Projection
+ Layer. The details can be found in
+    `A. Tran, A. Mathews, et al.: Factorized Fourier Neural Operators <https://arxiv.org/abs/2111.13802>`_ .
+
+ Args:
+ in_channels (int): The number of channels in the input space.
+ out_channels (int): The number of channels in the output space.
+ n_modes (Union[int, list(int)]): The number of modes reserved after linear transformation in Fourier Layer.
+ resolutions (Union[int, list(int)]): The resolutions of the input tensor.
+ hidden_channels (int): The number of channels of the FNOBlock input and output. Default: ``20``.
+        lifting_channels (int): The number of channels of the lifting layer mid channels. Default: ``None``.
+ projection_channels (int): The number of channels of the projection layer mid channels. Default: ``128``.
+ factor (int): The number of neurons in the hidden layer of a feedforward network. Default: ``1``.
+ n_layers (int): The number that Fourier Layer nests. Default: ``4``.
+ n_ff_layers (int): The number of layers (hidden layers) in the feedforward neural network. Default: ``2``.
+ ff_weight_norm (bool): Whether to do weight normalization in feedforward or not. Used as a reserved function
+ interface, the weight normalization is not supported in feedforward. Default: ``False``.
+ layer_norm (bool): Whether to do layer normalization in feedforward or not. Default: ``True``.
+ share_weight (bool): Whether to share weights between SpectralConv layers or not. Default: ``False``.
+ r_padding (int): The number used to pad a tensor on the right in a certain dimension. Pad the domain if
+ input is non-periodic. Default: ``0``.
+ data_format (str): The input data channel sequence. Default: ``channels_last``.
+ positional_embedding (bool): Whether to embed positional information or not. Default: ``True``.
+        dft_compute_dtype (dtype.Number): The computation type of DFT in SpectralConv. Default: ``mstype.float32``.
+        ffno_compute_dtype (dtype.Number): The computation type of MLP in ffno skip. Default: ``mstype.float16``.
+ Should be ``mstype.float32`` or ``mstype.float16``. mstype.float32 is recommended for
+ the GPU backend, mstype.float16 is recommended for the Ascend backend.
+
+ Inputs:
+ - **x** (Tensor) - Tensor of shape :math:`(batch\_size, resolution, in\_channels)`.
+
+    Outputs:
+        Tensor, the output of this FFNO.
+
+        - **output** (Tensor) - Tensor of shape :math:`(batch\_size, resolution, out\_channels)`.
+
+ Raises:
+ TypeError: If `in_channels` is not an int.
+ TypeError: If `out_channels` is not an int.
+ TypeError: If `hidden_channels` is not an int.
+ TypeError: If `lifting_channels` is not an int.
+ TypeError: If `projection_channels` is not an int.
+ TypeError: If `factor` is not an int.
+ TypeError: If `n_layers` is not an int.
+ TypeError: If `n_ff_layers` is not an int.
+ TypeError: If `ff_weight_norm` is not a Boolean value.
+ ValueError: If `ff_weight_norm` is not ``False``.
+ TypeError: If `layer_norm` is not a Boolean value.
+ TypeError: If `share_weight` is not a Boolean value.
+ TypeError: If `r_padding` is not an int.
+ TypeError: If `data_format` is not a str.
+ TypeError: If `positional_embedding` is not a bool.
+
+ Supported Platforms:
+ ``Ascend``
+
+ Examples:
+ >>> import numpy as np
+ >>> from mindspore import Tensor
+ >>> import mindspore.common.dtype as mstype
+        >>> from mindscience.models.neural_operator.ffno import FFNO
+ >>> data = Tensor(np.ones([2, 128, 128, 2]), mstype.float32)
+ >>> net = FFNO(in_channels=2, out_channels=2, n_modes=[20, 20], resolutions=[128, 128])
+ >>> out = net(data)
+ >>> print(data.shape, out.shape)
+ (2, 128, 128, 2) (2, 128, 128, 2)
+ """
+
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ n_modes,
+ resolutions,
+ hidden_channels=20,
+ lifting_channels=None,
+ projection_channels=128,
+ factor=1,
+ n_layers=4,
+ n_ff_layers=2,
+ ff_weight_norm=False,
+ layer_norm=True,
+ share_weight=False,
+ r_padding=0,
+ data_format="channels_last",
+ positional_embedding=True,
+ dft_compute_dtype=mstype.float32,
+ ffno_compute_dtype=mstype.float16
+ ):
+ super().__init__()
+ check_param_type(in_channels, "in_channels", data_type=int, exclude_type=bool)
+ check_param_type(out_channels, "out_channels", data_type=int, exclude_type=bool)
+ check_param_type(hidden_channels, "hidden_channels", data_type=int, exclude_type=bool)
+ check_param_type(factor, "factor", data_type=int, exclude_type=bool)
+ check_param_type(n_layers, "n_layers", data_type=int, exclude_type=bool)
+ check_param_type(n_ff_layers, "n_ff_layers", data_type=int, exclude_type=bool)
+ check_param_type(ff_weight_norm, "ff_weight_norm", data_type=bool, exclude_type=str)
+ check_param_type(layer_norm, "layer_norm", data_type=bool, exclude_type=str)
+ check_param_type(share_weight, "share_weight", data_type=bool, exclude_type=str)
+ check_param_type(r_padding, "r_padding", data_type=int, exclude_type=bool)
+ check_param_type(data_format, "data_format", data_type=str, exclude_type=bool)
+ check_param_type(positional_embedding, "positional_embedding", data_type=bool, exclude_type=str)
+
+ if ff_weight_norm:
+ raise ValueError(
+ f"The weight normalization is not supported in feedforward\
+ but got value of ff_weight_norm {ff_weight_norm}")
+
+ if r_padding < 0:
+ raise ValueError(f"The right padding value cannot be negative but got value of r_padding {r_padding}")
+
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.hidden_channels = hidden_channels
+ self.lifting_channels = lifting_channels
+ self.projection_channels = projection_channels
+ self.n_modes, self.resolutions = validate_and_expand_dimensions(
+ 1, n_modes, resolutions, False)
+ self.n_layers = n_layers
+ self.r_padding = r_padding
+ self.data_format = data_format
+ self.positional_embedding = positional_embedding
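+        # When enabled, a coordinate grid is concatenated to the input,
+        # adding one extra input channel per spatial dimension.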
+ if self.positional_embedding:
+ self.in_channels += len(self.resolutions)
+ self.dft_compute_dtype = dft_compute_dtype
+ self.ffno_compute_dtype = ffno_compute_dtype
+ self._concat = ops.Concat(axis=-1)
+ self._positional_embedding = self._transpose(len(self.resolutions))
+ self._padding = self._pad(len(self.resolutions))
+ if self.lifting_channels:
+ self._lifting = nn.SequentialCell([
+ nn.Dense(self.in_channels, self.lifting_channels, has_bias=True).to_float(self.ffno_compute_dtype),
+ nn.Dense(self.lifting_channels, self.hidden_channels, has_bias=True).to_float(self.ffno_compute_dtype)])
+ else:
+ self._lifting = nn.SequentialCell(
+ nn.Dense(self.in_channels, self.hidden_channels, has_bias=True).to_float(self.ffno_compute_dtype)
+ )
+
+ self.fourier_weight = None
+ if share_weight:
+ param_list = []
+ for i, n_mode in enumerate(self.n_modes):
+ weight_shape = [hidden_channels, hidden_channels, n_mode]
+
+ w_re = Parameter(initializer(XavierNormal(), weight_shape, mstype.float32), name=f'base_w_re_{i}',
+ requires_grad=True)
+ w_im = Parameter(initializer(XavierNormal(), weight_shape, mstype.float32), name=f'base_w_im_{i}',
+ requires_grad=True)
+
+ param_list.append(w_re)
+ param_list.append(w_im)
+
+            self.fourier_weight = ParameterTuple(param_list)
+
+ self.factor = factor
+ self.ff_weight_norm = ff_weight_norm
+ self.n_ff_layers = n_ff_layers
+ self.layer_norm = layer_norm
+
+ self._ffno_blocks = nn.CellList([FFNOBlocks(in_channels=self.hidden_channels,
+ out_channels=self.hidden_channels,
+ n_modes=self.n_modes,
+ resolutions=self.resolutions,
+ factor=self.factor,
+ n_ff_layers=self.n_ff_layers,
+ ff_weight_norm=self.ff_weight_norm,
+ layer_norm=self.layer_norm,
+ dropout=0.0, r_padding=self.r_padding,
+ use_fork=False, forecast_ff=None, backcast_ff=None,
+ fourier_weight=self.fourier_weight,
+ dft_compute_dtype=self.dft_compute_dtype
+ ) for _ in range(self.n_layers)])
+
+ if self.projection_channels:
+ self._projection = nn.SequentialCell([
+ nn.Dense(self.hidden_channels, self.projection_channels, has_bias=True).to_float(
+ self.ffno_compute_dtype),
+ nn.Dense(self.projection_channels, self.out_channels, has_bias=True).to_float(
+ self.ffno_compute_dtype)
+ ])
+ else:
+ self._projection = nn.SequentialCell(
+ nn.Dense(self.hidden_channels, self.out_channels, has_bias=True).to_float(
+ self.ffno_compute_dtype))
+
+ def construct(self, x: Tensor):
+ """construct"""
+ batch_size = x.shape[0]
+ grid = mint.repeat_interleave(self._positional_embedding.astype(x.dtype), repeats=batch_size, dim=0)
+
+ if self.data_format != "channels_last":
+ x = ops.movedim(x, 1, -1)
+
+ if self.positional_embedding:
+ x = self._concat((x, grid))
+
+ x = self._lifting(x)
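+        # For non-periodic domains, right-pad every spatial dim; channels are moved
+        # to dim 1 first so that ops.pad only touches the trailing spatial axes.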
+ if self.r_padding != 0:
+ x = ops.movedim(x, -1, 1)
+ x = ops.pad(x, self._padding)
+ x = ops.movedim(x, 1, -1)
+
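+        # Each factorized Fourier block refines x and emits a backcast b;
+        # the final backcast feeds the projection head.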
+ b = Tensor(0, dtype=mstype.float32)
+ for block in self._ffno_blocks:
+ x, b = block(x)
+
+ if self.r_padding != 0:
+ b = self._remove_padding(len(self.resolutions), b)
+
+ x = self._projection(b)
+
+ if self.data_format != "channels_last":
+ x = ops.movedim(x, -1, 1)
+
+ return x
+
+ def _transpose(self, n_dim):
+ """transpose tensor"""
+ if n_dim == 1:
+ positional_embedding = Tensor(get_grid_1d(resolution=self.resolutions))
+ elif n_dim == 2:
+ positional_embedding = Tensor(get_grid_2d(resolution=self.resolutions))
+ elif n_dim == 3:
+ positional_embedding = Tensor(get_grid_3d(resolution=self.resolutions))
+ else:
+ raise ValueError(f"The length of input resolutions dimensions should be in [1, 2, 3], but got: {n_dim}")
+ return positional_embedding
+
+ def _pad(self, n_dim):
+ """pad the domain if input is non-periodic"""
+        if n_dim not in {1, 2, 3}:
+ raise ValueError(f"The length of input resolutions dimensions should be in [1, 2, 3], but got: {n_dim}")
+ return n_dim * [0, self.r_padding]
+
+ def _remove_padding(self, n_dim, b_input):
+ """remove pad domain"""
+ if n_dim == 1:
+ b = b_input[..., :-self.r_padding, :]
+ elif n_dim == 2:
+ b = b_input[..., :-self.r_padding, :-self.r_padding, :]
+ elif n_dim == 3:
+ b = b_input[..., :-self.r_padding, :-self.r_padding, :-self.r_padding, :]
+ else:
+ raise ValueError(f"The length of input resolutions dimensions should be in [1, 2, 3], but got: {n_dim}")
+ return b
+
+
+class FFNO1D(FFNO):
+ r"""
+ The 1D Factorized Fourier Neural Operator, which usually contains a Lifting Layer,
+ a Factorized Fourier Block Layer and a Projection Layer. The details can be found in
+    `A. Tran, A. Mathews, et al.: FACTORIZED FOURIER NEURAL OPERATORS `_.
+
+ Args:
+ in_channels (int): The number of channels in the input space.
+ out_channels (int): The number of channels in the output space.
+ n_modes (Union[int, list(int)]): The number of modes reserved after linear transformation in Fourier Layer.
+ resolutions (Union[int, list(int)]): The resolutions of the input tensor.
+ hidden_channels (int): The number of channels of the FNOBlock input and output. Default: ``20``.
+        lifting_channels (int): The number of channels of the lifting layer mid channels. Default: ``None``.
+ projection_channels (int): The number of channels of the projection layer mid channels. Default: ``128``.
+ factor (int): The number of neurons in the hidden layer of a feedforward network. Default: ``1``.
+        n_layers (int): The number of stacked Fourier Layers. Default: ``4``.
+ n_ff_layers (int): The number of layers (hidden layers) in the feedforward neural network. Default: ``2``.
+ ff_weight_norm (bool): Whether to do weight normalization in feedforward or not. Used as a reserved function
+ interface, the weight normalization is not supported in feedforward. Default: ``False``.
+ layer_norm (bool): Whether to do layer normalization in feedforward or not. Default: ``True``.
+ share_weight (bool): Whether to share weights between SpectralConv layers or not. Default: ``False``.
+ r_padding (int): The number used to pad a tensor on the right in a certain dimension. Default: ``0``.
+ data_format (str): The input data channel sequence. Default: ``channels_last``.
+ positional_embedding (bool): Whether to embed positional information or not. Default: ``True``.
+ dft_compute_dtype (dtype.Number): The computation type of DFT in SpectralConvDft. Default: ``mstype.float32``.
+        ffno_compute_dtype (dtype.Number): The computation type of the MLP in the FFNO skip connection.
+            Default: ``mstype.float16``. Should be ``mstype.float32`` or ``mstype.float16``;
+            ``mstype.float32`` is recommended for the GPU backend, and ``mstype.float16`` for the Ascend backend.
+
+ Inputs:
+ - **x** (Tensor) - Tensor of shape :math:`(batch\_size, resolution, in\_channels)`.
+
+ Outputs:
+        Tensor, the output of this FFNO1D network.
+
+        - **output** (Tensor) - Tensor of shape :math:`(batch\_size, resolution, out\_channels)`.
+
+ Raises:
+ TypeError: If `in_channels` is not an int.
+ TypeError: If `out_channels` is not an int.
+ TypeError: If `hidden_channels` is not an int.
+ TypeError: If `lifting_channels` is not an int.
+ TypeError: If `projection_channels` is not an int.
+ TypeError: If `factor` is not an int.
+ TypeError: If `n_layers` is not an int.
+ TypeError: If `n_ff_layers` is not an int.
+ TypeError: If `ff_weight_norm` is not a Boolean value.
+ ValueError: If `ff_weight_norm` is not ``False``.
+ TypeError: If `layer_norm` is not a Boolean value.
+ TypeError: If `share_weight` is not a Boolean value.
+ TypeError: If `r_padding` is not an int.
+ TypeError: If `data_format` is not a str.
+ TypeError: If `positional_embedding` is not a bool.
+
+ Supported Platforms:
+ ``Ascend``
+
+ Examples:
+ >>> import numpy as np
+ >>> import mindspore
+ >>> import mindflow
+ >>> from mindspore import Tensor
+ >>> import mindspore.common.dtype as mstype
+ >>> from mindflow.cell import FFNO1D
+ >>> data = Tensor(np.ones([2, 128, 3]), mstype.float32)
+ >>> net = FFNO1D(in_channels=3, out_channels=3, n_modes=[20], resolutions=[128])
+ >>> out = net(data)
+ >>> print(data.shape, out.shape)
+ (2, 128, 3) (2, 128, 3)
+ """
+
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ n_modes,
+ resolutions,
+ hidden_channels=20,
+ lifting_channels=None,
+ projection_channels=128,
+ factor=1,
+ n_layers=4,
+ n_ff_layers=2,
+ ff_weight_norm=False,
+ layer_norm=True,
+ share_weight=False,
+ r_padding=0,
+ data_format="channels_last",
+ positional_embedding=True,
+ dft_compute_dtype=mstype.float32,
+ ffno_compute_dtype=mstype.float16
+ ):
+ n_modes, resolutions = validate_and_expand_dimensions(1, n_modes, resolutions)
+ super().__init__(
+ in_channels,
+ out_channels,
+ n_modes,
+ resolutions,
+ hidden_channels,
+ lifting_channels,
+ projection_channels,
+ factor,
+ n_layers,
+ n_ff_layers,
+ ff_weight_norm,
+ layer_norm,
+ share_weight,
+ r_padding,
+ data_format,
+ positional_embedding,
+ dft_compute_dtype,
+ ffno_compute_dtype
+ )
+
+
+class FFNO2D(FFNO):
+ r"""
+ The 2D Factorized Fourier Neural Operator, which usually contains a Lifting Layer,
+ a Factorized Fourier Block Layer and a Projection Layer. The details can be found in
+    `A. Tran, A. Mathews, et al.: FACTORIZED FOURIER NEURAL OPERATORS `_.
+
+ Args:
+ in_channels (int): The number of channels in the input space.
+ out_channels (int): The number of channels in the output space.
+ n_modes (Union[int, list(int)]): The number of modes reserved after linear transformation in Fourier Layer.
+ resolutions (Union[int, list(int)]): The resolutions of the input tensor.
+ hidden_channels (int): The number of channels of the FNOBlock input and output. Default: ``20``.
+        lifting_channels (int): The number of channels of the lifting layer mid channels. Default: ``None``.
+ projection_channels (int): The number of channels of the projection layer mid channels. Default: ``128``.
+ factor (int): The number of neurons in the hidden layer of a feedforward network. Default: ``1``.
+        n_layers (int): The number of stacked Fourier Layers. Default: ``4``.
+ n_ff_layers (int): The number of layers (hidden layers) in the feedforward neural network. Default: ``2``.
+ ff_weight_norm (bool): Whether to do weight normalization in feedforward or not. Used as a reserved function
+ interface, the weight normalization is not supported in feedforward. Default: ``False``.
+ layer_norm (bool): Whether to do layer normalization in feedforward or not. Default: ``True``.
+ share_weight (bool): Whether to share weights between SpectralConv layers or not. Default: ``False``.
+ r_padding (int): The number used to pad a tensor on the right in a certain dimension. Default: ``0``.
+ data_format (str): The input data channel sequence. Default: ``channels_last``.
+ positional_embedding (bool): Whether to embed positional information or not. Default: ``True``.
+ dft_compute_dtype (dtype.Number): The computation type of DFT in SpectralConvDft. Default: ``mstype.float32``.
+        ffno_compute_dtype (dtype.Number): The computation type of the MLP in the FFNO skip connection.
+            Default: ``mstype.float16``. Should be ``mstype.float32`` or ``mstype.float16``;
+            ``mstype.float32`` is recommended for the GPU backend, and ``mstype.float16`` for the Ascend backend.
+
+ Inputs:
+        - **x** (Tensor) - Tensor of shape :math:`(batch\_size, resolution[0], resolution[1], in\_channels)`.
+
+    Outputs:
+        Tensor, the output of this FFNO2D network.
+
+        - **output** (Tensor) - Tensor of shape :math:`(batch\_size, resolution[0], resolution[1], out\_channels)`.
+
+ Raises:
+ TypeError: If `in_channels` is not an int.
+ TypeError: If `out_channels` is not an int.
+ TypeError: If `hidden_channels` is not an int.
+ TypeError: If `lifting_channels` is not an int.
+ TypeError: If `projection_channels` is not an int.
+ TypeError: If `factor` is not an int.
+ TypeError: If `n_layers` is not an int.
+ TypeError: If `n_ff_layers` is not an int.
+ TypeError: If `ff_weight_norm` is not a Boolean value.
+ ValueError: If `ff_weight_norm` is not ``False``.
+ TypeError: If `layer_norm` is not a Boolean value.
+ TypeError: If `share_weight` is not a Boolean value.
+ TypeError: If `r_padding` is not an int.
+ TypeError: If `data_format` is not a str.
+ TypeError: If `positional_embedding` is not a bool.
+
+ Supported Platforms:
+ ``Ascend``
+
+ Examples:
+ >>> import numpy as np
+ >>> import mindspore
+ >>> import mindflow
+ >>> from mindspore import Tensor
+ >>> import mindspore.common.dtype as mstype
+ >>> from mindflow.cell import FFNO2D
+ >>> data = Tensor(np.ones([2, 128, 128, 3]), mstype.float32)
+ >>> net = FFNO2D(in_channels=3, out_channels=3, n_modes=[20, 20], resolutions=[128, 128])
+ >>> out = net(data)
+ >>> print(data.shape, out.shape)
+ (2, 128, 128, 3) (2, 128, 128, 3)
+ """
+
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ n_modes,
+ resolutions,
+ hidden_channels=20,
+ lifting_channels=None,
+ projection_channels=128,
+ factor=1,
+ n_layers=4,
+ n_ff_layers=2,
+ ff_weight_norm=False,
+ layer_norm=True,
+ share_weight=False,
+ r_padding=0,
+ data_format="channels_last",
+ positional_embedding=True,
+ dft_compute_dtype=mstype.float32,
+ ffno_compute_dtype=mstype.float16
+ ):
+ n_modes, resolutions = validate_and_expand_dimensions(2, n_modes, resolutions)
+ super().__init__(
+ in_channels,
+ out_channels,
+ n_modes,
+ resolutions,
+ hidden_channels,
+ lifting_channels,
+ projection_channels,
+ factor,
+ n_layers,
+ n_ff_layers,
+ ff_weight_norm,
+ layer_norm,
+ share_weight,
+ r_padding,
+ data_format,
+ positional_embedding,
+ dft_compute_dtype,
+ ffno_compute_dtype
+ )
+
+
+class FFNO3D(FFNO):
+ r"""
+ The 3D Factorized Fourier Neural Operator, which usually contains a Lifting Layer,
+ a Factorized Fourier Block Layer and a Projection Layer. The details can be found in
+    `A. Tran, A. Mathews, et al.: FACTORIZED FOURIER NEURAL OPERATORS `_.
+
+ Args:
+ in_channels (int): The number of channels in the input space.
+ out_channels (int): The number of channels in the output space.
+ n_modes (Union[int, list(int)]): The number of modes reserved after linear transformation in Fourier Layer.
+ resolutions (Union[int, list(int)]): The resolutions of the input tensor.
+ hidden_channels (int): The number of channels of the FNOBlock input and output. Default: ``20``.
+        lifting_channels (int): The number of channels of the lifting layer mid channels. Default: ``None``.
+ projection_channels (int): The number of channels of the projection layer mid channels. Default: ``128``.
+ factor (int): The number of neurons in the hidden layer of a feedforward network. Default: ``1``.
+        n_layers (int): The number of stacked Fourier Layers. Default: ``4``.
+ n_ff_layers (int): The number of layers (hidden layers) in the feedforward neural network. Default: ``2``.
+ ff_weight_norm (bool): Whether to do weight normalization in feedforward or not. Used as a reserved function
+ interface, the weight normalization is not supported in feedforward. Default: ``False``.
+ layer_norm (bool): Whether to do layer normalization in feedforward or not. Default: ``True``.
+ share_weight (bool): Whether to share weights between SpectralConv layers or not. Default: ``False``.
+ r_padding (int): The number used to pad a tensor on the right in a certain dimension. Default: ``0``.
+ data_format (str): The input data channel sequence. Default: ``channels_last``.
+ positional_embedding (bool): Whether to embed positional information or not. Default: ``True``.
+ dft_compute_dtype (dtype.Number): The computation type of DFT in SpectralConvDft. Default: ``mstype.float32``.
+        ffno_compute_dtype (dtype.Number): The computation type of the MLP in the FFNO skip connection.
+            Default: ``mstype.float16``. Should be ``mstype.float32`` or ``mstype.float16``;
+            ``mstype.float32`` is recommended for the GPU backend, and ``mstype.float16`` for the Ascend backend.
+
+ Inputs:
+        - **x** (Tensor) - Tensor of shape
+          :math:`(batch\_size, resolution[0], resolution[1], resolution[2], in\_channels)`.
+
+    Outputs:
+        Tensor, the output of this FFNO3D network.
+
+        - **output** (Tensor) - Tensor of shape
+          :math:`(batch\_size, resolution[0], resolution[1], resolution[2], out\_channels)`.
+
+ Raises:
+ TypeError: If `in_channels` is not an int.
+ TypeError: If `out_channels` is not an int.
+ TypeError: If `hidden_channels` is not an int.
+ TypeError: If `lifting_channels` is not an int.
+ TypeError: If `projection_channels` is not an int.
+ TypeError: If `factor` is not an int.
+ TypeError: If `n_layers` is not an int.
+ TypeError: If `n_ff_layers` is not an int.
+ TypeError: If `ff_weight_norm` is not a Boolean value.
+ ValueError: If `ff_weight_norm` is not ``False``.
+ TypeError: If `layer_norm` is not a Boolean value.
+ TypeError: If `share_weight` is not a Boolean value.
+ TypeError: If `r_padding` is not an int.
+ TypeError: If `data_format` is not a str.
+ TypeError: If `positional_embedding` is not a bool.
+
+ Supported Platforms:
+ ``Ascend``
+
+ Examples:
+ >>> import numpy as np
+ >>> import mindspore
+ >>> import mindflow
+ >>> from mindspore import Tensor
+ >>> import mindspore.common.dtype as mstype
+ >>> from mindflow.cell import FFNO3D
+ >>> data = Tensor(np.ones([2, 128, 128, 128, 3]), mstype.float32)
+ >>> net = FFNO3D(in_channels=3, out_channels=3, n_modes=[20, 20, 20], resolutions=[128, 128, 128])
+ >>> out = net(data)
+ >>> print(data.shape, out.shape)
+ (2, 128, 128, 128, 3) (2, 128, 128, 128, 3)
+ """
+
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ n_modes,
+ resolutions,
+ hidden_channels=20,
+ lifting_channels=None,
+ projection_channels=128,
+ factor=1,
+ n_layers=4,
+ n_ff_layers=2,
+ ff_weight_norm=False,
+ layer_norm=True,
+ share_weight=False,
+ r_padding=0,
+ data_format="channels_last",
+ positional_embedding=True,
+ dft_compute_dtype=mstype.float32,
+ ffno_compute_dtype=mstype.float16
+ ):
+ n_modes, resolutions = validate_and_expand_dimensions(3, n_modes, resolutions)
+ super().__init__(
+ in_channels,
+ out_channels,
+ n_modes,
+ resolutions,
+ hidden_channels,
+ lifting_channels,
+ projection_channels,
+ factor,
+ n_layers,
+ n_ff_layers,
+ ff_weight_norm,
+ layer_norm,
+ share_weight,
+ r_padding,
+ data_format,
+ positional_embedding,
+ dft_compute_dtype,
+ ffno_compute_dtype
+ )
diff --git a/mindscience/models/neural_operator/ffno_sp.py b/mindscience/models/neural_operator/ffno_sp.py
new file mode 100644
index 0000000000000000000000000000000000000000..5737112a88f461ca2d81df8e2f2c55bb605a7cee
--- /dev/null
+++ b/mindscience/models/neural_operator/ffno_sp.py
@@ -0,0 +1,465 @@
+'''
+# Copyright 2023 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+'''
+import mindspore as ms
+import mindspore.common.dtype as mstype
+from mindspore import nn, ops, Tensor, Parameter, ParameterTuple, mint
+from mindspore.common.initializer import XavierNormal, initializer
+from ...common.math import get_grid_1d, get_grid_2d, get_grid_3d
+from ...sciops.fourier import RDFTn, IRDFTn
+
+
+class FeedForward(nn.Cell):
+ """FeedForward cell"""
+
+ def __init__(self, dim, factor, ff_weight_norm, n_layers, layer_norm, dropout):
+ super().__init__()
+ self.layers = nn.CellList()
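+        # Dense -> Dropout -> (ReLU) stack: hidden layers are widened by `factor`,
+        # the last layer maps back to `dim` and optionally applies LayerNorm.
+        # With ff_weight_norm=True the Dense degenerates to Identity, since
+        # weight normalization is not supported here.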
+ for i in range(n_layers):
+ in_dim = dim if i == 0 else dim * factor
+ out_dim = dim if i == n_layers - 1 else dim * factor
+ layer = nn.SequentialCell([
+ nn.Dense(in_dim, out_dim, has_bias=True) if not ff_weight_norm else nn.Identity(),
+ nn.Dropout(p=dropout),
+ nn.ReLU() if i < n_layers - 1 else nn.Identity(),
+ nn.LayerNorm((out_dim,), epsilon=1e-5) if layer_norm and i == n_layers - 1 else nn.Identity()])
+ self.layers.append(layer)
+
+ def construct(self, x):
+ for layer in self.layers:
+ x = layer(x)
+ return x
+
+
+class SpectralConv(nn.Cell):
+ """Base Class for Fourier Layer, including DFT, factorization, linear transform, and Inverse DFT"""
+
+ def __init__(self, in_channels, out_channels, n_modes, resolutions, forecast_ff, backcast_ff,
+ fourier_weight, factor, ff_weight_norm, n_ff_layers, layer_norm, use_fork, dropout, filter_mode,
+ compute_dtype=mstype.float32):
+ super().__init__()
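+        # mint.einsum is only used on MindSpore >= 2.5.0; older versions fall back
+        # to the transpose/reshape + bmm path in _einsum below.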
+ self.einsum_flag = tuple([int(s) for s in ms.__version__.split('.')]) >= (2, 5, 0)
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ if isinstance(n_modes, int):
+ n_modes = [n_modes]
+ self.n_modes = n_modes
+ if isinstance(resolutions, int):
+ resolutions = [resolutions]
+ self.resolutions = resolutions
+ if len(self.n_modes) != len(self.resolutions):
+ raise ValueError(
+ "The dimension of n_modes should be equal to that of resolutions, \
+ but got dimension of n_modes {} and dimension of resolutions {}".format(len(self.n_modes),
+ len(self.resolutions)))
+ self.compute_dtype = compute_dtype
+ self.use_fork = use_fork
+ self.fourier_weight = fourier_weight
+ self.filter_mode = filter_mode
+
+ if not self.fourier_weight:
+ param_list = []
+            for i, n_mode in enumerate(self.n_modes):
+                weight_shape = [in_channels, out_channels, n_mode]
+
+                w_re = Parameter(initializer(XavierNormal(), weight_shape, mstype.float32), name=f'w_re_{i}',
+                                 requires_grad=True)
+                w_im = Parameter(initializer(XavierNormal(), weight_shape, mstype.float32), name=f'w_im_{i}',
+                                 requires_grad=True)
+
+ param_list.append(w_re)
+ param_list.append(w_im)
+
+            self.fourier_weight = ParameterTuple(param_list)
+
+ if use_fork:
+ self.forecast_ff = forecast_ff
+ if not self.forecast_ff:
+ self.forecast_ff = FeedForward(
+ out_channels, factor, ff_weight_norm, n_ff_layers, layer_norm, dropout)
+
+ self.backcast_ff = backcast_ff
+ if not self.backcast_ff:
+ self.backcast_ff = FeedForward(
+ out_channels, factor, ff_weight_norm, n_ff_layers, layer_norm, dropout)
+
+ self._positional_embedding, self._input_perm, self._output_perm = self._transpose(len(self.resolutions))
+
+ def construct(self, x: Tensor):
+ raise NotImplementedError()
+
+ def _fourier_dimension(self, n, mode, n_dim):
+ """" n- shape - 3D: S1 S2 S3 / 2D: M N / 1D: C
+ mode - output length - n//2 +1
+ dim - 3D: -1 -2 -3 / 2D: -1 -2 / 1D: -1 """
+ dft_cell = RDFTn(shape=n, dim=n_dim, norm='ortho', modes=mode, compute_dtype=self.compute_dtype)
+ idft_cell = IRDFTn(shape=n, dim=n_dim, norm='ortho', modes=mode, compute_dtype=self.compute_dtype)
+
+ return dft_cell, idft_cell
+
+ def _einsum(self, inputs, weights, dim):
+ """The Einstein multiplication function"""
+ res_len = len(self.resolutions)
+
+ if res_len not in [1, 2, 3]:
+ raise ValueError(
+ "The length of input resolutions dimensions should be in [1, 2, 3], but got: {}".format(res_len))
+
+ if self.einsum_flag:
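+            # One contraction per factorized axis: e.g. 'bixy,iox->boxy' mixes
+            # channels i -> o independently for every retained x mode.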
+ expressions = {
+ ('x', 1): 'bix,iox->box',
+ ('x', 2): 'bixy,iox->boxy',
+ ('y', 2): 'bixy,ioy->boxy',
+ ('x', 3): 'bixyz,iox->boxyz',
+ ('y', 3): 'bixyz,ioy->boxyz',
+ ('z', 3): 'bixyz,ioz->boxyz'
+ }
+
+ key = (dim, res_len)
+ if key not in expressions:
+ raise ValueError(f"Unsupported type of the last dim of weight: {dim}")
+
+ out = mint.einsum(expressions[key], inputs, weights)
+
+ else:
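+            # Emulate the einsum without mint: move the weighted mode axis to the
+            # front, batch-matmul over the channel axis, then restore the layout.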
+ _, weight_out, weight_dim = weights.shape
+ batch_size, inputs_in = inputs.shape[0], inputs.shape[1]
+ weights_perm = (2, 0, 1)
+
+ if res_len == 1:
+ if dim == 'x':
+ input_perm = (2, 0, 1)
+ output_perm = (1, 2, 0)
+ else:
+ raise ValueError(f"Unsupported type of the last dim of weight: {dim}")
+
+ inputs = ops.transpose(inputs, input_perm=input_perm)
+ weights = ops.transpose(weights, input_perm=weights_perm)
+ out = ops.bmm(inputs, weights)
+ out = ops.transpose(out, input_perm=output_perm)
+ elif res_len == 2:
+ if dim == 'y':
+ input_perm = (3, 0, 2, 1)
+ output_perm = (1, 3, 2, 0)
+ elif dim == 'x':
+ input_perm = (2, 0, 3, 1)
+ output_perm = (1, 3, 0, 2)
+ else:
+ raise ValueError(f"Unsupported type of the last dim of weight: {dim}")
+
+ inputs = ops.transpose(inputs, input_perm=input_perm)
+ inputs = ops.reshape(inputs, (weight_dim, -1, inputs_in))
+ weights = ops.transpose(weights, input_perm=weights_perm)
+ out = ops.bmm(inputs, weights)
+ out = ops.reshape(out, (weight_dim, batch_size, -1, weight_out))
+ out = ops.transpose(out, input_perm=output_perm)
+ else:
+ input_dim1, input_dim2, input_dim3 = inputs.shape[2], inputs.shape[3], inputs.shape[4]
+
+ if dim == 'z':
+ input_perm = (4, 0, 2, 3, 1)
+ output_perm = (1, 4, 2, 3, 0)
+ reshape_dim = input_dim1
+ elif dim == 'y':
+ input_perm = (3, 0, 4, 2, 1)
+ output_perm = (1, 4, 3, 0, 2)
+ reshape_dim = input_dim3
+ elif dim == 'x':
+ input_perm = (2, 0, 3, 4, 1)
+ output_perm = (1, 4, 0, 2, 3)
+ reshape_dim = input_dim2
+ else:
+ raise ValueError(f"Unsupported type of the last dim of weight: {dim}")
+
+ inputs = ops.transpose(inputs, input_perm=input_perm)
+ inputs = ops.reshape(inputs, (weight_dim, -1, inputs_in))
+ weights = ops.transpose(weights, input_perm=weights_perm)
+ out = ops.bmm(inputs, weights)
+ out = ops.reshape(out, (weight_dim, batch_size, reshape_dim, -1, weight_out))
+ out = ops.transpose(out, input_perm=output_perm)
+
+ return out
+
+ def _transpose(self, n_dim):
+ """transpose tensor"""
+ if n_dim == 1:
+ positional_embedding = Tensor(get_grid_1d(resolution=self.resolutions))
+ input_perm = (0, 2, 1)
+ output_perm = (0, 2, 1)
+ elif n_dim == 2:
+ positional_embedding = Tensor(get_grid_2d(resolution=self.resolutions))
+ input_perm = (0, 2, 3, 1)
+ output_perm = (0, 3, 1, 2)
+ elif n_dim == 3:
+ positional_embedding = Tensor(get_grid_3d(resolution=self.resolutions))
+ input_perm = (0, 2, 3, 4, 1)
+ output_perm = (0, 4, 1, 2, 3)
+ else:
+ raise ValueError(
+ "The length of input resolutions dimensions should be in [1, 2, 3], but got: {}".format(n_dim))
+ return positional_embedding, input_perm, output_perm
+
+ def _complex_mul(self, input_re, input_im, weight_re, weight_im, dim):
+ """(a + bj) * (c + dj) = (ac - bd) + (ad + bc)j"""
+ out_re = self._einsum(input_re, weight_re, dim) - self._einsum(input_im, weight_im, dim)
+ out_im = self._einsum(input_re, weight_im, dim) + self._einsum(input_im, weight_re, dim)
+
+ return out_re, out_im
+
+
+class SpectralConv1d(SpectralConv):
+ """1D Fourier layer. It does DFT, factorization, linear transform, and Inverse DFT."""
+
+ def __init__(self, in_channels, out_channels, n_modes, resolutions, forecast_ff, backcast_ff,
+ fourier_weight, factor, ff_weight_norm, n_ff_layers, layer_norm, use_fork, dropout, r_padding,
+ filter_mode, compute_dtype=mstype.float32):
+ super().__init__(in_channels, out_channels, n_modes, resolutions, forecast_ff, backcast_ff, fourier_weight,
+                         factor, ff_weight_norm, n_ff_layers, layer_norm, use_fork, dropout, filter_mode,
+                         compute_dtype=compute_dtype)
+
+ self._dft1_x_cell, self._idft1_x_cell = self._fourier_dimension(resolutions[0] + r_padding, n_modes[0], -1)
+
+ def construct(self, x: Tensor):
+ x = self.construct_fourier(x)
+ b = self.backcast_ff(x)
+ f = self.forecast_ff(x) if self.use_fork else None
+
+ return b, f
+
+ def construct_fourier(self, x):
+ """1D Fourier layer."""
+ x = ops.transpose(x, input_perm=self._output_perm) # x shape: batch, in_dim, grid_size
+
+ x_ft_re = x
+
+ x_ftx_re, x_ftx_im = self._dft1_x_cell(x_ft_re)
+
+ x_ftx_re_part = x_ftx_re[:, :, :self.n_modes[0]]
+ x_ftx_im_part = x_ftx_im[:, :, :self.n_modes[0]]
+
+ re0, re1, re2 = x_ftx_re.shape
+ im0, im1, im2 = x_ftx_im.shape
+ out_ftx_remain_re = ops.zeros((re0, re1, re2 - self.n_modes[0]))
+ out_ftx_remain_im = ops.zeros((im0, im1, im2 - self.n_modes[0]))
+
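+        # 'full' applies the learned spectral weights to the kept modes,
+        # 'low_pass' keeps them unchanged, anything else zeroes the spectrum.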
+ if self.filter_mode == 'full':
+ ftx_re, ftx_im = self._complex_mul(
+ x_ftx_re_part, x_ftx_im_part, self.fourier_weight[0], self.fourier_weight[1], 'x')
+ out_ftx_re = ops.cat([ftx_re, out_ftx_remain_re], axis=2)
+ out_ftx_im = ops.cat([ftx_im, out_ftx_remain_im], axis=2)
+ elif self.filter_mode == 'low_pass':
+ out_ftx_re = ops.cat([x_ftx_re_part, out_ftx_remain_re], axis=2)
+ out_ftx_im = ops.cat([x_ftx_im_part, out_ftx_remain_im], axis=2)
+ else:
+ out_ftx_re = ops.zeros_like(x_ftx_re)
+ out_ftx_im = ops.zeros_like(x_ftx_im)
+
+ x = self._idft1_x_cell(out_ftx_re, out_ftx_im)
+ x = ops.transpose(x, input_perm=self._input_perm)
+
+ return x
+
+
+class SpectralConv2d(SpectralConv):
+ """2D Fourier layer. It does DFT, factorization, linear transform, and Inverse DFT."""
+
+ def __init__(self, in_channels, out_channels, n_modes, resolutions, forecast_ff, backcast_ff,
+ fourier_weight, factor, ff_weight_norm, n_ff_layers, layer_norm, use_fork, dropout, r_padding,
+ filter_mode, compute_dtype=mstype.float32):
+ super().__init__(in_channels, out_channels, n_modes, resolutions, forecast_ff, backcast_ff, fourier_weight,
+                         factor, ff_weight_norm, n_ff_layers, layer_norm, use_fork, dropout, filter_mode,
+                         compute_dtype=compute_dtype)
+
+ self._dft1_x_cell, self._idft1_x_cell = self._fourier_dimension(resolutions[0] + r_padding, n_modes[0], -2)
+ self._dft1_y_cell, self._idft1_y_cell = self._fourier_dimension(resolutions[1] + r_padding, n_modes[1], -1)
+
+ def construct(self, x: Tensor):
+ x = self.construct_fourier(x)
+ b = self.backcast_ff(x)
+ f = self.forecast_ff(x) if self.use_fork else None
+
+ return b, f
+
+ def construct_fourier(self, x):
+ """2D Fourier layer."""
+ x = ops.transpose(x, input_perm=self._output_perm) # x shape: batch, in_dim, grid_size, grid_size
+
+ x_ft_re = x
+
+        # Dimension Y
+ x_fty_re, x_fty_im = self._dft1_y_cell(x_ft_re)
+
+ x_fty_re_part = x_fty_re[:, :, :, :self.n_modes[1]]
+ x_fty_im_part = x_fty_im[:, :, :, :self.n_modes[1]]
+
+ re0, re1, re2, re3 = x_fty_re.shape
+ im0, im1, im2, im3 = x_fty_im.shape
+ out_fty_remain_re = ops.zeros((re0, re1, re2, re3 - self.n_modes[1]))
+ out_fty_remain_im = ops.zeros((im0, im1, im2, im3 - self.n_modes[1]))
+
+ if self.filter_mode == 'full':
+ fty_re, fty_im = self._complex_mul(
+ x_fty_re_part, x_fty_im_part, self.fourier_weight[2], self.fourier_weight[3], 'y')
+ out_fty_re = ops.cat([fty_re, out_fty_remain_re], axis=3)
+ out_fty_im = ops.cat([fty_im, out_fty_remain_im], axis=3)
+ elif self.filter_mode == 'low_pass':
+ out_fty_re = ops.cat([x_fty_re_part, out_fty_remain_re], axis=3)
+ out_fty_im = ops.cat([x_fty_im_part, out_fty_remain_im], axis=3)
+ else:
+ out_fty_re = ops.zeros_like(x_fty_re)
+ out_fty_im = ops.zeros_like(x_fty_im)
+
+ xy = self._idft1_y_cell(out_fty_re, out_fty_im)
+
+        # Dimension X
+ x_ftx_re, x_ftx_im = self._dft1_x_cell(x_ft_re)
+
+ x_ftx_re_part = x_ftx_re[:, :, :self.n_modes[0], :]
+ x_ftx_im_part = x_ftx_im[:, :, :self.n_modes[0], :]
+
+ re0, re1, re2, re3 = x_ftx_re.shape
+ im0, im1, im2, im3 = x_ftx_im.shape
+ out_ftx_remain_re = ops.zeros((re0, re1, re2 - self.n_modes[0], re3))
+ out_ftx_remain_im = ops.zeros((im0, im1, im2 - self.n_modes[0], im3))
+
+ if self.filter_mode == 'full':
+ ftx_re, ftx_im = self._complex_mul(
+ x_ftx_re_part, x_ftx_im_part, self.fourier_weight[0], self.fourier_weight[1], 'x')
+ out_ftx_re = ops.cat([ftx_re, out_ftx_remain_re], axis=2)
+ out_ftx_im = ops.cat([ftx_im, out_ftx_remain_im], axis=2)
+ elif self.filter_mode == 'low_pass':
+ out_ftx_re = ops.cat([x_ftx_re_part, out_ftx_remain_re], axis=2)
+ out_ftx_im = ops.cat([x_ftx_im_part, out_ftx_remain_im], axis=2)
+ else:
+ out_ftx_re = ops.zeros_like(x_ftx_re)
+ out_ftx_im = ops.zeros_like(x_ftx_im)
+
+ xx = self._idft1_x_cell(out_ftx_re, out_ftx_im)
+
+ # Combining Dimensions
+ x = xx + xy
+
+ x = ops.transpose(x, input_perm=self._input_perm)
+
+ return x
+
+
+class SpectralConv3d(SpectralConv):
+ """3D Fourier layer. It does DFT, factorization, linear transform, and Inverse DFT."""
+
+ def __init__(self, in_channels, out_channels, n_modes, resolutions, forecast_ff, backcast_ff,
+ fourier_weight, factor, ff_weight_norm, n_ff_layers, layer_norm, use_fork, dropout, r_padding,
+ filter_mode, compute_dtype=mstype.float32):
+ super().__init__(in_channels, out_channels, n_modes, resolutions, forecast_ff, backcast_ff, fourier_weight,
+                         factor, ff_weight_norm, n_ff_layers, layer_norm, use_fork, dropout, filter_mode,
+                         compute_dtype=compute_dtype)
+
+ self._dft1_x_cell, self._idft1_x_cell = self._fourier_dimension(resolutions[0] + r_padding, n_modes[0], -3)
+ self._dft1_y_cell, self._idft1_y_cell = self._fourier_dimension(resolutions[1] + r_padding, n_modes[1], -2)
+ self._dft1_z_cell, self._idft1_z_cell = self._fourier_dimension(resolutions[2] + r_padding, n_modes[2], -1)
+
+ def construct(self, x: Tensor):
+ x = self.construct_fourier(x)
+ b = self.backcast_ff(x)
+ f = self.forecast_ff(x) if self.use_fork else None
+
+ return b, f
+
+ def construct_fourier(self, x):
+ """3D Fourier layer."""
+ x = ops.transpose(x, input_perm=self._output_perm) # x shape: batch, in_dim, grid_size, grid_size, grid_size
+
+ x_ft_re = x
+
+        # Dimension Z
+ x_ftz_re, x_ftz_im = self._dft1_z_cell(x_ft_re)
+
+ x_ftz_re_part = x_ftz_re[:, :, :, :, :self.n_modes[2]]
+ x_ftz_im_part = x_ftz_im[:, :, :, :, :self.n_modes[2]]
+
+ re0, re1, re2, re3, re4 = x_ftz_re.shape
+ im0, im1, im2, im3, im4 = x_ftz_im.shape
+ out_ftz_remain_re = ops.zeros((re0, re1, re2, re3, re4 - self.n_modes[2]))
+ out_ftz_remain_im = ops.zeros((im0, im1, im2, im3, im4 - self.n_modes[2]))
+
+ if self.filter_mode == 'full':
+ ftz_re, ftz_im = self._complex_mul(
+ x_ftz_re_part, x_ftz_im_part, self.fourier_weight[4], self.fourier_weight[5], 'z')
+ out_ftz_re = ops.cat([ftz_re, out_ftz_remain_re], axis=4)
+ out_ftz_im = ops.cat([ftz_im, out_ftz_remain_im], axis=4)
+ elif self.filter_mode == 'low_pass':
+ out_ftz_re = ops.cat([x_ftz_re_part, out_ftz_remain_re], axis=4)
+ out_ftz_im = ops.cat([x_ftz_im_part, out_ftz_remain_im], axis=4)
+ else:
+ out_ftz_re = ops.zeros_like(x_ftz_re)
+ out_ftz_im = ops.zeros_like(x_ftz_im)
+
+ xz = self._idft1_z_cell(out_ftz_re, out_ftz_im)
+
+        # Dimension Y
+ x_fty_re, x_fty_im = self._dft1_y_cell(x_ft_re)
+
+ x_fty_re_part = x_fty_re[:, :, :, :self.n_modes[1], :]
+ x_fty_im_part = x_fty_im[:, :, :, :self.n_modes[1], :]
+
+ re0, re1, re2, re3, re4 = x_fty_re.shape
+ im0, im1, im2, im3, im4 = x_fty_im.shape
+ out_fty_remain_re = ops.zeros((re0, re1, re2, re3 - self.n_modes[1], re4))
+ out_fty_remain_im = ops.zeros((im0, im1, im2, im3 - self.n_modes[1], im4))
+
+ if self.filter_mode == 'full':
+ fty_re, fty_im = self._complex_mul(
+ x_fty_re_part, x_fty_im_part, self.fourier_weight[2], self.fourier_weight[3], 'y')
+ out_fty_re = ops.cat([fty_re, out_fty_remain_re], axis=3)
+ out_fty_im = ops.cat([fty_im, out_fty_remain_im], axis=3)
+ elif self.filter_mode == 'low_pass':
+ out_fty_re = ops.cat([x_fty_re_part, out_fty_remain_re], axis=3)
+ out_fty_im = ops.cat([x_fty_im_part, out_fty_remain_im], axis=3)
+ else:
+ out_fty_re = ops.zeros_like(x_fty_re)
+ out_fty_im = ops.zeros_like(x_fty_im)
+
+ xy = self._idft1_y_cell(out_fty_re, out_fty_im)
+
+        # Dimension X
+ x_ftx_re, x_ftx_im = self._dft1_x_cell(x_ft_re)
+
+ x_ftx_re_part = x_ftx_re[:, :, :self.n_modes[0], :, :]
+ x_ftx_im_part = x_ftx_im[:, :, :self.n_modes[0], :, :]
+
+ re0, re1, re2, re3, re4 = x_ftx_re.shape
+ im0, im1, im2, im3, im4 = x_ftx_im.shape
+ out_ftx_remain_re = ops.zeros((re0, re1, re2 - self.n_modes[0], re3, re4))
+ out_ftx_remain_im = ops.zeros((im0, im1, im2 - self.n_modes[0], im3, im4))
+
+ if self.filter_mode == 'full':
+ ftx_re, ftx_im = self._complex_mul(
+ x_ftx_re_part, x_ftx_im_part, self.fourier_weight[0], self.fourier_weight[1], 'x')
+ out_ftx_re = ops.cat([ftx_re, out_ftx_remain_re], axis=2)
+ out_ftx_im = ops.cat([ftx_im, out_ftx_remain_im], axis=2)
+ elif self.filter_mode == 'low_pass':
+ out_ftx_re = ops.cat([x_ftx_re_part, out_ftx_remain_re], axis=2)
+ out_ftx_im = ops.cat([x_ftx_im_part, out_ftx_remain_im], axis=2)
+ else:
+ out_ftx_re = ops.zeros_like(x_ftx_re)
+ out_ftx_im = ops.zeros_like(x_ftx_im)
+
+ xx = self._idft1_x_cell(out_ftx_re, out_ftx_im)
+
+ # Combining Dimensions
+ x = xx + xy + xz
+
+ x = ops.transpose(x, input_perm=self._input_perm)
+
+ return x
diff --git a/mindscience/models/neural_operator/fno.py b/mindscience/models/neural_operator/fno.py
index f013c6b32949f1a792dd56f605be0cfc9e578e58..8050ad03c238530d31f4d4d2f0467d9ace63e2f0 100644
--- a/mindscience/models/neural_operator/fno.py
+++ b/mindscience/models/neural_operator/fno.py
@@ -16,12 +16,12 @@
'''
# pylint: disable=W0235
-from mindspore import nn, ops, Tensor
+from mindspore import nn, ops, Tensor, mint
import mindspore.common.dtype as mstype
-from .dft import SpectralConv1dDft, SpectralConv2dDft, SpectralConv3dDft
+from .fno_sp import SpectralConv1dDft, SpectralConv2dDft, SpectralConv3dDft
from ..activation import get_activation
-from ...core.math import get_grid_1d, get_grid_2d, get_grid_3d
+from ...common.math import get_grid_1d, get_grid_2d, get_grid_3d
from ...utils.check_func import check_param_type
@@ -306,7 +306,7 @@ class FNO(nn.Cell):
def construct(self, x: Tensor):
"""construct"""
batch_size = x.shape[0]
- grid = self._positional_embedding.repeat(batch_size, axis=0).astype(x.dtype)
+ grid = mint.repeat_interleave(self._positional_embedding.astype(x.dtype), batch_size, dim=0)
if self.data_format != "channels_last":
x = ops.transpose(x, input_perm=self._output_perm)
if self.positional_embedding:
diff --git a/mindscience/models/neural_operator/fno_sp.py b/mindscience/models/neural_operator/fno_sp.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b2e4e8baa0b742fab1c003d27d15cc84373b050
--- /dev/null
+++ b/mindscience/models/neural_operator/fno_sp.py
@@ -0,0 +1,243 @@
+'''
+# Copyright 2023 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+'''
+import numpy as np
+
+import mindspore.common.dtype as mstype
+from mindspore import nn, ops, Tensor, Parameter, mint
+from mindspore.common.initializer import Zero
+from mindspore.ops import operations as P
+
+from ...sciops.fourier import RDFTn, IRDFTn
+
+
+class SpectralConvDft(nn.Cell):
+ """Base Class for Fourier Layer, including DFT, linear transform, and Inverse DFT"""
+
+ def __init__(self, in_channels, out_channels, n_modes, resolutions, compute_dtype=mstype.float32):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ if isinstance(n_modes, int):
+ n_modes = [n_modes]
+ self.n_modes = n_modes
+ if isinstance(resolutions, int):
+ resolutions = [resolutions]
+ self.resolutions = resolutions
+ if len(self.n_modes) != len(self.resolutions):
+ raise ValueError(
+ "The dimension of n_modes should be equal to that of resolutions, \
+ but got dimension of n_modes {} and dimension of resolutions {}".format(len(self.n_modes),
+ len(self.resolutions)))
+ self.compute_dtype = compute_dtype
+
+ def construct(self, x: Tensor):
+ raise NotImplementedError()
+
+ def _einsum(self, inputs, weights):
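+        """Channel mixing equivalent to einsum('bi...,io...->bo...'), via broadcasting."""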
+ weights = weights.expand_dims(0)
+ inputs = inputs.expand_dims(2)
+ out = inputs * weights
+ return out.sum(1)
+
+
+class SpectralConv1dDft(SpectralConvDft):
+ """1D Fourier Layer. It does DFT, linear transform, and Inverse DFT."""
+
+ def __init__(self, in_channels, out_channels, n_modes, resolutions, compute_dtype=mstype.float32):
+        super().__init__(in_channels, out_channels, n_modes, resolutions, compute_dtype)
+ self._scale = (1. / (self.in_channels * self.out_channels))
+ w_re = Tensor(self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0]),
+ dtype=mstype.float32)
+ w_im = Tensor(self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0]),
+ dtype=mstype.float32)
+ self._w_re = Parameter(w_re, requires_grad=True)
+ self._w_im = Parameter(w_im, requires_grad=True)
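+        # Forward/inverse real DFT pair over the last axis, truncated to n_modes[0] frequencies.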
+ self._dft1_cell = RDFTn(
+ shape=(self.resolutions[0],), norm='ortho', modes=self.n_modes[0], compute_dtype=self.compute_dtype)
+ self._idft1_cell = IRDFTn(
+ shape=(self.resolutions[0],), norm='ortho', modes=self.n_modes[0], compute_dtype=self.compute_dtype)
+
+ def construct(self, x: Tensor):
+ x_re = x
+ x_ft_re, x_ft_im = self._dft1_cell(x_re)
+ w_re = P.Cast()(self._w_re, self.compute_dtype)
+ w_im = P.Cast()(self._w_im, self.compute_dtype)
+ out_ft_re = self._einsum(x_ft_re[:, :, :self.n_modes[0]], w_re) - self._einsum(x_ft_im[:, :, :self.n_modes[0]],
+ w_im)
+ out_ft_im = self._einsum(x_ft_re[:, :, :self.n_modes[0]], w_im) + self._einsum(x_ft_im[:, :, :self.n_modes[0]],
+ w_re)
+
+ x = self._idft1_cell(out_ft_re, out_ft_im)
+
+ return x
+
+
+class SpectralConv2dDft(SpectralConvDft):
+ """2D Fourier Layer. It does DFT, linear transform, and Inverse DFT."""
+
+ def __init__(self, in_channels, out_channels, n_modes, resolutions, compute_dtype=mstype.float32):
+        super().__init__(in_channels, out_channels, n_modes, resolutions, compute_dtype)
+ self._scale = (1. / (self.in_channels * self.out_channels))
+ w_re1 = Tensor(
+ self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0], self.n_modes[1]),
+ dtype=self.compute_dtype)
+ w_im1 = Tensor(
+ self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0], self.n_modes[1]),
+ dtype=self.compute_dtype)
+ w_re2 = Tensor(
+ self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0], self.n_modes[1]),
+ dtype=self.compute_dtype)
+ w_im2 = Tensor(
+ self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0], self.n_modes[1]),
+ dtype=self.compute_dtype)
+
+ self._w_re1 = Parameter(w_re1, requires_grad=True)
+ self._w_im1 = Parameter(w_im1, requires_grad=True)
+ self._w_re2 = Parameter(w_re2, requires_grad=True)
+ self._w_im2 = Parameter(w_im2, requires_grad=True)
+
+ self._dft2_cell = RDFTn(shape=(self.resolutions[0], self.resolutions[1]), norm='ortho',
+ modes=(self.n_modes[0], self.n_modes[1]), compute_dtype=self.compute_dtype)
+ self._idft2_cell = IRDFTn(shape=(self.resolutions[0], self.resolutions[1]), norm='ortho',
+ modes=(self.n_modes[0], self.n_modes[1]), compute_dtype=self.compute_dtype)
+        self._mat = Tensor(shape=(1, self.out_channels, self.resolutions[0] - 2 * self.n_modes[0], self.n_modes[1]),
+ dtype=self.compute_dtype, init=Zero())
+ self._concat = ops.Concat(-2)
+
+ def construct(self, x: Tensor):
+ x_re = x
+ x_ft_re, x_ft_im = self._dft2_cell(x_re)
+
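+        # The two weight pairs cover the kept corners of the half-spectrum: block 1 the
+        # lowest x-frequencies, block 2 the highest (negative) ones; the complex product
+        # is expanded as (a+bj)(c+dj) = (ac-bd) + (ad+bc)j.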
+ out_ft_re1 = self._einsum(x_ft_re[:, :, :self.n_modes[0], :self.n_modes[1]], self._w_re1) - self._einsum(
+ x_ft_im[:, :, :self.n_modes[0], :self.n_modes[1]], self._w_im1)
+ out_ft_im1 = self._einsum(x_ft_re[:, :, :self.n_modes[0], :self.n_modes[1]], self._w_im1) + self._einsum(
+ x_ft_im[:, :, :self.n_modes[0], :self.n_modes[1]], self._w_re1)
+
+ out_ft_re2 = self._einsum(x_ft_re[:, :, -self.n_modes[0]:, :self.n_modes[1]], self._w_re2) - self._einsum(
+ x_ft_im[:, :, -self.n_modes[0]:, :self.n_modes[1]], self._w_im2)
+ out_ft_im2 = self._einsum(x_ft_re[:, :, -self.n_modes[0]:, :self.n_modes[1]], self._w_im2) + self._einsum(
+ x_ft_im[:, :, -self.n_modes[0]:, :self.n_modes[1]], self._w_re2)
+
+ batch_size = x.shape[0]
+ mat = mint.repeat_interleave(self._mat, batch_size, 0)
+ out_re = self._concat((out_ft_re1, mat, out_ft_re2))
+ out_im = self._concat((out_ft_im1, mat, out_ft_im2))
+
+ x = self._idft2_cell(out_re, out_im)
+
+ return x
+
+
+class SpectralConv3dDft(SpectralConvDft):
+ """3D Fourier layer. It does DFT, linear transform, and Inverse DFT."""
+
+ def __init__(self, in_channels, out_channels, n_modes, resolutions, compute_dtype=mstype.float32):
+        super().__init__(in_channels, out_channels, n_modes, resolutions, compute_dtype)
+ self._scale = (1 / (self.in_channels * self.out_channels))
+
+ w_re1 = Tensor(
+ self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0], self.n_modes[1],
+ self.n_modes[2]), dtype=self.compute_dtype)
+ w_im1 = Tensor(
+ self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0], self.n_modes[1],
+ self.n_modes[2]), dtype=self.compute_dtype)
+ w_re2 = Tensor(
+ self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0], self.n_modes[1],
+ self.n_modes[2]), dtype=self.compute_dtype)
+ w_im2 = Tensor(
+ self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0], self.n_modes[1],
+ self.n_modes[2]), dtype=self.compute_dtype)
+ w_re3 = Tensor(
+ self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0], self.n_modes[1],
+ self.n_modes[2]), dtype=self.compute_dtype)
+ w_im3 = Tensor(
+ self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0], self.n_modes[1],
+ self.n_modes[2]), dtype=self.compute_dtype)
+ w_re4 = Tensor(
+ self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0], self.n_modes[1],
+ self.n_modes[2]), dtype=self.compute_dtype)
+ w_im4 = Tensor(
+ self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0], self.n_modes[1],
+ self.n_modes[2]), dtype=self.compute_dtype)
+
+ self._w_re1 = Parameter(w_re1, requires_grad=True)
+ self._w_im1 = Parameter(w_im1, requires_grad=True)
+ self._w_re2 = Parameter(w_re2, requires_grad=True)
+ self._w_im2 = Parameter(w_im2, requires_grad=True)
+ self._w_re3 = Parameter(w_re3, requires_grad=True)
+ self._w_im3 = Parameter(w_im3, requires_grad=True)
+ self._w_re4 = Parameter(w_re4, requires_grad=True)
+ self._w_im4 = Parameter(w_im4, requires_grad=True)
+
+ self._dft3_cell = RDFTn(shape=(self.resolutions[0], self.resolutions[1], self.resolutions[2]), norm='ortho',
+ modes=(self.n_modes[0], self.n_modes[1], self.n_modes[2]),
+ compute_dtype=self.compute_dtype)
+ self._idft3_cell = IRDFTn(shape=(self.resolutions[0], self.resolutions[1], self.resolutions[2]), norm='ortho',
+ modes=(self.n_modes[0], self.n_modes[1], self.n_modes[2]),
+ compute_dtype=self.compute_dtype)
+ self._mat_x = Tensor(
+ shape=(1, self.out_channels, self.resolutions[0] - 2 * self.n_modes[0], self.n_modes[1], self.n_modes[2]),
+ dtype=self.compute_dtype, init=Zero())
+ self._mat_y = Tensor(
+ shape=(1, self.out_channels, self.resolutions[0], self.resolutions[1] - 2 * self.n_modes[1],
+ self.n_modes[2]),
+ dtype=self.compute_dtype, init=Zero())
+ self._concat = ops.Concat(-2)
+
+ def construct(self, x: Tensor):
+ x_re = x
+ x_ft_re, x_ft_im = self._dft3_cell(x_re)
+
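+        # The four weight pairs cover the (+/-x, +/-y) corners of the kept half-spectrum;
+        # mat_x/mat_y are zero fillers for the discarded frequencies before the inverse DFT.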
+ out_ft_re1 = self._einsum(x_ft_re[:, :, :self.n_modes[0], :self.n_modes[1], :self.n_modes[2]],
+ self._w_re1) - self._einsum(x_ft_im[:, :, :self.n_modes[0], :self.n_modes[1],
+ :self.n_modes[2]], self._w_im1)
+ out_ft_im1 = self._einsum(x_ft_re[:, :, :self.n_modes[0], :self.n_modes[1], :self.n_modes[2]],
+ self._w_im1) + self._einsum(x_ft_im[:, :, :self.n_modes[0], :self.n_modes[1],
+ :self.n_modes[2]], self._w_re1)
+ out_ft_re2 = self._einsum(x_ft_re[:, :, -self.n_modes[0]:, :self.n_modes[1], :self.n_modes[2]],
+ self._w_re2) - self._einsum(x_ft_im[:, :, -self.n_modes[0]:, :self.n_modes[1],
+ :self.n_modes[2]], self._w_im2)
+ out_ft_im2 = self._einsum(x_ft_re[:, :, -self.n_modes[0]:, :self.n_modes[1], :self.n_modes[2]],
+ self._w_im2) + self._einsum(x_ft_im[:, :, -self.n_modes[0]:, :self.n_modes[1],
+ :self.n_modes[2]], self._w_re2)
+ out_ft_re3 = self._einsum(x_ft_re[:, :, :self.n_modes[0], -self.n_modes[1]:, :self.n_modes[2]],
+ self._w_re3) - self._einsum(x_ft_im[:, :, :self.n_modes[0], -self.n_modes[1]:,
+ :self.n_modes[2]], self._w_im3)
+ out_ft_im3 = self._einsum(x_ft_re[:, :, :self.n_modes[0], -self.n_modes[1]:, :self.n_modes[2]],
+ self._w_im3) + self._einsum(x_ft_im[:, :, :self.n_modes[0], -self.n_modes[1]:,
+ :self.n_modes[2]], self._w_re3)
+ out_ft_re4 = self._einsum(x_ft_re[:, :, -self.n_modes[0]:, -self.n_modes[1]:, :self.n_modes[2]],
+ self._w_re4) - self._einsum(x_ft_im[:, :, -self.n_modes[0]:, -self.n_modes[1]:,
+ :self.n_modes[2]], self._w_im4)
+ out_ft_im4 = self._einsum(x_ft_re[:, :, -self.n_modes[0]:, -self.n_modes[1]:, :self.n_modes[2]],
+ self._w_im4) + self._einsum(x_ft_im[:, :, -self.n_modes[0]:, -self.n_modes[1]:,
+ :self.n_modes[2]], self._w_re4)
+
+ batch_size = x.shape[0]
+ mat_x = mint.repeat_interleave(self._mat_x, batch_size, 0)
+ mat_y = mint.repeat_interleave(self._mat_y, batch_size, 0)
+
+ out_re1 = ops.concat((out_ft_re1, mat_x, out_ft_re2), -3)
+ out_im1 = ops.concat((out_ft_im1, mat_x, out_ft_im2), -3)
+
+ out_re2 = ops.concat((out_ft_re3, mat_x, out_ft_re4), -3)
+ out_im2 = ops.concat((out_ft_im3, mat_x, out_ft_im4), -3)
+ out_re = ops.concat((out_re1, mat_y, out_re2), -2)
+ out_im = ops.concat((out_im1, mat_y, out_im2), -2)
+ x = self._idft3_cell(out_re, out_im)
+
+ return x
diff --git a/mindscience/models/neural_operator/kno1d.py b/mindscience/models/neural_operator/kno1d.py
index b4c11bb5ddfcd5bd3d698281b2f234adc110f344..81820de728a5a6d5563e04c6819e28eb6db5bf1b 100644
--- a/mindscience/models/neural_operator/kno1d.py
+++ b/mindscience/models/neural_operator/kno1d.py
@@ -16,7 +16,7 @@
import mindspore.common.dtype as mstype
from mindspore import ops, nn, Tensor
-from .dft import SpectralConv1dDft
+from .fno_sp import SpectralConv1dDft
from ...utils.check_func import check_param_type
diff --git a/mindscience/models/neural_operator/kno2d.py b/mindscience/models/neural_operator/kno2d.py
index 9674709f952725070d9c333b06055652342596cd..79f9ae98a2a824cb2339287ced8f7c66e3655af5 100644
--- a/mindscience/models/neural_operator/kno2d.py
+++ b/mindscience/models/neural_operator/kno2d.py
@@ -1,119 +1,119 @@
-# Copyright 2023 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-"""KNO2D"""
-import mindspore.common.dtype as mstype
-from mindspore import ops, nn, Tensor
-
-from .dft import SpectralConv2dDft
-from ...utils.check_func import check_param_type
-
-
-class KNO2D(nn.Cell):
- r"""
- The 2-dimensional Koopman Neural Operator (KNO2D) contains a encoder layer and a decoder layer,
- multiple Koopman layers.
- The details can be found in `KoopmanLab: machine learning for solving complex physics equations
- `_.
-
- Args:
- in_channels (int): The number of channels in the input space. Default: ``1``.
- channels (int): The number of channels after dimension lifting of the input. Default: ``32``.
- modes (int): The number of low-frequency components to keep. Default: ``16``.
- resolution (int): The spatial resolution of the input. Default: ``1024``.
- depths (int): The number of KNO layers. Default: ``4``.
- compute_dtype (dtype.Number): The computation type of dense. Default: ``mstype.float16``.
- Should be ``mstype.float32`` or ``mstype.float16``. mstype.float32 is recommended for
- the GPU backend, mstype.float16 is recommended for the Ascend backend.
-
- Inputs:
- - **x** (Tensor) - Tensor of shape :math:`(batch\_size, resolution, in\_channels)`.
-
- Outputs:
- Tensor, the output of this KNO network.
-
- - **output** (Tensor) -Tensor of shape :math:`(batch\_size, resolution, in\_channels)`.
-
- Raises:
- TypeError: If `in_channels` is not an int.
- TypeError: If `channels` is not an int.
- TypeError: If `modes` is not an int.
- TypeError: If `depths` is not an int.
- TypeError: If `resolution` is not an int.
-
- Supported Platforms:
- ``Ascend`` ``GPU``
-
- Examples:
- >>> import numpy as np
- >>> from mindflow.cell.neural_operators import KNO2D
- >>> input_ = Tensor(np.ones([32, 64, 64, 10]), mstype.float32)
- >>> net = KNO2D()
- >>> x, x_reconstruct = net(input_)
- >>> print(x.shape, x_reconstruct.shape)
- (32, 64, 64, 10) (32, 64, 64, 10)
- """
-
- def __init__(self,
- in_channels=10,
- channels=32,
- modes=16,
- depths=4,
- resolution=64,
- compute_dtype=mstype.float32):
- super().__init__()
- check_param_type(in_channels, "in_channels",
- data_type=int, exclude_type=bool)
- check_param_type(channels, "channels",
- data_type=int, exclude_type=bool)
- check_param_type(modes, "modes",
- data_type=int, exclude_type=bool)
- check_param_type(depths, "depths",
- data_type=int, exclude_type=bool)
- check_param_type(resolution, "resolution",
- data_type=int, exclude_type=bool)
- self.in_channels = in_channels
- self.channels = channels
- self.modes = modes
- self.depths = depths
- self.resolution = resolution
- self.enc = nn.Dense(in_channels, channels, has_bias=True)
- self.dec = nn.Dense(channels, in_channels, has_bias=True)
- self.koopman_layer = SpectralConv2dDft(channels, channels, [modes, modes], [resolution, resolution],
- compute_dtype=compute_dtype)
- self.w0 = nn.Conv2d(channels, channels, 1, has_bias=True)
-
- def construct(self, x: Tensor):
- """KNO2D forward function.
-
- Args:
- x (Tensor): Input Tensor.
- """
- # reconstruct
- x_reconstruct = self.enc(x)
- x_reconstruct = ops.tanh(x_reconstruct)
- x_reconstruct = self.dec(x_reconstruct)
-
- # predict
- x = self.enc(x)
- x = ops.tanh(x)
- x = x.transpose(0, 3, 1, 2)
- x_w = x
- for _ in range(self.depths):
- x1 = self.koopman_layer(x)
- x = ops.tanh(x + x1)
- x = ops.tanh(self.w0(x_w) + x)
- x = x.transpose(0, 2, 3, 1)
- x = self.dec(x)
- return x, x_reconstruct
+# Copyright 2023 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""KNO2D"""
+import mindspore.common.dtype as mstype
+from mindspore import ops, nn, Tensor
+
+from .fno_sp import SpectralConv2dDft
+from ...utils.check_func import check_param_type
+
+
+class KNO2D(nn.Cell):
+ r"""
+    The 2-dimensional Koopman Neural Operator (KNO2D) contains an encoder layer, a decoder layer
+    and multiple Koopman layers.
+ The details can be found in `KoopmanLab: machine learning for solving complex physics equations
+ `_.
+
+ Args:
+        in_channels (int): The number of channels in the input space. Default: ``10``.
+ channels (int): The number of channels after dimension lifting of the input. Default: ``32``.
+ modes (int): The number of low-frequency components to keep. Default: ``16``.
+        resolution (int): The spatial resolution of the input. Default: ``64``.
+ depths (int): The number of KNO layers. Default: ``4``.
+        compute_dtype (dtype.Number): The computation type of dense. Default: ``mstype.float32``.
+            Should be ``mstype.float32`` or ``mstype.float16``; ``mstype.float32`` is recommended for
+            the GPU backend, and ``mstype.float16`` for the Ascend backend.
+
+ Inputs:
+        - **x** (Tensor) - Tensor of shape :math:`(batch\_size, resolution, resolution, in\_channels)`.
+
+    Outputs:
+        Tuple of 2 Tensors, the prediction and the reconstruction of the input.
+
+        - **output** (Tensor) - Tensor of shape :math:`(batch\_size, resolution, resolution, in\_channels)`.
+        - **x_reconstruct** (Tensor) - Tensor of shape :math:`(batch\_size, resolution, resolution, in\_channels)`.
+
+ Raises:
+ TypeError: If `in_channels` is not an int.
+ TypeError: If `channels` is not an int.
+ TypeError: If `modes` is not an int.
+ TypeError: If `depths` is not an int.
+ TypeError: If `resolution` is not an int.
+
+ Supported Platforms:
+ ``Ascend`` ``GPU``
+
+ Examples:
+        >>> import numpy as np
+        >>> from mindspore import Tensor
+        >>> import mindspore.common.dtype as mstype
+        >>> from mindflow.cell.neural_operators import KNO2D
+ >>> input_ = Tensor(np.ones([32, 64, 64, 10]), mstype.float32)
+ >>> net = KNO2D()
+ >>> x, x_reconstruct = net(input_)
+ >>> print(x.shape, x_reconstruct.shape)
+ (32, 64, 64, 10) (32, 64, 64, 10)
+ """
+
+ def __init__(self,
+ in_channels=10,
+ channels=32,
+ modes=16,
+ depths=4,
+ resolution=64,
+ compute_dtype=mstype.float32):
+ super().__init__()
+ check_param_type(in_channels, "in_channels",
+ data_type=int, exclude_type=bool)
+ check_param_type(channels, "channels",
+ data_type=int, exclude_type=bool)
+ check_param_type(modes, "modes",
+ data_type=int, exclude_type=bool)
+ check_param_type(depths, "depths",
+ data_type=int, exclude_type=bool)
+ check_param_type(resolution, "resolution",
+ data_type=int, exclude_type=bool)
+ self.in_channels = in_channels
+ self.channels = channels
+ self.modes = modes
+ self.depths = depths
+ self.resolution = resolution
+ self.enc = nn.Dense(in_channels, channels, has_bias=True)
+ self.dec = nn.Dense(channels, in_channels, has_bias=True)
+ self.koopman_layer = SpectralConv2dDft(channels, channels, [modes, modes], [resolution, resolution],
+ compute_dtype=compute_dtype)
+ self.w0 = nn.Conv2d(channels, channels, 1, has_bias=True)
+
+ def construct(self, x: Tensor):
+ """KNO2D forward function.
+
+ Args:
+ x (Tensor): Input Tensor.
+ """
+ # reconstruct
+ x_reconstruct = self.enc(x)
+ x_reconstruct = ops.tanh(x_reconstruct)
+ x_reconstruct = self.dec(x_reconstruct)
+
+ # predict
+ x = self.enc(x)
+ x = ops.tanh(x)
+ x = x.transpose(0, 3, 1, 2)
+ x_w = x
+ for _ in range(self.depths):
+ x1 = self.koopman_layer(x)
+ x = ops.tanh(x + x1)
+ x = ops.tanh(self.w0(x_w) + x)
+ x = x.transpose(0, 2, 3, 1)
+ x = self.dec(x)
+ return x, x_reconstruct
diff --git a/mindscience/models/neural_operator/sno.py b/mindscience/models/neural_operator/sno.py
index 335bd711febbfce1864f2650e2b97cf3b7251506..7f66352bccdf88cba55e69ba77741f73615ab1be 100644
--- a/mindscience/models/neural_operator/sno.py
+++ b/mindscience/models/neural_operator/sno.py
@@ -17,8 +17,7 @@ import mindspore.common.dtype as mstype
from mindspore import ops, nn
from .sp_transform import ConvCell, TransformCell, Dim
-from ..activation import get_activation
-from ..unet2d import UNet2D
+from ..layers import get_activation, UNet2D
from ...utils.check_func import check_param_type, check_param_type_value
diff --git a/mindscience/models/transformer/attention.py b/mindscience/models/transformer/attention.py
index ce4460809412ccc9301faf7d919b34a997675f1e..d19fff9cb9548e2510670ae7f718a94f5ea73bde 100644
--- a/mindscience/models/transformer/attention.py
+++ b/mindscience/models/transformer/attention.py
@@ -17,7 +17,7 @@ from typing import Optional
from mindspore import ops, nn, Tensor
import mindspore.common.dtype as mstype
-from .basic_block import DropPath
+from ..layers import DropPath
class Attention(nn.Cell):
diff --git a/mindscience/models/transformer/vit.py b/mindscience/models/transformer/vit.py
index 5fcf6fa6e1d2a3539024e175a5aa2d27a449f9cc..9923a668ce8f321c0c755ac2652290692e91a03f 100644
--- a/mindscience/models/transformer/vit.py
+++ b/mindscience/models/transformer/vit.py
@@ -21,8 +21,8 @@ import mindspore.ops.operations as P
from mindspore.common.initializer import initializer, XavierUniform
import mindspore.common.dtype as mstype
-from .utils import to_2tuple, get_2d_sin_cos_pos_embed
-from .attention import AttentionBlock
+from ...common import to_2tuple, get_2d_sin_cos_pos_embed
+from .attention import TransformerBlock
class PatchEmbedding(nn.Cell):
@@ -132,7 +132,7 @@ class VitEncoder(nn.Cell):
mstype.float32
)
for _ in range(depths):
- layer = AttentionBlock(
+ layer = TransformerBlock(
in_channels=hidden_channels,
num_heads=num_heads,
dropout_rate=dropout_rate,
@@ -212,7 +212,7 @@ class VitDecoder(nn.Cell):
mstype.float32
)
for _ in range(depths):
- layer = AttentionBlock(
+ layer = TransformerBlock(
in_channels=hidden_channels,
num_heads=num_heads,
dropout_rate=dropout_rate,
diff --git a/mindscience/pde/__init__.py b/mindscience/pde/__init__.py
index 7e16260f0cf8177888f9d20182d8b1b118b2f209..736f8b2c7feb3c21dbe1dfc0d96a5ebe93581933 100644
--- a/mindscience/pde/__init__.py
+++ b/mindscience/pde/__init__.py
@@ -19,5 +19,3 @@ from .flow_with_loss import FlowWithLoss, SteadyFlowWithLoss, UnsteadyFlowWithLo
__all__ = ["Burgers", "NavierStokes", "Poisson", "sympy_to_mindspore", "PDEWithLoss",
"FlowWithLoss", "SteadyFlowWithLoss", "UnsteadyFlowWithLoss"]
-
-__all__.sort()
diff --git a/mindscience/pde/flow_with_loss.py b/mindscience/pde/flow_with_loss.py
index 16983dfba31576091c6ea74ae60c4d35f64751b6..8ff6b29166644337f45f6ebaae9fe493e3a45778 100644
--- a/mindscience/pde/flow_with_loss.py
+++ b/mindscience/pde/flow_with_loss.py
@@ -14,7 +14,7 @@
# ==============================================================================
"""flow with loss"""
from mindspore import nn, ops, jit_class
-from ..core import get_loss_metric
+from ..common import get_loss_metric
from ..utils.check_func import check_param_type
diff --git a/mindscience/pde/pde_with_loss.py b/mindscience/pde/pde_with_loss.py
index 04792af2e18f24f56cd2590d2c2e48426eb3f0b9..a7d772992e18eae988b9b14f4e41bc1745af7609 100644
--- a/mindscience/pde/pde_with_loss.py
+++ b/mindscience/pde/pde_with_loss.py
@@ -21,8 +21,8 @@ import numpy as np
from mindspore import jit_class
from .sympy2mindspore import sympy_to_mindspore
-from ..core import batched_hessian, batched_jacobian
-from ..core import get_loss_metric
+from ..common import batched_hessian, batched_jacobian
+from ..common import get_loss_metric
@jit_class
diff --git a/mindscience/sciops/__init__.py b/mindscience/sciops/__init__.py
index 705e9c2bc4bc2b34e6bfc28893965a46e210d971..b3b71a07714f10ae140f16e0f9e6448ee7d82b54 100644
--- a/mindscience/sciops/__init__.py
+++ b/mindscience/sciops/__init__.py
@@ -15,4 +15,6 @@
"""
init
"""
+from .fourier import RDFTn, IRDFTn, DFTn, IDFTn, DCT, IDCT, DST, IDST
+__all__ = ["RDFTn", "IRDFTn", "DFTn", "IDFTn", "DCT", "IDCT", "DST", "IDST"]
diff --git a/mindscience/sciops/fourier.py b/mindscience/sciops/fourier.py
new file mode 100644
index 0000000000000000000000000000000000000000..64f980668abda0e4b045f4645c35202cb0b24445
--- /dev/null
+++ b/mindscience/sciops/fourier.py
@@ -0,0 +1,656 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+''' Provide complex DFT operations based on the real DFT API in mindflow.dft '''
+import numpy as np
+import scipy
+import mindspore as ms
+import mindspore.common.dtype as mstype
+from mindspore import nn, ops, Tensor, mint
+from mindspore.common.initializer import Zero
+from mindspore.ops import operations as P
+
+from ..utils.check_func import check_param_no_greater, check_param_value
+
+
+class MyRoll(nn.Cell):
+    ''' Custom roll operator that works around a MindSpore bug on certain backends/modes '''
+ def __init__(self):
+ super().__init__()
+
+ if ms.get_context('device_target') == 'Ascend' and ms.get_context('mode') == ms.GRAPH_MODE:
+ self.roller = mint.roll
+ else:
+ self.roller = None
+
+ def construct(self, x, shifts, dims):
+ ''' Same as mint.roll '''
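+        # e.g. shifts=1, dims=0 maps [1, 2, 3, 4] -> [4, 1, 2, 3], matching mint.roll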
+ shifts = np.atleast_1d(shifts).astype(int).tolist()
+ dims = np.atleast_1d(dims).astype(int).tolist()
+
+ if self.roller:
+ return self.roller(x, shifts, dims)
+
+ for i, j in zip(shifts, dims):
+ n = x.shape[j]
+ x = ops.swapaxes(x, j, 0)
+ x = ops.cat([x[n - i % n:], x[:n - i % n]], axis=0)
+ x = ops.swapaxes(x, j, 0)
+ return x
+
+class MyFlip(nn.Cell):
+    ''' Custom flip operator that works around a MindSpore bug on certain backends/modes '''
+ def __init__(self):
+ super().__init__()
+ msver = tuple([int(s) for s in ms.__version__.split('.')])
+
+ if msver <= (2, 4, 0) and \
+ ms.get_context('device_target') == 'Ascend' and \
+ ms.get_context('mode') == ms.PYNATIVE_MODE:
+ self.fliper = None
+ else:
+ self.fliper = mint.flip
+
+ def construct(self, x, dims):
+        ''' Same as mint.flip '''
+ dims = np.atleast_1d(dims).astype(int).tolist()
+
+ if self.fliper:
+ return self.fliper(x, dims)
+
+ for j in dims:
+ x = ops.swapaxes(x, j, 0)
+ x = x[::-1]
+ x = ops.swapaxes(x, j, 0)
+ return x
+
+
+def convert_shape(shape):
+ ''' convert shape to suitable format '''
+ if isinstance(shape, int):
+ n = shape
+ elif len(shape) == 1:
+ n, = shape
+ else:
+ raise TypeError("Only support 1D dct/dst, but got shape {}".format(shape))
+ return n
+
+
+def convert_params(shape, modes, dim):
+ ''' convert input arguments to suitable format '''
+ shape = tuple(np.atleast_1d(shape).astype(int).tolist())
+ ndim = len(shape)
+
+ if dim is None:
+ dim = tuple([n - ndim for n in range(ndim)])
+ else:
+ dim = tuple(np.atleast_1d(dim).astype(int).tolist())
+
+ if modes is None or isinstance(modes, int):
+ modes = tuple([modes] * ndim)
+ else:
+ modes = tuple(np.atleast_1d(modes).astype(int).tolist())
+
+ return shape, modes, dim
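+    # illustrative: convert_params(64, None, None) -> ((64,), (None,), (-1,))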
+
+
+def check_params(shape, modes, dim):
+    ''' check validity of input arguments '''
+ check_param_no_greater(len(dim), "dim length", 3)
+ check_param_value(len(shape), "shape length", len(dim))
+ check_param_value(len(modes), "modes length", len(dim))
+ if np.any(modes):
+ for i, (m, n) in enumerate(zip(modes, shape)):
+            # if the mode for the last axis needs to be n//2+1, `modes` should be set to None
+ check_param_no_greater(m, f'mode{i+1}', n // 2)
+
+
+class _DFT1d(nn.Cell):
+ '''One dimensional Discrete Fourier Transformation'''
+
+ def __init__(self, n, mode, last_index, idx=0, scale='sqrtn', inv=False, compute_dtype=mstype.float32):
+ super().__init__()
+
+ self.n = n
+ self.dft_mat = scipy.linalg.dft(n, scale=scale)
+ self.last_index = last_index
+ self.inv = inv
+ self.odd = bool(n % 2)
+ self.idx = idx
+ self.mode_upper = mode if mode else n // 2 + (self.last_index or self.odd)
+ self.mode_lower = mode if mode else n - self.mode_upper
+ self.compute_dtype = compute_dtype
+
+ # generate DFT matrix for positive and negative frequencies
+ dft_mat_mode = self.dft_mat[:, :self.mode_upper]
+ self.a_re_upper = Tensor(dft_mat_mode.real, dtype=compute_dtype)
+ self.a_im_upper = Tensor(dft_mat_mode.imag, dtype=compute_dtype)
+
+ dft_mat_mode = self.dft_mat[:, -self.mode_lower:]
+ self.a_re_lower = Tensor(dft_mat_mode.real, dtype=compute_dtype)
+ self.a_im_lower = Tensor(dft_mat_mode.imag, dtype=compute_dtype)
+
+ # the zero matrix to fill the un-transformed modes
+ m = self.n - (self.mode_upper + self.mode_lower)
+ if m > 0:
+ self.mat = Tensor(shape=m, dtype=compute_dtype, init=Zero())
+
+ self.concat = ops.Concat(axis=-1)
+ self.cast = P.Cast()
+
+ if self.inv:
+ self.a_re_upper = self.a_re_upper.T
+ self.a_im_upper = -self.a_im_upper.T
+ self.a_re_lower = self.a_re_lower.T
+ self.a_im_lower = -self.a_im_lower.T
+
+            # the last axis is real-transformed, so the inverse uses the conjugate of the positive frequencies
+ if last_index:
+ mode_res = min(self.mode_lower, self.mode_upper - 1)
+ dft_mat_res = self.dft_mat[:, -mode_res:]
+ a_re_res = MyFlip()(Tensor(dft_mat_res.real, dtype=compute_dtype), dims=-1)
+ a_im_res = MyFlip()(Tensor(dft_mat_res.imag, dtype=compute_dtype), dims=-1)
+
+ a_re_res = ops.pad(a_re_res, (1, self.mode_upper - mode_res - 1))
+ a_im_res = ops.pad(a_im_res, (1, self.mode_upper - mode_res - 1))
+
+ self.a_re_upper += a_re_res.T
+ self.a_im_upper += a_im_res.T
+
+ def swap_axes(self, x_re, x_im):
+ return x_re.swapaxes(-1, self.idx), x_im.swapaxes(-1, self.idx)
+
+ def complex_matmul(self, x_re, x_im, a_re, a_im):
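+        # complex product (x_re + i*x_im) @ (a_re + i*a_im), kept as separate real/imag matmuls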
+ y_re = ops.matmul(x_re, a_re) - ops.matmul(x_im, a_im)
+ y_im = ops.matmul(x_im, a_re) + ops.matmul(x_re, a_im)
+ return y_re, y_im
+
+ def zero_mat(self, dims):
+ mat = self.mat
+ for n in dims[::-1]:
+ mat = mint.repeat_interleave(mat.expand_dims(0), n, 0)
+ return mat
+
+ def compute_forward(self, x_re, x_im):
+ ''' Forward transform for rdft '''
+ y_re, y_im = self.complex_matmul(
+ x_re=x_re, x_im=x_im, a_re=self.a_re_upper, a_im=self.a_im_upper)
+
+ if self.last_index:
+ return y_re, y_im
+
+ y_re2, y_im2 = self.complex_matmul(
+ x_re=x_re, x_im=x_im, a_re=self.a_re_lower, a_im=self.a_im_lower)
+
+ if self.n == self.mode_upper + self.mode_lower:
+ y_re = self.concat((y_re, y_re2))
+ y_im = self.concat((y_im, y_im2))
+ else:
+ mat = self.zero_mat(x_re.shape[:-1])
+ y_re = self.concat((y_re, mat, y_re2))
+ y_im = self.concat((y_im, mat, y_im2))
+
+ return y_re, y_im
+
+ def compute_inverse(self, x_re, x_im):
+ ''' Inverse transform for irdft '''
+ y_re, y_im = self.complex_matmul(x_re=x_re[..., :self.mode_upper],
+ x_im=x_im[..., :self.mode_upper],
+ a_re=self.a_re_upper,
+ a_im=self.a_im_upper)
+ if self.last_index:
+ return y_re, y_im
+
+ y_re_res, y_im_res = self.complex_matmul(x_re=x_re[..., -self.mode_lower:],
+ x_im=x_im[..., -self.mode_lower:],
+ a_re=self.a_re_lower,
+ a_im=self.a_im_lower)
+ return y_re + y_re_res, y_im + y_im_res
+
+ def construct(self, x):
+ ''' perform 1d rdft/irdft with matmul operations '''
+ x_re, x_im = x
+ x_re, x_im = self.cast(x_re, self.compute_dtype), self.cast(x_im, self.compute_dtype)
+ x_re, x_im = self.swap_axes(x_re, x_im)
+ if self.inv:
+ y_re, y_im = self.compute_inverse(x_re, x_im)
+ else:
+ y_re, y_im = self.compute_forward(x_re, x_im)
+ y_re, y_im = self.swap_axes(y_re, y_im)
+ return y_re, y_im
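+
+# A minimal numpy cross-check of the matmul-based DFT used above (illustrative, not executed):
+#   >>> import numpy as np, scipy.linalg
+#   >>> x = np.arange(4.0)
+#   >>> np.allclose(x @ scipy.linalg.dft(4), np.fft.fft(x))
+#   True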
+
+
+class _DFTn(nn.Cell):
+ ''' Base class for n-D DFT transform '''
+ def __init__(self, shape, dim=None, norm='backward', modes=None, compute_dtype=mstype.float32):
+ super().__init__()
+
+ shape, modes, dim = convert_params(shape, modes, dim)
+ check_params(shape, modes, dim)
+
+ ndim = len(shape)
+ inv, scale, r2c_flags = self.set_options(ndim, norm)
+ self.dft1_seq = nn.SequentialCell()
+ for n, m, r, d in zip(shape, modes, r2c_flags, dim):
+ self.dft1_seq.append(_DFT1d(
+ n=n, mode=m, last_index=r, idx=d, scale=scale, inv=inv, compute_dtype=compute_dtype))
+
+ def set_options(self, ndim, norm):
+ '''
+        Choose the real-to-complex flags, normalization scale, and transform direction (forward/inverse).
+        Derived APIs override these options to achieve their specific behaviors.
+ '''
+ inv = False
+ scale = {
+ 'backward': None,
+ 'forward': 'n',
+ 'ortho': 'sqrtn',
+ }[norm]
+ r2c_flags = np.zeros(ndim, dtype=bool).tolist()
+ r2c_flags[-1] = True
+ return inv, scale, r2c_flags
+
+ def construct(self, *args, **kwargs):
+ raise NotImplementedError
+
+
+class RDFTn(_DFTn):
+ r"""
+    1/2/3D discrete real Fourier transformation on real tensors. The results should be the same as
+    `scipy.fft.rfftn() <https://docs.scipy.org/doc/scipy/reference/generated/scipy.fft.rfftn.html>`_ .
+
+ Args:
+ shape (tuple): The shape of the dimensions to be transformed, other dimensions need not be included.
+        dim (tuple): Dimensions to be transformed. Default: None, the trailing dimensions will be transformed.
+ norm (str): Normalization mode, should be one of 'forward', 'backward', 'ortho'. Default: 'backward',
+ same as torch.fft.rfftn
+        modes (tuple, int, None): The number of low-frequency modes kept along each transformed axis. Each mode
+            must be no greater than half of the corresponding dimension of the input 'x'. Default: None,
+            all modes are kept.
+ compute_dtype (mindspore.dtype): The type of input tensor. Default: mindspore.float32.
+
+ Inputs:
+ - **ar** (Tensor) - The real tensor to be transformed, with trailing dimensions aligned with `shape`.
+
+ Outputs:
+ - **br** (Tensor) - Real part of the output tensor, with trailing dimensions aligned with `shape`,
+          except for the last dimension, which should be shape[-1] // 2 + 1.
+ - **bi** (Tensor) - Imag part of the output tensor, with trailing dimensions aligned with `shape`,
+          except for the last dimension, which should be shape[-1] // 2 + 1.
+
+ Supported Platforms:
+ ``Ascend`` ``CPU``
+
+ Examples:
+ >>> from mindspore import ops
+ >>> from mindflow.core import RDFTn
+ >>> ar = ops.rand((2, 32, 512))
+        >>> dft_cell = RDFTn(ar.shape[-2:])
+ >>> br, bi = dft_cell(ar)
+ >>> print(br.shape)
+ (2, 32, 257)
+ """
+ def construct(self, ar):
+ ''' perform n-dimensional rDFT on real tensor '''
+ # n-D Fourier transform with last axis being real-transformed, output dimension (..., m, n//2+1)
+ # the last ndim dimensions of ar must accord with shape
+ return self.dft1_seq((ar, ar * 0))
+
+
+class IRDFTn(_DFTn):
+ r"""
+    1/2/3D discrete inverse real Fourier transformation on complex tensors. The results should be the same as
+    `scipy.fft.irfftn() <https://docs.scipy.org/doc/scipy/reference/generated/scipy.fft.irfftn.html>`_ .
+
+ Args:
+ shape (tuple): The shape of the dimensions to be transformed, other dimensions need not be included.
+        dim (tuple): Dimensions to be transformed. Default: None, the trailing dimensions will be transformed.
+ norm (str): Normalization mode, should be one of 'forward', 'backward', 'ortho'. Default: 'backward',
+ same as torch.fft.irfftn
+        modes (tuple, int, None): The number of low-frequency modes kept along each transformed axis. Each mode
+            must be no greater than half of the corresponding dimension of the input 'x'. Default: None,
+            all modes are kept.
+ compute_dtype (mindspore.dtype): The type of input tensor. Default: mindspore.float32.
+
+ Inputs:
+ - **ar** (Tensor) - Real part of the tensor to be transformed, with trailing dimensions aligned with `shape`,
+          except for the last dimension, which should be shape[-1] // 2 + 1.
+ - **ai** (Tensor) - Imag part of the tensor to be transformed, with trailing dimensions aligned with `shape`,
+          except for the last dimension, which should be shape[-1] // 2 + 1.
+
+ Outputs:
+ - **br** (Tensor) - The output real tensor, with trailing dimensions aligned with `shape`.
+
+ Supported Platforms:
+ ``Ascend`` ``CPU``
+
+ Examples:
+ >>> from mindspore import ops
+ >>> from mindflow.core import IRDFTn
+ >>> ar = ops.rand((2, 32, 257))
+ >>> ai = ops.rand((2, 32, 257))
+        >>> dft_cell = IRDFTn((32, 512))
+        >>> br = dft_cell(ar, ai)
+ >>> print(br.shape)
+ (2, 32, 512)
+ """
+ def set_options(self, ndim, norm):
+ inv = True
+ scale = {
+ 'forward': None,
+ 'backward': 'n',
+ 'ortho': 'sqrtn',
+ }[norm]
+ r2c_flags = np.zeros(ndim, dtype=bool).tolist()
+ r2c_flags[-1] = True
+ return inv, scale, r2c_flags
+
+ def construct(self, ar, ai):
+ ''' perform n-dimensional irDFT on complex tensor and output real tensor '''
+ return self.dft1_seq((ar, ai))[0]
+
+
+class DFTn(_DFTn):
+ r"""
+    1/2/3D discrete Fourier transformation on complex tensors. The results should be the same as
+    `scipy.fft.fftn() <https://docs.scipy.org/doc/scipy/reference/generated/scipy.fft.fftn.html>`_ .
+
+ Args:
+ shape (tuple): The shape of the dimensions to be transformed, other dimensions need not be included.
+        dim (tuple): Dimensions to be transformed. Default: None, the trailing dimensions will be transformed.
+        norm (str): Normalization mode, should be one of 'forward', 'backward', 'ortho'. Default: 'backward',
+            same as torch.fft.fftn
+        modes (tuple, int, None): The number of low-frequency modes kept along each transformed axis. Each mode
+            must be no greater than half of the corresponding dimension of the input 'x'. Default: None,
+            all modes are kept.
+ compute_dtype (mindspore.dtype): The type of input tensor. Default: mindspore.float32.
+
+ Inputs:
+ - **ar** (Tensor) - Real part of the tensor to be transformed, with trailing dimensions aligned with `shape`.
+ - **ai** (Tensor) - Imag part of the tensor to be transformed, with trailing dimensions aligned with `shape`.
+
+ Outputs:
+ - **br** (Tensor) - Real part of the output tensor, with trailing dimensions aligned with `shape`.
+ - **bi** (Tensor) - Imag part of the output tensor, with trailing dimensions aligned with `shape`.
+
+ Supported Platforms:
+ ``Ascend`` ``CPU``
+
+ Examples:
+ >>> from mindspore import ops
+ >>> from mindflow.cell import DFTn
+ >>> ar = ops.rand((2, 32, 512))
+ >>> ai = ops.rand((2, 32, 512))
+        >>> dft_cell = DFTn(ar.shape[-2:])
+ >>> br, bi = dft_cell(ar, ai)
+ >>> print(br.shape)
+ (2, 32, 512)
+ """
+ def set_options(self, ndim, norm):
+ inv = False
+ scale = {
+ 'forward': 'n',
+ 'backward': None,
+ 'ortho': 'sqrtn',
+ }[norm]
+ r2c_flags = np.zeros(ndim, dtype=bool).tolist()
+ return inv, scale, r2c_flags
+
+ def construct(self, ar, ai):
+ ''' perform n-dimensional DFT on complex tensor '''
+ # n-D complex Fourier transform, output dimension (..., m, n)
+ return self.dft1_seq((ar, ai))
+
+
+class IDFTn(DFTn):
+ r"""
+    1/2/3D discrete inverse Fourier transformation on complex tensors. The results should be the same as
+    `scipy.fft.ifftn() <https://docs.scipy.org/doc/scipy/reference/generated/scipy.fft.ifftn.html>`_ .
+
+ Args:
+ shape (tuple): The shape of the dimensions to be transformed, other dimensions need not be included.
+        dim (tuple): Dimensions to be transformed. Default: None, the trailing dimensions will be transformed.
+        norm (str): Normalization mode, should be one of 'forward', 'backward', 'ortho'. Default: 'backward',
+            same as torch.fft.ifftn
+        modes (tuple, int, None): The number of low-frequency modes kept along each transformed axis. Each mode
+            must be no greater than half of the corresponding dimension of the input 'x'. Default: None,
+            all modes are kept.
+ compute_dtype (mindspore.dtype): The type of input tensor. Default: mindspore.float32.
+
+ Inputs:
+ - **ar** (Tensor) - Real part of the tensor to be transformed, with trailing dimensions aligned with `shape`.
+ - **ai** (Tensor) - Imag part of the tensor to be transformed, with trailing dimensions aligned with `shape`.
+
+ Outputs:
+ - **br** (Tensor) - Real part of the output tensor, with trailing dimensions aligned with `shape`.
+ - **bi** (Tensor) - Imag part of the output tensor, with trailing dimensions aligned with `shape`.
+
+ Supported Platforms:
+ ``Ascend`` ``CPU``
+
+ Examples:
+ >>> from mindspore import ops
+        >>> from mindflow.cell import IDFTn
+ >>> ar = ops.rand((2, 32, 512))
+ >>> ai = ops.rand((2, 32, 512))
+        >>> dft_cell = IDFTn(ar.shape[-2:])
+ >>> br, bi = dft_cell(ar, ai)
+ >>> print(br.shape)
+ (2, 32, 512)
+ """
+ def set_options(self, ndim, norm):
+ inv = True
+ scale = {
+ 'forward': None,
+ 'backward': 'n',
+ 'ortho': 'sqrtn',
+ }[norm]
+ r2c_flags = np.zeros(ndim, dtype=bool).tolist()
+ return inv, scale, r2c_flags
+
+
+class DCT(nn.Cell):
+ r"""
+    1D discrete cosine transformation of real tensors along the last axis. The results should be the same as
+    `scipy.fft.dct() <https://docs.scipy.org/doc/scipy/reference/generated/scipy.fft.dct.html>`_ .
+ Reference: `Type 2 DCT using N FFT (Makhoul) `_ .
+
+ Args:
+ shape (tuple): The shape of the dimensions to be transformed, other dimensions need not be included.
+ Must be a length-1 tuple.
+ compute_dtype (mindspore.dtype): The type of input tensor. Default: mindspore.float32.
+
+ Inputs:
+ - **a** (Tensor) - The real tensor to be transformed, with trailing dimensions aligned with `shape`.
+
+ Outputs:
+ - **b** (Tensor) - The output real tensor, with trailing dimensions aligned with `shape`.
+
+ Supported Platforms:
+ ``Ascend`` ``CPU``
+
+ Examples:
+ >>> from mindspore import ops
+ >>> from mindflow.cell import DCT
+ >>> a = ops.rand((2, 32, 512))
+        >>> dft_cell = DCT(a.shape[-1:])
+ >>> b = dft_cell(a)
+ >>> print(b.shape)
+ (2, 32, 512)
+ """
+ def __init__(self, shape, compute_dtype=mstype.float32):
+ super().__init__()
+
+ n = convert_shape(shape)
+
+ self.dft_cell = DFTn(n, compute_dtype=compute_dtype)
+
+ w = Tensor(np.arange(n) * np.pi / (2 * n), dtype=compute_dtype)
+ self.cosw = ops.cos(w)
+ self.sinw = ops.sin(w)
+
+ self.fliper = MyFlip()
+
+ def construct(self, a):
+ ''' perform 1-dimensional DCT on real tensor '''
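+        # the identity used below (Makhoul's N-point FFT trick, stated for reference):
+        #   with b = [a0, a2, a4, ..., a5, a3, a1] (even entries, then reversed odd entries),
+        #   DCT-II(a)[k] = 2 * Re(exp(-i*pi*k/(2n)) * FFT(b)[k])
+        #               = 2 * (Re(FFT(b)[k]) * cos(w_k) + Im(FFT(b)[k]) * sin(w_k)),  w_k = pi*k/(2n)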
+ b_half1 = a[..., ::2]
+ b_half2 = self.fliper(a[..., 1::2], dims=-1)
+ b = ops.cat([b_half1, b_half2], axis=-1)
+ cr, ci = self.dft_cell(b, b * 0)
+ return 2 * (cr * self.cosw + ci * self.sinw)
+
+
+class IDCT(nn.Cell):
+ r"""
+    1D inverse discrete cosine transformation of real tensors along the last axis. The results should be the same
+    as `scipy.fft.idct() <https://docs.scipy.org/doc/scipy/reference/generated/scipy.fft.idct.html>`_ .
+ Reference: `A fast cosine transform in one and two dimensions
+ `_ .
+
+ Args:
+ shape (tuple): The shape of the dimensions to be transformed, other dimensions need not be included.
+ Must be a length-1 tuple.
+ compute_dtype (mindspore.dtype): The type of input tensor. Default: mindspore.float32.
+
+ Inputs:
+ - **a** (Tensor) - The real tensor to be transformed, with trailing dimensions aligned with `shape`.
+
+ Outputs:
+ - **b** (Tensor) - The output real tensor, with trailing dimensions aligned with `shape`.
+
+ Supported Platforms:
+ ``Ascend`` ``CPU``
+
+ Examples:
+ >>> from mindspore import ops
+ >>> from mindflow.cell import IDCT
+ >>> a = ops.rand((2, 32, 512))
+        >>> dft_cell = IDCT(a.shape[-1:])
+ >>> b = dft_cell(a)
+ >>> print(b.shape)
+ (2, 32, 512)
+ """
+ def __init__(self, shape, compute_dtype=mstype.float32):
+ super().__init__()
+
+ n = convert_shape(shape)
+
+        # note: n should be even, or IRDFTn would fail (assert n % 2 == 0 left disabled)
+
+ self.dft_cell = IRDFTn(n, compute_dtype=compute_dtype)
+
+ w = Tensor(np.arange(n // 2 + 1) * np.pi / (2 * n), dtype=compute_dtype)
+ self.cosw = ops.cos(w)
+ self.sinw = ops.sin(w)
+
+ self.fliper = MyFlip()
+
+ def construct(self, a):
+ ''' perform 1-dimensional iDCT on real tensor '''
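+        # inverse of the Makhoul trick used in DCT: rebuild the half-spectrum (vr, vi)
+        # by undoing the phase rotation, apply the inverse real FFT, then de-interleave
+        # the result back into even/odd positions.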
+ n = a.shape[-1]
+
+ br = a[..., :n // 2 + 1]
+ bi = ops.pad(self.fliper(- a[..., -(n // 2):], dims=-1), (1, 0))
+ vr = (br * self.cosw - bi * self.sinw) / 2
+ vi = (bi * self.cosw + br * self.sinw) / 2
+
+ c = self.dft_cell(vr, vi) # (..., n)
+ c1 = c[..., :(n + 1) // 2]
+ c2 = self.fliper(c[..., (n + 1) // 2:], dims=-1)
+ d1 = ops.pad(c1.reshape(-1)[..., None], (0, 1)).reshape(*c1.shape[:-1], -1)
+ d2 = ops.pad(c2.reshape(-1)[..., None], (1, 0)).reshape(*c2.shape[:-1], -1)
+ # in case n is odd, d1 and d2 need to be aligned
+ d1 = d1[..., :n]
+ d2 = ops.pad(d2, (0, n % 2))
+ return d1 + d2
+
+
+class DST(nn.Cell):
+ r"""
+    1D discrete sine transformation of real tensors along the last axis. The results should be the same as
+    `scipy.fft.dst() <https://docs.scipy.org/doc/scipy/reference/generated/scipy.fft.dst.html>`_ .
+    Reference: `Wikipedia <https://en.wikipedia.org/wiki/Discrete_sine_transform>`_ .
+
+ Args:
+ shape (tuple): The shape of the dimensions to be transformed, other dimensions need not be included.
+ Must be a length-1 tuple.
+ compute_dtype (mindspore.dtype): The type of input tensor. Default: mindspore.float32.
+
+ Inputs:
+ - **a** (Tensor) - The real tensor to be transformed, with trailing dimensions aligned with `shape`.
+
+ Outputs:
+ - **b** (Tensor) - The output real tensor, with trailing dimensions aligned with `shape`.
+
+ Supported Platforms:
+ ``Ascend`` ``CPU``
+
+ Examples:
+ >>> from mindspore import ops
+ >>> from mindflow.cell import DST
+ >>> a = ops.rand((2, 32, 512))
+        >>> dft_cell = DST(a.shape[-1:])
+ >>> b = dft_cell(a)
+ >>> print(b.shape)
+ (2, 32, 512)
+ """
+ def __init__(self, shape, compute_dtype=mstype.float32):
+ super().__init__()
+ n = convert_shape(shape)
+ self.dft_cell = DCT(n, compute_dtype=compute_dtype)
+ multiplier = np.ones(n)
+ multiplier[..., 1::2] *= -1
+ self.multiplier = Tensor(multiplier, dtype=compute_dtype)
+
+ def construct(self, a):
+ ''' perform 1-dimensional DST on real tensor '''
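+        # standard identity (assumed here): DST-II(a)[k] = DCT-II(a')[n-1-k], with a'_j = (-1)^j * a_j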
+ return self.dft_cell.fliper(self.dft_cell(a * self.multiplier), dims=-1)
+
+
+class IDST(nn.Cell):
+ r"""
+    1D inverse discrete sine transformation of real tensors along the last axis. The results should be the same
+    as `scipy.fft.idst() <https://docs.scipy.org/doc/scipy/reference/generated/scipy.fft.idst.html>`_ .
+    Reference: `Wikipedia <https://en.wikipedia.org/wiki/Discrete_sine_transform>`_ .
+
+ Args:
+ shape (tuple): The shape of the dimensions to be transformed, other dimensions need not be included.
+ Must be a length-1 tuple.
+ compute_dtype (mindspore.dtype): The type of input tensor. Default: mindspore.float32.
+
+ Inputs:
+ - **a** (Tensor) - The real tensor to be transformed, with trailing dimensions aligned with `shape`.
+
+ Outputs:
+ - **b** (Tensor) - The output real tensor, with trailing dimensions aligned with `shape`.
+
+ Supported Platforms:
+ ``Ascend`` ``CPU``
+
+ Examples:
+ >>> from mindspore import ops
+ >>> from mindflow.cell import IDST
+ >>> a = ops.rand((2, 32, 512))
+        >>> dft_cell = IDST(a.shape[-1:])
+ >>> b = dft_cell(a)
+ >>> print(b.shape)
+ (2, 32, 512)
+ """
+ def __init__(self, shape, compute_dtype=mstype.float32):
+ super().__init__()
+ n = convert_shape(shape)
+ self.dft_cell = IDCT(n, compute_dtype=compute_dtype)
+ multiplier = np.ones(n)
+ multiplier[..., 1::2] *= -1
+ self.multiplier = Tensor(multiplier, dtype=compute_dtype)
+
+ def construct(self, a):
+ ''' perform 1-dimensional iDST on real tensor '''
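+        # inverse of the DST identity above: reverse the input, apply iDCT, then flip the signs of odd entries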
+ return self.dft_cell(self.dft_cell.fliper(a, dims=-1)) * self.multiplier
diff --git a/tests/common/test_optimizers.py b/tests/common/test_optimizers.py
new file mode 100644
index 0000000000000000000000000000000000000000..4085507fd42dcb9b85facc4ea593c5fb03b249fe
--- /dev/null
+++ b/tests/common/test_optimizers.py
@@ -0,0 +1,275 @@
+# ============================================================================
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Optimizers Test Case"""
+import os
+import random
+import sys
+
+import pytest
+import numpy as np
+
+import mindspore as ms
+from mindspore import ops, set_seed, nn, mint
+from mindspore import dtype as mstype
+from mindflow import UNet2D, TransformerBlock, MultiHeadAttention, AdaHessian
+from mindflow.cell.attention import FeedForward
+from mindflow.cell.unet2d import Down
+
+PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../"))
+sys.path.append(PROJECT_ROOT)
+
+# pylint: disable=wrong-import-position
+
+from common.cell import FP32_RTOL
+
+# pylint: enable=wrong-import-position
+
+set_seed(0)
+np.random.seed(0)
+random.seed(0)
+
+
+class TestAdaHessianAccuracy(AdaHessian):
+ ''' Child class for testing the accuracy of AdaHessian optimizer '''
+
+ def gen_rand_vecs(self, grads):
+ ''' generate certain vector for accuracy test '''
+ return [ms.Tensor(np.arange(p.size).reshape(p.shape) - p.size // 2, dtype=ms.float32) for p in grads]
+
+
+class TestUNet2D(UNet2D):
+ ''' Child class for testing optimizing UNet with AdaHessian '''
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ class TestDown(Down):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ in_channels = args[0]
+ kernel_size = kwargs['kernel_size']
+ stride = kwargs['stride']
+ # replace the `maxpool` layer in the original UNet with `conv` to avoid `vjp` problem
+ self.maxpool = nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, stride=stride)
+
+ self.layers_down = nn.CellList()
+ for i in range(self.n_layers):
+ self.layers_down.append(TestDown(self.base_channels * 2**i, self.base_channels * 2 ** (i+1),
+ kernel_size=self.kernel_size, stride=self.stride,
+ activation=self.activation, enable_bn=self.enable_bn))
+
+
+class TestAttentionBlock(TransformerBlock):
+ ''' Child class for testing optimizing Attention with AdaHessian '''
+
+ def __init__(self,
+ in_channels: int,
+ num_heads: int,
+ enable_flash_attn: bool = False,
+ fa_dtype: mstype = mstype.bfloat16,
+ drop_mode: str = "dropout",
+ dropout_rate: float = 0.0,
+ compute_dtype: mstype = mstype.float32,
+ ):
+ super().__init__(in_channels=in_channels,
+ num_heads=num_heads,
+ enable_flash_attn=enable_flash_attn,
+ fa_dtype=fa_dtype,
+ drop_mode=drop_mode,
+ dropout_rate=dropout_rate,
+ compute_dtype=compute_dtype,
+ )
+
+ class TestMlp(FeedForward):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.act_fn = nn.ReLU() # replace `gelu` with `relu` to avoid `vjp` problem
+
+ class TestMultiHeadAttention(MultiHeadAttention):
+ ''' MultiHeadAttention modified to support vjp '''
+ def get_qkv(self, x: ms.Tensor) -> tuple[ms.Tensor]:
+ ''' use masks to select out q, k, v, instead of tensor reshaping & indexing '''
+ b, n, c_full = x.shape
+ c = c_full // self.num_heads
+
+ # use matmul with masks to select out q, k, v to avoid vjp problem
+ q_mask = ms.Tensor(np.vstack([np.eye(c), np.zeros([2 * c, c])]), dtype=self.compute_dtype)
+ k_mask = ms.Tensor(np.vstack([np.zeros([c, c]), np.eye(c), np.zeros([c, c])]), dtype=self.compute_dtype)
+ v_mask = ms.Tensor(np.vstack([np.zeros([2 * c, c]), np.eye(c)]), dtype=self.compute_dtype)
+
+ qkv = self.qkv(x)
+ qkv = qkv.reshape(b, n, self.num_heads, -1).swapaxes(1, 2)
+
+ q = mint.matmul(qkv, q_mask)
+ k = mint.matmul(qkv, k_mask)
+ v = mint.matmul(qkv, v_mask)
+
+ return q, k, v
+
+ self.ffn = TestMlp(
+ in_channels=in_channels,
+ dropout_rate=dropout_rate,
+ compute_dtype=compute_dtype,
+ )
+ self.attention = TestMultiHeadAttention(
+ in_channels=in_channels,
+ num_heads=num_heads,
+ enable_flash_attn=enable_flash_attn,
+ fa_dtype=fa_dtype,
+ drop_mode=drop_mode,
+ dropout_rate=dropout_rate,
+ compute_dtype=compute_dtype,
+ )
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend910b_training
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
+def test_adahessian_accuracy(mode):
+ """
+ Feature: AdaHessian forward accuracy test
+ Description: Test the accuracy of the AdaHessian optimizer in both GRAPH_MODE and PYNATIVE_MODE
+ with input data specified in the code below.
+                 The expected output is compared to a reference output stored in
+                 '/home/workspace/mindspore_dataset/mindscience/mindflow/optimizers/adahessian_output.npy'.
+ Expectation: The output should match the target data within the defined relative tolerance,
+ ensuring the AdaHessian computation is accurate.
+ """
+ ms.set_context(mode=mode)
+
+ weight_init = ms.Tensor(np.reshape(range(72), [4, 2, 3, 3]), dtype=ms.float32)
+ bias_init = ms.Tensor(np.arange(4), dtype=ms.float32)
+
+ net = nn.Conv2d(
+ in_channels=2, out_channels=4, kernel_size=3, has_bias=True, weight_init=weight_init, bias_init=bias_init)
+
+ def forward(a):
+ return ops.sqrt(ops.mean(ops.square(net(a))))
+
+ grad_fn = ms.grad(forward, grad_position=None, weights=net.trainable_params())
+
+ optimizer = TestAdaHessianAccuracy(
+ net.trainable_params(),
+ learning_rate=0.1, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.)
+
+ inputs = ms.Tensor(np.reshape(range(100), [2, 2, 5, 5]), dtype=ms.float32)
+
+ for _ in range(4):
+ optimizer(grad_fn, inputs)
+
+ outputs = net(inputs).numpy()
+ outputs_ref = np.load('/home/workspace/mindspore_dataset/mindscience/mindflow/optimizers/adahessian_output.npy')
+ relative_error = np.max(np.abs(outputs - outputs_ref)) / np.max(np.abs(outputs_ref))
+ assert relative_error < FP32_RTOL, "The verification of adahessian accuracy is not successful."
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend910b_training
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
+@pytest.mark.parametrize('model_option', ['unet', 'attention'])
+def test_adahessian_st(mode, model_option):
+ """
+ Feature: AdaHessian ST test
+    Description: Test the function of the AdaHessian optimizer in both GRAPH_MODE and PYNATIVE_MODE
+                 on complex networks such as UNet and an attention block. The input is a Tensor
+                 specified in the code and the output is the loss after 4 rounds of optimization.
+    Expectation: The output should be finite, ensuring AdaHessian runs successfully on these networks.
+ """
+ ms.set_context(mode=mode)
+
+ # default test with Attention network
+ net = TestAttentionBlock(in_channels=256, num_heads=4)
+ inputs = ms.Tensor(np.sin(np.arange(102400)).reshape(4, 100, 256), dtype=ms.float32)
+
+ # test with UNet network
+ if model_option.lower() == 'unet':
+ net = TestUNet2D(
+ in_channels=2,
+ out_channels=4,
+ base_channels=8,
+ n_layers=4,
+ kernel_size=2,
+ stride=2,
+ activation='relu',
+ data_format="NCHW",
+ enable_bn=False, # bn leads to bug in PYNATIVE_MODE for MS2.5.0
+ )
+ inputs = ms.Tensor(np.random.rand(2, 2, 64, 64), dtype=ms.float32)
+
+ def forward(a):
+ return ops.sqrt(ops.mean(ops.square(net(a))))
+
+ grad_fn = ms.grad(forward, grad_position=None, weights=net.trainable_params())
+
+ optimizer = AdaHessian(
+ net.trainable_params(),
+ learning_rate=0.1, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.)
+
+ for _ in range(4):
+ optimizer(grad_fn, inputs)
+
+ loss = forward(inputs)
+ assert ops.isfinite(loss)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend910b_training
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('mode', [ms.PYNATIVE_MODE])
+def test_adahessian_compare(mode):
+ """
+ Feature: AdaHessian compare with Adam
+ Description: Compare the algorithm results of the AdaHessian optimizer with Adam.
+ The code runs in PYNATIVE_MODE and the network under comparison is TransformerBlock.
+                 The optimization runs 20 rounds to demonstrate an essential loss decrease.
+ Expectation: The loss of AdaHessian outperforms Adam by 20% under the same configuration on an Attention network.
+ """
+ ms.set_context(mode=mode)
+
+ def get_loss(optimizer_option):
+ ''' compare Adam and AdaHessian '''
+ net = TestAttentionBlock(in_channels=256, num_heads=4)
+ inputs = ms.Tensor(np.sin(np.arange(102400)).reshape(4, 100, 256), dtype=ms.float32)
+
+ def forward(a):
+ return ops.sqrt(ops.mean(ops.square(net(a))))
+
+ grad_fn = ms.grad(forward, grad_position=None, weights=net.trainable_params())
+
+ if optimizer_option.lower() == 'adam':
+ optimizer = nn.Adam(
+ net.trainable_params(),
+ learning_rate=0.01, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.)
+ else:
+ optimizer = AdaHessian(
+ net.trainable_params(),
+ learning_rate=0.01, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.)
+
+ for _ in range(20):
+ if optimizer_option.lower() == 'adam':
+ optimizer(grad_fn(inputs))
+ else:
+ optimizer(grad_fn, inputs)
+
+ loss = forward(inputs)
+ return loss
+
+ loss_adam = get_loss('adam')
+ loss_adahessian = get_loss('adahessian')
+
+ assert loss_adam * 0.8 > loss_adahessian, (loss_adam, loss_adahessian)
diff --git a/tests/models/ffno/test_ffno.py b/tests/models/ffno/test_ffno.py
new file mode 100644
index 0000000000000000000000000000000000000000..8cda98808dc60934f3e44d050638a093de4340d3
--- /dev/null
+++ b/tests/models/ffno/test_ffno.py
@@ -0,0 +1,381 @@
+# Copyright 2023 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""mindflow st testcase"""
+
+import os
+import sys
+import time
+
+import pytest
+import numpy as np
+
+import mindspore as ms
+from mindspore import nn, Tensor, set_seed, load_param_into_net, load_checkpoint
+from mindspore import dtype as mstype
+
+from mindflow.cell import FFNO1D, FFNO2D, FFNO3D
+
+PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../"))
+sys.path.append(PROJECT_ROOT)
+
+# pylint: disable=wrong-import-position
+
+from common.cell.utils import compare_output
+from common.cell import FP32_RTOL
+
+# pylint: enable=wrong-import-position
+
+set_seed(123456)
+folder_path = "/home/workspace/mindspore_dataset/mindscience/ffno"
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend910b_training
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
+def test_ffno1d_output(mode):
+ """
+    Feature: Test FFNO1D network on the Ascend platform.
+ Description: None.
+ Expectation: Success or throw AssertionError.
+ """
+ ms.set_context(mode=mode)
+ model1d = FFNO1D(in_channels=2,
+ out_channels=2,
+ n_modes=[2],
+ resolutions=[6],
+ hidden_channels=2,
+ n_layers=2,
+ share_weight=True,
+ r_padding=8,
+ ffno_compute_dtype=mstype.float32)
+
+ data1d = Tensor(np.load(os.path.join(folder_path, "ffno_data1d.npy")), dtype=mstype.float32)
+ param1d = load_checkpoint(os.path.join(folder_path, "ffno1d.ckpt"))
+ load_param_into_net(model1d, param1d)
+ output1d = model1d(data1d)
+ target1d = np.load(os.path.join(folder_path, "ffno_target1d.npy"))
+
+ assert output1d.shape == (2, 6, 2)
+ assert output1d.dtype == mstype.float32
+ assert compare_output(output1d.asnumpy(), target1d, rtol=FP32_RTOL, atol=FP32_RTOL)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend910b_training
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
+def test_ffno1d_mse_loss_output(mode):
+ """
+    Feature: Test FFNO1D MSE Loss on the Ascend platform.
+ Description: None.
+ Expectation: Success or throw AssertionError.
+ """
+ ms.set_context(mode=mode)
+ model1d = FFNO1D(in_channels=2,
+ out_channels=2,
+ n_modes=[2],
+ resolutions=[6],
+ hidden_channels=2,
+ n_layers=2,
+ share_weight=True,
+ r_padding=8,
+ ffno_compute_dtype=mstype.float32)
+
+ data1d = Tensor(np.ones((2, 6, 2)), dtype=mstype.float32)
+ label_1d = Tensor(np.ones((2, 6, 2)), dtype=mstype.float32)
+ param1d = load_checkpoint(os.path.join(folder_path, "ffno1d.ckpt"))
+ load_param_into_net(model1d, param1d)
+
+ loss_fn = nn.MSELoss()
+ optimizer_1d = nn.SGD(model1d.trainable_params(), learning_rate=0.01)
+ net_with_loss_1d = nn.WithLossCell(model1d, loss_fn)
+ train_step_1d = nn.TrainOneStepCell(net_with_loss_1d, optimizer_1d)
+
+ # calculate two steps of loss
+ loss_1d = train_step_1d(data1d, label_1d)
+ target_loss_1_1d = 0.63846040
+ assert compare_output(loss_1d.asnumpy(), target_loss_1_1d, rtol=FP32_RTOL, atol=FP32_RTOL)
+
+ loss_1d = train_step_1d(data1d, label_1d)
+ target_loss_2_1d = 0.04462930
+ assert compare_output(loss_1d.asnumpy(), target_loss_2_1d, rtol=FP32_RTOL, atol=FP32_RTOL)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend910b_training
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
+def test_ffno2d_output(mode):
+ """
+    Feature: Test FFNO2D network on the Ascend platform.
+ Description: None.
+ Expectation: Success or throw AssertionError.
+ """
+ ms.set_context(mode=mode)
+ model2d = FFNO2D(in_channels=2,
+ out_channels=2,
+ n_modes=[2, 2],
+ resolutions=[6, 6],
+ hidden_channels=2,
+ n_layers=2,
+ share_weight=True,
+ r_padding=8,
+ ffno_compute_dtype=mstype.float32)
+
+ data2d = Tensor(np.load(os.path.join(folder_path, "ffno_data2d.npy")), dtype=mstype.float32)
+ param2d = load_checkpoint(os.path.join(folder_path, "ffno2d.ckpt"))
+ load_param_into_net(model2d, param2d)
+ output2d = model2d(data2d)
+ target2d = np.load(os.path.join(folder_path, "ffno_target2d.npy"))
+
+ assert output2d.shape == (2, 6, 6, 2)
+ assert output2d.dtype == mstype.float32
+ assert compare_output(output2d.asnumpy(), target2d, rtol=FP32_RTOL, atol=FP32_RTOL)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend910b_training
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
+def test_ffno2d_mse_loss_output(mode):
+ """
+    Feature: Test FFNO2D MSE Loss on the Ascend platform.
+ Description: None.
+ Expectation: Success or throw AssertionError.
+ """
+ ms.set_context(mode=mode)
+ model2d = FFNO2D(in_channels=2,
+ out_channels=2,
+ n_modes=[2, 2],
+ resolutions=[6, 6],
+ hidden_channels=2,
+ n_layers=2,
+ share_weight=True,
+ r_padding=8,
+ ffno_compute_dtype=mstype.float32)
+
+ data2d = Tensor(np.ones((2, 6, 6, 2)), dtype=mstype.float32)
+ label_2d = Tensor(np.ones((2, 6, 6, 2)), dtype=mstype.float32)
+ param2d = load_checkpoint(os.path.join(folder_path, "ffno2d.ckpt"))
+ load_param_into_net(model2d, param2d)
+
+ loss_fn = nn.MSELoss()
+ optimizer_2d = nn.SGD(model2d.trainable_params(), learning_rate=0.01)
+ net_with_loss_2d = nn.WithLossCell(model2d, loss_fn)
+ train_step_2d = nn.TrainOneStepCell(net_with_loss_2d, optimizer_2d)
+
+ # calculate two steps of loss
+ loss_2d = train_step_2d(data2d, label_2d)
+ target_loss_1_2d = 1.70347130
+ assert compare_output(loss_2d.asnumpy(), target_loss_1_2d, rtol=FP32_RTOL, atol=FP32_RTOL)
+
+ loss_2d = train_step_2d(data2d, label_2d)
+ target_loss_2_2d = 0.28143430
+ assert compare_output(loss_2d.asnumpy(), target_loss_2_2d, rtol=FP32_RTOL, atol=FP32_RTOL)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend910b_training
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
+def test_ffno3d_output(mode):
+ """
+    Feature: Test FFNO3D network on the Ascend platform.
+ Description: None.
+ Expectation: Success or throw AssertionError.
+ """
+ ms.set_context(mode=mode)
+ model3d = FFNO3D(in_channels=2,
+ out_channels=2,
+ n_modes=[2, 2, 2],
+ resolutions=[6, 6, 6],
+ hidden_channels=2,
+ n_layers=2,
+ share_weight=True,
+ r_padding=8,
+ ffno_compute_dtype=mstype.float32)
+
+ data3d = Tensor(np.load(os.path.join(folder_path, "ffno_data3d.npy")), dtype=mstype.float32)
+ param3d = load_checkpoint(os.path.join(folder_path, "ffno3d.ckpt"))
+ load_param_into_net(model3d, param3d)
+ output3d = model3d(data3d)
+ target3d = np.load(os.path.join(folder_path, "ffno_target3d.npy"))
+
+ assert output3d.shape == (2, 6, 6, 6, 2)
+ assert output3d.dtype == mstype.float32
+ assert compare_output(output3d.asnumpy(), target3d, rtol=FP32_RTOL, atol=FP32_RTOL)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend910b_training
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
+def test_ffno3d_mse_loss_output(mode):
+ """
+    Feature: Test FFNO3D MSE Loss on the Ascend platform.
+ Description: None.
+ Expectation: Success or throw AssertionError.
+ """
+ ms.set_context(mode=mode)
+ model3d = FFNO3D(in_channels=2,
+ out_channels=2,
+ n_modes=[2, 2, 2],
+ resolutions=[6, 6, 6],
+ hidden_channels=2,
+ n_layers=2,
+ share_weight=True,
+ r_padding=8,
+ ffno_compute_dtype=mstype.float32)
+
+ data3d = Tensor(np.ones((2, 6, 6, 6, 2)), dtype=mstype.float32)
+ label_3d = Tensor(np.ones((2, 6, 6, 6, 2)), dtype=mstype.float32)
+ param3d = load_checkpoint(os.path.join(folder_path, "ffno3d.ckpt"))
+ load_param_into_net(model3d, param3d)
+
+ loss_fn = nn.MSELoss()
+ optimizer_3d = nn.SGD(model3d.trainable_params(), learning_rate=0.01)
+ net_with_loss_3d = nn.WithLossCell(model3d, loss_fn)
+ train_step_3d = nn.TrainOneStepCell(net_with_loss_3d, optimizer_3d)
+
+ # calculate two steps of loss
+ loss_3d = train_step_3d(data3d, label_3d)
+ target_loss_1_3d = 1.94374371
+ assert compare_output(loss_3d.asnumpy(), target_loss_1_3d, rtol=FP32_RTOL, atol=FP32_RTOL)
+
+ loss_3d = train_step_3d(data3d, label_3d)
+ target_loss_2_3d = 0.24034855
+ assert compare_output(loss_3d.asnumpy(), target_loss_2_3d, rtol=FP32_RTOL, atol=FP32_RTOL)
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend910b_training
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
+def test_ffno1d_speed(mode):
+ """
+    Feature: Test FFNO1D training speed on the Ascend platform.
+ Description: The speed of each training step.
+ Expectation: Success or throw AssertionError.
+ """
+ ms.set_context(mode=mode)
+ model1d = FFNO1D(in_channels=32,
+ out_channels=32,
+ n_modes=[16],
+ resolutions=[128],
+ hidden_channels=2,
+ n_layers=2,
+ share_weight=True,
+ r_padding=8,
+ ffno_compute_dtype=mstype.float32)
+
+ data1d = Tensor(np.ones((32, 128, 32)), dtype=mstype.float32)
+ label_1d = Tensor(np.ones((32, 128, 32)), dtype=mstype.float32)
+
+ loss_fn = nn.MSELoss()
+ optimizer_1d = nn.SGD(model1d.trainable_params(), learning_rate=0.01)
+ net_with_loss_1d = nn.WithLossCell(model1d, loss_fn)
+ train_step_1d = nn.TrainOneStepCell(net_with_loss_1d, optimizer_1d)
+
+ steps = 10
+    for _ in range(steps):
+ train_step_1d(data1d, label_1d)
+
+ start_time = time.time()
+    for _ in range(steps):
+ train_step_1d(data1d, label_1d)
+ end_time = time.time()
+
+ assert (end_time - start_time) / steps < 0.5
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend910b_training
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
+def test_ffno2d_speed(mode):
+ """
+    Feature: Test FFNO2D training speed on the Ascend platform.
+ Description: The speed of each training step.
+ Expectation: Success or throw AssertionError.
+ """
+ ms.set_context(mode=mode)
+ model2d = FFNO2D(in_channels=32,
+ out_channels=32,
+ n_modes=[16, 16],
+ resolutions=[64, 64],
+ hidden_channels=2,
+ n_layers=2,
+ share_weight=True,
+ r_padding=8,
+ ffno_compute_dtype=mstype.float32)
+
+ data2d = Tensor(np.ones((32, 64, 64, 32)), dtype=mstype.float32)
+ label_2d = Tensor(np.ones((32, 64, 64, 32)), dtype=mstype.float32)
+
+ loss_fn = nn.MSELoss()
+ optimizer_2d = nn.SGD(model2d.trainable_params(), learning_rate=0.01)
+ net_with_loss_2d = nn.WithLossCell(model2d, loss_fn)
+ train_step_2d = nn.TrainOneStepCell(net_with_loss_2d, optimizer_2d)
+
+ steps = 10
+ for _ in range(steps):
+ train_step_2d(data2d, label_2d)
+
+ start_time = time.time()
+ for _ in range(steps):
+ train_step_2d(data2d, label_2d)
+ end_time = time.time()
+
+ assert (end_time - start_time) / steps < 1
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend910b_training
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
+def test_ffno3d_speed(mode):
+ """
+    Feature: Test FFNO3D training speed on the Ascend platform.
+ Description: The speed of each training step.
+ Expectation: Success or throw AssertionError.
+ """
+ ms.set_context(mode=mode)
+ model3d = FFNO3D(in_channels=2,
+ out_channels=2,
+ n_modes=[16, 16, 16],
+ resolutions=[32, 32, 32],
+ hidden_channels=2,
+ n_layers=2,
+ share_weight=True,
+ r_padding=8,
+ ffno_compute_dtype=mstype.float32)
+
+ data3d = Tensor(np.ones((2, 32, 32, 32, 2)), dtype=mstype.float32)
+ label_3d = Tensor(np.ones((2, 32, 32, 32, 2)), dtype=mstype.float32)
+
+ loss_fn = nn.MSELoss()
+ optimizer_3d = nn.SGD(model3d.trainable_params(), learning_rate=0.01)
+ net_with_loss_3d = nn.WithLossCell(model3d, loss_fn)
+ train_step_3d = nn.TrainOneStepCell(net_with_loss_3d, optimizer_3d)
+
+ steps = 10
+ for _ in range(steps):
+ train_step_3d(data3d, label_3d)
+
+ start_time = time.time()
+ for _ in range(steps):
+ train_step_3d(data3d, label_3d)
+ end_time = time.time()
+
+ assert (end_time - start_time) / steps < 3
diff --git a/tests/models/fno/fno1d.yaml b/tests/models/fno/fno1d.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..c3e9e8c5a675dcdc5d82e9cf8d1979e961baed2c
--- /dev/null
+++ b/tests/models/fno/fno1d.yaml
@@ -0,0 +1,27 @@
+model:
+ name: FNO1D
+ in_channels: 1
+ out_channels: 1
+ modes: 16
+ resolutions: 1024
+ hidden_channels: 10
+ depths: 1
+data:
+ name: "burgers1d"
+ root_dir: "./dataset"
+ train:
+ num_samples: 1000
+ test:
+ num_samples: 200
+ batch_size: 8
+ resolution: 1024
+ t_in: 1
+ t_out: 1
+ step: 8
+optimizer:
+ learning_rate: 0.001
+ epochs: 100
+summary:
+ test_interval: 10
+ summary_dir: "./summary"
+ ckpt_dir: "./checkpoints"
\ No newline at end of file
diff --git a/tests/models/fno/test_fno.py b/tests/models/fno/test_fno.py
new file mode 100644
index 0000000000000000000000000000000000000000..7fbc0c437d57b553a2a592c494dadcc3ceccb3f8
--- /dev/null
+++ b/tests/models/fno/test_fno.py
@@ -0,0 +1,95 @@
+# Copyright 2023 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""mindflow st testcase"""
+
+import pytest
+import numpy as np
+
+from mindspore import Tensor, context, set_seed, load_param_into_net, load_checkpoint
+from mindspore import dtype as mstype
+from mindflow.cell import FNO1D, FNO2D, FNO3D
+from mindflow.cell.neural_operators.fno_sp import SpectralConv1dDft, SpectralConv2dDft, SpectralConv3dDft
+
+RTOL = 0.001
+set_seed(123456)
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend910b_training
+@pytest.mark.env_onecard
+def test_fno_output():
+ """
+    Feature: Test FNO1D, FNO2D and FNO3D networks on GPU and Ascend platforms.
+ Description: None.
+ Expectation: Success or throw AssertionError.
+    Note: needs to be adapted for Ascend 910B.
+ """
+ context.set_context(mode=context.GRAPH_MODE)
+ model1d = FNO1D(
+ in_channels=2, out_channels=2, n_modes=[2], resolutions=[6], fno_compute_dtype=mstype.float32)
+ model2d = FNO2D(
+ in_channels=2, out_channels=2, n_modes=[2, 2], resolutions=[6, 6], fno_compute_dtype=mstype.float32)
+ model3d = FNO3D(
+ in_channels=2, out_channels=2, n_modes=[2, 2, 2], resolutions=[6, 6, 6], fno_compute_dtype=mstype.float32)
+ data1d = Tensor(np.ones((2, 6, 2)), dtype=mstype.float32)
+ data2d = Tensor(np.ones((2, 6, 6, 2)), dtype=mstype.float32)
+ data3d = Tensor(np.ones((2, 6, 6, 6, 2)), dtype=mstype.float32)
+ output1d = model1d(data1d)
+ output2d = model2d(data2d)
+ output3d = model3d(data3d)
+ assert output1d.shape == (2, 6, 2)
+ assert output1d.dtype == mstype.float32
+ assert output2d.shape == (2, 6, 6, 2)
+ assert output2d.dtype == mstype.float32
+ assert output3d.shape == (2, 6, 6, 6, 2)
+ assert output3d.dtype == mstype.float32
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend910b_training
+@pytest.mark.env_onecard
+def test_spectralconvdft_output():
+ """
+    Feature: Test SpectralConv1dDft, SpectralConv2dDft and SpectralConv3dDft networks on GPU and Ascend platforms.
+ Description: None.
+ Expectation: Success or throw AssertionError.
+ """
+ context.set_context(mode=context.GRAPH_MODE)
+ model1d = SpectralConv1dDft(in_channels=2, out_channels=2, n_modes=[2], resolutions=[6])
+ model2d = SpectralConv2dDft(in_channels=2, out_channels=2, n_modes=[2, 2], resolutions=[6, 6])
+ model3d = SpectralConv3dDft(in_channels=2, out_channels=2, n_modes=[2, 2, 2], resolutions=[6, 6, 6])
+ data1d = Tensor(np.ones((2, 2, 6)), dtype=mstype.float32)
+ data2d = Tensor(np.ones((2, 2, 6, 6)), dtype=mstype.float32)
+ data3d = Tensor(np.ones((2, 2, 6, 6, 6)), dtype=mstype.float32)
+ target1d = 3.64671636
+ target2d = 35.93239212
+ target3d = 149.64256287
+ param1 = load_checkpoint("./spectralconv1d.ckpt")
+ param2 = load_checkpoint("./spectralconv2d.ckpt")
+ param3 = load_checkpoint("./spectralconv3d.ckpt")
+ load_param_into_net(model1d, param1)
+ load_param_into_net(model2d, param2)
+ load_param_into_net(model3d, param3)
+ output1d = model1d(data1d)
+ output2d = model2d(data2d)
+ output3d = model3d(data3d)
+ assert output1d.shape == (2, 2, 6)
+ assert output1d.dtype == mstype.float32
+    assert abs(output1d.sum() - target1d) < RTOL
+ assert output2d.shape == (2, 2, 6, 6)
+ assert output2d.dtype == mstype.float32
+    assert abs(output2d.sum() - target2d) < RTOL
+ assert output3d.shape == (2, 2, 6, 6, 6)
+ assert output3d.dtype == mstype.float32
+    assert abs(output3d.sum() - target3d) < RTOL
diff --git a/tests/models/fno/test_fno1d.py b/tests/models/fno/test_fno1d.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d3d635b60782a28dc3f1050bb99e36fb38cbc0b
--- /dev/null
+++ b/tests/models/fno/test_fno1d.py
@@ -0,0 +1,228 @@
+# ============================================================================
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""FNO1D Test Case"""
+import os
+import random
+import sys
+
+import pytest
+import numpy as np
+
+import mindspore as ms
+from mindspore.train.serialization import load_checkpoint, load_param_into_net
+from mindspore import Tensor, ops, set_seed
+from mindspore import dtype as mstype
+from mindflow import FNO1D, RelativeRMSELoss, load_yaml_config
+from mindflow.pde import SteadyFlowWithLoss
+
+PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../"))
+sys.path.append(PROJECT_ROOT)
+
+from common.cell import validate_checkpoint, compare_output
+from common.cell import FP16_RTOL, FP16_ATOL
+
+set_seed(0)
+np.random.seed(0)
+random.seed(0)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend910b_training
+@pytest.mark.env_onecard
+def test_fno1d_checkpoint():
+ """
+ Feature: FNO1D checkpoint loading and verification
+ Description: Test the consistency of the FNO1D model when loading from a saved checkpoint.
+ Two FNO1D models are initialized with the same parameters, and one of them
+        loads weights from the checkpoint located at './fno1d/ckpt/fno1d.ckpt'.
+        The test input is a randomly generated tensor, and the validation checks whether
+        both models (one with loaded parameters) produce the same outputs.
+ Expectation: The model loaded from the checkpoint should behave identically to a newly initialized
+ model with the same parameters, verifying that the checkpoint restores the model's state correctly.
+ """
+ config = load_yaml_config('./fno1d/configs/fno1d.yaml')
+ model_params = config["model"]
+ ckpt_path = './fno1d/ckpt/fno1d.ckpt'
+
+ model1 = FNO1D(in_channels=model_params["in_channels"],
+ out_channels=model_params["out_channels"],
+ n_modes=model_params["modes"],
+ resolutions=model_params["resolutions"],
+ hidden_channels=model_params["hidden_channels"],
+ n_layers=model_params["depths"],
+ projection_channels=4*model_params["hidden_channels"],
+ )
+
+ model2 = FNO1D(in_channels=model_params["in_channels"],
+ out_channels=model_params["out_channels"],
+ n_modes=model_params["modes"],
+ resolutions=model_params["resolutions"],
+ hidden_channels=model_params["hidden_channels"],
+ n_layers=model_params["depths"],
+ projection_channels=4*model_params["hidden_channels"],
+ )
+
+ params = load_checkpoint(ckpt_path)
+ load_param_into_net(model1, params)
+ test_inputs = Tensor(np.random.randn(1, 1024, 1), mstype.float32)
+
+ validate_ans = validate_checkpoint(model1, model2, (test_inputs,))
+ assert validate_ans, "The verification of FNO1D checkpoint is not successful."
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend910b_training
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
+def test_fno1d_forward_accuracy(mode):
+ """
+ Feature: FNO1D forward accuracy test
+ Description: Test the forward accuracy of the FNO1D model in both GRAPH_MODE and PYNATIVE_MODE.
+        The model is initialized with parameters from './fno1d/configs/fno1d.yaml',
+        and weights are loaded from the checkpoint located at './fno1d/ckpt/fno1d.ckpt'.
+        The input data is loaded from './fno1d/data/fno1d_input.npy', and the output
+        is compared against the expected prediction stored in './fno1d/data/fno1d_pred.npy'.
+ Expectation: The output should match the target prediction data within the specified relative and absolute
+ tolerance values, ensuring the forward pass of the FNO1D model is accurate.
+ """
+ ms.set_context(mode=mode)
+ config = load_yaml_config('./fno1d/configs/fno1d.yaml')
+ model_params = config["model"]
+ ckpt_path = './fno1d/ckpt/fno1d.ckpt'
+
+ model = FNO1D(in_channels=model_params["in_channels"],
+ out_channels=model_params["out_channels"],
+ n_modes=model_params["modes"],
+ resolutions=model_params["resolutions"],
+ hidden_channels=model_params["hidden_channels"],
+ n_layers=model_params["depths"],
+ projection_channels=4*model_params["hidden_channels"],
+ )
+
+ params = load_checkpoint(ckpt_path)
+ load_param_into_net(model, params)
+ input_data = np.load('./fno1d/data/fno1d_input.npy')
+ test_inputs = Tensor(input_data, mstype.float32)
+ output = model(test_inputs)
+ output = output.asnumpy()
+ output_target = np.load('./fno1d/data/fno1d_pred.npy')
+ validate_ans = compare_output(output, output_target, rtol=FP16_RTOL, atol=FP16_ATOL)
+ assert validate_ans, "The verification of FNO1D forward accuracy is not successful."
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend910b_training
+@pytest.mark.env_onecard
+def test_fno1d_amp():
+ """
+ Feature: FNO1D AMP (Automatic Mixed Precision) accuracy test
+ Description: Test the accuracy of FNO1D model with and without AMP (Automatic Mixed Precision).
+ Two FNO1D models are initialized with identical parameters. The first model uses the
+        default precision (float16), while the second model is set to use float32 precision for computation.
+        Both models load the same checkpoint from './fno1d/ckpt/fno1d.ckpt'.
+        The input data is loaded from './fno1d/data/fno1d_input.npy', and outputs
+ of the two models are compared to check if they match within the specified tolerance.
+ Expectation: The outputs of both models (with and without AMP) should match within the defined
+ relative and absolute tolerance values, verifying that AMP does not affect the accuracy.
+ """
+ config = load_yaml_config('./fno1d/configs/fno1d.yaml')
+ model_params = config["model"]
+ ckpt_path = './fno1d/ckpt/fno1d.ckpt'
+
+ model1 = FNO1D(in_channels=model_params["in_channels"],
+ out_channels=model_params["out_channels"],
+ n_modes=model_params["modes"],
+ resolutions=model_params["resolutions"],
+ hidden_channels=model_params["hidden_channels"],
+ n_layers=model_params["depths"],
+ projection_channels=4*model_params["hidden_channels"],
+ )
+
+ model2 = FNO1D(in_channels=model_params["in_channels"],
+ out_channels=model_params["out_channels"],
+ n_modes=model_params["modes"],
+ resolutions=model_params["resolutions"],
+ hidden_channels=model_params["hidden_channels"],
+ n_layers=model_params["depths"],
+ projection_channels=4*model_params["hidden_channels"],
+ fno_compute_dtype=mstype.float32,
+ )
+
+ params = load_checkpoint(ckpt_path)
+ load_param_into_net(model1, params)
+ load_param_into_net(model2, params)
+ input_data = np.load('./fno1d/data/fno1d_input.npy')
+ test_inputs = Tensor(input_data, mstype.float32)
+ output1 = model1(test_inputs)
+ output1 = output1.asnumpy()
+ output2 = model2(test_inputs)
+ output2 = output2.asnumpy()
+ validate_ans = compare_output(output1, output2, rtol=FP16_RTOL, atol=FP16_ATOL)
+ assert validate_ans, "The verification of FNO1D AMP accuracy is not successful."
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend910b_training
+@pytest.mark.env_onecard
+def test_fno1d_grad_accuracy():
+ """
+ Feature: FNO1D gradient accuracy test
+ Description: Test the accuracy of the computed gradients for the FNO1D model. The model is initialized
+        with parameters from './fno1d/configs/fno1d.yaml' and weights are loaded
+        from the checkpoint located at './fno1d/ckpt/fno1d.ckpt'. The loss function used
+        is RelativeRMSELoss. The input data is loaded from './fno1d/data/fno1d_input.npy'
+        and the label from './fno1d/data/fno1d_input_label.npy'. Gradients are computed
+        using MindSpore's value_and_grad and compared against the reference gradients stored in
+        './fno1d/data/fno1d_grads.npz'.
+ Expectation: The computed gradients should match the reference gradients within the specified relative and
+ absolute tolerance values, ensuring the gradient calculation is accurate.
+ """
+ config = load_yaml_config('./fno1d/configs/fno1d.yaml')
+ model_params = config["model"]
+ ckpt_path = './fno1d/ckpt/fno1d.ckpt'
+
+ model = FNO1D(in_channels=model_params["in_channels"],
+ out_channels=model_params["out_channels"],
+ n_modes=model_params["modes"],
+ resolutions=model_params["resolutions"],
+ hidden_channels=model_params["hidden_channels"],
+ n_layers=model_params["depths"],
+ projection_channels=4*model_params["hidden_channels"],
+ )
+
+ params = load_checkpoint(ckpt_path)
+ load_param_into_net(model, params)
+ input_data = np.load('./fno1d/data/fno1d_input.npy')
+ input_label = np.load('./fno1d/data/fno1d_input_label.npy')
+ test_inputs = Tensor(input_data, mstype.float32)
+ test_label = Tensor(input_label, mstype.float32)
+
+    problem = SteadyFlowWithLoss(model, loss_fn=RelativeRMSELoss())
+
+ def forward_fn(data, label):
+ loss = problem.get_loss(data, label)
+ return loss
+
+    grad_fn = ops.value_and_grad(forward_fn, None, model.trainable_params(), has_aux=False)
+
+ _, grads = grad_fn(test_inputs, test_label)
+ convert_grads = tuple(grad.asnumpy() for grad in grads)
+ with np.load('./fno1d/data/fno1d_grads.npz') as data:
+ output_target = tuple(data[key] for key in data.files)
+ validate_ans = compare_output(convert_grads, output_target, rtol=FP16_RTOL, atol=FP16_ATOL)
+ assert validate_ans, "The verification of FNO1D grad accuracy is not successful."
diff --git a/tests/sciops/test_fourier.py b/tests/sciops/test_fourier.py
new file mode 100644
index 0000000000000000000000000000000000000000..188e247206e784ac3910b38d4c19b5f27ac98513
--- /dev/null
+++ b/tests/sciops/test_fourier.py
@@ -0,0 +1,248 @@
+# ============================================================================
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Optimizers Test Case"""
+import os
+import random
+import sys
+from time import time
+import pytest
+import numpy as np
+from scipy.fft import dct, dst
+import mindspore as ms
+from mindspore import set_seed, ops
+from mindflow import DFTn, IDFTn, RDFTn, IRDFTn, DCT, IDCT, DST, IDST
+
+PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../"))
+sys.path.append(PROJECT_ROOT)
+
+# pylint: disable=wrong-import-position
+
+from common.cell import FP32_RTOL, FP16_RTOL, FP32_ATOL, FP16_ATOL
+from common.cell.utils import compare_output
+
+# pylint: enable=wrong-import-position
+
+set_seed(0)
+np.random.seed(0)
+random.seed(0)
+
+
+def gen_input(shape=(5, 6, 4, 8), rand_test=True):
+    '''Generate a random (or deterministic) complex array and its real/imaginary Tensors for the tests.
+    '''
+ a = np.random.rand(*shape) + 1j * np.random.rand(*shape)
+ if not rand_test:
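+        # deterministic ramp: each entry equals the sum of its indices, with identical real and imaginary parts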
+ a = sum([np.arange(n).reshape([n] + [1] * j) for j, n in enumerate(shape[::-1])]) + 1j * \
+ sum([np.arange(n).reshape([n] + [1] * j) for j, n in enumerate(shape[::-1])])
+ ar, ai = (ms.Tensor(a.real, dtype=ms.float32), ms.Tensor(a.imag, dtype=ms.float32))
+ return a, ar, ai
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend910b_training
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('device_target', ['CPU', 'Ascend'])
+@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
+@pytest.mark.parametrize('ndim', [1, 2, 3])
+@pytest.mark.parametrize('compute_dtype', [ms.float32, ms.float16])
+def test_rdft_accuracy(device_target, mode, ndim, compute_dtype):
+ """
+ Feature: Test RDFTn & IRDFTn accuracy
+ Description: Input random tensor, compare the results of RDFTn and IRDFTn with numpy results
+ Expectation: The output tensors should be equal within tolerance
+ """
+ ms.set_context(device_target=device_target, mode=mode)
+ a, ar, _ = gen_input()
+ shape = a.shape
+
+ b = np.fft.rfftn(a.real, s=a.shape[-ndim:], axes=range(-ndim, 0))
+ br, bi = RDFTn(shape[-ndim:], compute_dtype=compute_dtype)(ar)
+ cr = IRDFTn(shape[-ndim:], compute_dtype=compute_dtype)(br, bi)
+
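+    # float16 accumulates noticeably more rounding error through the transform chain, so its tolerances are loosened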
+ rtol = FP32_RTOL if compute_dtype == ms.float32 else FP16_RTOL * 10
+ atol = FP32_ATOL if compute_dtype == ms.float32 else FP16_ATOL * 20
+
+ assert compare_output(br.numpy(), b.real, rtol, atol)
+ assert compare_output(bi.numpy(), b.imag, rtol, atol)
+ assert compare_output(cr.numpy(), a.real, rtol, atol)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend910b_training
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('device_target', ['CPU', 'Ascend'])
+@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
+@pytest.mark.parametrize('ndim', [1, 2, 3])
+@pytest.mark.parametrize('compute_dtype', [ms.float32, ms.float16])
+def test_dft_accuracy(device_target, mode, ndim, compute_dtype):
+ """
+ Feature: Test DFTn & IDFTn accuracy
+ Description: Input random tensor, compare the results of DFTn and IDFTn with numpy results
+ Expectation: The output tensors should be equal within tolerance
+ """
+ ms.set_context(device_target=device_target, mode=mode)
+ a, ar, ai = gen_input()
+ shape = a.shape
+
+ b = np.fft.fftn(a, s=a.shape[-ndim:], axes=range(-ndim, 0))
+ br, bi = DFTn(shape[-ndim:], compute_dtype=compute_dtype)(ar, ai)
+ cr, ci = IDFTn(shape[-ndim:], compute_dtype=compute_dtype)(br, bi)
+
+ rtol = FP32_RTOL if compute_dtype == ms.float32 else FP16_RTOL * 10
+ atol = FP32_ATOL if compute_dtype == ms.float32 else FP16_ATOL * 20
+
+ assert compare_output(br.numpy(), b.real, rtol, atol)
+ assert compare_output(bi.numpy(), b.imag, rtol, atol)
+ assert compare_output(cr.numpy(), a.real, rtol, atol)
+ assert compare_output(ci.numpy(), a.imag, rtol, atol)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend910b_training
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('device_target', ['CPU', 'Ascend'])
+@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
+@pytest.mark.parametrize('compute_dtype', [ms.float32, ms.float16])
+def test_dct_accuracy(device_target, mode, compute_dtype):
+ """
+ Feature: Test DCT & IDCT accuracy
+ Description: Input random tensor, compare the results of DCT and IDCT with numpy results
+ Expectation: The output tensors should be equal within tolerance
+ """
+ ms.set_context(device_target=device_target, mode=mode)
+ a, ar, _ = gen_input()
+ shape = a.shape
+
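+    # scipy's dct with default arguments (type-II, unnormalized, along the last axis) is the forward reference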
+ b = dct(a.real)
+ br = DCT(shape[-1:], compute_dtype=compute_dtype)(ar)
+ cr = IDCT(shape[-1:], compute_dtype=compute_dtype)(br)
+
+ rtol = FP32_RTOL if compute_dtype == ms.float32 else FP16_RTOL * 10
+ atol = FP32_ATOL if compute_dtype == ms.float32 else FP16_ATOL * 20
+
+ assert compare_output(br.numpy(), b.real, rtol, atol)
+ assert compare_output(cr.numpy(), a.real, rtol, atol)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend910b_training
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('device_target', ['CPU', 'Ascend'])
+@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
+@pytest.mark.parametrize('compute_dtype', [ms.float32, ms.float16])
+def test_dst_accuracy(device_target, mode, compute_dtype):
+ """
+ Feature: Test DST & IDST accuracy
+ Description: Input random tensor, compare the results of DST and IDST with numpy results
+ Expectation: The output tensors should be equal within tolerance
+ """
+ ms.set_context(device_target=device_target, mode=mode)
+ a, ar, _ = gen_input()
+ shape = a.shape
+
+ b = dst(a.real)
+ br = DST(shape[-1:], compute_dtype=compute_dtype)(ar)
+ cr = IDST(shape[-1:], compute_dtype=compute_dtype)(br)
+
+ rtol = FP32_RTOL if compute_dtype == ms.float32 else FP16_RTOL * 10
+ atol = FP32_ATOL if compute_dtype == ms.float32 else FP16_ATOL * 20
+
+ assert compare_output(br.numpy(), b.real, rtol, atol)
+ assert compare_output(cr.numpy(), a.real, rtol, atol)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend910b_training
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('device_target', ['Ascend'])
+@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
+@pytest.mark.parametrize('ndim', [1, 2, 3])
+def test_dft_speed(device_target, mode, ndim):
+ """
+ Feature: Test DFTn & IDFTn speed
+ Description: Input random tensor, clock the time of 10 runs of the
+ gradient function containing DFT & iDFT operators
+ Expectation: The average time of each run should be within 0.5s
+ """
+ # test dftn & idftn speed
+ ms.set_context(device_target=device_target, mode=mode)
+ a, ar, ai = gen_input(shape=(64, 128, 256))
+ shape = a.shape
+
+ warmup_steps = 10
+ timed_steps = 10
+
+ dft_cell = DFTn(shape[-ndim:])
+ idft_cell = IDFTn(shape[-ndim:])
+
+ def forward_fn(xr, xi):
+ br, bi = dft_cell(xr, xi)
+ cr, ci = idft_cell(br, bi)
+ return ops.sum(cr * cr + ci * ci)
+
+ grad_fn = ms.value_and_grad(forward_fn, grad_position=(0, 1))
+
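+    # the first calls pay one-off graph compilation and operator initialization costs,
+    # so they are excluded from the timed measurement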
+ # warmup run
+ for _ in range(warmup_steps):
+ _, (g1, g2) = grad_fn(ar, ai)
+ ar = ar - .1 * g1
+ ai = ai - .1 * g2
+
+ # timed run
+    tic = time()
+ for _ in range(timed_steps):
+ _, (g1, g2) = grad_fn(ar, ai)
+ ar = ar - .1 * g1
+ ai = ai - .1 * g2
+
+    assert (time() - tic) / timed_steps < 0.5
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend910b_training
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('device_target', ['CPU', 'Ascend'])
+@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
+@pytest.mark.parametrize('ndim', [1, 2, 3])
+@pytest.mark.parametrize('compute_dtype', [ms.float32, ms.float16])
+def test_dft_grad(device_target, mode, ndim, compute_dtype):
+ """
+ Feature: Test the correctness of DFTn & IDFTn grad calculation
+ Description: Input random tensor, compare the autograd results with theoretic solutions
+ Expectation: The autograd results should be equal to theoretic solutions
+ """
+ ms.set_context(device_target=device_target, mode=mode)
+ a, ar, ai = gen_input()
+ shape = a.shape
+
+ dft_cell = DFTn(shape[-ndim:], compute_dtype=compute_dtype)
+
+ def forward_fn(xr, xi):
+ yr, yi = dft_cell(xr, xi)
+ return ops.sum(yr * yr + yi * yi)
+
+ grad_fn = ms.value_and_grad(forward_fn, grad_position=(0, 1))
+ _, (g1, g2) = grad_fn(ar, ai)
+
+ # analytic solution of the gradient
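+    # for L(x) = sum(|F x|^2) the gradient is dL/dx = 2 * F^H(F x); with numpy's
+    # unnormalized fftn, F^H(y) = N * ifftn(y) where N = prod(a.shape[-ndim:]),
+    # hence grad = 2 * N * ifftn(fftn(x)), split below into real and imaginary parts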
+ b = np.fft.fftn(a, s=a.shape[-ndim:], axes=range(-ndim, 0))
+ g = np.fft.ifftn(b, s=a.shape[-ndim:], axes=range(-ndim, 0)) * 2 * np.prod(a.shape[-ndim:])
+
+ rtol = FP32_RTOL if compute_dtype == ms.float32 else FP16_RTOL * 10
+ atol = FP32_ATOL if compute_dtype == ms.float32 else FP16_ATOL * 500 # grad func leads to larger error
+
+ assert compare_output(g1.numpy(), g.real, rtol, atol)
+ assert compare_output(g2.numpy(), g.imag, rtol, atol)