diff --git a/cmake/package.cmake b/cmake/package.cmake index c85bcb59a4f55b9880d6072ebb8f3efe792ed6ca..4830e79cc934ba1063f07d0480407c2128d12e5c 100644 --- a/cmake/package.cmake +++ b/cmake/package.cmake @@ -34,6 +34,7 @@ install( ${CMAKE_SOURCE_DIR}/mindscience/ccsrc ${CMAKE_SOURCE_DIR}/mindscience/common ${CMAKE_SOURCE_DIR}/mindscience/data + ${CMAKE_SOURCE_DIR}/mindscience/diffuser ${CMAKE_SOURCE_DIR}/mindscience/distributed ${CMAKE_SOURCE_DIR}/mindscience/e3nn ${CMAKE_SOURCE_DIR}/mindscience/gnn diff --git a/docs/architecture/mindscience-core.md b/docs/architecture/mindscience-core.md index 91c9eeae3e437e7053aeb8c11726e9365bbfc373..dd86d9b1daa67f27e84aabeeb06f94ece5aa4b3e 100644 --- a/docs/architecture/mindscience-core.md +++ b/docs/architecture/mindscience-core.md @@ -1,6 +1,8 @@ -# **重构设计: 分层架构与模块化方案** + -## **1. 顶层设计** +# 重构设计:分层架构与模块化方案 + +## 1. 顶层设计 ### 顶层设计目标 @@ -13,26 +15,26 @@ ### 设计原则 -1.开放扩展性 -- 提供清晰的接口规范(`ParallelStrategy`、`MemoryTechnique`、`PrecisionManager`、`GradientController`等) -- 支持注册自定义实现(`register_strategy()`,`register_memory_tech()`) +1. 开放扩展性 + - 提供清晰的接口规范(`ParallelStrategy`、`MemoryTechnique`、`PrecisionManager`、`GradientController`等) + - 支持注册自定义实现(`register_strategy()`,`register_memory_tech()`) -2.透明可观测 -- 内置详尽的性能分析工具 -- 所有优化操作可追溯、可禁用 +2. 透明可观测 + - 内置详尽的性能分析工具 + - 所有优化操作可追溯、可禁用 -3.渐进式抽象 -- 基础模式:直接配置驱动(YAML) -- 高级模式:细粒度API控制 -- 专家模式:直接访问底层组件 +3. 渐进式抽象 + - 基础模式:直接配置驱动(YAML) + - 高级模式:细粒度API控制 + - 专家模式:直接访问底层组件 -4.领域友好 -- 物理约束、数值方法等作为可插拔组件 -- 保留科学计算特有的控制流(如PINNS) +4. 领域友好 + - 物理约束、数值方法等作为可插拔组件 + - 保留科学计算特有的控制流(如PINNS) -5.性能优先 -- 所有抽象层保持零开销原则 -- 关键路径直接调用MindSpore原生接口 +5. 
性能优先 + - 所有抽象层保持零开销原则 + - 关键路径直接调用MindSpore原生接口 ### 整体架构设计 @@ -41,11 +43,12 @@ MindSpore Science-Core的代码架构设计如下图: ![软件架构](../images/architecture.png) MindScience套件的核心框架为`MindSpore Science-Core`,`MindSpore Science-Core`之上为领域套件,包括`MindFlow`、`MindEarth`等。 + - **注意**:`gnn`和`e3nn`库包含了数据接口/网络等,内容较多,作为独立模块统一开发管理更为合理,能对齐业界方式。 - `MindSpore Science-Core`框架的底层为`speed`加速模块,包含了`sciops`和`distributed`。 - `sciops`目录包含两部分,区别于MindSpore的ops库,使用方式:`from mindscience import sciops`。 -- - 对接了`speed`模块.so文件的自定义算子接口,对应ccsrc目录,使用AscendC实现了`EvoformerAttn`、`FFT`等底层加速算子,不开源的代码开放.so文件。 -- - Python算子接口为基于Python开发的算子,如自动微分、`FFT`、`Irreps`、`SSM`、`FA`等算子。 + - 对接了`speed`模块.so文件的自定义算子接口,对应ccsrc目录,使用AscendC实现了`EvoformerAttn`、`FFT`等底层加速算子,不开源的代码开放.so文件。 + - Python算子接口为基于Python开发的算子,如自动微分、`FFT`、`Irreps`、`SSM`、`FA`等算子。 - `distributed`为并行加速模块,提供科学计算领域常用的并行接口,包括DP/OP/TP/PP等功能。 - `gnn`模块包含了图数据接口和常用GNN网络。 - `pde`模块包含了PINNs网络接口。 @@ -58,7 +61,7 @@ MindScience套件的核心框架为`MindSpore Science-Core`,`MindSpore Science `tests`为测试模块,包含框架的UT、ST用例,大部分为UT用例,由于门禁对用例时长有限制,端到端的ST用例主要在本地看护。 -## **2. 目录结构重构** +## 2. 
目录结构重构 ```bash mindscience/ @@ -69,7 +72,7 @@ mindscience/ ├── setup.py # 版本依赖 ├── build.sh # 编译脚本 ├── docker/ # docker配置 -│ ├── Dockerfile +│ ├── Dockerfile ├── mindscience/ │ ├── common/ # 通用基础模块 │ │ ├── schedulers # Learning rate schedulers @@ -78,13 +81,13 @@ mindscience/ │ │ │ ├── primitives_2d/ # Rectangle, Disk, Triangle, Polygon │ │ │ ├── primitives_3d/ # Cuboid, Cylinder, Cone, Tetrahedron │ │ │ ├── primitives_nd/ # HyperCube, FixedPoint -│ │ ├── optimizers/ # 优化器(如LBFGS/PINNs优化器,AdaHessian二阶优化器) +│ │ ├── optimizers/ # 优化器(如LBFGS/PINNs优化器,AdaHessian二阶优化器) │ │ ├── metrics/ # 领域指标 │ │ │ ├── fid.py -│ │ │ ├── accuracy +│ │ │ ├── accuracy │ │ ├── losses/ # 损失函数(L2、物理约束损失) │ ├── data/ # 通用数据集模块 -│ │ ├── elec/ # 通用电磁数据接口 +│ │ ├── elec/ # 通用电磁数据接口 │ │ ├── earth/ # 通用气象数据接口 │ │ ├── flow/ # 通用流体数据接口 │ ├── ccsrc/ # 底层算子文件 @@ -95,53 +98,53 @@ mindscience/ │ │ ├── CMakeLists.txt # CMAKE │ ├── sciops/ # 通用算子库(加速算子、融合算子等) │ │ ├── einsum.py # einsum实现 -│ │ ├── differential.py # gradient, divergence -│ │ ├── evoformer_attention.py # evoformer实现 +│ │ ├── differential.py # gradient, divergence +│ │ ├── evoformer_attention.py # evoformer实现 │ │ ├── dft.py # DFT实现 │ ├── gnn/ # GNN模型库 │ │ ├── graph.py # 图数据接口 -│ │ ├── gat.py -│ │ ├── gcn.py +│ │ ├── gat.py +│ │ ├── gcn.py │ ├── e3nn/ # 等变计算库 │ │ ├── equivariant.py # 等变网络层 │ │ ├── irreps.py # 不可约表示(Irreps)定义 │ ├── pde/ # PINNS库 │ │ ├── pde_node.py # PDE定义 │ │ ├── pde_loss.py # PDE类实现 +│ ├── diffuser/ # 扩散模型 +│ │ ├── ddpm.py/ # DDPM扩散模型 +│ │ ├── ddim.py/ # DDIM扩散模型 │ ├── models/ # 跨领域可复用基础神经网络模块 │ │ ├── neural_operator/ # 通用神经算子库 -│ │ │ ├── fno.py -│ │ │ ├── kno.py -│ │ │ └── sno.py +│ │ │ ├── fno.py +│ │ │ ├── kno.py +│ │ │ └── sno.py │ │ ├── transformer/ # transformer模型 │ │ │ ├── attention.py # 自注意力网络层 -│ │ │ ├── ViT.py # ViT +│ │ │ ├── vit.py # VisionTransformer │ │ │ ├── DiT.py # DiT -│ │ ├── diffuser/ # 扩散模型 -│ │ │ ├── ddpm.py/ # DDPM扩散模型 -│ │ │ ├── ddim.py/ # DDIM扩散模型 │ │ ├── layers/ # 扩散模型 │ │ │ ├── mlp.py # MLP │ │ │ ├── 
kan.py # KAN │ │ │ ├── siren.py # siren -│ │ ├── PDEFormer/ +│ │ ├── PDEFormer/ │ │ ├── GraphCast/ -│ │ └── pangu/ +│ │ └── pangu/ │ ├── solvers/ # 通用求解器框架 │ │ ├── base_solver.py # Solver抽象基类 -│ │ ├── cbs.py -│ │ ├── cfd.py -│ │ ├── fdtd.py +│ │ ├── cbs.py +│ │ ├── cfd.py +│ │ ├── fdtd.py │ │ └── ... # 通用求解策略(如迭代求解、自适应步长) │ ├── utils/ # 统一工具API入口 │ │ ├── logging/ # 日志 -│ │ ├── config/ +│ │ ├── config/ │ │ ├── visualization/ # 画图 │ │ ├── io/ # IO │ │ ├── config.py # 配置管理 │ │ └── export.py # 模型导出工具(ONNX、MindIR) -│ ├── constants.py # 常量 -│ │ +│ ├── constants.py # 常量 +│ │ ├── mindflow/ # 计算流体动力学套件 │ ├── utils/ # 领域模型(如PINNs、FNO) │ ├── applications/ # 应用案例(圆柱绕流、湍流模拟) @@ -151,19 +154,18 @@ mindscience/ │ │ ├── mindelec/ # 计算电磁套件 │ ├── utils/ # 电磁模型 -│ ├── applications/ +│ ├── applications/ # 电磁应用案例 │ │ ├──src/ # 案例源文件 │ │ ├──train.py # 训练脚本 -│ └── ... -│ │ +│ └── ... ├── mindchemistry/ # 计算化学套件 │ ├── utils/ # 化学模型 -│ ├── applications/ +│ ├── applications/ │ │ ├──src/ # 案例源文件 │ │ ├──train.py # 训练脚本 -│ └── ... -│ -├── owner +│ └── ... +│ +├── owner │ ├── tests/ # 分层测试 │ ├── utils/ # 辅助函数 @@ -189,79 +191,87 @@ mindscience/ ``` +## 3. 核心模块设计 -## **3. 核心模块设计** +### (1) 通用API接口(mindscience/common) -### **(1) 通用API接口(mindscience/common)** - **功能**:提供通用接口,如学习率、数学运算、优化器、并行等接口。 - **示例**: + ```python - # mindscience/core/lr.py - def get_warmup_cosine_annealing_lr(): +# mindscience/core/lr.py +def get_warmup_cosine_annealing_lr(): ... - # mindscience/core/math.py - def get_grid_2d(resolution): - - class AdaHessian(): - +# mindscience/core/math.py +def get_grid_2d(resolution): ... +class AdaHessian(): + ... 
``` -### **(2) 通用数据接口(mindscience/data)** +### (2) 通用数据接口(mindscience/data) + - **功能**:提供数据通用接口,包括流体/气象数据等接口。 - **示例**: - ```python - # mindscience/data/base.py - class BaseDataset: - def __init__(self, data_path, split_ratio=(0.7, 0.2, 0.1)): - self.data_path = data_path - self.split_ratio = split_ratio - - @abstractmethod - def load_data(self): - """加载原始数据(需子类实现)""" - - def preprocess(self): - """默认预处理(子类可覆盖)""" - - def split(self): - """按比例划分数据集""" - - # 领域套件实现示例(MindSponge) - class ProteinDataset(BaseDataset): - def load_data(self): - self.raw_data = read_pdb(self.data_path) - - def preprocess(self): - self.graph_data = pdb_to_graph(self.raw_data) - ``` -### 算子层(mindscience/sciops)** +```python +# mindscience/data/base.py +class BaseDataset: + def __init__(self, data_path, split_ratio=(0.7, 0.2, 0.1)): + self.data_path = data_path + self.split_ratio = split_ratio + + @abstractmethod + def load_data(self): + """加载原始数据(需子类实现)""" + + def preprocess(self): + """默认预处理(子类可覆盖)""" + + def split(self): + """按比例划分数据集""" + +# 领域套件实现示例(MindSponge) +class ProteinDataset(BaseDataset): + def load_data(self): + self.raw_data = read_pdb(self.data_path) + + def preprocess(self): + self.graph_data = pdb_to_graph(self.raw_data) +``` + +### 算子层(mindscience/sciops) + - **功能**:提供基于MindSpore开发的科学计算算子,包括底层融合算子和Python算子。 -#### 算子层(mindscience/sciops)** +#### 算子层(mindscience/sciops) + - **功能**:提供基于Python开发的科学计算算子。 - **示例**: - ```python - # mindscience/models/ops/fno.py - def hessian(): - ``` -#### **(3) 底层融合算子(mindscience/ccsrc)** +```python +# mindscience/models/ops/fno.py +def hessian(): + ... 
+``` + +#### (3) 底层融合算子(mindscience/ccsrc) + - **功能**:提供自定义算子的共享so文件,不开源,供给`sciops`层调用。主要涉及核心数值计算、物理仿真和科学计算基础算子等计算密集型、需要高度优化的算子,如DFT、矩阵分解、微分、信号处理等算子。MindSpore使用pybind11连接Python侧的接口和C++侧的接口,兼顾Python的简单便捷和C++的高性能,无缝继承numpy,减少数据拷贝。采用CPython解释器实现,会有GIL(Global Interpreter Lock,全局解释器锁),多线程并不能利用多核优势。可以通过换其他解释器,或者通过c++实现真正的多线程。 - **示例**: - ```C++ - #include "frontend/ops/ops.h" - #include - #include "include/core/utils/python_adapter.h" - #include "pipeline/jit/ps/parse/data_converter.h" - - namespace mindspore { - // namespace to support primitive operators - namespace prim { - ValuePtr GetPythonOps(const std::string &op_name, const std::string &module_name, bool use_signature) { + +```cpp +#include "frontend/ops/ops.h" +#include +#include "include/core/utils/python_adapter.h" +#include "pipeline/jit/ps/parse/data_converter.h" + +namespace mindspore { +// namespace to support primitive operators +namespace prim { +ValuePtr GetPythonOps(const std::string &op_name, const std::string &module_name, bool use_signature) { py::object obj = python_adapter::GetPyFn(module_name, op_name); ValuePtr node = nullptr; bool succ = parse::ConvertData(obj, &node, use_signature); @@ -269,31 +279,31 @@ mindscience/ MS_LOG(INTERNAL_EXCEPTION) << "Get Python op " << op_name << " from " << module_name << " fail."; } return node; - } - } // namespace prim - } // namespace mindspore - - ``` +} +} // namespace prim +} // namespace mindspore +``` +### 通用PINNS库(mindscience/pde) -### 通用PINNS库(mindscience/pde)** - **功能**:封装可复用的PINNS接口,供各个领域使用。 - **示例**: - ```python - # mindscience/models/transformer/attention.py - class PdeNode: - """PDE节点定义""" - def __init__(self, use_flash=False): - ... - - def construct(self, q, k, v): - ... - # mindscience/models/pde/diffusion.py - class PDEWithLoss: - def __init__(self, use_flash=False): - ... - +```python +# mindscience/models/transformer/attention.py +class PdeNode: + """PDE节点定义""" + def __init__(self, use_flash=False): + ... 
+ + def construct(self, q, k, v): + ... + +# mindscience/models/pde/diffusion.py +class PDEWithLoss: + def __init__(self, use_flash=False): + ... + def pde(self): """ Governing equation based on sympy, abstract method. @@ -309,12 +319,13 @@ mindscience/ def parse_node(self, formula_nodes, inputs=None, norm=None): ... - ``` +``` +#### 并行加速库(mindscience/distributed) -#### 并行加速库(mindscience/distributed)** - **功能**:并行训练接口。 - **示例**: + ```python # mindscience/models/transformer/attention.py class Op: @@ -326,22 +337,24 @@ mindscience/ class Tp: def __init__(self,): ... - ``` -### **(4) 神经网络基础层(mindscience/models)** +### (4) 神经网络基础层(mindscience/models) + 在`sciops`层之上,封装可复用的网络模块,继承自`nn.Cell`,供各个领域使用。 -#### 通用Transformer库(mindscience/models/transformer)** +#### 通用Transformer库(mindscience/models/transformer) + - **功能**:封装可复用的Transformer模块,供各个领域使用。 - **示例**: + ```python # mindscience/models/gnn/Gat.py class Transformer: def __init__(self, use_flash=False): ... - + def construct(self, q, k, v): ... @@ -349,49 +362,55 @@ mindscience/ class Attention: def __init__(self, use_flash=False): ... - + def construct(self, q, k, v): ... ``` -#### 通用diffusion库(mindscience/models/diffuser)** +#### 通用diffusion库(mindscience/diffuser) + - **功能**:封装可复用的diffusion模块,供各个领域使用。 - **示例**: + ```python # mindscience/models/diffusion/ddpm.py class DDPM: def __init__(self, use_flash=False): ... - + def add_noise(self, q, k, v): ... ``` -#### 通用网络层库(mindscience/models/layers)** +#### 通用网络层库(mindscience/models/layers) + - **功能**:封装可复用的diffusion模块,供各个领域使用。 - **示例**: + ```python # mindscience/models/diffusion/ddpm.py class MLP: def __init__(self,): ... - + def add_noise(self, x): ... ``` -### 通用等变计算库(mindscience/e3nn)** +### 通用等变计算库(mindscience/e3nn) + - **功能**:封装可复用的等变计算接口,供各个领域使用。 - **示例**: + ```python # mindscience/models/e3nn/irreps.py class Irreps: """不可约表示(兼容e3nn语法)""" def __init__(self, irreps_str): self.irreps = self._parse(irreps_str) - + def _parse(self, s): # 解析字符串如 "5x0e + 3x1o" ...
@@ -407,7 +426,7 @@ mindscience/ nn.Dense(64, irreps_out.scalar_dim) ]) self.gate_nn = nn.Dense(irreps_in.vector_dim, irreps_out.gate_dim) - + def construct(self, x_scalar, x_vector): gate = self.gate_nn(x_vector) scalar_out = self.scalar_nn(x_scalar) * gate @@ -425,8 +444,10 @@ mindscience/ ``` ### 通用GNN库(mindscience/gnn)** + - **功能**:封装可复用的GNN模块,供各个领域使用。 - **示例**: + ```python class Graph: def __init__(self, nodes, edges, node_feat, edge_feat): @@ -434,7 +455,7 @@ mindscience/ self.edges = edges # 边列表 (, dst) self.node_feat = node_feat # 节点特征张量 self.edge_feat = edge_feat # 边特征张量 - + def to_mindspore_tensor(self): """转换为MindSpore张量格式""" @@ -444,10 +465,11 @@ mindscience/ """从图中采样k-hop邻居子图""" ``` +### (5) 求解器框架(mindscience/solvers) -### **(5) 求解器框架(mindscience/solvers)** - **功能**:定义通用求解器框架。 - **示例**: + ```python # core/solvers/base_solver.py class BaseSolver(nn.Cell): @@ -455,109 +477,124 @@ mindscience/ self.model = model self.optimizer = optimizer self.loss_fn = loss_fn - + def train_step(self, data): # 标准训练步骤(前向、损失、反向) ... - + def solve(self, inputs): # 推理接口 return self.model(inputs) ``` -### **(6) 工具接口(mindscience/utils)** -- **功能**:定义通用求解器框架。 +### (6) 工具接口(mindscience/utils) + +- **功能**:提供通用工具和辅助函数。 - **示例**: - ```python - # mindscience/utils/logging.py - class Logger(nn.Cell): - def __init__(self, ): - ... - ``` +```python +# mindscience/utils/logging.py +class Logger(nn.Cell): + def __init__(self, ): + ... +``` --- -## **4. 领域套件设计(以MindFlow为例)** +## 4. 领域套件设计(以MindFlow为例) + +### (1) 领域工具(mindflow/utils) -### **(1) 领域工具(mindflow/utils)** - **职责**:领域专用工具(如流场可视化、边界条件生成)。 - **示例**: - ```python - # mindflow/utils/visualization.py - def plot_velocity_field(data, pred): - """绘制速度场对比图""" - ... - ``` -### **(2) 应用案例(mindflow/applications)** +```python +# mindflow/utils/visualization.py +def plot_velocity_field(data, pred): + """绘制速度场对比图""" + ... 
+``` + +### (2) 应用案例(mindflow/applications) + - **职责**:提供端到端案例(数据、训练、可视化)。 - **示例**: - ```python - # mindflow/applications/cylinder_flow/train.py - def train_cylinder_flow(): - # 加载数据 - data = load_data("cylinder_dataset.h5") - # 初始化模型(使用core模块组件) - model = PINNs(MLP(input_dims=3, hidden_dims=[128, 128])) - solvers = BaseSolver(model, Adam(model.params()), MSE()) - # 训练与评估 - solvers.train(data, epochs=1000) - solvers.export("cylinder_model.mindir") - ``` + +```python +# mindflow/applications/cylinder_flow/train.py +def train_cylinder_flow(): + # 加载数据 + data = load_data("cylinder_dataset.h5") + # 初始化模型(使用core模块组件) + model = PINNs(MLP(input_dims=3, hidden_dims=[128, 128])) + solvers = BaseSolver(model, Adam(model.params()), MSE()) + # 训练与评估 + solvers.train(data, epochs=1000) + solvers.export("cylinder_model.mindir") +``` --- -## **5. 依赖管理与构建优化** -### **(1) 模块化安装** +## 5. 依赖管理与构建优化 + +### (1) 模块化安装 + - **setup.py 配置**: - ```python - # 支持按需安装 - extras_require={ - "mindflow": ["matplotlib", "h5py"], - "mindsponge": ["biopython"], - "all": ["matplotlib", "h5py", "biopython"] - } - ``` -### **(2) 动态导入机制** +```python +# 支持按需安装 +extras_require={ + "mindflow": ["matplotlib", "h5py"], + "mindsponge": ["biopython"], + "all": ["matplotlib", "h5py", "biopython"] +} +``` + +### (2) 动态导入机制 + - **延迟加载领域模块**: - ```python - # core/api/__init__.py - def get_solver(solver_type): - if solver_type == "mindflow": - from mindflow.solvers import FlowSolver - return FlowSolver - elif solver_type == "mindsponge": - from mindsponge.solvers import MDsolver - return MDsolver - ``` + +```python +# core/api/__init__.py +def get_solver(solver_type): + if solver_type == "mindflow": + from mindflow.solvers import FlowSolver + return FlowSolver + elif solver_type == "mindsponge": + from mindsponge.solvers import MDsolver + return MDsolver +``` --- -## **6. 测试与文档策略** +## 6. 
测试与文档策略 + +### (1) 门禁测试 -### **(1) 门禁测试** gitee门禁上需要大修改,要联系门禁组增加需求。 + - **基础模块测试**:覆盖core下的所有算子与组件。 - ```python - # tests/core/test_attention.py - def test_multihead_attention(): - attn = MultiHeadAttention(use_flash=True) - output = attn(q, k, v) - assert output.shape == expected_shape - ``` + +```python +# tests/core/test_attention.py +def test_multihead_attention(): + attn = MultiHeadAttention(use_flash=True) + output = attn(q, k, v) + assert output.shape == expected_shape +``` + 测试目录 - **领域测试**:验证领域模型与案例。 - ```python - # tests/mindflow/test_pinn.py - def test_pinn_convergence(): - loss = train_pinn(...) - assert loss < 1e-3 - ``` -### **(2) 文档统一化** +```python +# tests/mindflow/test_pinn.py +def test_pinn_convergence(): + loss = train_pinn(...) + assert loss < 1e-3 +``` + +### (2) 文档统一化 当前每个领域套件都有自己的一个页面,展示案例和API,重构后每个套件只展示领域案例,框架的API统一展示。 @@ -566,52 +603,57 @@ gitee门禁上需要大修改,要联系门禁组增加需求。 --- -## **7. 性能优化与生态集成** +## 7. 性能优化与生态集成 + +### (1) MindSpore特性利用 -### **(1) MindSpore特性利用** - **图算融合**:通过`nn.Cell`、`@ms_function`、`@lazy_inline`、`@jit`、`@jit_class`优化计算图。 - **原生并行**:在底层集成分布式策略,领域套件无需关注细节。 - ```python - # core/parallel/config.py - def auto_parallel(model, data_parallel=4, model_parallel=2): - model = auto_parallel(model) - ... - ``` -### **(2) 软硬件适配** +```python +# core/parallel/config.py +def auto_parallel(model, data_parallel=4, model_parallel=2): + model = auto_parallel(model) + ... +``` + +### (2) 软硬件适配 -主要针对Ascend910B+Mindspore2.6.0版本 +主要针对Ascend+MindSpore2.6.0版本 - **多后端支持**:在core层封装CPU/Ascend算子差异。 - ```python - # mindscience/models/ops/math_ops/fft.py - def fft2d(x): - if get_device() == "CPU": - return cufft.fft2d(x) - else: - return ascend_fft.fft2d(x) - ``` + +```python +# mindscience/models/ops/math_ops/fft.py +def fft2d(x): + if get_device() == "CPU": + return cufft.fft2d(x) + else: + return ascend_fft.fft2d(x) +``` --- -## **8. 实施步骤** +## 8. 实施步骤 + 由于`MindEnergy`套件开源的需求急迫,因此先按照重构后的方式开发`MindEnergy`套件。 + 1.
**代码分层拆分** - - 将现有代码按功能拆分为`mindscience`核心模块和领域套件。 + - 将现有代码按功能拆分为`mindscience`核心模块和领域套件。 2. **接口标准化** - - 定义通用基类(如`BaseSolver`)并重构现有模型。 + - 定义通用基类(如`BaseSolver`)并重构现有模型。 3. **依赖解耦** - - 确保领域套件仅依赖`mindscience`,无横向依赖。 + - 确保领域套件仅依赖`mindscience`,无横向依赖。 4. **测试迁移** - - 将原有测试按模块归属迁移到对应的`tests/modules`和`tests/mindflow`等。 + - 将原有测试按模块归属迁移到对应的`tests/modules`和`tests/mindflow`等。 5. **CI/CD适配** - - 更新CI流程,分模块运行测试与构建。 + - 更新CI流程,分模块运行测试与构建。 6. **文档迁移与更新** - - 重构文档结构,突出分层设计。 + - 重构文档结构,突出分层设计。 -## **9. 负责人** +## 9. 负责人 -### 责任田负责人: +### 责任田负责人 责任田负责人负责模块的开发合入,相关代码的重构和合入优先找特性的第一责任人开发,其次找对应领域的开发人员。 @@ -639,5 +681,6 @@ committer:刘红升、郭伯强、张毅、龚玥 --- -## **总结** -通过以上设计,MindScience将实现**高内聚、低耦合**的架构,各领域套件可快速复用核心能力,同时专注于领域创新。此方案平衡了灵活性与统一性,为后续扩展(如新增MindEnergy套件)提供了清晰路径。 \ No newline at end of file +## 总结 + +通过以上设计,MindScience将实现**高内聚、低耦合**的架构,各领域套件可快速复用核心能力,同时专注于领域创新。此方案平衡了灵活性与统一性,为后续扩展(如新增MindEnergy套件)提供了清晰路径。 diff --git a/mindscience/__init__.py b/mindscience/__init__.py index 056989625e83fab699f439350602153017ff6b2c..341e30e6c4c66c3375d9805a9064d8fa1bfcabee 100644 --- a/mindscience/__init__.py +++ b/mindscience/__init__.py @@ -20,6 +20,7 @@ import time from .common import * from .data import * +from .diffuser import * from .distributed import * from .e3nn import * from .gnn import * @@ -74,6 +75,7 @@ _mindspore_version_check() __all__ = [] __all__.extend(common.__all__) __all__.extend(data.__all__) +__all__.extend(diffuser.__all__) __all__.extend(distributed.__all__) __all__.extend(e3nn.__all__) __all__.extend(gnn.__all__) diff --git a/mindscience/models/diffuser/__init__.py b/mindscience/diffuser/__init__.py similarity index 88% rename from mindscience/models/diffuser/__init__.py rename to mindscience/diffuser/__init__.py index b87f7c9375d8b491faa309ab978d17ada93e8d5d..1d1296671630ff6a4786f5b3657a7af84a498ace 100644 --- a/mindscience/models/diffuser/__init__.py +++ b/mindscience/diffuser/__init__.py @@ -18,4 +18,5 @@ init from .diffusion import DDPMPipeline, DiffusionScheduler, 
DDPMScheduler, DDIMPipeline, DDIMScheduler, DiffusionTrainer from .diffusion_transformer import DiffusionTransformer, ConditionDiffusionTransformer -__all__ = ["DiffusionScheduler", "DDPMPipeline", "DDPMScheduler", "DDIMPipeline", "DDIMScheduler", "DiffusionTrainer", "DiffusionTransformer", "ConditionDiffusionTransformer"] \ No newline at end of file +__all__ = ["DiffusionScheduler", "DDPMPipeline", "DDPMScheduler", "DDIMPipeline", + "DDIMScheduler", "DiffusionTrainer", "DiffusionTransformer", "ConditionDiffusionTransformer"] diff --git a/mindscience/models/diffuser/diffusion.py b/mindscience/diffuser/diffusion.py similarity index 91% rename from mindscience/models/diffuser/diffusion.py rename to mindscience/diffuser/diffusion.py index a407ca81f150df8d90b3cf1f3253b12f8bde16a1..2f1bee7c48ac8523ca17ae92c67cea6884e82c81 100644 --- a/mindscience/models/diffuser/diffusion.py +++ b/mindscience/diffuser/diffusion.py @@ -28,7 +28,10 @@ def extract(a, t, x_shape): return out.reshape(b, *((1,) * (len(x_shape) - 1))) -def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999, alpha_transform_type="cosine"): +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine"): """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. @@ -180,15 +183,19 @@ class DiffusionScheduler: posterior_variance = self.betas * \ (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) - # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + # above: equal to 1. / (1. / (1. 
- alpha_cumprod_tm1) + alpha_t / + # beta_t) posterior_variance = np.clip(posterior_variance, 1e-20, None) self.posterior_variance = Tensor( posterior_variance, dtype=compute_dtype) - # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain - self.posterior_log_variance_clipped = Tensor(np.log( - posterior_variance), dtype=compute_dtype) # Tensor(np.log(posterior_variance)) - # See formula (7) from `Denoising Diffusion Probabilistic Models `_ + # below: log calculation clipped because the posterior variance is 0 at + # the beginning of the diffusion chain + self.posterior_log_variance_clipped = Tensor( + np.log(posterior_variance), + dtype=compute_dtype) # Tensor(np.log(posterior_variance)) + # See formula (7) from `Denoising Diffusion Probabilistic Models + # `_ self.posterior_mean_coef1 = Tensor( self.betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod), dtype=compute_dtype) self.posterior_mean_coef2 = Tensor( @@ -196,7 +203,10 @@ class DiffusionScheduler: self.num_inference_steps = None self.dynamic_thresholding_ratio = dynamic_thresholding_ratio - def _init_betas(self, beta_schedule="squaredcos_cap_v2", rescale_betas_zero_snr=False): + def _init_betas( + self, + beta_schedule="squaredcos_cap_v2", + rescale_betas_zero_snr=False): """init noise beta schedule Args: @@ -218,7 +228,9 @@ class DiffusionScheduler: elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. 
betas = np.linspace( - self.beta_start**0.5, self.beta_end**0.5, self.num_train_timesteps) ** 2 + self.beta_start**0.5, + self.beta_end**0.5, + self.num_train_timesteps) ** 2 elif beta_schedule == "squaredcos_cap_v2": # Glide cosine schedule betas = betas_for_alpha_bar(self.num_train_timesteps) @@ -251,8 +263,7 @@ class DiffusionScheduler: raise ValueError( f"`num_inference_steps`: {num_inference_steps} cannot be larger than `num_train_timesteps`:" f" {self.num_train_timesteps} as the diffusion model trained with this scheduler can only handle" - f" maximal {self.num_train_timesteps} timesteps." - ) + f" maximal {self.num_train_timesteps} timesteps.") self.num_inference_steps = num_inference_steps # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of if self.timestep_spacing == "linspace": @@ -266,15 +277,17 @@ class DiffusionScheduler: elif self.timestep_spacing == "leading": step_ratio = self.num_train_timesteps // num_inference_steps # creates integer timesteps by multiplying by ratio - # casting to int to avoid issues when num_inference_step is power of 3 + # casting to int to avoid issues when num_inference_step is power + # of 3 timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].astype(np.int32) elif self.timestep_spacing == "trailing": step_ratio = self.num_train_timesteps // num_inference_steps # creates integer timesteps by multiplying by ratio - # casting to int to avoid issues when num_inference_step is power of 3 - timesteps = np.round( - np.arange(self.num_train_timesteps, 0, -step_ratio)).astype(np.int32) + # casting to int to avoid issues when num_inference_step is power + # of 3 + timesteps = np.round(np.arange( + self.num_train_timesteps, 0, -step_ratio)).astype(np.int32) timesteps -= 1 else: raise ValueError( @@ -309,7 +322,11 @@ class DiffusionScheduler: return sample - def _pred_origin_sample(self, model_output: Tensor, sample: Tensor, timestep: Tensor): + def _pred_origin_sample( + self, + 
model_output: Tensor, + sample: Tensor, + timestep: Tensor): """ Predict x_0 with x_t. @@ -329,23 +346,26 @@ class DiffusionScheduler: elif self.prediction_type == "sample": pred_original_sample = model_output elif self.prediction_type == "v_prediction": - pred_original_sample = extract(self.sqrt_alphas_cumprod, timestep, sample.shape)*sample - \ - extract(self.sqrt_one_minus_alphas_cumprod, - timestep, sample.shape)*model_output + pred_original_sample = extract( + self.sqrt_alphas_cumprod, timestep, sample.shape) * sample - extract( + self.sqrt_one_minus_alphas_cumprod, timestep, sample.shape) * model_output else: raise ValueError( f"prediction_type given as {self.prediction_type} must be one of `epsilon`, `sample`, or" - " `v_prediction`" - ) + " `v_prediction`") # 2. Clip or threshold "predicted x_0" if self.thresholding: pred_original_sample = self._threshold_sample(pred_original_sample) elif self.clip_sample: - pred_original_sample = pred_original_sample.clamp(-self.clip_sample_range, - self.clip_sample_range) + pred_original_sample = pred_original_sample.clamp( + -self.clip_sample_range, self.clip_sample_range) return pred_original_sample - def add_noise(self, original_samples: Tensor, noise: Tensor, timesteps: Tensor): + def add_noise( + self, + original_samples: Tensor, + noise: Tensor, + timesteps: Tensor): """ Diffusion add noise process. @@ -357,8 +377,17 @@ class DiffusionScheduler: Returns: Tensor, the noised sample of the next step. 
""" - return (extract(self.sqrt_alphas_cumprod, timesteps, original_samples.shape)*original_samples + - extract(self.sqrt_one_minus_alphas_cumprod, timesteps, original_samples.shape)*noise) + return ( + extract( + self.sqrt_alphas_cumprod, + timesteps, + original_samples.shape) * + original_samples + + extract( + self.sqrt_one_minus_alphas_cumprod, + timesteps, + original_samples.shape) * + noise) def step(self, model_output: Tensor, sample: Tensor, timestep: Tensor): """ @@ -479,7 +508,8 @@ class DDPMScheduler(DiffusionScheduler): # hacks - were probably added for training stability if self.variance_type == "fixed_small": pass - # for rl-diffuser `Planning with Diffusion for Flexible Behavior Synthesis `_ + # for rl-diffuser `Planning with Diffusion for Flexible Behavior + # Synthesis `_ elif self.variance_type == "fixed_small_log": variance = ops.log(variance) variance = ops.exp(0.5 * variance) @@ -534,11 +564,19 @@ class DDPMScheduler(DiffusionScheduler): pred_original_sample = self._pred_origin_sample( model_output, sample, timestep) # 2. Compute predicted previous sample µ_t - # See formula (7) from `Denoising Diffusion Probabilistic Models `_ + # See formula (7) from `Denoising Diffusion Probabilistic Models + # `_ pred_prev_sample = ( - extract(self.posterior_mean_coef1, timestep, sample.shape)*pred_original_sample + - extract(self.posterior_mean_coef2, timestep, sample.shape)*sample - ) + extract( + self.posterior_mean_coef1, + timestep, + sample.shape) * + pred_original_sample + + extract( + self.posterior_mean_coef2, + timestep, + sample.shape) * + sample) # 3. Add noise v = self._get_variance(sample, timestep, predicted_variance) @@ -660,7 +698,11 @@ class DDIMScheduler(DiffusionScheduler): return variance - def _pred_epsilon(self, model_output: Tensor, sample: Tensor, timestep: Tensor): + def _pred_epsilon( + self, + model_output: Tensor, + sample: Tensor, + timestep: Tensor): """ Predict epsilon. 
@@ -748,17 +790,21 @@ class DDIMScheduler(DiffusionScheduler): std_dev_t = (eta * ops.sqrt(variance)).astype(dtype) if use_clipped_model_output: - # the pred_epsilon is always re-derived from the clipped x_0 in Glide + # the pred_epsilon is always re-derived from the clipped x_0 in + # Glide pred_epsilon = ( (sample - (alpha_prod_t ** (0.5)).astype(dtype) * pred_original_sample) / beta_prod_t ** (0.5) ).astype(dtype) - # 5. compute "direction pointing to x_t" of formula (12) from `Denoising Diffusion Implicit Models `_ - pred_sample_direction = ( - (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5)).reshape(batch_size, 1, 1) * pred_epsilon + # 5. compute "direction pointing to x_t" of formula (12) from + # `Denoising Diffusion Implicit Models + # `_ + pred_sample_direction = ((1 - alpha_prod_t_prev - std_dev_t**2) + ** (0.5)).reshape(batch_size, 1, 1) * pred_epsilon - # 6. compute x_t without "random noise" of formula (12) from `Denoising Diffusion Implicit Models `_ + # 6. compute x_t without "random noise" of formula (12) from `Denoising + # Diffusion Implicit Models `_ coef = ops.sqrt(alpha_prod_t_prev).reshape(batch_size, 1, 1) prev_sample = coef * pred_original_sample + pred_sample_direction if eta > 0: @@ -826,7 +872,14 @@ class DiffusionPipeline: """ - def __init__(self, model, scheduler, batch_size, seq_len, num_inference_steps=1000, compute_dtype=mstype.float32): + def __init__( + self, + model, + scheduler, + batch_size, + seq_len, + num_inference_steps=1000, + compute_dtype=mstype.float32): self.model = model self.scheduler = scheduler self.seq_len = seq_len @@ -863,8 +916,12 @@ class DiffusionPipeline: Returns: Tensor, Predicted original samples. 
""" - sample = Tensor(np.random.randn(self.batch_size, self.seq_len, - self.model.in_channels), dtype=self.compute_dtype) + sample = Tensor( + np.random.randn( + self.batch_size, + self.seq_len, + self.model.in_channels), + dtype=self.compute_dtype) if condition is not None: condition = condition.reshape(self.batch_size, -1) @@ -932,11 +989,19 @@ class DDPMPipeline(DiffusionPipeline): """ # pylint: disable=W0235 - def __init__(self, model, scheduler, batch_size, seq_len, num_inference_steps=1000, compute_dtype=mstype.float32): + def __init__( + self, + model, + scheduler, + batch_size, + seq_len, + num_inference_steps=1000, + compute_dtype=mstype.float32): if not isinstance(scheduler, DDPMScheduler): raise TypeError('scheduler type must be DDPMScheduler') if num_inference_steps != scheduler.num_train_timesteps: - raise ValueError('num_inference_steps must equal to num_train_timesteps') + raise ValueError( + 'num_inference_steps must equal to num_train_timesteps') super().__init__(model, scheduler, batch_size, seq_len, num_inference_steps, compute_dtype) @@ -1001,17 +1066,34 @@ class DDIMPipeline(DiffusionPipeline): """ # pylint: disable=W0235 - def __init__(self, model, scheduler, batch_size, seq_len, num_inference_steps=1000, compute_dtype=mstype.float32): + def __init__( + self, + model, + scheduler, + batch_size, + seq_len, + num_inference_steps=1000, + compute_dtype=mstype.float32): if not isinstance(scheduler, DDIMScheduler): raise TypeError('scheduler type must be DDIMScheduler') super().__init__(model, scheduler, batch_size, seq_len, num_inference_steps, compute_dtype) # pylint: disable=W0221 - def _sample_step(self, sample, condition, timesteps, eta, use_clipped_model_output): + def _sample_step( + self, + sample, + condition, + timesteps, + eta, + use_clipped_model_output): model_output = self._pred_noise(sample, condition, timesteps) - sample = self.scheduler.step(model_output=model_output, sample=sample, timestep=timesteps, - eta=eta, 
use_clipped_model_output=use_clipped_model_output) + sample = self.scheduler.step( + model_output=model_output, + sample=sample, + timestep=timesteps, + eta=eta, + use_clipped_model_output=use_clipped_model_output) return sample def __call__(self, condition=None, eta=0., use_clipped_model_output=False): @@ -1034,15 +1116,23 @@ class DDIMPipeline(DiffusionPipeline): """ if not 0 <= eta <= 1: raise ValueError('eta must in range [0, 1]') - sample = Tensor(np.random.randn(self.batch_size, self.seq_len, - self.model.in_channels), dtype=self.compute_dtype) + sample = Tensor( + np.random.randn( + self.batch_size, + self.seq_len, + self.model.in_channels), + dtype=self.compute_dtype) if condition is not None: condition = condition.reshape(self.batch_size, -1) for t in self.scheduler.num_timesteps: batched_times = ops.ones((self.batch_size,), mstype.int32) * int(t) sample = self._sample_step( - sample, condition, batched_times, eta, use_clipped_model_output) + sample, + condition, + batched_times, + eta, + use_clipped_model_output) return sample @@ -1134,7 +1224,12 @@ class DiffusionTrainer: else: raise ValueError(f'invalid loss type {self.loss_type}') - def get_loss(self, original_samples: Tensor, noise: Tensor, timesteps: Tensor, condition: Tensor = None): + def get_loss( + self, + original_samples: Tensor, + noise: Tensor, + timesteps: Tensor, + condition: Tensor = None): r""" Calculate the forward loss of diffusion process. 
@@ -1160,9 +1255,17 @@ class DiffusionTrainer: elif self.objective == 'pred_x0': target = original_samples elif self.objective == 'pred_v': - target = (extract(self.scheduler.sqrt_alphas_cumprod, timesteps, original_samples.shape)*noise - - extract(self.scheduler.sqrt_one_minus_alphas_cumprod, timesteps, - original_samples.shape)*original_samples) + target = ( + extract( + self.scheduler.sqrt_alphas_cumprod, + timesteps, + original_samples.shape) * + noise - + extract( + self.scheduler.sqrt_one_minus_alphas_cumprod, + timesteps, + original_samples.shape) * + original_samples) else: target = noise diff --git a/mindscience/models/diffuser/diffusion_transformer.py b/mindscience/diffuser/diffusion_transformer.py similarity index 99% rename from mindscience/models/diffuser/diffusion_transformer.py rename to mindscience/diffuser/diffusion_transformer.py index 4d18af80d10d35e5be1f86f48451a7b784ad198a..c6baea62a775ed05c669d246b5c4c6bf83e2fd0b 100644 --- a/mindscience/models/diffuser/diffusion_transformer.py +++ b/mindscience/diffuser/diffusion_transformer.py @@ -19,7 +19,7 @@ import math import numpy as np from mindspore import nn, ops, Tensor from mindspore import dtype as mstype -from ..transformer import TransformerBlock +from ..models import TransformerBlock class Mlp(nn.Cell): diff --git a/mindscience/models/GraphCast/__init__.py b/mindscience/models/GraphCast/__init__.py index ccc0beecde732fb11a0215a83d5127b3c465319b..1f46ab5243cf79a3c3c32b8d6375d1941efa5388 100644 --- a/mindscience/models/GraphCast/__init__.py +++ b/mindscience/models/GraphCast/__init__.py @@ -12,7 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""init""" +"""GraphCast Models Package + +This package contains implementations of GraphCast models for weather forecasting. 
+GraphCast is a machine learning model for weather forecasting that uses graph neural networks. +""" from .graphcastnet import GraphCastNet __all__ = ['GraphCastNet'] diff --git a/mindscience/models/__init__.py b/mindscience/models/__init__.py index e33e2ab90a363210a6b8313ba23b98088e46bc62..2cc514aa99bf2e676f033151dff32fb1a9228b98 100644 --- a/mindscience/models/__init__.py +++ b/mindscience/models/__init__.py @@ -13,16 +13,18 @@ # limitations under the License. # ============================================================================ """ -init +Models Package + +This package contains various neural network models and related components. +It includes implementations of popular architectures such as Vision Transformer, +Graph Neural Networks, and other domain-specific models. """ -from .diffuser import * from .GraphCast import * from .layers import * from .neural_operator import * from .transformer import * __all__ = [] -__all__.extend(diffuser.__all__) __all__.extend(GraphCast.__all__) __all__.extend(layers.__all__) __all__.extend(neural_operator.__all__) diff --git a/mindscience/models/layers/__init__.py b/mindscience/models/layers/__init__.py index 929173e0fd7af51a269eb6311a6e5716d00f4700..203fc4563e15dbb7fe7b20792aafec41f3a563d5 100644 --- a/mindscience/models/layers/__init__.py +++ b/mindscience/models/layers/__init__.py @@ -13,10 +13,16 @@ # limitations under the License. # ============================================================================ """ -init +Layers Package + +This package contains various neural network layers and building blocks +that are used across different models in the MindScience toolkit. +It includes activation functions, basic blocks, and specialized layers +like UNet2D. 
""" from .activation import get_activation from .basic_block import LinearBlock, ResBlock, InputScale, FCSequential, MultiScaleFCSequential, DropPath from .unet2d import UNet2D -__all__ = ["get_activation", "LinearBlock", "ResBlock", "InputScale", "FCSequential", "MultiScaleFCSequential", "DropPath", "UNet2D"] \ No newline at end of file +__all__ = ["get_activation", "LinearBlock", "ResBlock", "InputScale", "FCSequential", + "MultiScaleFCSequential", "DropPath", "UNet2D"] diff --git a/mindscience/models/neural_operator/afno2d.py b/mindscience/models/neural_operator/afno2d.py index a10cece18886a09bea674d934c271aec97a7731d..b3e78739be2b807cf10118462d875176abeb6dca 100644 --- a/mindscience/models/neural_operator/afno2d.py +++ b/mindscience/models/neural_operator/afno2d.py @@ -77,12 +77,20 @@ class Mlp(nn.Cell): dropout_rate=1.0, compute_dtype=mstype.float16): super(Mlp, self).__init__() - self.fc1 = nn.Dense(embed_dims, embed_dims * mlp_ratio, - weight_init=initializer(Normal(sigma=0.02), shape=(embed_dims * mlp_ratio, embed_dims)), - ).to_float(compute_dtype) - self.fc2 = nn.Dense(embed_dims * mlp_ratio, embed_dims, - weight_init=initializer(Normal(sigma=0.02), shape=(embed_dims, embed_dims * mlp_ratio)), - ).to_float(compute_dtype) + self.fc1 = nn.Dense( + embed_dims, embed_dims * mlp_ratio, + weight_init=initializer( + Normal(sigma=0.02), + shape=(embed_dims * mlp_ratio, embed_dims) + ), + ).to_float(compute_dtype) + self.fc2 = nn.Dense( + embed_dims * mlp_ratio, embed_dims, + weight_init=initializer( + Normal(sigma=0.02), + shape=(embed_dims, embed_dims * mlp_ratio) + ), + ).to_float(compute_dtype) self.act_fn = nn.GELU() self.dropout = nn.Dropout(dropout_rate) @@ -130,12 +138,14 @@ class AFNOBlock(nn.Cell): self.ffn_norm = nn.LayerNorm([embed_dims], epsilon=1e-6).to_float(compute_dtype) self.mlp = Mlp(embed_dims, mlp_ratio, dropout_rate, compute_dtype=compute_dtype) - self.filter = AFNO2D(h_size=h_size // patch_size, - w_size=w_size // patch_size, - 
embed_dims=embed_dims, - num_blocks=num_blocks, - high_freq=high_freq, - compute_dtype=compute_dtype) + self.filter = AFNO2D( + h_size=h_size // patch_size, + w_size=w_size // patch_size, + embed_dims=embed_dims, + num_blocks=num_blocks, + high_freq=high_freq, + compute_dtype=compute_dtype + ) self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() def construct(self, x): @@ -218,15 +228,27 @@ class ForwardFeatures(nn.Cell): dropout_rate=1.0, compute_dtype=mstype.float16): super(ForwardFeatures, self).__init__() - self.patch_embed = PatchEmbed(in_channels, embed_dims, patch_size, compute_dtype=compute_dtype) + self.patch_embed = PatchEmbed( + in_channels, embed_dims, patch_size, compute_dtype=compute_dtype + ) self.pos_embed = Parameter( - initializer(TruncatedNormal(sigma=0.02), [1, grid_size[0] * grid_size[1], embed_dims], dtype=compute_dtype), + initializer( + TruncatedNormal(sigma=0.02), + [1, grid_size[0] * grid_size[1], embed_dims], + dtype=compute_dtype + ), requires_grad=True) self.layer = nn.CellList([]) self.encoder_norm = nn.LayerNorm([embed_dims], epsilon=1e-6).to_float(compute_dtype) for _ in range(depth): - self.layer.append(AFNOBlock(embed_dims, mlp_ratio, dropout_rate, h_size=h_size, w_size=w_size, - patch_size=patch_size, compute_dtype=compute_dtype)) + self.layer.append( + AFNOBlock( + embed_dims, mlp_ratio, dropout_rate, + h_size=h_size, w_size=w_size, + patch_size=patch_size, + compute_dtype=compute_dtype + ) + ) self.pos_drop = nn.Dropout(keep_prob=dropout_rate) self.h = grid_size[0] @@ -283,22 +305,45 @@ class AFNO2D(nn.Cell): self.h_size = h_size self.w_size = w_size - self.dft2_cell = dft2(shape=(h_size, w_size), dim=(-3, -2), - modes=(h_size // 2, w_size // 2 + 1), compute_dtype=compute_dtype) - self.idft2_cell = idft2(shape=(h_size, w_size), dim=(-3, -2), - modes=(h_size // 2, w_size // 2 + 1), compute_dtype=compute_dtype) + self.dft2_cell = dft2( + shape=(h_size, w_size), dim=(-3, -2), + modes=(h_size // 2, w_size // 2 + 
1), compute_dtype=compute_dtype + ) + self.idft2_cell = idft2( + shape=(h_size, w_size), dim=(-3, -2), + modes=(h_size // 2, w_size // 2 + 1), compute_dtype=compute_dtype + ) self.scale = 0.02 self.num_blocks = num_blocks self.block_size = embed_dims // self.num_blocks self.hidden_size_factor = 1 - w1 = self.scale * Tensor(np.random.randn( - 2, self.num_blocks, self.block_size, self.block_size * self.hidden_size_factor), dtype=compute_dtype) - b1 = self.scale * Tensor(np.random.randn(2, self.num_blocks, self.block_size * self.hidden_size_factor), - dtype=compute_dtype) - w2 = self.scale * Tensor(np.random.randn( - 2, self.num_blocks, self.block_size * self.hidden_size_factor, self.block_size), dtype=compute_dtype) - b2 = self.scale * Tensor(np.random.randn(2, self.num_blocks, self.block_size), dtype=compute_dtype) + w1 = self.scale * Tensor( + np.random.randn( + 2, self.num_blocks, self.block_size, + self.block_size * self.hidden_size_factor + ), + dtype=compute_dtype + ) + b1 = self.scale * Tensor( + np.random.randn( + 2, self.num_blocks, self.block_size * self.hidden_size_factor + ), + dtype=compute_dtype + ) + w2 = self.scale * Tensor( + np.random.randn( + 2, self.num_blocks, self.block_size * self.hidden_size_factor, + self.block_size + ), + dtype=compute_dtype + ) + b2 = self.scale * Tensor( + np.random.randn( + 2, self.num_blocks, self.block_size + ), + dtype=compute_dtype + ) self.w1 = Parameter(w1, requires_grad=True) self.b1 = Parameter(b1, requires_grad=True) @@ -341,21 +386,43 @@ class AFNO2D(nn.Cell): x_ft_re, x_ft_im = self.dft2_cell((x_re, x_im)) - x_ft_re = x_ft_re.reshape(b, x_ft_re.shape[1], x_ft_re.shape[2], self.num_blocks, self.block_size) - x_ft_im = x_ft_im.reshape(b, x_ft_im.shape[1], x_ft_im.shape[2], self.num_blocks, self.block_size) + x_ft_re = x_ft_re.reshape( + b, x_ft_re.shape[1], x_ft_re.shape[2], + self.num_blocks, self.block_size + ) + x_ft_im = x_ft_im.reshape( + b, x_ft_im.shape[1], x_ft_im.shape[2], + self.num_blocks, 
self.block_size + ) kept_modes = h // 2 + 1 - o1_real = self.relu(self.mul2d(x_ft_re, self.w1[0]) - self.mul2d(x_ft_im, self.w1[1]) + self.b1[0]) + o1_real = self.relu( + self.mul2d(x_ft_re, self.w1[0]) - + self.mul2d(x_ft_im, self.w1[1]) + + self.b1[0] + ) o1_real[:, :, kept_modes:] = 0.0 - o1_imag = self.relu(self.mul2d(x_ft_im, self.w1[0]) + self.mul2d(x_ft_re, self.w1[1]) + self.b1[1]) + o1_imag = self.relu( + self.mul2d(x_ft_im, self.w1[0]) + + self.mul2d(x_ft_re, self.w1[1]) + + self.b1[1] + ) o1_imag[:, :, kept_modes:] = 0.0 - o2_real = (self.mul2d(o1_real, self.w2[0]) - self.mul2d(o1_imag, self.w2[1]) + self.b2[0]) + o2_real = ( + self.mul2d(o1_real, self.w2[0]) - + self.mul2d(o1_imag, self.w2[1]) + + self.b2[0] + ) o2_real[:, :, kept_modes:] = 0.0 - o2_imag = (self.mul2d(o1_imag, self.w2[0]) + self.mul2d(o1_real, self.w2[1]) + self.b2[1]) + o2_imag = ( + self.mul2d(o1_imag, self.w2[0]) + + self.mul2d(o1_real, self.w2[1]) + + self.b2[1] + ) o2_imag[:, :, kept_modes:] = 0.0 o2_real = self.cast(o2_real, self.compute_type) diff --git a/mindscience/models/neural_operator/afnonet.py b/mindscience/models/neural_operator/afnonet.py index 8f7b41320f00c5266e7066e48c7d37503224f3ea..6c50416b7901bc15595faa31270dcf36e9b544df 100644 --- a/mindscience/models/neural_operator/afnonet.py +++ b/mindscience/models/neural_operator/afnonet.py @@ -38,8 +38,8 @@ class AFNONet(nn.Cell): encoder_embed_dim (int): The encoder embedding dimension of encoder layer. Default: 768. mlp_ratio (int): The rate of mlp layer. Default: 4. dropout_rate (float): The rate of dropout layer. Default: 1.0. - compute_dtype (dtype): The data type for encoder, decoding_embedding, decoder and dense layer. - Default: mindspore.float32. + compute_dtype (dtype): The data type for encoder, decoding_embedding, + decoder and dense layer. Default: mindspore.float32. Inputs: - **x** (Tensor) - Tensor of shape :math:`(batch\_size, feature\_size, image\_height, image\_width)`. 
@@ -77,7 +77,8 @@ class AFNONet(nn.Cell): super(AFNONet, self).__init__() image_size = to_2tuple(image_size) try: - grid_size = (image_size[0] // patch_size, image_size[1] // patch_size) + grid_size = (image_size[0] // patch_size, + image_size[1] // patch_size) except ZeroDivisionError: ops.Print()("Patch size can't be Zero") @@ -91,25 +92,40 @@ class AFNONet(nn.Cell): self.transpose = ops.Transpose() - self.forward_features = ForwardFeatures(grid_size=grid_size, - h_size=image_size[0], - w_size=image_size[1], - in_channels=in_channels, - patch_size=patch_size, - depth=encoder_depths, - embed_dims=encoder_embed_dim, - mlp_ratio=mlp_ratio, - dropout_rate=dropout_rate, - compute_dtype=compute_dtype) + self.forward_features = ForwardFeatures( + grid_size=grid_size, + h_size=image_size[0], + w_size=image_size[1], + in_channels=in_channels, + patch_size=patch_size, + depth=encoder_depths, + embed_dims=encoder_embed_dim, + mlp_ratio=mlp_ratio, + dropout_rate=dropout_rate, + compute_dtype=compute_dtype + ) self.compute_type = compute_dtype - self.head = nn.Dense(encoder_embed_dim, patch_size ** 2 * out_channels, - weight_init=initializer(Normal(sigma=0.02), - shape=(patch_size ** 2 * out_channels, encoder_embed_dim)), - has_bias=False).to_float(compute_dtype) + self.head = nn.Dense( + encoder_embed_dim, patch_size ** 2 * out_channels, + weight_init=initializer( + Normal(sigma=0.02), + shape=(patch_size ** 2 * out_channels, encoder_embed_dim) + ), + has_bias=False + ).to_float(compute_dtype) def construct(self, x): + """ + Forward pass of the AFNONet model. + + Args: + x (Tensor): Input tensor of shape (batch_size, feature_size, image_height, image_width). + + Returns: + Tensor: Output tensor of shape (batch_size, patch_size, embed_dim). 
+ """ x = self.forward_features(x) x = self.head(x) diff --git a/mindscience/models/neural_operator/dft.py b/mindscience/models/neural_operator/dft.py index 7da6d648326a9c5b98cee9b7b1a4476847a4064d..24fe07b9d4fee1fd02e762d5c5086fdd9405ffd0 100644 --- a/mindscience/models/neural_operator/dft.py +++ b/mindscience/models/neural_operator/dft.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ - +""" +DFT +""" import numpy as np from scipy.linalg import dft @@ -61,7 +63,7 @@ class DFT1d(nn.Cell): else: self.dft_mat_res = self.dft_mat[:, -modes + 1:] - mat = Tensor(np.zeros(n, ), dtype=compute_dtype).reshape(n, 1) + mat = Tensor(np.zeros(n,), dtype=compute_dtype).reshape(n, 1) self.a_re_res = mindspore.numpy.flip( Tensor(self.dft_mat_res.real, dtype=compute_dtype), axis=-1) self.a_im_res = mindspore.numpy.flip( @@ -92,6 +94,7 @@ class DFT1d(nn.Cell): return y_re, y_im def construct(self, x): + """construct""" x_re, x_im = x x_re, x_im = P.Cast()(x_re, self.compute_dtype), P.Cast()(x_im, self.compute_dtype) if not self.inv: @@ -600,15 +603,27 @@ class SpectralConv2dDft(SpectralConvDft): x_im = ops.zeros_like(x_re) x_ft_re, x_ft_im = self._dft2_cell((x_re, x_im)) - out_ft_re1 = self._einsum(x_ft_re[:, :, :self.n_modes[0], :self.n_modes[1]], self._w_re1) - self._einsum( - x_ft_im[:, :, :self.n_modes[0], :self.n_modes[1]], self._w_im1) - out_ft_im1 = self._einsum(x_ft_re[:, :, :self.n_modes[0], :self.n_modes[1]], self._w_im1) + self._einsum( - x_ft_im[:, :, :self.n_modes[0], :self.n_modes[1]], self._w_re1) - - out_ft_re2 = self._einsum(x_ft_re[:, :, -self.n_modes[0]:, :self.n_modes[1]], self._w_re2) - self._einsum( - x_ft_im[:, :, -self.n_modes[0]:, :self.n_modes[1]], self._w_im2) - out_ft_im2 = self._einsum(x_ft_re[:, :, -self.n_modes[0]:, :self.n_modes[1]], self._w_im2) + self._einsum( - x_ft_im[:, :, -self.n_modes[0]:, 
:self.n_modes[1]], self._w_re2) + out_ft_re1 = self._einsum( + x_ft_re[:, :, :self.n_modes[0], :self.n_modes[1]], self._w_re1 + ) - self._einsum( + x_ft_im[:, :, :self.n_modes[0], :self.n_modes[1]], self._w_im1 + ) + out_ft_im1 = self._einsum( + x_ft_re[:, :, :self.n_modes[0], :self.n_modes[1]], self._w_im1 + ) + self._einsum( + x_ft_im[:, :, :self.n_modes[0], :self.n_modes[1]], self._w_re1 + ) + + out_ft_re2 = self._einsum( + x_ft_re[:, :, -self.n_modes[0]:, :self.n_modes[1]], self._w_re2 + ) - self._einsum( + x_ft_im[:, :, -self.n_modes[0]:, :self.n_modes[1]], self._w_im2 + ) + out_ft_im2 = self._einsum( + x_ft_re[:, :, -self.n_modes[0]:, :self.n_modes[1]], self._w_im2 + ) + self._einsum( + x_ft_im[:, :, -self.n_modes[0]:, :self.n_modes[1]], self._w_re2 + ) batch_size = x.shape[0] mat = self._mat.repeat(batch_size, 0) @@ -681,30 +696,62 @@ class SpectralConv3dDft(SpectralConvDft): x_im = ops.zeros_like(x_re) x_ft_re, x_ft_im = self._dft3_cell((x_re, x_im)) - out_ft_re1 = self._einsum(x_ft_re[:, :, :self.n_modes[0], :self.n_modes[1], :self.n_modes[2]], - self._w_re1) - self._einsum(x_ft_im[:, :, :self.n_modes[0], :self.n_modes[1], - :self.n_modes[2]], self._w_im1) - out_ft_im1 = self._einsum(x_ft_re[:, :, :self.n_modes[0], :self.n_modes[1], :self.n_modes[2]], - self._w_im1) + self._einsum(x_ft_im[:, :, :self.n_modes[0], :self.n_modes[1], - :self.n_modes[2]], self._w_re1) - out_ft_re2 = self._einsum(x_ft_re[:, :, -self.n_modes[0]:, :self.n_modes[1], :self.n_modes[2]], - self._w_re2) - self._einsum(x_ft_im[:, :, -self.n_modes[0]:, :self.n_modes[1], - :self.n_modes[2]], self._w_im2) - out_ft_im2 = self._einsum(x_ft_re[:, :, -self.n_modes[0]:, :self.n_modes[1], :self.n_modes[2]], - self._w_im2) + self._einsum(x_ft_im[:, :, -self.n_modes[0]:, :self.n_modes[1], - :self.n_modes[2]], self._w_re2) - out_ft_re3 = self._einsum(x_ft_re[:, :, :self.n_modes[0], -self.n_modes[1]:, :self.n_modes[2]], - self._w_re3) - self._einsum(x_ft_im[:, :, :self.n_modes[0], -self.n_modes[1]:, 
- :self.n_modes[2]], self._w_im3) - out_ft_im3 = self._einsum(x_ft_re[:, :, :self.n_modes[0], -self.n_modes[1]:, :self.n_modes[2]], - self._w_im3) + self._einsum(x_ft_im[:, :, :self.n_modes[0], -self.n_modes[1]:, - :self.n_modes[2]], self._w_re3) - out_ft_re4 = self._einsum(x_ft_re[:, :, -self.n_modes[0]:, -self.n_modes[1]:, :self.n_modes[2]], - self._w_re4) - self._einsum(x_ft_im[:, :, -self.n_modes[0]:, -self.n_modes[1]:, - :self.n_modes[2]], self._w_im4) - out_ft_im4 = self._einsum(x_ft_re[:, :, -self.n_modes[0]:, -self.n_modes[1]:, :self.n_modes[2]], - self._w_im4) + self._einsum(x_ft_im[:, :, -self.n_modes[0]:, -self.n_modes[1]:, - :self.n_modes[2]], self._w_re4) + out_ft_re1 = self._einsum( + x_ft_re[:, :, :self.n_modes[0], :self.n_modes[1], :self.n_modes[2]], + self._w_re1 + ) - self._einsum( + x_ft_im[:, :, :self.n_modes[0], :self.n_modes[1], :self.n_modes[2]], + self._w_im1 + ) + out_ft_im1 = self._einsum( + x_ft_re[:, :, :self.n_modes[0], :self.n_modes[1], :self.n_modes[2]], + self._w_im1 + ) + self._einsum( + x_ft_im[:, :, :self.n_modes[0], :self.n_modes[1], :self.n_modes[2]], + self._w_re1 + ) + out_ft_re2 = self._einsum( + x_ft_re[:, :, -self.n_modes[0]:, :self.n_modes[1], :self.n_modes[2]], + self._w_re2 + ) - self._einsum( + x_ft_im[:, :, -self.n_modes[0]:, :self.n_modes[1], :self.n_modes[2]], + self._w_im2 + ) + out_ft_im2 = self._einsum( + x_ft_re[:, :, -self.n_modes[0]:, :self.n_modes[1], :self.n_modes[2]], + self._w_im2 + ) + self._einsum( + x_ft_im[:, :, -self.n_modes[0]:, :self.n_modes[1], :self.n_modes[2]], + self._w_re2 + ) + out_ft_re3 = self._einsum( + x_ft_re[:, :, :self.n_modes[0], -self.n_modes[1]:, :self.n_modes[2]], + self._w_re3 + ) - self._einsum( + x_ft_im[:, :, :self.n_modes[0], -self.n_modes[1]:, :self.n_modes[2]], + self._w_im3 + ) + out_ft_im3 = self._einsum( + x_ft_re[:, :, :self.n_modes[0], -self.n_modes[1]:, :self.n_modes[2]], + self._w_im3 + ) + self._einsum( + x_ft_im[:, :, :self.n_modes[0], -self.n_modes[1]:, 
:self.n_modes[2]], + self._w_re3 + ) + out_ft_re4 = self._einsum( + x_ft_re[:, :, -self.n_modes[0]:, -self.n_modes[1]:, :self.n_modes[2]], + self._w_re4 + ) - self._einsum( + x_ft_im[:, :, -self.n_modes[0]:, -self.n_modes[1]:, :self.n_modes[2]], + self._w_im4 + ) + out_ft_im4 = self._einsum( + x_ft_re[:, :, -self.n_modes[0]:, -self.n_modes[1]:, :self.n_modes[2]], + self._w_im4 + ) + self._einsum( + x_ft_im[:, :, -self.n_modes[0]:, -self.n_modes[1]:, :self.n_modes[2]], + self._w_re4 + ) batch_size = x.shape[0] mat_x = self._mat_x.repeat(batch_size, 0) diff --git a/mindscience/models/neural_operator/fno.py b/mindscience/models/neural_operator/fno.py index 257aa5d29d574947aa7c4143efe4395320491b9e..d4b254f76acea42cca845196913255d3d5286cc0 100644 --- a/mindscience/models/neural_operator/fno.py +++ b/mindscience/models/neural_operator/fno.py @@ -119,33 +119,21 @@ class FNOBlocks(nn.Cell): ).to_float(self.fno_compute_dtype) elif len(self.resolutions) == 2: self._convs = SpectralConv2dDft( - self.in_channels, - self.out_channels, - self.n_modes, - self.resolutions, - compute_dtype=self.dft_compute_dtype + self.in_channels, self.out_channels, self.n_modes, + self.resolutions, compute_dtype=self.dft_compute_dtype ) self._fno_skips = nn.Conv2d( - self.in_channels, - self.out_channels, - kernel_size=1, - has_bias=False, - weight_init="HeUniform" + self.in_channels, self.out_channels, kernel_size=1, + has_bias=False, weight_init="HeUniform" ).to_float(self.fno_compute_dtype) elif len(self.resolutions) == 3: self._convs = SpectralConv3dDft( - self.in_channels, - self.out_channels, - self.n_modes, - self.resolutions, - compute_dtype=self.dft_compute_dtype + self.in_channels, self.out_channels, self.n_modes, + self.resolutions, compute_dtype=self.dft_compute_dtype ) self._fno_skips = nn.Conv3d( - self.in_channels, - self.out_channels, - kernel_size=1, - has_bias=False, - weight_init="HeUniform" + self.in_channels, self.out_channels, kernel_size=1, + has_bias=False, 
weight_init="HeUniform" ).to_float(self.fno_compute_dtype) else: raise ValueError("The length of input resolutions dimensions should be in [1, 2, 3], but got: {}".format( @@ -531,12 +519,14 @@ class FNO2D(FNO): resolutions = [resolutions, resolutions] if len(n_modes) != 2: raise ValueError( - "The dimension of n_modes should be equal to 2 when using FNO2D\ - but got dimension of n_modes {}".format(len(n_modes))) + "The dimension of n_modes should be equal to 2 when using FNO2D " + "but got dimension of n_modes {}".format(len(n_modes)) + ) if len(resolutions) != 2: raise ValueError( - "The dimension of resolutions should be equal to 2 when using FNO2D\ - but got dimension of resolutions {}".format(len(resolutions))) + "The dimension of resolutions should be equal to 2 when using FNO2D " + "but got dimension of resolutions {}".format(len(resolutions)) + ) super().__init__( in_channels, out_channels, @@ -647,12 +637,14 @@ class FNO3D(FNO): resolutions = [resolutions, resolutions, resolutions] if len(n_modes) != 3: raise ValueError( - "The dimension of n_modes should be equal to 3 when using FNO3D\ - but got dimension of n_modes {}".format(len(n_modes))) + "The dimension of n_modes should be equal to 3 when using FNO3D " + "but got dimension of n_modes {}".format(len(n_modes)) + ) if len(resolutions) != 3: raise ValueError( - "The dimension of resolutions should be equal to 3 when using FNO3D\ - but got dimension of resolutions {}".format(len(resolutions))) + "The dimension of resolutions should be equal to 3 when using FNO3D " + "but got dimension of resolutions {}".format(len(resolutions)) + ) super().__init__( in_channels, out_channels, diff --git a/mindscience/models/transformer/__init__.py b/mindscience/models/transformer/__init__.py index b7300a82d71d04011afd22afef58fe15c4d95722..9774942d0f8cf2487ffc6f0fe831cb95e3fb28f0 100644 --- a/mindscience/models/transformer/__init__.py +++ b/mindscience/models/transformer/__init__.py @@ -12,8 +12,13 @@ # See the License for 
the specific language governing permissions and # limitations under the License. # ============================================================================ -"""init""" +""" +Transformer Models Package + +This package contains implementations of various Transformer architectures, +including Vision Transformer (ViT) and related attention mechanisms. +""" from .attention import Attention, MultiHeadAttention, TransformerBlock -from .vit import ViT +from .vit import VisionTransformer -__all__ = ["Attention", "MultiHeadAttention", "TransformerBlock", "ViT"] +__all__ = ["Attention", "MultiHeadAttention", "TransformerBlock", "VisionTransformer"] diff --git a/mindscience/models/transformer/attention.py b/mindscience/models/transformer/attention.py index d19fff9cb9548e2510670ae7f718a94f5ea73bde..72e9ef258cbe96341fadc60c9e02e9deba87f8e0 100644 --- a/mindscience/models/transformer/attention.py +++ b/mindscience/models/transformer/attention.py @@ -31,7 +31,8 @@ class Attention(nn.Cell): Inputs: - **x** (Tensor) - Tensor with shape :math:`(batch\_size, sequence\_len, in\_channels)`. - **attn_mask** (Tensor, optional) - Tensor with shape :math:`(sequence\_len, sequence\_len)` or - or :math:`(batch\_size, 1, sequence\_len, sequence\_len)`. Default: ``None``. + or :math:`(batch\_size, 1, sequence\_len, sequence\_len)`. + Default: ``None``. - **key_padding_mask** (Tensor, optional) - Tensor with shape :math:`(batch\_size, sequence\_len)`. Default: ``None``. 
@@ -72,14 +73,17 @@ class Attention(nn.Cell): elif len(attn_mask.shape) == 4: pass else: - raise Exception(f'attn_mask shape {attn_mask.shape} not support') + raise Exception( + f'attn_mask shape {attn_mask.shape} not support') mask = mask + attn_mask.astype(mstype.uint8) if key_padding_mask is not None: batch, node = key_padding_mask.shape[0], key_padding_mask.shape[-1] if len(key_padding_mask.shape) == 2: - key_padding_mask = ops.broadcast_to(key_padding_mask.unsqueeze(1), (batch, node, node)).unsqueeze(1) + key_padding_mask = ops.broadcast_to( + key_padding_mask.unsqueeze(1), (batch, node, node)).unsqueeze(1) else: - raise Exception(f'key_padding_mask shape {attn_mask.shape} not support') + raise Exception( + f'key_padding_mask shape {attn_mask.shape} not support') mask = mask + key_padding_mask.astype(mstype.uint8) return mask @@ -95,8 +99,8 @@ class Attention(nn.Cell): """get qkv value""" b, n, _ = x.shape qkv = ( - self.qkv(x).reshape(b, n, 3, self.num_heads, - - 1).transpose((2, 0, 3, 1, 4)) + self.qkv(x).reshape(b, n, 3, self.num_heads, + - 1).transpose((2, 0, 3, 1, 4)) ) return qkv[0], qkv[1], qkv[2] @@ -168,7 +172,8 @@ class FlashAttn(nn.Cell): self.scale = scale def construct(self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None): - query, key, value = query.astype(self.fa_dtype), key.astype(self.fa_dtype), value.astype(self.fa_dtype) + query, key, value = query.astype(self.fa_dtype), key.astype( + self.fa_dtype), value.astype(self.fa_dtype) if mask is not None: mask = mask.astype(mstype.uint8) scores = ops.flash_attention_score(query, key, value, input_layout='BNSD', head_num=self.num_heads, @@ -183,18 +188,22 @@ class MultiHeadAttention(Attention): in_channels (int): The input channels. num_heads (int): The number of attention heads. enable_flash_attn (bool): Whether use flash attention. FlashAttention only supports Ascend backend. 
- FlashAttention proposed in `FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness `_. + FlashAttention proposed in + `FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness + `_. Default: ``False``. fa_dtype (mindspore.dtype): FlashAttention compute dtype. Choose from `mstype.bfloat16`, `mstype.float16`. Default: ``mstype.bfloat16``, indicates ``mindspore.bfloat16``. - drop_mode (str): Dropout method, ``dropout`` or ``droppath``. Default: ``dropout``. + drop_mode (str): Dropout method, ``dropout`` or ``droppath``. + Default: ``dropout``. dropout_rate (float): The drop rate of dropout layer, greater than 0 and less equal than 1. Default: ``0.0``. compute_dtype (mindspore.dtype): Compute dtype. Default: ``mstype.float32``, indicates ``mindspore.float32``. Inputs: - **x** (Tensor) - Tensor with shape :math:`(batch\_size, sequence\_len, in\_channels)`. - - **attn_mask** (Tensor, optional) - Tensor with shape :math:`(sequence\_len, sequence\_len)` or - or :math:`(batch\_size, 1, sequence\_len, sequence\_len)`. Default: ``None``. + - **attn_mask (Tensor, optional) - Tensor with shape :math:`(sequence\_len, sequence\_len)` or + :math:`(batch\_size, 1, sequence\_len, sequence\_len)`. + Default: ``None``. - **key_padding_mask** (Tensor, optional) - Tensor with shape :math:`(batch\_size, sequence\_len)`. Default: ``None``. 
@@ -232,7 +241,8 @@ class MultiHeadAttention(Attention): self.proj = nn.Dense(in_channels, in_channels).to_float(compute_dtype) if enable_flash_attn: print('use flash attention') - self.attn = FlashAttn(num_heads=num_heads, scale=scale, fa_dtype=fa_dtype) + self.attn = FlashAttn(num_heads=num_heads, + scale=scale, fa_dtype=fa_dtype) else: self.attn = ScaledDot(scale=scale) if drop_mode == "dropout": @@ -254,10 +264,13 @@ class MultiHeadAttention(Attention): class FeedForward(nn.Cell): """FeedForward""" + def __init__(self, in_channels, dropout_rate=0.0, compute_dtype=mstype.float16): super().__init__() - self.fc1 = nn.Dense(in_channels, in_channels * 4).to_float(compute_dtype) - self.fc2 = nn.Dense(in_channels * 4, in_channels).to_float(compute_dtype) + self.fc1 = nn.Dense(in_channels, in_channels + * 4).to_float(compute_dtype) + self.fc2 = nn.Dense( + in_channels * 4, in_channels).to_float(compute_dtype) self.act_fn = nn.GELU() self.dropout = nn.Dropout(p=dropout_rate) @@ -278,7 +291,9 @@ class TransformerBlock(nn.Cell): in_channels (int): The input channels. num_heads (int): The number of attention heads. enable_flash_attn (bool): Whether use flash attention. FlashAttention only supports Ascend backend. - FlashAttention proposed in `FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness `_. + FlashAttention proposed in + `FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness + `_. Default: ``False``. fa_dtype (mindspore.dtype): FlashAttention compute dtype. Choose from `mstype.bfloat16`, `mstype.float16`. Default: ``mstype.bfloat16``, indicates ``mindspore.bfloat16``. diff --git a/mindscience/models/transformer/vit.py b/mindscience/models/transformer/vit.py index 9923a668ce8f321c0c755ac2652290692e91a03f..ac7250b2da365331b027c90741bc2478b646a7ba 100644 --- a/mindscience/models/transformer/vit.py +++ b/mindscience/models/transformer/vit.py @@ -13,7 +13,7 @@ # limitations under the License. 
# ============================================================================ """ -The ViT model +The VisionTransformer model """ from mindspore import ops, Parameter, Tensor, nn @@ -62,9 +62,9 @@ class PatchEmbedding(nn.Cell): return x -class VitEncoder(nn.Cell): +class VisionTransformerEncoder(nn.Cell): r""" - ViT Encoder module with multi-layer stacked of `MultiHeadAttention`, + VisionTransformer Encoder module with multi-layer stacked of `MultiHeadAttention`, including multihead self attention and feedforward layer. Args: @@ -76,7 +76,7 @@ class VitEncoder(nn.Cell): num_heads (int): The encoder heads' number of encoder layer. Default: ``16``. dropout_rate (float): The rate of dropout layer. Default: ``0.0``. compute_dtype (dtype): The data type for encoder, encoding_embedding, encoder and dense layer. - Default: ``mstype.float16``. + Default: ``mstype.float16``. Inputs: - **input** (Tensor) - Tensor of shape :math:`(batch\_size, feature\_size, image\_height, image\_width)`. @@ -90,11 +90,11 @@ class VitEncoder(nn.Cell): Examples: >>> from mindspore import ops - >>> from mindflow.cell.vit import VitEncoder + >>> from mindflow.cell.vision_transformer import VisionTransformerEncoder >>> input_tensor = ops.rand(32, 3, 192, 384) >>> print(input_tensor.shape) (32, 3, 192, 384) - >>>encoder = VitEncoder(grid_size=(192 // 16, 384 // 16), + >>>encoder = VisionTransformerEncoder(grid_size=(192 // 16, 384 // 16), >>> in_channels=3, >>> patch_size=16, >>> depths=6, @@ -150,9 +150,9 @@ class VitEncoder(nn.Cell): return x -class VitDecoder(nn.Cell): +class VisionTransformerDecoder(nn.Cell): r""" - ViT Decoder module with multi-layer stacked of `MultiHeadAttention`, + VisionTransformer Decoder module with multi-layer stacked of `MultiHeadAttention`, including multihead self attention and feedforward layer. Args: @@ -162,7 +162,7 @@ class VitDecoder(nn.Cell): num_heads (int): The decoder heads' number of decoder layer. dropout_rate (float): The rate of dropout layer. 
Default: ``0.0``. compute_dtype (dtype): The data type for encoder, decoding_embedding, decoder and dense layer. - Default: ``mstype.float16``. + Default: ``mstype.float16``. Inputs: - **input** (Tensor) - Tensor of shape :math:`(batch\_size, patchify\_size, embed\_dim)`. @@ -176,17 +176,17 @@ class VitDecoder(nn.Cell): Examples: >>> from mindspore import ops - >>> from mindflow.cell.vit import VitDecoder + >>> from mindflow.cell.vision_transformer import VisionTransformerDecoder >>> input_tensor = ops.rand(32, 288, 512) >>> print(input_tensor.shape) (32, 288, 768) - >>> decoder = VitDecoder(grid_size=grid_size, + >>> decoder = VisionTransformerDecoder(grid_size=grid_size, >>> depths=6, >>> hidden_channels=512, >>> num_heads=16, >>> dropout_rate=0.0, >>> compute_dtype=mstype.float16) - >>> output_tensor = VitDecoder(input_tensor) + >>> output_tensor = VisionTransformerDecoder(input_tensor) >>> print("output_tensor:",output_tensor.shape) (32, 288, 512) """ @@ -229,9 +229,10 @@ class VitDecoder(nn.Cell): return x -class ViT(nn.Cell): +class VisionTransformer(nn.Cell): r""" - This module based on ViT backbone which including encoder, decoding_embedding, decoder and dense layer. + This module based on VisionTransformer backbone which including encoder, decoding_embedding, + decoder and dense layer. Args: image_size (tuple[int]): The image size of input. Default: ``(192, 384)``. @@ -246,7 +247,7 @@ class ViT(nn.Cell): decoder_num_heads (int): The decoder heads' number of decoder layer. Default: ``16``. dropout_rate (float): The rate of dropout layer. Default: ``0.0``. compute_dtype (dtype): The data type for encoder, decoding_embedding, decoder and dense layer. - Default: ``mstype.float16``. + Default: ``mstype.float16``. Inputs: - **input** (Tensor) - Tensor of shape :math:`(batch\_size, feature\_size, image\_height, image\_width)`. 
@@ -260,11 +261,11 @@ class ViT(nn.Cell): Examples: >>> from mindspore import ops - >>> from mindflow.cell import ViT + >>> from mindflow.cell import VisionTransformer >>> input_tensor = ops.rand(32, 3, 192, 384) >>> print(input_tensor.shape) (32, 3, 192, 384) - >>> model = ViT(in_channels=3, + >>> model = VisionTransformer(in_channels=3, >>> out_channels=3, >>> encoder_depths=6, >>> encoder_embed_dim=768, @@ -310,7 +311,7 @@ class ViT(nn.Cell): self.transpose = ops.Transpose() - self.encoder = VitEncoder( + self.encoder = VisionTransformerEncoder( in_channels=in_channels, hidden_channels=encoder_embed_dim, patch_size=patch_size, @@ -328,7 +329,7 @@ class ViT(nn.Cell): weight_init="XavierUniform", ).to_float(compute_dtype) - self.decoder = VitDecoder( + self.decoder = VisionTransformerDecoder( hidden_channels=decoder_embed_dim, grid_size=grid_size, depths=decoder_depths, diff --git a/mindscience/solvers/cfd/space_solver/riemann_computer/rusanov_net.py b/mindscience/solvers/cfd/space_solver/riemann_computer/rusanov_net.py index d3a99a1a9f3e56367afb6250d193bb76b76a8146..354b9fe2e35e975f40961c1d61c43014b91962f2 100644 --- a/mindscience/solvers/cfd/space_solver/riemann_computer/rusanov_net.py +++ b/mindscience/solvers/cfd/space_solver/riemann_computer/rusanov_net.py @@ -23,8 +23,10 @@ from .base import RiemannComputer @jit_class class RusanovNet(RiemannComputer): r""" - Rusanov Riemann computer with network. The network is inspired by Rusanov_Net from paper "JAX-FLUIDS: A - fully-differentiable high-order computational fluid dynamics solver for compressible two-phase flows" + Rusanov Riemann computer with network. 
The network is inspired by Rusanov_Net + from paper "JAX-FLUIDS: A fully-differentiable high-order computational + fluid dynamics solver for compressible two-phase flows" + https://arxiv.org/pdf/2203.13760.pdf Args: @@ -69,23 +71,29 @@ class RusanovNet(RiemannComputer): pri_var_left = cal_pri_var(con_var_left, self.material) pri_var_right = cal_pri_var(con_var_right, self.material) - flux_left = cal_flux(con_var_left, pri_var_left, axis) - flux_right = cal_flux(con_var_right, pri_var_right, axis) + flux_left = cal_flux(con_var_left, pri_var_left, + axis) + flux_right = cal_flux(con_var_right, pri_var_right, + axis) sound_speed_left = self.material.sound_speed(pri_var_left) sound_speed_right = self.material.sound_speed(pri_var_right) mean_sound_speed = 0.5 * (sound_speed_left + sound_speed_right) delta_vel = mnp.abs(pri_var_right[axis + 1] - pri_var_left[axis + 1]) - mean_vel = 0.5 * (mnp.abs(pri_var_right[axis + 1]) + mnp.abs(pri_var_left[axis + 1])) + mean_vel = 0.5 * \ + (mnp.abs(pri_var_right[axis + 1]) + + mnp.abs(pri_var_left[axis + 1])) delta_sound_speed = mnp.abs(sound_speed_left - sound_speed_right) - var = mnp.stack([delta_vel, mean_vel, mean_sound_speed, delta_sound_speed], axis=0) + var = mnp.stack([delta_vel, mean_vel, mean_sound_speed, + delta_sound_speed], axis=0) var = self.transpose(var, (3, 1, 2, 0)) net_out = mnp.exp(self.net(var)) net_out = self.transpose(net_out, (3, 1, 2, 0)) - flux = 0.5 * (flux_left + flux_right) - net_out * (con_var_right - con_var_left) + flux = 0.5 * (flux_left + flux_right) - net_out * \ + (con_var_right - con_var_left) return flux diff --git a/mindscience/utils/__init__.py b/mindscience/utils/__init__.py index b12ee36e716425843339f12cd06827b30e139a63..ed45dda221093c07a9d615ea4bf737013a4780dd 100644 --- a/mindscience/utils/__init__.py +++ b/mindscience/utils/__init__.py @@ -12,13 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================ -"""init""" +"""Utilities Package + +This package contains various utility functions and classes used across the +MindScience toolkit. It includes configuration loading, logging utilities, +time utilities, and parameter checking functions. +""" from .load_config import load_yaml_config from .log_utils import print_log, log_config from .time_utils import log_timer -from .check_func import check_dict_type, check_dict_value, check_param_type, check_param_type_value, check_param_value, check_dict_type_value - -__all__ = ["load_yaml_config", "print_log", "log_config", "log_timer", "check_param_type", "check_param_type_value",\ - "check_param_value", "check_dict_type_value", "check_dict_type", "check_dict_value"] - +from .check_func import (check_dict_type, check_dict_value, check_param_type, + check_param_type_value, check_param_value, + check_dict_type_value) +__all__ = ["load_yaml_config", "print_log", "log_config", "log_timer", + "check_param_type", "check_param_type_value", "check_param_value", + "check_dict_type_value", "check_dict_type", "check_dict_value"] diff --git a/mindscience/utils/check_func.py b/mindscience/utils/check_func.py index d5513016e43e15827a5838b5512c2905e4e42598..9cc5c051abfe6fa4c82c04370e2f9905c9b5a760 100644 --- a/mindscience/utils/check_func.py +++ b/mindscience/utils/check_func.py @@ -12,13 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""functions""" +""" +Parameter Checking Functions + + +This module provides utility functions for checking parameters in the MindScience toolkit. +It includes functions for validating parameter types, values, and dictionary structures. 
+""" +# pylint: disable=C0123 from __future__ import absolute_import from mindspore import context _SPACE = " " -__all__ = ["check_param_type", "check_param_type_value", "check_param_value", "check_dict_type_value", "check_dict_type"] +__all__ = ["check_param_type", "check_param_type_value", + "check_param_value", "check_dict_type_value", "check_dict_type"] def _convert_to_tuple(params): @@ -37,25 +45,32 @@ def check_param_type(param, param_name, data_type=None, exclude_type=None): exclude_type = _convert_to_tuple(exclude_type) if data_type and not isinstance(param, data_type): - raise TypeError("The type of {} should be instance of {}, but got {} with type {}".format( - param_name, data_type, param, type(param))) + raise TypeError( + "The type of {} should be instance of {}, but got {} " + "with type {}".format(param_name, data_type, param, type(param)) + ) if exclude_type and type(param) in exclude_type: - raise TypeError("The type of {} should not be instance of {}, but got {} with type {}".format( - param_name, exclude_type, param, type(param))) - return None + raise TypeError( + "The type of {} should not be instance of {}, but got {} " + "with type {}".format(param_name, exclude_type, param, type(param)) + ) def check_param_value(param, param_name, valid_value): """check parameter's value""" valid_value = _convert_to_tuple(valid_value) if param not in valid_value: - raise ValueError("The value of {} should be in {}, but got {}".format( - param_name, valid_value, param)) + raise ValueError( + "The value of {} should be in {}, but got {}".format( + param_name, valid_value, param + ) + ) def check_param_type_value(param, param_name, valid_value, data_type=None, exclude_type=None): """check both data type and value""" - check_param_type(param, param_name, data_type=data_type, exclude_type=exclude_type) + check_param_type(param, param_name, data_type=data_type, + exclude_type=exclude_type) check_param_value(param, param_name, valid_value) @@ -65,12 +80,16 @@ def 
check_dict_type(param_dict, param_name, key_type=None, value_type=None): for key in param_dict.keys(): if key_type: - check_param_type(key, _SPACE.join(("key of", param_name)), data_type=key_type) + check_param_type( + key, _SPACE.join(("key of", param_name)), + data_type=key_type + ) if value_type: values = _convert_to_tuple(param_dict[key]) for value in values: - check_param_type(value, _SPACE.join(("value of", param_name)), data_type=value_type) - return None + check_param_type( + value, _SPACE.join(("value of", param_name)), data_type=value_type + ) def check_dict_value(param_dict, param_name, key_value=None, value_value=None): @@ -79,55 +98,84 @@ def check_dict_value(param_dict, param_name, key_value=None, value_value=None): for key in param_dict.keys(): if key_value: - check_param_value(key, _SPACE.join(("key of", param_name)), key_value) + check_param_value( + key, _SPACE.join(("key of", param_name)), + key_value + ) if value_value: values = _convert_to_tuple(param_dict[key]) for value in values: - check_param_value(value, _SPACE.join(("value of", param_name)), value_value) - return None + check_param_value( + value, _SPACE.join(("value of", param_name)), value_value + ) def check_dict_type_value(param_dict, param_name, key_type=None, value_type=None, key_value=None, value_value=None): """check values for key and value of specified dict""" - check_dict_type(param_dict, param_name, key_type=key_type, value_type=value_type) - check_dict_value(param_dict, param_name, key_value=key_value, value_value=value_value) - return None + check_dict_type(param_dict, param_name, + key_type=key_type, value_type=value_type) + check_dict_value(param_dict, param_name, + key_value=key_value, value_value=value_value) def check_mode(api_name): """check running mode""" if context.get_context("mode") == context.PYNATIVE_MODE: - raise RuntimeError("{} is only supported GRAPH_MODE now but got PYNATIVE_MODE".format(api_name)) + raise RuntimeError( + "{} is only supported GRAPH_MODE 
now but got PYNATIVE_MODE".format(api_name)) def check_param_no_greater(param, param_name, compared_value): """ Check whether the param less than the given compared_value""" if param > compared_value: - raise ValueError("The value of {} should be no greater than {}, but got {}".format( - param_name, compared_value, param)) + raise ValueError( + "The value of {} should be no greater than {}, but got {}".format( + param_name, compared_value, param + ) + ) def check_param_odd(param, param_name): """ Check whether the param is an odd number""" if param % 2 == 0: - raise ValueError("The value of {} should be an odd number, but got {}".format( - param_name, param)) + raise ValueError( + "The value of {} should be an odd number, but got {}".format( + param_name, param + ) + ) def check_param_even(param, param_name): """ Check whether the param is an even number""" for value in param: if value % 2 != 0: - raise ValueError("The value of {} should be an even number, but got {}".format( - param_name, param)) + raise ValueError( + "The value of {} should be an even number, but got {}".format( + param_name, param + ) + ) -def check_lr_param_type_value(param, param_name, param_type, thresh_hold=0, restrict=False, exclude=None): +def check_lr_param_type_value(param, param_name, param_type, thresh_hold=0, + restrict=False, exclude=None): + """Check the type and value of the learning rate parameter.""" if (exclude and isinstance(param, exclude)) or not isinstance(param, param_type): - raise TypeError("the type of {} should be {}, but got {}".format(param_name, param_type, type(param))) + raise TypeError( + "the type of {} should be {}, but got {}".format( + param_name, param_type, type(param) + ) + ) if restrict: if param <= thresh_hold: - raise ValueError("the value of {} should be > {}, but got: {}".format(param_name, thresh_hold, param)) + raise ValueError( + "the value of {} should be > {}, but got: {}".format( + param_name, thresh_hold, param + ) + ) else: if param < 
thresh_hold: - raise ValueError("the value of {} should be >= {}, but got: {}".format(param_name, thresh_hold, param)) + raise ValueError( + "the value of {} should be >= {}, but got: {}".format( + param_name, thresh_hold, param + ) + ) diff --git a/mindscience/utils/log_utils.py b/mindscience/utils/log_utils.py index a284a40ae492142946d4534d34fb367a19a8dbc9..9f9b96560bb533dd7a1436b42b7ddad3ad2f7ee2 100644 --- a/mindscience/utils/log_utils.py +++ b/mindscience/utils/log_utils.py @@ -1,4 +1,9 @@ -"""log utils""" +"""Logging Utilities + +This module provides utility functions for logging in the MindScience toolkit. +It includes functions for configuring logging and printing messages to both +standard output and log files. +""" import logging import os @@ -13,11 +18,12 @@ def log_config(log_dir='./logs', model_name="model", permission=0o644): if not os.path.exists(log_dir): os.mkdir(log_dir) log_path = os.path.join(log_dir, f"{model_name}.log") - logging.basicConfig(level=logging.INFO, - format='%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s', - datefmt='%a, %d %b %Y %H:%M:%S', - filename=log_path, - filemode='w') + logging.basicConfig( + level=logging.INFO, + format='%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s', + datefmt='%a, %d %b %Y %H:%M:%S', + filename=log_path, + filemode='w') os.chmod(log_path, permission) @@ -27,8 +33,8 @@ def print_log(*msg, level=logging.INFO, enable_log=True): Args: *msg (any): Message(s) to print and log. level (int): Log level. Default: logging.INFO. - enable_log (bool): Whether to log the message. In some cases, like before logging configuration, this flag would - be set as False. Default: ``True``. + enable_log (bool): Whether to log the message. In some cases, like before logging + configuration, this flag would be set as False. Default: ``True``. 
""" def log_help_func(*messages): diff --git a/mindscience/utils/time_utils.py b/mindscience/utils/time_utils.py index c119baf3e79a900f6470c6713d50a71d17458f68..5704b58cbdadca43aeef2dfadc682ec3597e28d0 100644 --- a/mindscience/utils/time_utils.py +++ b/mindscience/utils/time_utils.py @@ -1,4 +1,8 @@ -"""time utils""" +"""Time Utilities + +This module provides utility functions for time-related operations in the MindScience toolkit. +It includes a decorator for measuring execution time of functions. +""" import time from .log_utils import print_log diff --git a/models.md b/models.md index 5da891b29ba998730c797e6c9d0b32e62b6f8c43..68914d9971e1c1924d1cf09967f4139f69a30800 100644 --- a/models.md +++ b/models.md @@ -62,7 +62,7 @@ | 计算流体 | [CAE-LSTM](https://doi.org/10.13700/j.bh.1001-5965.2022.0085) | [Link](https://gitee.com/mindspore/mindscience/blob/master/MindFlow/applications/research/cae_lstm/README_CN.md#) | ✅ | ✅ | | 计算流体 | [eHDNN](https://doi.org/10.1016/j.ast.2022.107636) | [Link](https://gitee.com/mindspore/mindscience/blob/master/MindFlow/applications/research/transonic_buffet_ehdnn/README_CN.md#) | ✅ | ✅ | | 计算流体 | [HDNN](https://doi.org/10.1016/j.ast.2022.107636) | [Link](https://gitee.com/mindspore/mindscience/blob/master/MindFlow/applications/research/move_boundary_hdnn/README_CN.md#) | ✅ | ✅ | -| 计算流体 | [ViT](https://gitee.com/mindspore/mindscience/blob/master/MindFlow/applications/data_driven/airfoil/2D_steady/2D_steady_CN.ipynb) | [Link](https://gitee.com/mindspore/mindscience/blob/master/MindFlow/applications/data_driven/airfoil/2D_steady/README_CN.md#) | ✅ | ✅ | +| 计算流体 | [VisionTransformer](https://gitee.com/mindspore/mindscience/blob/master/MindFlow/applications/data_driven/airfoil/2D_steady/2D_steady_CN.ipynb) | [Link](https://gitee.com/mindspore/mindscience/blob/master/MindFlow/applications/data_driven/airfoil/2D_steady/README_CN.md#) | ✅ | ✅ | | 计算流体 | [PeRCNN](https://www.nature.com/articles/s42256-023-00685-7) | 
[Link](https://gitee.com/mindspore/mindscience/blob/master/MindFlow/applications/data_mechanism_fusion/percnn/README_CN.md#) | ✅ | ✅ | | 计算流体 | [Burgers1D](https://www.sciencedirect.com/science/article/abs/pii/S0021999118307125) | [Link](https://gitee.com/mindspore/mindscience/blob/master/MindFlow/applications/physics_driven/burgers/README_CN.md#) | ✅ | ✅ | | 计算流体 | [Cylinder Flow](https://gitee.com/mindspore/mindscience/blob/master/MindFlow/applications/physics_driven/navier_stokes/cylinder_flow_forward/navier_stokes2D_CN.ipynb) | [Link](https://gitee.com/mindspore/mindscience/blob/master/MindFlow/applications/physics_driven/navier_stokes/cylinder_flow_forward/README_CN.md#) | ✅ | ✅ | diff --git a/models_en.md b/models_en.md index 4842f9deec4e19fb97b1e395dbdb77f37f3cfe96..fbe8c9c7cda818729842069af51a9ec81989c74f 100644 --- a/models_en.md +++ b/models_en.md @@ -62,7 +62,7 @@ | Computational Fluid Dynamics | [CAE-LSTM](https://doi.org/10.13700/j.bh.1001-5965.2022.0085) | [Link](https://gitee.com/mindspore/mindscience/blob/master/MindFlow/applications/research/cae_lstm/README.md#) | ✅ | ✅ | | Computational Fluid Dynamics | [eHDNN](https://doi.org/10.1016/j.ast.2022.107636) | [Link](https://gitee.com/mindspore/mindscience/blob/master/MindFlow/applications/research/transonic_buffet_ehdnn/README.md#) | ✅ | ✅ | | Computational Fluid Dynamics | [HDNN](https://doi.org/10.1016/j.ast.2022.107636) | [Link](https://gitee.com/mindspore/mindscience/blob/master/MindFlow/applications/research/move_boundary_hdnn/README.md#) | ✅ | ✅ | -| Computational Fluid Dynamics | [ViT](https://gitee.com/mindspore/mindscience/blob/master/MindFlow/applications/data_driven/airfoil/2D_steady/2D_steady.ipynb) | [Link](https://gitee.com/mindspore/mindscience/blob/master/MindFlow/applications/data_driven/airfoil/2D_steady/README.MD#) | ✅ | ✅ | +| Computational Fluid Dynamics | 
[VisionTransformer](https://gitee.com/mindspore/mindscience/blob/master/MindFlow/applications/data_driven/airfoil/2D_steady/2D_steady.ipynb) | [Link](https://gitee.com/mindspore/mindscience/blob/master/MindFlow/applications/data_driven/airfoil/2D_steady/README.MD#) | ✅ | ✅ | | Computational Fluid Dynamics | [PeRCNN](https://www.nature.com/articles/s42256-023-00685-7) | [Link](https://gitee.com/mindspore/mindscience/blob/master/MindFlow/applications/data_mechanism_fusion/percnn/README.md#) | ✅ | ✅ | | Computational Fluid Dynamics | [Burgers1D](https://www.sciencedirect.com/science/article/abs/pii/S0021999118307125) | [Link](https://gitee.com/mindspore/mindscience/blob/master/MindFlow/applications/physics_driven/burgers/README.md#) | ✅ | ✅ | | Computational Fluid Dynamics | [Cylinder Flow](https://gitee.com/mindspore/mindscience/blob/master/MindFlow/applications/physics_driven/navier_stokes/cylinder_flow_forward/navier_stokes2D.ipynb) | [Link](https://gitee.com/mindspore/mindscience/blob/master/MindFlow/applications/physics_driven/navier_stokes/cylinder_flow_forward/README.md#) | ✅ | ✅ | diff --git a/tests/models/diffusion/ae.py b/tests/models/diffuser/ae.py similarity index 100% rename from tests/models/diffusion/ae.py rename to tests/models/diffuser/ae.py diff --git a/tests/models/diffusion/dataset.py b/tests/models/diffuser/dataset.py similarity index 100% rename from tests/models/diffusion/dataset.py rename to tests/models/diffuser/dataset.py diff --git a/tests/models/diffusion/ddim_gt.py b/tests/models/diffuser/ddim_gt.py similarity index 100% rename from tests/models/diffusion/ddim_gt.py rename to tests/models/diffuser/ddim_gt.py diff --git a/tests/models/diffusion/ddpm_gt.py b/tests/models/diffuser/ddpm_gt.py similarity index 100% rename from tests/models/diffusion/ddpm_gt.py rename to tests/models/diffuser/ddpm_gt.py diff --git a/tests/models/diffusion/test_diffusion.py b/tests/models/diffuser/test_diffusion.py similarity index 99% rename from 
tests/models/diffusion/test_diffusion.py rename to tests/models/diffuser/test_diffusion.py index 86413171871e36054f7084f0b2f797fe7b801d5c..39c50f08c4bbacedb60c6c05a2050aad3bc673e2 100644 --- a/tests/models/diffusion/test_diffusion.py +++ b/tests/models/diffuser/test_diffusion.py @@ -22,7 +22,7 @@ import numpy as np from mindspore import Tensor, ops, context from mindspore import dtype as mstype -from mindscience.models import DiffusionScheduler, DDPMPipeline, DDIMPipeline, DDPMScheduler, DDIMScheduler, \ +from mindscience.diffuser import DiffusionScheduler, DDPMPipeline, DDIMPipeline, DDPMScheduler, DDIMScheduler, \ DiffusionTransformer, ConditionDiffusionTransformer PROJECT_ROOT = os.path.abspath(os.path.join( diff --git a/tests/models/diffusion/test_diffusion_train.py b/tests/models/diffuser/test_diffusion_train.py similarity index 98% rename from tests/models/diffusion/test_diffusion_train.py rename to tests/models/diffuser/test_diffusion_train.py index 47567bd484eecc20632397e64626bf9cf80a2803..a235a862eac61db6f767a82fde3d221237294cf6 100644 --- a/tests/models/diffusion/test_diffusion_train.py +++ b/tests/models/diffuser/test_diffusion_train.py @@ -20,7 +20,7 @@ import numpy as np from mindspore import Tensor, ops, amp, nn, jit from mindspore import dtype as mstype -from mindscience.models import DiffusionTransformer, DiffusionTrainer, DDPMScheduler, DDIMScheduler, DDPMPipeline, \ +from mindscience.diffuser import DiffusionTransformer, DiffusionTrainer, DDPMScheduler, DDIMScheduler, DDPMPipeline, \ DDIMPipeline, ConditionDiffusionTransformer from dataset import get_latent_dataset from ae import LATENT_DIM, generate_image diff --git a/tests/models/diffusion/utils.py b/tests/models/diffuser/utils.py similarity index 100% rename from tests/models/diffusion/utils.py rename to tests/models/diffuser/utils.py diff --git a/tests/models/transformer/attention_block.ckpt b/tests/models/transformer/attention_block.ckpt deleted file mode 100644 index 
b8726dcaf271d1f9de31e215f595997fb36595a4..0000000000000000000000000000000000000000 Binary files a/tests/models/transformer/attention_block.ckpt and /dev/null differ diff --git a/tests/models/transformer/attention_block_output.npy b/tests/models/transformer/attention_block_output.npy deleted file mode 100644 index b4507ee642080df94ad38323fcfc40f3ed5fbc41..0000000000000000000000000000000000000000 Binary files a/tests/models/transformer/attention_block_output.npy and /dev/null differ diff --git a/tests/models/transformer/grads.npz b/tests/models/transformer/grads.npz deleted file mode 100644 index 7426de8db2435ef77de457c6af1e74b434726562..0000000000000000000000000000000000000000 Binary files a/tests/models/transformer/grads.npz and /dev/null differ diff --git a/tests/models/transformer/input.npy b/tests/models/transformer/input.npy deleted file mode 100644 index 130198f08db5b536c0cad7ab3430350225b54306..0000000000000000000000000000000000000000 Binary files a/tests/models/transformer/input.npy and /dev/null differ diff --git a/tests/models/transformer/label.npy b/tests/models/transformer/label.npy deleted file mode 100644 index 2e11f8d6d8fb76fb07e8f9e22e02808d2b1a43bd..0000000000000000000000000000000000000000 Binary files a/tests/models/transformer/label.npy and /dev/null differ diff --git a/tests/models/transformer/mask.npy b/tests/models/transformer/mask.npy deleted file mode 100644 index 2458fe287ad7847e4bc0ca4d1faa4f5bb0f15c28..0000000000000000000000000000000000000000 Binary files a/tests/models/transformer/mask.npy and /dev/null differ diff --git a/tests/models/transformer/multihead.ckpt b/tests/models/transformer/multihead.ckpt deleted file mode 100644 index 259465f226ee8a03f2e9e25b22a67d33c0b218eb..0000000000000000000000000000000000000000 Binary files a/tests/models/transformer/multihead.ckpt and /dev/null differ diff --git a/tests/models/transformer/multihead_output.npy b/tests/models/transformer/multihead_output.npy deleted file mode 100644 index 
27b6d8a4bbe3d0f0befe95e778211e36b4cc6900..0000000000000000000000000000000000000000 Binary files a/tests/models/transformer/multihead_output.npy and /dev/null differ diff --git a/tests/models/transformer/test_attention.py b/tests/models/transformer/test_attention.py index 637bbfd678b61d6edba79424c45110efe0d4659a..7c3a664a2b1037ae5b66ef84742427ec14b801f5 100644 --- a/tests/models/transformer/test_attention.py +++ b/tests/models/transformer/test_attention.py @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================ """attention testcase""" +# pylint: disable=C0413 import os import sys import pytest @@ -21,22 +22,35 @@ import numpy as np from mindspore import Tensor, ops, load_checkpoint, load_param_into_net, jit_class, context from mindspore import dtype as mstype -from mindscience.models import Attention, MultiHeadAttention, TransformerBlock, DropPath, ViT +from mindscience.models import ( + Attention, MultiHeadAttention, TransformerBlock, DropPath, VisionTransformer +) from mindscience.common import RelativeRMSELoss PROJECT_ROOT = os.path.abspath(os.path.join( os.path.dirname(__file__), "../../")) sys.path.append(PROJECT_ROOT) -from tools import compare_output, validate_checkpoint, validate_model_infer, validate_output_dtype -from tools import FP32_RTOL, FP32_ATOL, FP16_RTOL, FP16_ATOL +from tools import ( + compare_output, validate_checkpoint, validate_model_infer, validate_output_dtype, + FP16_ATOL, FP16_RTOL, FP32_ATOL, FP32_RTOL +) BATCH_SIZE, NUM_HEADS, SEQ_LEN, IN_CHANNELS = 2, 4, 15, 64 +DATA_PATH = '/home/workspace/mindspore_dataset/mindscience/attention' +MHA_CKPT_FILE = os.path.join(DATA_PATH, 'multihead.ckpt') +ATB_CKPT_FILE = os.path.join(DATA_PATH, 'attention_block.ckpt') +ATB_OUT_FILE = os.path.join(DATA_PATH, 'attention_block_output.npy') +MHO_OUT_FILE = os.path.join(DATA_PATH, 'multihead_output.npy') +INPUT_FILE = os.path.join(DATA_PATH, 'input.npy') +MASK_FILE = 
os.path.join(DATA_PATH, 'mask.npy') +LABEL_FILE = os.path.join(DATA_PATH, 'label.npy') +GRAD_FILE = os.path.join(DATA_PATH, 'grads.npz') def load_inputs(): - x = Tensor(np.load('input.npy').astype(np.float32)) - mask = Tensor(np.load('mask.npy').astype(np.int32)) + x = Tensor(np.load(INPUT_FILE).astype(np.float32)) + mask = Tensor(np.load(MASK_FILE).astype(np.int32)) return x, mask @@ -57,7 +71,8 @@ def test_attention_qkv(mode, compute_dtype): qkv = net.get_qkv(x) for tensor in qkv: assert tensor.dtype == compute_dtype - assert tensor.shape == (BATCH_SIZE, NUM_HEADS, SEQ_LEN, IN_CHANNELS//NUM_HEADS) + assert tensor.shape == (BATCH_SIZE, NUM_HEADS, + SEQ_LEN, IN_CHANNELS//NUM_HEADS) @pytest.mark.level0 @@ -73,9 +88,11 @@ def test_flash_attn(mode, fa_dtype): """ context.set_context(mode=mode) in_shape = (BATCH_SIZE, NUM_HEADS, SEQ_LEN, IN_CHANNELS//NUM_HEADS) - query, key, value = ops.randn(in_shape), ops.randn(in_shape), ops.randn(in_shape) + query, key, value = ops.randn(in_shape), ops.randn( + in_shape), ops.randn(in_shape) mask = ops.randint(0, 2, (SEQ_LEN, SEQ_LEN)) - net = MultiHeadAttention(IN_CHANNELS, NUM_HEADS, enable_flash_attn=True, fa_dtype=fa_dtype) + net = MultiHeadAttention(IN_CHANNELS, NUM_HEADS, + enable_flash_attn=True, fa_dtype=fa_dtype) output = net.attn(query, key, value, mask) assert output.dtype == fa_dtype assert output.shape == in_shape @@ -93,7 +110,8 @@ def test_multihead_fa(mode, fa_dtype): Expectation: success """ context.set_context(mode=mode) - net = MultiHeadAttention(IN_CHANNELS, NUM_HEADS, enable_flash_attn=True, fa_dtype=fa_dtype) + net = MultiHeadAttention(IN_CHANNELS, NUM_HEADS, + enable_flash_attn=True, fa_dtype=fa_dtype) in_shape = (BATCH_SIZE, SEQ_LEN, IN_CHANNELS) x = ops.randn(in_shape) mask = ops.randint(0, 2, (BATCH_SIZE, 1, SEQ_LEN, SEQ_LEN)) @@ -115,7 +133,8 @@ def test_fa_forward(mode, fa_dtype): """ context.set_context(mode=mode) net = MultiHeadAttention(IN_CHANNELS, NUM_HEADS, enable_flash_attn=False) - fa_net = 
MultiHeadAttention(IN_CHANNELS, NUM_HEADS, enable_flash_attn=True, fa_dtype=fa_dtype) + fa_net = MultiHeadAttention( + IN_CHANNELS, NUM_HEADS, enable_flash_attn=True, fa_dtype=fa_dtype) batch_size, seq_len = 256, 512 in_shape = (batch_size, seq_len, IN_CHANNELS) x = ops.randn(in_shape) @@ -191,8 +210,8 @@ def test_multihead_attention(mode): context.set_context(mode=mode) net = MultiHeadAttention(in_channels=IN_CHANNELS, num_heads=NUM_HEADS) x, mask = load_inputs() - validate_model_infer(net, (x, mask), './multihead.ckpt', - './multihead_output.npy', FP32_RTOL, FP32_ATOL) + validate_model_infer(net, (x, mask), MHA_CKPT_FILE, + MHO_OUT_FILE, FP32_RTOL, FP32_ATOL) @pytest.mark.level0 @@ -226,32 +245,32 @@ def test_attn_block(mode): context.set_context(mode=mode) net = TransformerBlock(in_channels=IN_CHANNELS, num_heads=NUM_HEADS) x, mask = load_inputs() - validate_model_infer(net, (x, mask), './attention_block.ckpt', - './attention_block_output.npy', FP32_RTOL, FP32_ATOL) + validate_model_infer(net, (x, mask), ATB_CKPT_FILE, + ATB_OUT_FILE, FP32_RTOL, FP32_ATOL) @pytest.mark.level0 @pytest.mark.platform_arm_ascend910b_training @pytest.mark.env_onecard @pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) -def test_vit_forward(mode): +def test_vision_transformer_forward(mode): """ - Feature: ViT + Feature: VisionTransformer Description: test forward result dtype Expectation: success """ context.set_context(mode=mode) x = ops.rand(32, 3, 192, 384) - model = ViT(in_channels=3, - out_channels=3, - encoder_depths=6, - encoder_embed_dim=768, - encoder_num_heads=12, - decoder_depths=6, - decoder_embed_dim=512, - decoder_num_heads=16, - compute_dtype=mstype.float32 - ) + model = VisionTransformer(in_channels=3, + out_channels=3, + encoder_depths=6, + encoder_embed_dim=768, + encoder_num_heads=12, + decoder_depths=6, + decoder_embed_dim=512, + decoder_num_heads=16, + compute_dtype=mstype.float32 + ) output = model(x) assert output.dtype == 
mstype.float32 assert output.shape == (32, 288, 768) @@ -303,14 +322,13 @@ def test_multihead_attention_grad(mode): Expectation: success """ context.set_context(mode=mode) - ckpt_path = './multihead.ckpt' model = MultiHeadAttention( IN_CHANNELS, NUM_HEADS, compute_dtype=mstype.float32) - params = load_checkpoint(ckpt_path) + params = load_checkpoint(MHA_CKPT_FILE) load_param_into_net(model, params) - input_data = Tensor(np.load('./input.npy')) - input_label = Tensor(np.load('./label.npy')) + input_data = Tensor(np.load(INPUT_FILE)) + input_label = Tensor(np.load(LABEL_FILE)) trainer = Trainer(model, RelativeRMSELoss()) @@ -324,7 +342,7 @@ def test_multihead_attention_grad(mode): _, grads = grad_fn(input_data, input_label) convert_grads = tuple(grad.asnumpy() for grad in grads) - with np.load('./grads.npz') as data: + with np.load(GRAD_FILE) as data: output_target = tuple(data[key] for key in data.files) validate_ans = compare_output( diff --git a/tests/models/transformer/test_vit.py b/tests/models/transformer/test_vit.py index 8041d568f726aa5b2351662b45ecddc9119f14d0..41f91621d9ddfbbd07f31ba20daed93463896866 100644 --- a/tests/models/transformer/test_vit.py +++ b/tests/models/transformer/test_vit.py @@ -19,36 +19,32 @@ import numpy as np from mindspore import Tensor, context from mindspore import dtype as mstype -from mindscience.models import ViT +from mindscience.models import VisionTransformer @pytest.mark.level0 @pytest.mark.platform_arm_ascend910b_training @pytest.mark.env_onecard -def test_vit_output(): +def test_vision_transformer(): """ - Feature: Test ViT network in platform gpu and ascend. + Feature: Test VisionTransformer network in platform gpu and ascend. Description: None. Expectation: Success or throw AssertionError. 
Need to adaptive 910B """ context.set_context(mode=context.GRAPH_MODE) input_tensor = Tensor(np.ones((32, 3, 192, 384)), mstype.float32) - print('input_tensor.shape: ', input_tensor.shape) - print('input_tensor.dtype: ', input_tensor.dtype) - model = ViT(in_channels=3, - out_channels=3, - encoder_depths=6, - encoder_embed_dim=768, - encoder_num_heads=12, - decoder_depths=6, - decoder_embed_dim=512, - decoder_num_heads=16, - ) + model = VisionTransformer(in_channels=3, + out_channels=3, + encoder_depths=6, + encoder_embed_dim=768, + encoder_num_heads=12, + decoder_depths=6, + decoder_embed_dim=512, + decoder_num_heads=16, + ) output_tensor = model(input_tensor) - print('output_tensor.shape: ', output_tensor.shape) - print('output_tensor.dtype: ', output_tensor.dtype) assert output_tensor.shape == (32, 288, 768) assert output_tensor.dtype == mstype.float32