diff --git a/mindscience/models/neural_operator/__init__.py b/mindscience/models/neural_operator/__init__.py index b7532f16dce56fd66336e3c02ec67899f8a21ced..6f2079798cc3bda90cc36beb42c5d90d709b7f40 100644 --- a/mindscience/models/neural_operator/__init__.py +++ b/mindscience/models/neural_operator/__init__.py @@ -19,6 +19,10 @@ from .kno2d import KNO2D from .pdenet import PDENet from .percnn import PeRCNN from .sno import SNO, SNO1D, SNO2D, SNO3D +from .ffno import FFNOBlocks, FFNO, FFNO1D, FFNO2D, FFNO3D __all__ = ["FNOBlocks", "FNO1D", "FNO2D", "FNO3D", "KNO1D", "KNO2D", "PDENet", "PeRCNN", + "FFNOBlocks", "FFNO", "FFNO1D", "FFNO2D", "FFNO3D", "SNO", "SNO1D", "SNO2D", "SNO3D"] + +__all__.sort() diff --git a/mindscience/models/neural_operator/afno2d.py b/mindscience/models/neural_operator/afno2d.py index b3e78739be2b807cf10118462d875176abeb6dca..88a8b74d6a87ca10688a391247a10fdd4d0dab1c 100644 --- a/mindscience/models/neural_operator/afno2d.py +++ b/mindscience/models/neural_operator/afno2d.py @@ -19,8 +19,7 @@ from mindspore import ops, nn, Tensor, Parameter from mindspore import dtype as mstype from mindspore.common.initializer import initializer, Normal, TruncatedNormal from mindspore.nn.probability.distribution import Bernoulli - -from .dft import dft2, idft2 +from ...sciops.fourier import RDFTn, IRDFTn class DropPath(nn.Cell): @@ -305,13 +304,13 @@ class AFNO2D(nn.Cell): self.h_size = h_size self.w_size = w_size - self.dft2_cell = dft2( + self.dft2_cell = RDFTn( shape=(h_size, w_size), dim=(-3, -2), - modes=(h_size // 2, w_size // 2 + 1), compute_dtype=compute_dtype + norm='ortho', compute_dtype=compute_dtype ) - self.idft2_cell = idft2( + self.idft2_cell = IRDFTn( shape=(h_size, w_size), dim=(-3, -2), - modes=(h_size // 2, w_size // 2 + 1), compute_dtype=compute_dtype + norm='ortho', compute_dtype=compute_dtype ) self.scale = 0.02 @@ -381,10 +380,7 @@ class AFNO2D(nn.Cell): h, w = self.h_size, self.w_size x = x.reshape(b, h, w, c) - x_re = x - x_im = ops.zeros_like(x_re) - - x_ft_re, x_ft_im = self.dft2_cell((x_re, x_im)) + x_ft_re, x_ft_im = self.dft2_cell(x) x_ft_re = x_ft_re.reshape( b, x_ft_re.shape[1], x_ft_re.shape[2], @@ -434,7 +430,7 @@ class AFNO2D(nn.Cell): o2_real = o2_real.reshape(b, o2_real.shape[1], o2_real.shape[2], c) o2_imag = o2_imag.reshape(b, o2_imag.shape[1], o2_imag.shape[2], c) - x, _ = self.idft2_cell((o2_real, o2_imag)) + x = self.idft2_cell(o2_real, o2_imag) x = x.reshape(b, n, c) return x + bias diff --git a/mindscience/models/neural_operator/dft.py b/mindscience/models/neural_operator/dft.py deleted file mode 100644 index 24fe07b9d4fee1fd02e762d5c5086fdd9405ffd0..0000000000000000000000000000000000000000 --- a/mindscience/models/neural_operator/dft.py +++ /dev/null @@ -1,769 +0,0 @@ -# Copyright 2023 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================ -""" -DFT -""" -import numpy as np -from scipy.linalg import dft - -import mindspore -import mindspore.common.dtype as mstype -from mindspore import nn, ops, Tensor, Parameter -from mindspore.common.initializer import Zero -from mindspore.ops import operations as P - -from ...utils.check_func import check_param_no_greater, check_param_value, check_param_type, check_param_even - - -class DFT1d(nn.Cell): - '''One dimensional Discrete Fourier Transformation''' - - def __init__(self, n, modes, last_index, idx=0, inv=False, compute_dtype=mindspore.float32): - super().__init__() - - self.n = n - self.dft_mat = dft(n, scale="sqrtn") - self.modes = modes - self.last_index = last_index - self.inv = inv - self.idx = idx - self.compute_dtype = compute_dtype - - self.dft_mode_mat_upper = self.dft_mat[:, :modes] - self.a_re_upper = Tensor( - self.dft_mode_mat_upper.real, dtype=compute_dtype) - self.a_im_upper = Tensor( - self.dft_mode_mat_upper.imag, dtype=compute_dtype) - - self.dft_mode_mat_lower = self.dft_mat[:, -modes:] - self.a_re_lower = Tensor( - self.dft_mode_mat_lower.real, dtype=compute_dtype) - self.a_im_lower = Tensor( - self.dft_mode_mat_lower.imag, dtype=compute_dtype) - self.concat = ops.Concat(axis=-1) - - if self.inv: - self.a_re_upper = self.a_re_upper.T - self.a_im_upper = -self.a_im_upper.T - if last_index: - if modes == n // 2 + 1: - self.dft_mat_res = self.dft_mat[:, -modes + 2:] - else: - self.dft_mat_res = self.dft_mat[:, -modes + 1:] - - mat = Tensor(np.zeros(n,), dtype=compute_dtype).reshape(n, 1) - self.a_re_res = mindspore.numpy.flip( - Tensor(self.dft_mat_res.real, dtype=compute_dtype), axis=-1) - self.a_im_res = mindspore.numpy.flip( - Tensor(self.dft_mat_res.imag, dtype=compute_dtype), axis=-1) - if modes == n // 2 + 1: - self.a_re_res = self.concat((mat, self.a_re_res, mat)) - self.a_im_res = self.concat((mat, self.a_im_res, mat)) - else: - self.a_re_res = self.concat((mat, self.a_re_res)) - self.a_im_res = self.concat((mat, self.a_im_res)) - - self.a_re_res = self.a_re_res.T - self.a_im_res = -self.a_im_res.T - else: - self.a_re_res = self.a_re_lower.T - self.a_im_res = -self.a_im_lower.T - - if (self.n - 2 * self.modes) > 0: - self.mat = Tensor(shape=(self.n - 2 * self.modes), - dtype=compute_dtype, init=Zero()) - - def swap_axes(self, x_re, x_im): - return x_re.swapaxes(-1, self.idx), x_im.swapaxes(-1, self.idx) - - def complex_matmul(self, x_re, x_im, a_re, a_im): - y_re = ops.matmul(x_re, a_re) - ops.matmul(x_im, a_im) - y_im = ops.matmul(x_im, a_re) + ops.matmul(x_re, a_im) - return y_re, y_im - - def construct(self, x): - """construct""" - x_re, x_im = x - x_re, x_im = P.Cast()(x_re, self.compute_dtype), P.Cast()(x_im, self.compute_dtype) - if not self.inv: - x_re, x_im = self.swap_axes(x_re, x_im) - y_re, y_im = self.complex_matmul( - x_re=x_re, x_im=x_im, a_re=self.a_re_upper, a_im=self.a_im_upper) - - if not self.last_index: - y_re2, y_im2 = self.complex_matmul( - x_re=x_re, x_im=x_im, a_re=self.a_re_lower, a_im=self.a_im_lower) - - if self.n == self.modes * 2: - y_re = self.concat((y_re, y_re2)) - y_im = self.concat((y_im, y_im2)) - else: - dims = x_re.shape[:-1] - length = len(dims) - mat = self.mat - for i in range(length - 1, -1, -1): - mat = mat.expand_dims(0).repeat(dims[i], 0) - y_re = self.concat((y_re, mat, y_re2)) - y_im = self.concat((y_im, mat, y_im2)) - - y_re, y_im = self.swap_axes(y_re, y_im) - return y_re, y_im - - x_re, x_im = self.swap_axes(x_re, x_im) - y_re, 
y_im = self.complex_matmul(x_re=x_re[..., :self.modes], x_im=x_im[..., :self.modes], - a_re=self.a_re_upper, - a_im=self.a_im_upper) - y_re, y_im = self.swap_axes(y_re, y_im) - - if self.last_index: - y_re_res, y_im_res = self.complex_matmul( - x_re=x_re, x_im=x_im, a_re=self.a_re_res, a_im=-self.a_im_res) - else: - y_re_res, y_im_res = self.complex_matmul(x_re=x_re[..., -self.modes:], x_im=x_im[..., -self.modes:], - a_re=self.a_re_res, a_im=self.a_im_res) - - y_re_res, y_im_res = self.swap_axes(y_re_res, y_im_res) - return y_re + y_re_res, y_im + y_im_res - - -class DFTn(nn.Cell): - '''N dimensional Discrete Fourier Transformation''' - - def __init__(self, shape, modes, dim=None, inv=False, compute_dtype=mindspore.float32): - super().__init__() - - if dim is None: - dim = range(len(shape)) - self.dft1_seq = nn.SequentialCell() - last_index = [False for _ in range(len(shape))] - last_index[-1] = True - for dim_id, idx in enumerate(dim): - self.dft1_seq.append( - DFT1d(n=shape[dim_id], modes=modes[dim_id], last_index=last_index[dim_id], idx=idx, inv=inv, - compute_dtype=compute_dtype)) - - def construct(self, x): - return self.dft1_seq(x) - - -def _dftn(shape, modes, dim=None, compute_dtype=mindspore.float32): - dftn_ = DFTn(shape=shape, modes=modes, dim=dim, - inv=False, compute_dtype=compute_dtype) - return dftn_ - - -def _idftn(shape, modes, dim=None, compute_dtype=mindspore.float32): - idftn_ = DFTn(shape=shape, modes=modes, dim=dim, - inv=True, compute_dtype=compute_dtype) - return idftn_ - - -def dft3(shape, modes, dim=(-3, -2, -1), compute_dtype=mindspore.float32): - r""" - Calculate three-dimensional discrete Fourier transform. Corresponding to the rfftn operator in torch. - - Args: - shape (tuple): Dimension of the input 'x'. - modes (tuple): The length of the output transform axis. The `modes` must be no greater than half of the - dimension of input 'x'. - dim (tuple): Dimensions to be transformed. - compute_dtype (mindspore.dtype): The type of input tensor. Default: mindspore.float32. - - Inputs: - - **x** (Tensor, Tensor): The input data. It's 3-D tuple of Tensor. It's a complex, - including x real and imaginary. Tensor of shape :math:`(*, *)`. - - Returns: - Complex tensor with the same shape of input x. - - Raises: - TypeError: If `shape` is not a tuple. - ValueError: If the length of `shape` is no equal to 3. 
- - Examples: - >>> import numpy as np - >>> from mindspore import Tensor, ops - >>> import mindspore.common.dtype as mstype - >>> from mindflow.cell.neural_operators.dft import dft3 - >>> array = np.ones((6, 6, 6)) * np.arange(1, 7) - >>> x_re = Tensor(array, dtype=mstype.float32) - >>> x_im = x_re - >>> dft3_cell = dft3(shape=array.shape, modes=(2, 2, 2), compute_dtype=mstype.float32) - >>> ret, _ = dft3_cell((x_re, x_im)) - >>> print(ret) - [[[ 5.1439293e+01 -2.0076393e+01] - [ 7.9796671e-08 -1.9494735e-08] - [ 0.0000000e+00 0.0000000e+00] - [ 0.0000000e+00 0.0000000e+00] - [ 9.0537789e-08 1.0553553e-07] - [ 3.3567730e-07 1.0368046e-07]] - - [[ 4.7683722e-07 -3.1770034e-07] - [ 6.5267522e-15 -2.7775875e-15] - [ 0.0000000e+00 0.0000000e+00] - [ 0.0000000e+00 0.0000000e+00] - [-2.1755840e-15 -1.5215135e-15] - [ 3.6259736e-15 -4.0336615e-15]] - - [[ 0.0000000e+00 0.0000000e+00] - [ 0.0000000e+00 0.0000000e+00] - [ 0.0000000e+00 0.0000000e+00] - [ 0.0000000e+00 0.0000000e+00] - [ 0.0000000e+00 0.0000000e+00] - [ 0.0000000e+00 0.0000000e+00]] - - [[ 0.0000000e+00 0.0000000e+00] - [ 0.0000000e+00 0.0000000e+00] - [ 0.0000000e+00 0.0000000e+00] - [ 0.0000000e+00 0.0000000e+00] - [ 0.0000000e+00 0.0000000e+00] - [ 0.0000000e+00 0.0000000e+00]] - - [[ 1.1920930e-07 -5.1619136e-08] - [-3.6259733e-16 -1.0747753e-15] - [ 0.0000000e+00 0.0000000e+00] - [ 0.0000000e+00 0.0000000e+00] - [ 3.6259733e-16 -1.8129867e-16] - [ 3.6259733e-16 -1.4373726e-15]] - - [[ 5.9604650e-07 -2.5809570e-07] - [ 8.7023360e-15 -1.9812689e-15] - [ 0.0000000e+00 0.0000000e+00] - [ 0.0000000e+00 0.0000000e+00] - [ 2.9007787e-15 7.2519467e-16] - [ 8.7023360e-15 -1.7869532e-15]]] - - """ - check_param_type(shape, "shape", data_type=tuple) - check_param_type(modes, "modes", data_type=tuple) - check_param_value(len(shape), "shape length", 3) - check_param_value(len(modes), "modes length", 3) - check_param_even(shape, "shape") - check_param_no_greater(modes[0], "mode1", shape[0] // 2) - check_param_no_greater(modes[1], "mode2", shape[1] // 2) - check_param_no_greater(modes[2], "mode3", shape[2] // 2 + 1) - return _dftn(shape, modes, dim=dim, compute_dtype=compute_dtype) - - -def idft3(shape, modes, dim=(-3, -2, -1), compute_dtype=mindspore.float32): - r""" - Calculate three-dimensional discrete Fourier transform. Corresponding to the irfftn operator in torch. - - Args: - shape (tuple): Dimension of the input 'x'. - modes (tuple): The length of the output transform axis. The `modes` must be no greater than half of the - dimension of input 'x'. - dim (tuple): Dimensions to be transformed. - compute_dtype (mindspore.dtype): The type of input tensor. Default: mindspore.float32. - - Inputs: - - **x** (Tensor, Tensor): The input data. It's 3-D tuple of Tensor. It's a complex, including x real and - imaginary. Tensor of shape :math:`(*, *)`. - - Returns: - Complex tensor with the same shape of input x. - - Raises: - TypeError: If `shape` is not a tuple. - ValueError: If the length of `shape` is no equal to 3. 
- - Examples: - >>> import numpy as np - >>> from mindspore import Tensor, ops - >>> import mindspore.common.dtype as mstype - >>> from mindflow.cell.neural_operators.dft import idft3 - >>> array = np.ones((2, 2, 2)) * np.arange(1, 3) - >>> x_re = Tensor(array, dtype=mstype.float32) - >>> x_im = ops.zeros_like(x_re) - >>> idft3_cell = idft3(shape=(6, 6, 6), modes=(2, 2, 2), compute_dtype=mstype.float32) - >>> ret, _ = idft3_cell((x_re, x_im)) - >>> print(ret) - [[[ 5.44331074e+00 3.26598644e+00 -1.08866215e+00 -3.26598644e+00 -1.08866215e+00 3.26598644e+00] - [ 2.04124165e+00 2.04124165e+00 4.08248246e-01 -1.22474492e+00 -1.22474492e+00 4.08248365e-01] - [-6.80413842e-01 -1.22474492e+00 -6.80413783e-01 4.08248305e-01 9.52579379e-01 4.08248246e-01] - [ 0.00000000e+00 -2.30921616e-16 -2.30921616e-16 6.53092730e-32 2.30921616e-16 2.30921616e-16] - [-6.80413842e-01 4.08248246e-01 9.52579379e-01 4.08248305e-01 -6.80413783e-01 -1.22474492e+00] - [ 2.04124165e+00 4.08248365e-01 -1.22474492e+00 -1.22474492e+00 4.08248246e-01 2.04124165e+00]] - ...... - [[ 2.04124165e+00 4.08248544e-01 -1.22474492e+00 -1.22474504e+00 4.08248186e-01 2.04124165e+00] - [ 1.02062082e+00 6.12372518e-01 -2.04124182e-01 -6.12372518e-01 -2.04124182e-01 6.12372518e-01] - [-5.10310411e-01 -5.10310411e-01 -1.02062061e-01 3.06186229e-01 3.06186229e-01 -1.02062091e-01] - [-7.21630050e-17 -1.29893429e-16 -7.21630183e-17 4.32978030e-17 1.01028220e-16 4.32978163e-17] - [-6.08337416e-08 4.08248246e-01 4.08248305e-01 3.65002428e-08 -4.08248246e-01 -4.08248305e-01] - [ 5.10310471e-01 -3.06186140e-01 -7.14434564e-01 -3.06186318e-01 5.10310352e-01 9.18558717e-01]]] - - """ - check_param_type(shape, "shape", data_type=tuple) - check_param_type(modes, "modes", data_type=tuple) - check_param_value(len(shape), "shape length", 3) - check_param_value(len(modes), "modes length", 3) - check_param_even(shape, "shape") - check_param_no_greater(modes[0], "mode1", shape[0] // 2) - check_param_no_greater(modes[1], "mode2", shape[1] // 2) - check_param_no_greater(modes[2], "mode3", shape[2] // 2 + 1) - return _idftn(shape, modes, dim=dim, compute_dtype=compute_dtype) - - -def dft2(shape, modes, dim=(-2, -1), compute_dtype=mindspore.float32): - """ - Calculate two-dimensional discrete Fourier transform. Corresponding to the rfft2 operator in torch. - - Args: - shape (tuple): Dimension of the input 'x'. - modes (tuple): The length of the output transform axis. The `modes` must be no greater than half of the - dimension of input 'x'. - dim (tuple): Dimensions to be transformed. - compute_dtype (:class:`mindspore.dtype`): The type of input tensor. Default: mindspore.float32. - - Inputs: - - **x** (Tensor, Tensor): The input data. It's 2-D tuple of Tensor. It's a complex, - including x real and imaginary. Tensor of shape :math:`(*, *)`. - - Returns: - Complex tensor with the same shape of input x. - - Raises: - TypeError: If `shape` is not a tuple. - ValueError: If the length of `shape` is no equal to 2. 
- - Examples: - >>> import numpy as np - >>> from mindspore import Tensor, ops - >>> import mindspore.common.dtype as mstype - >>> from mindflow.cell.neural_operators.dft import dft2 - >>> array = np.ones((5, 5)) * np.arange(1, 6) - >>> x_re = Tensor(array, dtype=mstype.float32) - >>> x_im = x_re - >>> dft2_cell = dft2(shape=array.shape, modes=(2, 2), compute_dtype=mstype.float32) - >>> ret, _ = dft2_cell((x_re, x_im)) - >>> print(ret) - [[ 1.5000000e+01 -5.9409552e+00] - [-2.4656805e-07 7.6130398e-08] - [ 0.0000000e+00 0.0000000e+00] - [-1.9992007e-07 7.3572544e-08] - [-2.4656805e-07 7.6130398e-08]] - - """ - check_param_type(shape, "shape", data_type=tuple) - check_param_type(modes, "modes", data_type=tuple) - check_param_value(len(shape), "shape length", 2) - check_param_value(len(modes), "modes length", 2) - check_param_even(shape, "shape") - check_param_no_greater(modes[0], "mode1", shape[0] // 2) - check_param_no_greater(modes[1], "mode2", shape[1] // 2 + 1) - return _dftn(shape, modes, dim=dim, compute_dtype=compute_dtype) - - -def idft2(shape, modes, dim=(-2, -1), compute_dtype=mindspore.float32): - """ - Calculate two-dimensional discrete Fourier transform. Corresponding to the irfft2 operator in torch. - - Args: - shape (tuple): Dimension of the input 'x'. - modes (tuple): The length of the output transform axis. The `modes` must be no greater than half of the - dimension of input 'x'. - dim (tuple): Dimensions to be transformed. - compute_dtype (:class:`mindspore.dtype`): The type of input tensor. Default: mindspore.float32. - - Inputs: - - **x** (Tensor, Tensor): The input data. It's 2-D tuple of Tensor. It's a complex, - including x real and imaginary. Tensor of shape :math:`(*, *)`. - - Returns: - Complex tensor with the same shape of input x. - - Raises: - TypeError: If `shape` is not a tuple. - ValueError: If the length of `shape` is no equal to 2. - - Examples: - >>> import numpy as np - >>> from mindspore import Tensor, ops - >>> import mindspore.common.dtype as mstype - >>> from mindflow.cell.neural_operators.dft import idft2 - >>> array = np.ones((2, 2)) * np.arange(1, 3) - >>> x_re = Tensor(array, dtype=mstype.float32) - >>> x_im = ops.zeros_like(x_re) - >>> idft2_cell = idft2(shape=(5, 5), modes=(2, 2), compute_dtype=mstype.float32) - >>> ret, _ = idft2_cell((x_re, x_im)) - >>> print(ret) - [[ 3.9999998 1.7888544 -1.7888546 -1.7888546 1.7888544 ] - [ 0.80901694 0.80901694 -0.08541022 -0.6381966 -0.08541021] - [-0.30901706 -0.8618034 -0.30901694 0.5854102 0.5854101 ] - [-0.30901706 0.5854101 0.5854102 -0.30901694 -0.8618034 ] - [ 0.80901694 -0.08541021 -0.6381966 -0.08541022 0.80901694]] - - """ - check_param_type(shape, "shape", data_type=tuple) - check_param_type(modes, "modes", data_type=tuple) - check_param_value(len(shape), "shape length", 2) - check_param_value(len(modes), "modes length", 2) - check_param_even(shape, "shape") - check_param_no_greater(modes[0], "mode1", shape[0] // 2) - check_param_no_greater(modes[1], "mode2", shape[1] // 2 + 1) - return _idftn(shape, modes, dim=dim, compute_dtype=compute_dtype) - - -def dft1(shape, modes, dim=(-1,), compute_dtype=mindspore.float32): - """ - Calculate one-dimensional discrete Fourier transform. Corresponding to the rfft operator in torch. - - Args: - shape (tuple): Dimension of the input 'x'. - modes (int): The length of the output transform axis. The `modes` must be no greater than half of the - dimension of input 'x'. - dim (tuple): Dimensions to be transformed. 
- compute_dtype (:class:`mindspore.dtype`): The type of input tensor. - Default: mindspore.float32. - - Inputs: - - **x** (Tensor, Tensor): The input data. It's 2-D tuple of Tensor. It's a complex, - including x real and imaginary. Tensor of shape :math:`(*, *)`. - - Returns: - Complex tensor with the same shape of input x. - - Raises: - TypeError: If `shape` is not a tuple. - ValueError: If the length of `shape` is no equal to 1. - - Examples: - >>> from mindspore import Tensor, ops - >>> import mindspore.common.dtype as mstype - >>> from mindflow.cell.neural_operators.dft import dft1 - >>> array = [i for i in range(5)] - >>> x_re = Tensor(array, dtype=mstype.float32) - >>> x_im = ops.zeros_like(x_re) - >>> dft1_cell = dft1(shape=(len(x_re),), modes=2, compute_dtype=mstype.float32) - >>> ret, _ = dft1_cell((x_re, x_im)) - >>> print(ret) - [ 4.4721355 -1.1180341] - - """ - check_param_type(shape, "shape", data_type=tuple) - check_param_type(modes, "modes", data_type=int) - check_param_value(len(shape), "shape length", 1) - check_param_even(shape, "shape") - check_param_no_greater(modes, "mode1", shape[0] // 2 + 1) - modes = (modes,) - return _dftn(shape, modes, dim=dim, compute_dtype=compute_dtype) - - -def idft1(shape, modes, dim=(-1,), compute_dtype=mindspore.float32): - """ - Calculate one-dimensional discrete Fourier transform. Corresponding to the irfft operator in torch. - - Args: - shape (tuple): Dimension of the input 'x'. - modes (int): The length of the output transform axis. The `modes` must be no greater than half of the - dimension of input 'x'. - dim (tuple): Dimensions to be transformed. - compute_dtype (:class:`mindspore.dtype`): The type of input tensor. Default: mindspore.float32. - - Inputs: - - **x** (Tensor, Tensor): The input data. It's 2-D tuple of Tensor. It's a complex, - including x real and imaginary. Tensor of shape :math:`(*, *)`. - - Returns: - Complex tensor with the same shape of input x. - - Raises: - TypeError: If `shape` is not a tuple. - ValueError: If the length of `shape` is no equal to 1. 
- - Examples: - >>> from mindspore import Tensor, ops - >>> import mindspore.common.dtype as mstype - >>> from mindflow.cell.neural_operators.dft import idft1 - >>> array = [i for i in range(2)] - >>> x_re = Tensor(array, dtype=mstype.float32) - >>> x_im = x_re - >>> idft1_cell = idft1(shape=(len(x_re),), modes=2, compute_dtype=mstype.float32) - >>> ret, _ = idft1_cell((x_re, x_im)) - >>> print(ret) - [ 0.8944272 -0.5742576 -1.2493379 -0.19787574 1.127044 ] - - """ - check_param_type(shape, "shape", data_type=tuple) - check_param_type(modes, "modes", data_type=int) - check_param_value(len(shape), "shape length", 1) - check_param_even(shape, "shape") - check_param_no_greater(modes, "mode1", shape[0] // 2 + 1) - modes = (modes,) - return _idftn(shape, modes, dim=dim, compute_dtype=compute_dtype) - - -class SpectralConvDft(nn.Cell): - """Base Class for Fourier Layer, including DFT, linear transform, and Inverse DFT""" - - def __init__(self, in_channels, out_channels, n_modes, resolutions, compute_dtype=mstype.float32): - super().__init__() - self.in_channels = in_channels - self.out_channels = out_channels - if isinstance(n_modes, int): - n_modes = [n_modes] - self.n_modes = n_modes - if isinstance(resolutions, int): - resolutions = [resolutions] - self.resolutions = resolutions - if len(self.n_modes) != len(self.resolutions): - raise ValueError( - "The dimension of n_modes should be equal to that of resolutions, \ - but got dimension of n_modes {} and dimension of resolutions {}".format(len(self.n_modes), - len(self.resolutions))) - self.compute_dtype = compute_dtype - - def construct(self, x: Tensor): - raise NotImplementedError() - - def _einsum(self, inputs, weights): - weights = weights.expand_dims(0) - inputs = inputs.expand_dims(2) - out = inputs * weights - return out.sum(1) - - -class SpectralConv1dDft(SpectralConvDft): - """1D Fourier Layer. It does DFT, linear transform, and Inverse DFT.""" - - def __init__(self, in_channels, out_channels, n_modes, resolutions, compute_dtype=mstype.float32): - super().__init__(in_channels, out_channels, n_modes, resolutions) - self._scale = (1. / (self.in_channels * self.out_channels)) - w_re = Tensor(self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0]), - dtype=mstype.float32) - w_im = Tensor(self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0]), - dtype=mstype.float32) - self._w_re = Parameter(w_re, requires_grad=True) - self._w_im = Parameter(w_im, requires_grad=True) - self._dft1_cell = dft1(shape=(self.resolutions[0],), modes=self.n_modes[0], compute_dtype=self.compute_dtype) - self._idft1_cell = idft1(shape=(self.resolutions[0],), modes=self.n_modes[0], compute_dtype=self.compute_dtype) - - def construct(self, x: Tensor): - x_re = x - x_im = ops.zeros_like(x_re) - x_ft_re, x_ft_im = self._dft1_cell((x_re, x_im)) - w_re = P.Cast()(self._w_re, self.compute_dtype) - w_im = P.Cast()(self._w_im, self.compute_dtype) - out_ft_re = self._einsum(x_ft_re[:, :, :self.n_modes[0]], w_re) - self._einsum(x_ft_im[:, :, :self.n_modes[0]], - w_im) - out_ft_im = self._einsum(x_ft_re[:, :, :self.n_modes[0]], w_im) + self._einsum(x_ft_im[:, :, :self.n_modes[0]], - w_re) - - x, _ = self._idft1_cell((out_ft_re, out_ft_im)) - - return x - - -class SpectralConv2dDft(SpectralConvDft): - """2D Fourier Layer. 
It does DFT, linear transform, and Inverse DFT.""" - - def __init__(self, in_channels, out_channels, n_modes, resolutions, compute_dtype=mstype.float32): - super().__init__(in_channels, out_channels, n_modes, resolutions) - self._scale = (1. / (self.in_channels * self.out_channels)) - w_re1 = Tensor( - self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0], self.n_modes[1]), - dtype=self.compute_dtype) - w_im1 = Tensor( - self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0], self.n_modes[1]), - dtype=self.compute_dtype) - w_re2 = Tensor( - self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0], self.n_modes[1]), - dtype=self.compute_dtype) - w_im2 = Tensor( - self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0], self.n_modes[1]), - dtype=self.compute_dtype) - - self._w_re1 = Parameter(w_re1, requires_grad=True) - self._w_im1 = Parameter(w_im1, requires_grad=True) - self._w_re2 = Parameter(w_re2, requires_grad=True) - self._w_im2 = Parameter(w_im2, requires_grad=True) - - self._dft2_cell = dft2(shape=(self.resolutions[0], self.resolutions[1]), - modes=(self.n_modes[0], self.n_modes[1]), compute_dtype=self.compute_dtype) - self._idft2_cell = idft2(shape=(self.resolutions[0], self.resolutions[1]), - modes=(self.n_modes[0], self.n_modes[1]), compute_dtype=self.compute_dtype) - self._mat = Tensor(shape=(1, self.out_channels, self.resolutions[1] - 2 * self.n_modes[0], self.n_modes[1]), - dtype=self.compute_dtype, init=Zero()) - self._concat = ops.Concat(-2) - - def construct(self, x: Tensor): - x_re = x - x_im = ops.zeros_like(x_re) - x_ft_re, x_ft_im = self._dft2_cell((x_re, x_im)) - - out_ft_re1 = self._einsum( - x_ft_re[:, :, :self.n_modes[0], :self.n_modes[1]], self._w_re1 - ) - self._einsum( - x_ft_im[:, :, :self.n_modes[0], :self.n_modes[1]], self._w_im1 - ) - out_ft_im1 = self._einsum( - x_ft_re[:, :, :self.n_modes[0], :self.n_modes[1]], self._w_im1 - ) + self._einsum( - x_ft_im[:, :, :self.n_modes[0], :self.n_modes[1]], self._w_re1 - ) - - out_ft_re2 = self._einsum( - x_ft_re[:, :, -self.n_modes[0]:, :self.n_modes[1]], self._w_re2 - ) - self._einsum( - x_ft_im[:, :, -self.n_modes[0]:, :self.n_modes[1]], self._w_im2 - ) - out_ft_im2 = self._einsum( - x_ft_re[:, :, -self.n_modes[0]:, :self.n_modes[1]], self._w_im2 - ) + self._einsum( - x_ft_im[:, :, -self.n_modes[0]:, :self.n_modes[1]], self._w_re2 - ) - - batch_size = x.shape[0] - mat = self._mat.repeat(batch_size, 0) - out_re = self._concat((out_ft_re1, mat, out_ft_re2)) - out_im = self._concat((out_ft_im1, mat, out_ft_im2)) - - x, _ = self._idft2_cell((out_re, out_im)) - - return x - - -class SpectralConv3dDft(SpectralConvDft): - """3D Fourier layer. 
It does DFT, linear transform, and Inverse DFT.""" - - def __init__(self, in_channels, out_channels, n_modes, resolutions, compute_dtype=mstype.float32): - super().__init__(in_channels, out_channels, n_modes, resolutions) - self._scale = (1 / (self.in_channels * self.out_channels)) - - w_re1 = Tensor( - self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0], self.n_modes[1], - self.n_modes[2]), dtype=self.compute_dtype) - w_im1 = Tensor( - self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0], self.n_modes[1], - self.n_modes[2]), dtype=self.compute_dtype) - w_re2 = Tensor( - self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0], self.n_modes[1], - self.n_modes[2]), dtype=self.compute_dtype) - w_im2 = Tensor( - self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0], self.n_modes[1], - self.n_modes[2]), dtype=self.compute_dtype) - w_re3 = Tensor( - self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0], self.n_modes[1], - self.n_modes[2]), dtype=self.compute_dtype) - w_im3 = Tensor( - self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0], self.n_modes[1], - self.n_modes[2]), dtype=self.compute_dtype) - w_re4 = Tensor( - self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0], self.n_modes[1], - self.n_modes[2]), dtype=self.compute_dtype) - w_im4 = Tensor( - self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0], self.n_modes[1], - self.n_modes[2]), dtype=self.compute_dtype) - - self._w_re1 = Parameter(w_re1, requires_grad=True) - self._w_im1 = Parameter(w_im1, requires_grad=True) - self._w_re2 = Parameter(w_re2, requires_grad=True) - self._w_im2 = Parameter(w_im2, requires_grad=True) - self._w_re3 = Parameter(w_re3, requires_grad=True) - self._w_im3 = Parameter(w_im3, requires_grad=True) - self._w_re4 = Parameter(w_re4, requires_grad=True) - self._w_im4 = Parameter(w_im4, requires_grad=True) - - self._dft3_cell = dft3(shape=(self.resolutions[0], self.resolutions[1], self.resolutions[2]), - modes=(self.n_modes[0], self.n_modes[1], self.n_modes[2]), - compute_dtype=self.compute_dtype) - self._idft3_cell = idft3(shape=(self.resolutions[0], self.resolutions[1], self.resolutions[2]), - modes=(self.n_modes[0], self.n_modes[1], self.n_modes[2]), - compute_dtype=self.compute_dtype) - self._mat_x = Tensor( - shape=(1, self.out_channels, self.resolutions[0] - 2 * self.n_modes[0], self.n_modes[1], self.n_modes[2]), - dtype=self.compute_dtype, init=Zero()) - self._mat_y = Tensor( - shape=(1, self.out_channels, self.resolutions[0], self.resolutions[1] - 2 * self.n_modes[1], - self.n_modes[2]), - dtype=self.compute_dtype, init=Zero()) - self._concat = ops.Concat(-2) - - def construct(self, x: Tensor): - x_re = x - x_im = ops.zeros_like(x_re) - x_ft_re, x_ft_im = self._dft3_cell((x_re, x_im)) - - out_ft_re1 = self._einsum( - x_ft_re[:, :, :self.n_modes[0], :self.n_modes[1], :self.n_modes[2]], - self._w_re1 - ) - self._einsum( - x_ft_im[:, :, :self.n_modes[0], :self.n_modes[1], :self.n_modes[2]], - self._w_im1 - ) - out_ft_im1 = self._einsum( - x_ft_re[:, :, :self.n_modes[0], :self.n_modes[1], :self.n_modes[2]], - self._w_im1 - ) + self._einsum( - x_ft_im[:, :, :self.n_modes[0], :self.n_modes[1], :self.n_modes[2]], - self._w_re1 - ) - out_ft_re2 = self._einsum( - x_ft_re[:, :, -self.n_modes[0]:, :self.n_modes[1], :self.n_modes[2]], - self._w_re2 - ) - self._einsum( - x_ft_im[:, :, 
-self.n_modes[0]:, :self.n_modes[1], :self.n_modes[2]],
-            self._w_im2
-        )
-        out_ft_im2 = self._einsum(
-            x_ft_re[:, :, -self.n_modes[0]:, :self.n_modes[1], :self.n_modes[2]],
-            self._w_im2
-        ) + self._einsum(
-            x_ft_im[:, :, -self.n_modes[0]:, :self.n_modes[1], :self.n_modes[2]],
-            self._w_re2
-        )
-        out_ft_re3 = self._einsum(
-            x_ft_re[:, :, :self.n_modes[0], -self.n_modes[1]:, :self.n_modes[2]],
-            self._w_re3
-        ) - self._einsum(
-            x_ft_im[:, :, :self.n_modes[0], -self.n_modes[1]:, :self.n_modes[2]],
-            self._w_im3
-        )
-        out_ft_im3 = self._einsum(
-            x_ft_re[:, :, :self.n_modes[0], -self.n_modes[1]:, :self.n_modes[2]],
-            self._w_im3
-        ) + self._einsum(
-            x_ft_im[:, :, :self.n_modes[0], -self.n_modes[1]:, :self.n_modes[2]],
-            self._w_re3
-        )
-        out_ft_re4 = self._einsum(
-            x_ft_re[:, :, -self.n_modes[0]:, -self.n_modes[1]:, :self.n_modes[2]],
-            self._w_re4
-        ) - self._einsum(
-            x_ft_im[:, :, -self.n_modes[0]:, -self.n_modes[1]:, :self.n_modes[2]],
-            self._w_im4
-        )
-        out_ft_im4 = self._einsum(
-            x_ft_re[:, :, -self.n_modes[0]:, -self.n_modes[1]:, :self.n_modes[2]],
-            self._w_im4
-        ) + self._einsum(
-            x_ft_im[:, :, -self.n_modes[0]:, -self.n_modes[1]:, :self.n_modes[2]],
-            self._w_re4
-        )
-
-        batch_size = x.shape[0]
-        mat_x = self._mat_x.repeat(batch_size, 0)
-        mat_y = self._mat_y.repeat(batch_size, 0)
-
-        out_re1 = ops.concat((out_ft_re1, mat_x, out_ft_re2), -3)
-        out_im1 = ops.concat((out_ft_im1, mat_x, out_ft_im2), -3)
-
-        out_re2 = ops.concat((out_ft_re3, mat_x, out_ft_re4), -3)
-        out_im2 = ops.concat((out_ft_im3, mat_x, out_ft_im4), -3)
-        out_re = ops.concat((out_re1, mat_y, out_re2), -2)
-        out_im = ops.concat((out_im1, mat_y, out_im2), -2)
-        x, _ = self._idft3_cell((out_re, out_im))
-
-        return x
diff --git a/mindscience/models/neural_operator/ffno.py b/mindscience/models/neural_operator/ffno.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4ae17e31525da95478ebe6788a3fafbc1bbc9a9
--- /dev/null
+++ b/mindscience/models/neural_operator/ffno.py
@@ -0,0 +1,792 @@
+'''
+# Copyright 2023 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+'''
+# pylint: disable=W0235
+
+from mindspore import nn, ops, Tensor, Parameter, ParameterTuple, mint
+from mindspore.common.initializer import XavierNormal, initializer
+import mindspore.common.dtype as mstype
+
+from .ffno_sp import SpectralConv1d, SpectralConv2d, SpectralConv3d
+from ...common.math import get_grid_1d, get_grid_2d, get_grid_3d
+from ...utils.check_func import check_param_type
+
+
+class FFNOBlocks(nn.Cell):
+    r"""
+    The FFNOBlocks, which is usually accompanied by a Lifting Layer ahead and a Projection Layer behind,
+    is a part of the Factorized Fourier Neural Operator. It contains a Factorized Fourier Layer. The details can be
+    found in `A. Tran, A. Mathews, et al.: FACTORIZED FOURIER NEURAL OPERATORS `_.
+
+    Args:
+        in_channels (int): The number of channels in the input space.
+        out_channels (int): The number of channels in the output space.
+        n_modes (Union[int, list(int)]): The number of modes reserved after linear transformation in Fourier Layer.
+        resolutions (Union[int, list(int)]): The resolutions of the input tensor.
+        factor (int): The number of neurons in the hidden layer of a feedforward network. Default: ``1``.
+        n_ff_layers (int): The number of layers (hidden layers) in the feedforward neural network. Default: ``2``.
+        ff_weight_norm (bool): Whether to do weight normalization in feedforward or not. Used as a reserved function
+            interface, the weight normalization is not supported in feedforward. Default: ``False``.
+        layer_norm (bool): Whether to do layer normalization in feedforward or not. Default: ``True``.
+        dropout (float): The percentage of elements to be dropped when applying dropout regularization.
+            Default: ``0.0``.
+        r_padding (int): The number used to pad a tensor on the right in a certain dimension. Pad the domain if
+            input is non-periodic. Default: ``0``.
+        use_fork (bool): Whether to perform forecasting or not. Default: ``False``.
+        forecast_ff (Feedforward): The feedforward network generating the "forecast" output. Default: ``None``.
+        backcast_ff (Feedforward): The feedforward network generating the "backcast" output. Default: ``None``.
+        fourier_weight (ParameterTuple[Parameter]): The Fourier weight for transforming data in the frequency
+            domain, a ParameterTuple of Parameter with a length of 2N.
+
+            - Even indices (0, 2, 4, ...) represent the real parts of the complex parameter.
+            - Odd indices (1, 3, 5, ...) represent the imaginary parts of the complex parameter.
+            - Default: ``None``, meaning no data is provided.
+        dft_compute_dtype (dtype.Number): The computation type of DFT in SpectralConv. Default: ``mstype.float32``.
+        ffno_compute_dtype (dtype.Number): The computation type of the MLP in the FFNO skip connection.
+            Default: ``mstype.float32``. Should be ``mstype.float32`` or ``mstype.float16``. mstype.float32 is
+            recommended for the GPU backend, mstype.float16 is recommended for the Ascend backend.
+
+    Inputs:
+        - **x** (Tensor) - Tensor of shape :math:`(batch\_size, resolution, in\_channels)`.
+
+    Outputs:
+        Tuple of two Tensors, the residual output and the backcast output of this FFNOBlocks.
+
+        - **output** (Tensor, Tensor) - Two tensors, each of shape :math:`(batch\_size, resolution, out\_channels)`.
+
+    Raises:
+        TypeError: If `in_channels` is not an int.
+        TypeError: If `out_channels` is not an int.
+        TypeError: If `factor` is not an int.
+        TypeError: If `n_ff_layers` is not an int.
+        TypeError: If `ff_weight_norm` is not a Boolean value.
+        ValueError: If `ff_weight_norm` is not ``False``.
+        TypeError: If `layer_norm` is not a Boolean value.
+        TypeError: If `dropout` is not a float.
+        TypeError: If `r_padding` is not an int.
+        TypeError: If `use_fork` is not a Boolean value.
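+
+    Note:
+        When weights are shared across blocks (as `FFNO` does with ``share_weight=True``), `fourier_weight`
+        is built as interleaved (real, imaginary) parameter pairs, one pair per transformed dimension. A
+        minimal sketch of assembling such a tuple by hand, assuming ``hidden_channels=20`` and
+        ``n_modes=[20, 20]`` (illustrative values, not a fixed API contract):
+
+        >>> from mindspore import Parameter, ParameterTuple
+        >>> from mindspore.common.initializer import XavierNormal, initializer
+        >>> import mindspore.common.dtype as mstype
+        >>> params = []
+        >>> for i, n_mode in enumerate([20, 20]):
+        ...     shape = [20, 20, n_mode]
+        ...     params.append(Parameter(initializer(XavierNormal(), shape, mstype.float32), name=f'w_re_{i}'))
+        ...     params.append(Parameter(initializer(XavierNormal(), shape, mstype.float32), name=f'w_im_{i}'))
+        >>> fourier_weight = ParameterTuple(params)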
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> import numpy as np
+        >>> from mindspore import Tensor
+        >>> import mindspore.common.dtype as mstype
+        >>> from mindflow.cell.neural_operators import FFNOBlocks
+        >>> data = Tensor(np.ones([2, 128, 128, 2]), mstype.float32)
+        >>> net = FFNOBlocks(in_channels=2, out_channels=2, n_modes=[20, 20], resolutions=[128, 128])
+        >>> out0, out1 = net(data)
+        >>> print(data.shape, out0.shape, out1.shape)
+        (2, 128, 128, 2) (2, 128, 128, 2) (2, 128, 128, 2)
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 n_modes,
+                 resolutions,
+                 factor=1,
+                 n_ff_layers=2,
+                 ff_weight_norm=False,
+                 layer_norm=True,
+                 dropout=0.0,
+                 r_padding=0,
+                 use_fork=False,
+                 forecast_ff=None,
+                 backcast_ff=None,
+                 fourier_weight=None,
+                 dft_compute_dtype=mstype.float32,
+                 ffno_compute_dtype=mstype.float32
+                 ):
+        super().__init__()
+        check_param_type(in_channels, "in_channels", data_type=int)
+        check_param_type(out_channels, "out_channels", data_type=int)
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.n_modes, self.resolutions = validate_and_expand_dimensions(
+            1, n_modes, resolutions, False)
+
+        check_param_type(factor, "factor", data_type=int)
+        check_param_type(n_ff_layers, "n_ff_layers", data_type=int)
+        check_param_type(ff_weight_norm, "ff_weight_norm", data_type=bool)
+        check_param_type(layer_norm, "layer_norm", data_type=bool)
+        check_param_type(dropout, "dropout", data_type=float)
+        check_param_type(r_padding, "r_padding", data_type=int)
+
+        if ff_weight_norm:
+            raise ValueError(
+                f"The weight normalization is not supported in feedforward "
+                f"but got value of ff_weight_norm {ff_weight_norm}")
+
+        if r_padding < 0:
+            raise ValueError(
+                f"The right padding value cannot be negative "
+                f"but got value of r_padding {r_padding}")
+
+        check_param_type(use_fork, "use_fork", data_type=bool)
+        self.factor = factor
+        self.ff_weight_norm = ff_weight_norm
+        self.n_ff_layers = n_ff_layers
+        self.layer_norm = layer_norm
+        self.dropout = dropout
+        self.r_padding = r_padding
+        self.use_fork = use_fork
+        self.forecast_ff = forecast_ff
+        self.backcast_ff = backcast_ff
+        self.fourier_weight = fourier_weight
+        self.dft_compute_dtype = dft_compute_dtype
+        self.ffno_compute_dtype = ffno_compute_dtype
+
+        if len(self.resolutions) == 1:
+            spectral_conv = SpectralConv1d
+        elif len(self.resolutions) == 2:
+            spectral_conv = SpectralConv2d
+        elif len(self.resolutions) == 3:
+            spectral_conv = SpectralConv3d
+        else:
+            raise ValueError(
+                f"The length of input resolutions dimensions should be in [1, 2, 3], "
+                f"but got: {len(self.resolutions)}")
+
+        self._convs = spectral_conv(self.in_channels,
+                                    self.out_channels,
+                                    self.n_modes,
+                                    self.resolutions,
+                                    forecast_ff=self.forecast_ff,
+                                    backcast_ff=self.backcast_ff,
+                                    fourier_weight=self.fourier_weight,
+                                    factor=self.factor,
+                                    ff_weight_norm=self.ff_weight_norm,
+                                    n_ff_layers=self.n_ff_layers,
+                                    layer_norm=self.layer_norm,
+                                    use_fork=self.use_fork,
+                                    dropout=self.dropout,
+                                    r_padding=self.r_padding,
+                                    compute_dtype=self.dft_compute_dtype,
+                                    filter_mode='full')
+
+    def construct(self, x: Tensor):
+        b, _ = self._convs(x)
+        x = ops.add(x, b)
+        return x, b
+
+
+def validate_and_expand_dimensions(dim, n_modes, resolutions, is_validate_dim=True):
+    """Validate `n_modes` and `resolutions` and expand int shorthands to lists of length `dim`."""
+    if isinstance(n_modes, int):
+        n_modes = [n_modes] * dim
+    if isinstance(resolutions, int):
+        resolutions = [resolutions] * dim
+
+    n_modes_num = len(n_modes)
+    resolutions_num = len(resolutions)
+
+    if is_validate_dim:
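+        # Strict mode, used by FFNO1D/2D/3D: after the int-to-list expansion above, e.g.
+        # FFNO2D(n_modes=20, resolutions=128) arrives here as n_modes=[20, 20] and
+        # resolutions=[128, 128], both of which must match dim exactly.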
+        if n_modes_num != dim:
+            raise ValueError(
+                f"The dimension of n_modes should be equal to {dim} when using FFNO{dim}D "
+                f"but got dimension of n_modes {n_modes_num}")
+        if resolutions_num != dim:
+            raise ValueError(
+                f"The dimension of resolutions should be equal to {dim} when using FFNO{dim}D "
+                f"but got dimension of resolutions {resolutions_num}")
+    if n_modes_num != resolutions_num:
+        raise ValueError(
+            f"The dimension of n_modes should be equal to that of resolutions "
+            f"but got dimension of n_modes {n_modes_num} and dimension of resolutions {resolutions_num}")
+
+    return n_modes, resolutions
+
+
+class FFNO(nn.Cell):
+    r"""
+    The FFNO base class, which usually contains a Lifting Layer, a Factorized Fourier Block Layer and a Projection
+    Layer. The details can be found in
+    `A. Tran, A. Mathews, et al.: FACTORIZED FOURIER NEURAL OPERATORS `_.
+
+    Args:
+        in_channels (int): The number of channels in the input space.
+        out_channels (int): The number of channels in the output space.
+        n_modes (Union[int, list(int)]): The number of modes reserved after linear transformation in Fourier Layer.
+        resolutions (Union[int, list(int)]): The resolutions of the input tensor.
+        hidden_channels (int): The number of channels of the FFNOBlocks input and output. Default: ``20``.
+        lifting_channels (int): The number of channels of the lifting layer mid channels. Default: ``None``.
+        projection_channels (int): The number of channels of the projection layer mid channels. Default: ``128``.
+        factor (int): The number of neurons in the hidden layer of a feedforward network. Default: ``1``.
+        n_layers (int): The number of nested Fourier Layers. Default: ``4``.
+        n_ff_layers (int): The number of layers (hidden layers) in the feedforward neural network. Default: ``2``.
+        ff_weight_norm (bool): Whether to do weight normalization in feedforward or not. Used as a reserved function
+            interface, the weight normalization is not supported in feedforward. Default: ``False``.
+        layer_norm (bool): Whether to do layer normalization in feedforward or not. Default: ``True``.
+        share_weight (bool): Whether to share weights between SpectralConv layers or not. Default: ``False``.
+        r_padding (int): The number used to pad a tensor on the right in a certain dimension. Pad the domain if
+            input is non-periodic. Default: ``0``.
+        data_format (str): The input data channel sequence. Default: ``channels_last``.
+        positional_embedding (bool): Whether to embed positional information or not. Default: ``True``.
+        dft_compute_dtype (dtype.Number): The computation type of DFT in SpectralConv. Default: ``mstype.float32``.
+        ffno_compute_dtype (dtype.Number): The computation type of the MLP in the FFNO skip connection.
+            Default: ``mstype.float16``. Should be ``mstype.float32`` or ``mstype.float16``. mstype.float32 is
+            recommended for the GPU backend, mstype.float16 is recommended for the Ascend backend.
+
+    Inputs:
+        - **x** (Tensor) - Tensor of shape :math:`(batch\_size, resolution, in\_channels)`.
+
+    Outputs:
+        Tensor, the output of this FFNO.
+
+        - **output** (Tensor) - Tensor of shape :math:`(batch\_size, resolution, out\_channels)`.
+
+    Raises:
+        TypeError: If `in_channels` is not an int.
+        TypeError: If `out_channels` is not an int.
+        TypeError: If `hidden_channels` is not an int.
+        TypeError: If `lifting_channels` is not an int.
+        TypeError: If `projection_channels` is not an int.
+        TypeError: If `factor` is not an int.
+        TypeError: If `n_layers` is not an int.
+        TypeError: If `n_ff_layers` is not an int.
+        TypeError: If `ff_weight_norm` is not a Boolean value.
+        ValueError: If `ff_weight_norm` is not ``False``.
+        TypeError: If `layer_norm` is not a Boolean value.
+        TypeError: If `share_weight` is not a Boolean value.
+        TypeError: If `r_padding` is not an int.
+        TypeError: If `data_format` is not a str.
+        TypeError: If `positional_embedding` is not a bool.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> import numpy as np
+        >>> from mindspore import Tensor
+        >>> import mindspore.common.dtype as mstype
+        >>> from mindflow.cell.neural_operators.ffno import FFNO
+        >>> data = Tensor(np.ones([2, 128, 128, 2]), mstype.float32)
+        >>> net = FFNO(in_channels=2, out_channels=2, n_modes=[20, 20], resolutions=[128, 128])
+        >>> out = net(data)
+        >>> print(data.shape, out.shape)
+        (2, 128, 128, 2) (2, 128, 128, 2)
+    """
+
+    def __init__(
+            self,
+            in_channels,
+            out_channels,
+            n_modes,
+            resolutions,
+            hidden_channels=20,
+            lifting_channels=None,
+            projection_channels=128,
+            factor=1,
+            n_layers=4,
+            n_ff_layers=2,
+            ff_weight_norm=False,
+            layer_norm=True,
+            share_weight=False,
+            r_padding=0,
+            data_format="channels_last",
+            positional_embedding=True,
+            dft_compute_dtype=mstype.float32,
+            ffno_compute_dtype=mstype.float16
+    ):
+        super().__init__()
+        check_param_type(in_channels, "in_channels", data_type=int, exclude_type=bool)
+        check_param_type(out_channels, "out_channels", data_type=int, exclude_type=bool)
+        check_param_type(hidden_channels, "hidden_channels", data_type=int, exclude_type=bool)
+        check_param_type(factor, "factor", data_type=int, exclude_type=bool)
+        check_param_type(n_layers, "n_layers", data_type=int, exclude_type=bool)
+        check_param_type(n_ff_layers, "n_ff_layers", data_type=int, exclude_type=bool)
+        check_param_type(ff_weight_norm, "ff_weight_norm", data_type=bool, exclude_type=str)
+        check_param_type(layer_norm, "layer_norm", data_type=bool, exclude_type=str)
+        check_param_type(share_weight, "share_weight", data_type=bool, exclude_type=str)
+        check_param_type(r_padding, "r_padding", data_type=int, exclude_type=bool)
+        check_param_type(data_format, "data_format", data_type=str, exclude_type=bool)
+        check_param_type(positional_embedding, "positional_embedding", data_type=bool, exclude_type=str)
+
+        if ff_weight_norm:
+            raise ValueError(
+                f"The weight normalization is not supported in feedforward "
+                f"but got value of ff_weight_norm {ff_weight_norm}")
+        if r_padding < 0:
+            raise ValueError(f"The right padding value cannot be negative but got value of r_padding {r_padding}")
+
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.hidden_channels = hidden_channels
+        self.lifting_channels = lifting_channels
+        self.projection_channels = projection_channels
+        self.n_modes, self.resolutions = validate_and_expand_dimensions(1, n_modes, resolutions, False)
+        self.n_layers = n_layers
+        self.r_padding = r_padding
+        self.data_format = data_format
+        self.positional_embedding = positional_embedding
+        if self.positional_embedding:
+            self.in_channels += len(self.resolutions)
+        self.dft_compute_dtype = dft_compute_dtype
+        self.ffno_compute_dtype = ffno_compute_dtype
+        self._concat = ops.Concat(axis=-1)
+        self._positional_embedding = self._transpose(len(self.resolutions))
+        self._padding = self._pad(len(self.resolutions))
+        self._lifting = self.lift_channels(
+            self.in_channels, self.hidden_channels, self.lifting_channels, self.ffno_compute_dtype)
+
+        self.fourier_weight = None
+        if share_weight:
+            param_list = []
+            for i, n_mode in enumerate(self.n_modes):
+                weight_shape = [hidden_channels, hidden_channels, n_mode]
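+                # Interleave the shared weights as (real, imaginary) pairs: even indices of the
+                # resulting ParameterTuple hold real parts, odd indices imaginary parts, matching
+                # the `fourier_weight` layout documented in FFNOBlocks.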
+                w_re = Parameter(initializer(XavierNormal(), weight_shape, mstype.float32), name=f'base_w_re_{i}',
+                                 requires_grad=True)
+                w_im = Parameter(initializer(XavierNormal(), weight_shape, mstype.float32), name=f'base_w_im_{i}',
+                                 requires_grad=True)
+                param_list.append(w_re)
+                param_list.append(w_im)
+
+            self.fourier_weight = ParameterTuple(param_list)
+
+        self.factor = factor
+        self.ff_weight_norm = ff_weight_norm
+        self.n_ff_layers = n_ff_layers
+        self.layer_norm = layer_norm
+
+        self._ffno_blocks = nn.CellList([FFNOBlocks(in_channels=self.hidden_channels,
+                                                    out_channels=self.hidden_channels,
+                                                    n_modes=self.n_modes,
+                                                    resolutions=self.resolutions,
+                                                    factor=self.factor,
+                                                    n_ff_layers=self.n_ff_layers,
+                                                    ff_weight_norm=self.ff_weight_norm,
+                                                    layer_norm=self.layer_norm,
+                                                    dropout=0.0, r_padding=self.r_padding,
+                                                    use_fork=False, forecast_ff=None, backcast_ff=None,
+                                                    fourier_weight=self.fourier_weight,
+                                                    dft_compute_dtype=self.dft_compute_dtype
+                                                    ) for _ in range(self.n_layers)])
+
+        self._projection = self.lift_channels(
+            self.hidden_channels, self.out_channels, self.projection_channels, self.ffno_compute_dtype)
+
+    def lift_channels(self, in_c, out_c, mid_c=0, compute_dtype=mstype.float32):
+        """Build a channel-lifting (or projection) MLP, with an optional hidden layer of width `mid_c`."""
+        if mid_c:
+            return nn.SequentialCell([
+                nn.Dense(in_c, mid_c, has_bias=True).to_float(compute_dtype),
+                nn.Dense(mid_c, out_c, has_bias=True).to_float(compute_dtype)
+            ])
+        return nn.SequentialCell(nn.Dense(in_c, out_c, has_bias=True).to_float(compute_dtype))
+
+    def construct(self, x: Tensor):
+        """construct"""
+        batch_size = x.shape[0]
+        grid = mint.repeat_interleave(self._positional_embedding.astype(x.dtype), repeats=batch_size, dim=0)
+
+        if self.data_format != "channels_last":
+            x = ops.movedim(x, 1, -1)
+
+        if self.positional_embedding:
+            x = self._concat((x, grid))
+
+        x = self._lifting(x)
+        if self.r_padding != 0:
+            x = ops.movedim(x, -1, 1)
+            x = ops.pad(x, self._padding)
+            x = ops.movedim(x, 1, -1)
+
+        b = Tensor(0, dtype=mstype.float32)
+        for block in self._ffno_blocks:
+            x, b = block(x)
+
+        if self.r_padding != 0:
+            b = self._remove_padding(len(self.resolutions), b)
+
+        x = self._projection(b)
+
+        if self.data_format != "channels_last":
+            x = ops.movedim(x, -1, 1)
+
+        return x
+
+    def _transpose(self, n_dim):
+        """build the positional-embedding grid for the given number of spatial dimensions"""
+        if n_dim == 1:
+            positional_embedding = Tensor(get_grid_1d(resolution=self.resolutions))
+        elif n_dim == 2:
+            positional_embedding = Tensor(get_grid_2d(resolution=self.resolutions))
+        elif n_dim == 3:
+            positional_embedding = Tensor(get_grid_3d(resolution=self.resolutions))
+        else:
+            raise ValueError(f"The length of input resolutions dimensions should be in [1, 2, 3], but got: {n_dim}")
+        return positional_embedding
+
+    def _pad(self, n_dim):
+        """pad the domain if input is non-periodic"""
+        if n_dim not in {1, 2, 3}:
+            raise ValueError(f"The length of input resolutions dimensions should be in [1, 2, 3], but got: {n_dim}")
+        return n_dim * [0, self.r_padding]
+
+    def _remove_padding(self, n_dim, b_input):
+        """remove the padded domain"""
+        if n_dim == 1:
+            b = b_input[..., :-self.r_padding, :]
+        elif n_dim == 2:
+            b = b_input[..., :-self.r_padding, :-self.r_padding, :]
+        elif n_dim == 3:
+            b = b_input[..., :-self.r_padding, :-self.r_padding, :-self.r_padding, :]
+        else:
+            raise ValueError(f"The length of input resolutions dimensions should be in [1, 2, 3], but got: {n_dim}")
+        return b
+
+
+class FFNO1D(FFNO):
+    r"""
+    The 1D Factorized Fourier Neural Operator, which usually contains a Lifting Layer,
+    a Factorized Fourier Block Layer and a Projection Layer. The details can be found in
+    `A. Tran, A. Mathews, et al.: FACTORIZED FOURIER NEURAL OPERATORS `_.
+
+    Args:
+        in_channels (int): The number of channels in the input space.
+        out_channels (int): The number of channels in the output space.
+        n_modes (Union[int, list(int)]): The number of modes reserved after linear transformation in Fourier Layer.
+        resolutions (Union[int, list(int)]): The resolutions of the input tensor.
+        hidden_channels (int): The number of channels of the FFNOBlocks input and output. Default: ``20``.
+        lifting_channels (int): The number of channels of the lifting layer mid channels. Default: ``None``.
+        projection_channels (int): The number of channels of the projection layer mid channels. Default: ``128``.
+        factor (int): The number of neurons in the hidden layer of a feedforward network. Default: ``1``.
+        n_layers (int): The number of nested Fourier Layers. Default: ``4``.
+        n_ff_layers (int): The number of layers (hidden layers) in the feedforward neural network. Default: ``2``.
+        ff_weight_norm (bool): Whether to do weight normalization in feedforward or not. Used as a reserved function
+            interface, the weight normalization is not supported in feedforward. Default: ``False``.
+        layer_norm (bool): Whether to do layer normalization in feedforward or not. Default: ``True``.
+        share_weight (bool): Whether to share weights between SpectralConv layers or not. Default: ``False``.
+        r_padding (int): The number used to pad a tensor on the right in a certain dimension. Default: ``0``.
+        data_format (str): The input data channel sequence. Default: ``channels_last``.
+        positional_embedding (bool): Whether to embed positional information or not. Default: ``True``.
+        dft_compute_dtype (dtype.Number): The computation type of DFT in SpectralConv. Default: ``mstype.float32``.
+        ffno_compute_dtype (dtype.Number): The computation type of the MLP in the FFNO skip connection.
+            Default: ``mstype.float16``. Should be ``mstype.float32`` or ``mstype.float16``. mstype.float32 is
+            recommended for the GPU backend, mstype.float16 is recommended for the Ascend backend.
+
+    Inputs:
+        - **x** (Tensor) - Tensor of shape :math:`(batch\_size, resolution, in\_channels)`.
+
+    Outputs:
+        Tensor, the output of this FFNO1D.
+
+        - **output** (Tensor) - Tensor of shape :math:`(batch\_size, resolution, out\_channels)`.
+
+    Raises:
+        TypeError: If `in_channels` is not an int.
+        TypeError: If `out_channels` is not an int.
+        TypeError: If `hidden_channels` is not an int.
+        TypeError: If `lifting_channels` is not an int.
+        TypeError: If `projection_channels` is not an int.
+        TypeError: If `factor` is not an int.
+        TypeError: If `n_layers` is not an int.
+        TypeError: If `n_ff_layers` is not an int.
+        TypeError: If `ff_weight_norm` is not a Boolean value.
+        ValueError: If `ff_weight_norm` is not ``False``.
+        TypeError: If `layer_norm` is not a Boolean value.
+        TypeError: If `share_weight` is not a Boolean value.
+        TypeError: If `r_padding` is not an int.
+        TypeError: If `data_format` is not a str.
+        TypeError: If `positional_embedding` is not a bool.
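+
+    Note:
+        With ``data_format="channels_first"``, the channel axis is moved internally, so inputs of shape
+        :math:`(batch\_size, in\_channels, resolution)` are accepted as well. A minimal sketch (the printed
+        shape follows from `construct`, not from a recorded run; imports as in the Examples below):
+
+        >>> import numpy as np
+        >>> from mindspore import Tensor
+        >>> import mindspore.common.dtype as mstype
+        >>> data = Tensor(np.ones([2, 3, 128]), mstype.float32)
+        >>> net = FFNO1D(in_channels=3, out_channels=3, n_modes=[20], resolutions=[128],
+        ...              data_format="channels_first")
+        >>> print(net(data).shape)
+        (2, 3, 128)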
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> import numpy as np
+        >>> import mindspore
+        >>> import mindflow
+        >>> from mindspore import Tensor
+        >>> import mindspore.common.dtype as mstype
+        >>> from mindflow.cell import FFNO1D
+        >>> data = Tensor(np.ones([2, 128, 3]), mstype.float32)
+        >>> net = FFNO1D(in_channels=3, out_channels=3, n_modes=[20], resolutions=[128])
+        >>> out = net(data)
+        >>> print(data.shape, out.shape)
+        (2, 128, 3) (2, 128, 3)
+    """
+
+    def __init__(
+            self,
+            in_channels,
+            out_channels,
+            n_modes,
+            resolutions,
+            hidden_channels=20,
+            lifting_channels=None,
+            projection_channels=128,
+            factor=1,
+            n_layers=4,
+            n_ff_layers=2,
+            ff_weight_norm=False,
+            layer_norm=True,
+            share_weight=False,
+            r_padding=0,
+            data_format="channels_last",
+            positional_embedding=True,
+            dft_compute_dtype=mstype.float32,
+            ffno_compute_dtype=mstype.float16
+    ):
+        n_modes, resolutions = validate_and_expand_dimensions(1, n_modes, resolutions)
+        super().__init__(
+            in_channels,
+            out_channels,
+            n_modes,
+            resolutions,
+            hidden_channels,
+            lifting_channels,
+            projection_channels,
+            factor,
+            n_layers,
+            n_ff_layers,
+            ff_weight_norm,
+            layer_norm,
+            share_weight,
+            r_padding,
+            data_format,
+            positional_embedding,
+            dft_compute_dtype,
+            ffno_compute_dtype
+        )
+
+
+class FFNO2D(FFNO):
+    r"""
+    The 2D Factorized Fourier Neural Operator, which usually contains a Lifting Layer,
+    a Factorized Fourier Block Layer and a Projection Layer. The details can be found in
+    `A. Tran, A. Mathews, et al.: FACTORIZED FOURIER NEURAL OPERATORS `_.
+
+    Args:
+        in_channels (int): The number of channels in the input space.
+        out_channels (int): The number of channels in the output space.
+        n_modes (Union[int, list(int)]): The number of modes reserved after linear transformation in Fourier Layer.
+        resolutions (Union[int, list(int)]): The resolutions of the input tensor.
+        hidden_channels (int): The number of channels of the FFNOBlocks input and output. Default: ``20``.
+        lifting_channels (int): The number of channels of the lifting layer mid channels. Default: ``None``.
+        projection_channels (int): The number of channels of the projection layer mid channels. Default: ``128``.
+        factor (int): The number of neurons in the hidden layer of a feedforward network. Default: ``1``.
+        n_layers (int): The number of nested Fourier Layers. Default: ``4``.
+        n_ff_layers (int): The number of layers (hidden layers) in the feedforward neural network. Default: ``2``.
+        ff_weight_norm (bool): Whether to do weight normalization in feedforward or not. Used as a reserved function
+            interface, the weight normalization is not supported in feedforward. Default: ``False``.
+        layer_norm (bool): Whether to do layer normalization in feedforward or not. Default: ``True``.
+        share_weight (bool): Whether to share weights between SpectralConv layers or not. Default: ``False``.
+        r_padding (int): The number used to pad a tensor on the right in a certain dimension. Default: ``0``.
+        data_format (str): The input data channel sequence. Default: ``channels_last``.
+        positional_embedding (bool): Whether to embed positional information or not. Default: ``True``.
+        dft_compute_dtype (dtype.Number): The computation type of DFT in SpectralConv. Default: ``mstype.float32``.
+        ffno_compute_dtype (dtype.Number): The computation type of the MLP in the FFNO skip connection.
+            Default: ``mstype.float16``. Should be ``mstype.float32`` or ``mstype.float16``. mstype.float32 is
+            recommended for the GPU backend, mstype.float16 is recommended for the Ascend backend.
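+
+    Note:
+        Scalar shorthand is accepted for `n_modes` and `resolutions`: the constructor expands ints to
+        lists of length 2 before validation. A doctest-style sketch (hypothetical values, assuming FFNO2D
+        is imported as in the Examples below):
+
+        >>> net = FFNO2D(in_channels=3, out_channels=3, n_modes=20, resolutions=128)
+        >>> print(net.n_modes, net.resolutions)
+        [20, 20] [128, 128]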
+
+    Inputs:
+        - **x** (Tensor) - Tensor of shape :math:`(batch\_size, resolution[0], resolution[1], in\_channels)`.
+
+    Outputs:
+        Tensor, the output of this FFNO2D network.
+
+        - **output** (Tensor) - Tensor of shape :math:`(batch\_size, resolution[0], resolution[1], out\_channels)`.
+
+    Raises:
+        TypeError: If `in_channels` is not an int.
+        TypeError: If `out_channels` is not an int.
+        TypeError: If `hidden_channels` is not an int.
+        TypeError: If `lifting_channels` is not an int.
+        TypeError: If `projection_channels` is not an int.
+        TypeError: If `factor` is not an int.
+        TypeError: If `n_layers` is not an int.
+        TypeError: If `n_ff_layers` is not an int.
+        TypeError: If `ff_weight_norm` is not a Boolean value.
+        ValueError: If `ff_weight_norm` is not ``False``.
+        TypeError: If `layer_norm` is not a Boolean value.
+        TypeError: If `share_weight` is not a Boolean value.
+        TypeError: If `r_padding` is not an int.
+        TypeError: If `data_format` is not a str.
+        TypeError: If `positional_embedding` is not a bool.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> import numpy as np
+        >>> import mindspore
+        >>> import mindflow
+        >>> from mindspore import Tensor
+        >>> import mindspore.common.dtype as mstype
+        >>> from mindflow.cell import FFNO2D
+        >>> data = Tensor(np.ones([2, 128, 128, 3]), mstype.float32)
+        >>> net = FFNO2D(in_channels=3, out_channels=3, n_modes=[20, 20], resolutions=[128, 128])
+        >>> out = net(data)
+        >>> print(data.shape, out.shape)
+        (2, 128, 128, 3) (2, 128, 128, 3)
+    """
+
+    def __init__(
+            self,
+            in_channels,
+            out_channels,
+            n_modes,
+            resolutions,
+            hidden_channels=20,
+            lifting_channels=None,
+            projection_channels=128,
+            factor=1,
+            n_layers=4,
+            n_ff_layers=2,
+            ff_weight_norm=False,
+            layer_norm=True,
+            share_weight=False,
+            r_padding=0,
+            data_format="channels_last",
+            positional_embedding=True,
+            dft_compute_dtype=mstype.float32,
+            ffno_compute_dtype=mstype.float16
+    ):
+        n_modes, resolutions = validate_and_expand_dimensions(2, n_modes, resolutions)
+        super().__init__(
+            in_channels,
+            out_channels,
+            n_modes,
+            resolutions,
+            hidden_channels,
+            lifting_channels,
+            projection_channels,
+            factor,
+            n_layers,
+            n_ff_layers,
+            ff_weight_norm,
+            layer_norm,
+            share_weight,
+            r_padding,
+            data_format,
+            positional_embedding,
+            dft_compute_dtype,
+            ffno_compute_dtype
+        )
+
+
+class FFNO3D(FFNO):
+    r"""
+    The 3D Factorized Fourier Neural Operator, which usually contains a Lifting Layer,
+    a Factorized Fourier Block Layer and a Projection Layer. The details can be found in
+    `A. Tran, A. Mathews, et al.: FACTORIZED FOURIER NEURAL OPERATORS `_.
+
+    Args:
+        in_channels (int): The number of channels in the input space.
+        out_channels (int): The number of channels in the output space.
+        n_modes (Union[int, list(int)]): The number of modes reserved after the linear transformation in the
+            Fourier Layer.
+        resolutions (Union[int, list(int)]): The resolutions of the input tensor.
+        hidden_channels (int): The number of channels of the FFNO block input and output. Default: ``20``.
+        lifting_channels (int): The number of mid channels of the lifting layer. Default: ``None``.
+        projection_channels (int): The number of mid channels of the projection layer. Default: ``128``.
+        factor (int): The width expansion factor of the hidden layers in the feedforward network. Default: ``1``.
+        n_layers (int): The number of stacked Fourier Layers. Default: ``4``.
+        n_ff_layers (int): The number of layers in the feedforward network. Default: ``2``.
+        ff_weight_norm (bool): Whether to do weight normalization in feedforward or not. Used as a reserved
+            function interface, the weight normalization is not supported in feedforward. Default: ``False``.
+        layer_norm (bool): Whether to do layer normalization in feedforward or not. Default: ``True``.
+        share_weight (bool): Whether to share weights between SpectralConv layers or not. Default: ``False``.
+        r_padding (int): The number of elements used to pad the tensor on the right along each transformed
+            dimension. Default: ``0``.
+        data_format (str): The ordering of the channel dimension in the input data. Default: ``channels_last``.
+        positional_embedding (bool): Whether to embed positional information or not. Default: ``True``.
+        dft_compute_dtype (dtype.Number): The computation type of DFT in SpectralConv. Default: ``mstype.float32``.
+        ffno_compute_dtype (dtype.Number): The computation type of the MLP in the FFNO skip connection.
+            Default: ``mstype.float16``. Should be ``mstype.float32`` or ``mstype.float16``. mstype.float32 is
+            recommended for the GPU backend, mstype.float16 is recommended for the Ascend backend.
+
+    Inputs:
+        - **x** (Tensor) - Tensor of shape
+          :math:`(batch\_size, resolution[0], resolution[1], resolution[2], in\_channels)`.
+
+    Outputs:
+        Tensor, the output of this FFNO3D network.
+
+        - **output** (Tensor) - Tensor of shape
+          :math:`(batch\_size, resolution[0], resolution[1], resolution[2], out\_channels)`.
+
+    Raises:
+        TypeError: If `in_channels` is not an int.
+        TypeError: If `out_channels` is not an int.
+        TypeError: If `hidden_channels` is not an int.
+        TypeError: If `lifting_channels` is not an int.
+        TypeError: If `projection_channels` is not an int.
+        TypeError: If `factor` is not an int.
+        TypeError: If `n_layers` is not an int.
+        TypeError: If `n_ff_layers` is not an int.
+        TypeError: If `ff_weight_norm` is not a Boolean value.
+        ValueError: If `ff_weight_norm` is not ``False``.
+        TypeError: If `layer_norm` is not a Boolean value.
+        TypeError: If `share_weight` is not a Boolean value.
+        TypeError: If `r_padding` is not an int.
+        TypeError: If `data_format` is not a str.
+        TypeError: If `positional_embedding` is not a bool.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> import numpy as np
+        >>> import mindspore
+        >>> import mindflow
+        >>> from mindspore import Tensor
+        >>> import mindspore.common.dtype as mstype
+        >>> from mindflow.cell import FFNO3D
+        >>> data = Tensor(np.ones([2, 128, 128, 128, 3]), mstype.float32)
+        >>> net = FFNO3D(in_channels=3, out_channels=3, n_modes=[20, 20, 20], resolutions=[128, 128, 128])
+        >>> out = net(data)
+        >>> print(data.shape, out.shape)
+        (2, 128, 128, 128, 3) (2, 128, 128, 128, 3)
+    """
+
+    def __init__(
+            self,
+            in_channels,
+            out_channels,
+            n_modes,
+            resolutions,
+            hidden_channels=20,
+            lifting_channels=None,
+            projection_channels=128,
+            factor=1,
+            n_layers=4,
+            n_ff_layers=2,
+            ff_weight_norm=False,
+            layer_norm=True,
+            share_weight=False,
+            r_padding=0,
+            data_format="channels_last",
+            positional_embedding=True,
+            dft_compute_dtype=mstype.float32,
+            ffno_compute_dtype=mstype.float16
+    ):
+        n_modes, resolutions = validate_and_expand_dimensions(3, n_modes, resolutions)
+        super().__init__(
+            in_channels,
+            out_channels,
+            n_modes,
+            resolutions,
+            hidden_channels,
+            lifting_channels,
+            projection_channels,
+            factor,
+            n_layers,
+            n_ff_layers,
+            ff_weight_norm,
+            layer_norm,
+            share_weight,
+            r_padding,
+            data_format,
+            positional_embedding,
+            dft_compute_dtype,
+            ffno_compute_dtype
+        )
diff --git a/mindscience/models/neural_operator/ffno_sp.py b/mindscience/models/neural_operator/ffno_sp.py
new file mode 100644
index 0000000000000000000000000000000000000000..5737112a88f461ca2d81df8e2f2c55bb605a7cee
--- /dev/null
+++ b/mindscience/models/neural_operator/ffno_sp.py
@@ -0,0 +1,465 @@
+'''
+# Copyright 2023 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+'''
+import mindspore as ms
+import mindspore.common.dtype as mstype
+from mindspore import nn, ops, Tensor, Parameter, ParameterTuple, mint
+from mindspore.common.initializer import XavierNormal, initializer
+from ...common.math import get_grid_1d, get_grid_2d, get_grid_3d
+from ...sciops.fourier import RDFTn, IRDFTn
+
+
+class FeedForward(nn.Cell):
+    """FeedForward cell"""
+
+    def __init__(self, dim, factor, ff_weight_norm, n_layers, layer_norm, dropout):
+        super().__init__()
+        self.layers = nn.CellList()
+        for i in range(n_layers):
+            in_dim = dim if i == 0 else dim * factor
+            out_dim = dim if i == n_layers - 1 else dim * factor
+            layer = nn.SequentialCell([
+                nn.Dense(in_dim, out_dim, has_bias=True) if not ff_weight_norm else nn.Identity(),
+                nn.Dropout(p=dropout),
+                nn.ReLU() if i < n_layers - 1 else nn.Identity(),
+                nn.LayerNorm((out_dim,), epsilon=1e-5) if layer_norm and i == n_layers - 1 else nn.Identity()])
+            self.layers.append(layer)
+
+    def construct(self, x):
+        for layer in self.layers:
+            x = layer(x)
+        return x
+
+
+class SpectralConv(nn.Cell):
+    """Base Class for Fourier Layer, including DFT, factorization, linear transform, and Inverse DFT"""
+
+    def __init__(self, in_channels, out_channels, n_modes, resolutions, forecast_ff, backcast_ff,
+                 fourier_weight, factor, ff_weight_norm, n_ff_layers, layer_norm, use_fork, dropout, filter_mode,
+                 compute_dtype=mstype.float32):
+        super().__init__()
+        self.einsum_flag = tuple([int(s) for s in ms.__version__.split('.')]) >= (2, 5, 0)
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        if isinstance(n_modes, int):
+            n_modes = [n_modes]
+        self.n_modes = n_modes
+        if isinstance(resolutions, int):
+            resolutions = [resolutions]
+        self.resolutions = resolutions
+        if len(self.n_modes) != len(self.resolutions):
+            raise ValueError(
+                "The dimension of n_modes should be equal to that of resolutions, \
+                but got dimension of n_modes {} and dimension of resolutions {}".format(len(self.n_modes),
+                                                                                        len(self.resolutions)))
+        self.compute_dtype = compute_dtype
+        self.use_fork = use_fork
+        self.fourier_weight = fourier_weight
+        self.filter_mode = filter_mode
+
+        if not self.fourier_weight:
+            param_list = []
+            for i, n_mode in enumerate(self.n_modes):
+                weight_re = Tensor(ops.ones((in_channels, out_channels, n_mode)), mstype.float32)
+                weight_im = Tensor(ops.ones((in_channels, out_channels, n_mode)), mstype.float32)
+
+                w_re = Parameter(initializer(XavierNormal(), weight_re.shape, mstype.float32), name=f'w_re_{i}',
+                                 requires_grad=True)
+                w_im = Parameter(initializer(XavierNormal(), weight_im.shape, mstype.float32), name=f'w_im_{i}',
+                                 requires_grad=True)
+
+                param_list.append(w_re)
+                param_list.append(w_im)
+
+            self.fourier_weight = ParameterTuple(param_list)
+
+        if use_fork:
+            self.forecast_ff = forecast_ff
+            if not self.forecast_ff:
+                self.forecast_ff = FeedForward(
+                    out_channels, factor, ff_weight_norm, n_ff_layers, layer_norm, dropout)
+
+        self.backcast_ff = backcast_ff
+        if not self.backcast_ff:
+            self.backcast_ff = FeedForward(
+                out_channels, factor, ff_weight_norm, n_ff_layers, layer_norm, dropout)
+
+        self._positional_embedding, self._input_perm, self._output_perm = self._transpose(len(self.resolutions))
+
+    def construct(self, x: Tensor):
+        raise NotImplementedError()
+
+    def _fourier_dimension(self, n, mode, n_dim):
+        """Build the forward/inverse rDFT cells for a single axis.
+
+        n - length of the transformed axis (3D: S1 S2 S3 / 2D: M N / 1D: C);
+        mode - number of kept modes, at most n // 2 + 1 on the last axis;
+        n_dim - axis index: 3D: -1 -2 -3 / 2D:
-1 -2 / 1D: -1 """ + dft_cell = RDFTn(shape=n, dim=n_dim, norm='ortho', modes=mode, compute_dtype=self.compute_dtype) + idft_cell = IRDFTn(shape=n, dim=n_dim, norm='ortho', modes=mode, compute_dtype=self.compute_dtype) + + return dft_cell, idft_cell + + def _einsum(self, inputs, weights, dim): + """The Einstein multiplication function""" + res_len = len(self.resolutions) + + if res_len not in [1, 2, 3]: + raise ValueError( + "The length of input resolutions dimensions should be in [1, 2, 3], but got: {}".format(res_len)) + + if self.einsum_flag: + expressions = { + ('x', 1): 'bix,iox->box', + ('x', 2): 'bixy,iox->boxy', + ('y', 2): 'bixy,ioy->boxy', + ('x', 3): 'bixyz,iox->boxyz', + ('y', 3): 'bixyz,ioy->boxyz', + ('z', 3): 'bixyz,ioz->boxyz' + } + + key = (dim, res_len) + if key not in expressions: + raise ValueError(f"Unsupported type of the last dim of weight: {dim}") + + out = mint.einsum(expressions[key], inputs, weights) + + else: + _, weight_out, weight_dim = weights.shape + batch_size, inputs_in = inputs.shape[0], inputs.shape[1] + weights_perm = (2, 0, 1) + + if res_len == 1: + if dim == 'x': + input_perm = (2, 0, 1) + output_perm = (1, 2, 0) + else: + raise ValueError(f"Unsupported type of the last dim of weight: {dim}") + + inputs = ops.transpose(inputs, input_perm=input_perm) + weights = ops.transpose(weights, input_perm=weights_perm) + out = ops.bmm(inputs, weights) + out = ops.transpose(out, input_perm=output_perm) + elif res_len == 2: + if dim == 'y': + input_perm = (3, 0, 2, 1) + output_perm = (1, 3, 2, 0) + elif dim == 'x': + input_perm = (2, 0, 3, 1) + output_perm = (1, 3, 0, 2) + else: + raise ValueError(f"Unsupported type of the last dim of weight: {dim}") + + inputs = ops.transpose(inputs, input_perm=input_perm) + inputs = ops.reshape(inputs, (weight_dim, -1, inputs_in)) + weights = ops.transpose(weights, input_perm=weights_perm) + out = ops.bmm(inputs, weights) + out = ops.reshape(out, (weight_dim, batch_size, -1, weight_out)) + out = ops.transpose(out, input_perm=output_perm) + else: + input_dim1, input_dim2, input_dim3 = inputs.shape[2], inputs.shape[3], inputs.shape[4] + + if dim == 'z': + input_perm = (4, 0, 2, 3, 1) + output_perm = (1, 4, 2, 3, 0) + reshape_dim = input_dim1 + elif dim == 'y': + input_perm = (3, 0, 4, 2, 1) + output_perm = (1, 4, 3, 0, 2) + reshape_dim = input_dim3 + elif dim == 'x': + input_perm = (2, 0, 3, 4, 1) + output_perm = (1, 4, 0, 2, 3) + reshape_dim = input_dim2 + else: + raise ValueError(f"Unsupported type of the last dim of weight: {dim}") + + inputs = ops.transpose(inputs, input_perm=input_perm) + inputs = ops.reshape(inputs, (weight_dim, -1, inputs_in)) + weights = ops.transpose(weights, input_perm=weights_perm) + out = ops.bmm(inputs, weights) + out = ops.reshape(out, (weight_dim, batch_size, reshape_dim, -1, weight_out)) + out = ops.transpose(out, input_perm=output_perm) + + return out + + def _transpose(self, n_dim): + """transpose tensor""" + if n_dim == 1: + positional_embedding = Tensor(get_grid_1d(resolution=self.resolutions)) + input_perm = (0, 2, 1) + output_perm = (0, 2, 1) + elif n_dim == 2: + positional_embedding = Tensor(get_grid_2d(resolution=self.resolutions)) + input_perm = (0, 2, 3, 1) + output_perm = (0, 3, 1, 2) + elif n_dim == 3: + positional_embedding = Tensor(get_grid_3d(resolution=self.resolutions)) + input_perm = (0, 2, 3, 4, 1) + output_perm = (0, 4, 1, 2, 3) + else: + raise ValueError( + "The length of input resolutions dimensions should be in [1, 2, 3], but got: {}".format(n_dim)) + return 
positional_embedding, input_perm, output_perm
+
+    def _complex_mul(self, input_re, input_im, weight_re, weight_im, dim):
+        """(a + bj) * (c + dj) = (ac - bd) + (ad + bc)j"""
+        out_re = self._einsum(input_re, weight_re, dim) - self._einsum(input_im, weight_im, dim)
+        out_im = self._einsum(input_re, weight_im, dim) + self._einsum(input_im, weight_re, dim)
+
+        return out_re, out_im
+
+
+class SpectralConv1d(SpectralConv):
+    """1D Fourier layer. It does DFT, factorization, linear transform, and Inverse DFT."""
+
+    def __init__(self, in_channels, out_channels, n_modes, resolutions, forecast_ff, backcast_ff,
+                 fourier_weight, factor, ff_weight_norm, n_ff_layers, layer_norm, use_fork, dropout, r_padding,
+                 filter_mode, compute_dtype=mstype.float32):
+        super().__init__(in_channels, out_channels, n_modes, resolutions, forecast_ff, backcast_ff, fourier_weight,
+                         factor, ff_weight_norm, n_ff_layers, layer_norm, use_fork, dropout, filter_mode,
+                         compute_dtype=compute_dtype)
+
+        self._dft1_x_cell, self._idft1_x_cell = self._fourier_dimension(resolutions[0] + r_padding, n_modes[0], -1)
+
+    def construct(self, x: Tensor):
+        x = self.construct_fourier(x)
+        b = self.backcast_ff(x)
+        f = self.forecast_ff(x) if self.use_fork else None
+
+        return b, f
+
+    def construct_fourier(self, x):
+        """1D Fourier layer."""
+        x = ops.transpose(x, input_perm=self._output_perm)  # x shape: batch, in_dim, grid_size
+
+        x_ft_re = x
+
+        x_ftx_re, x_ftx_im = self._dft1_x_cell(x_ft_re)
+
+        x_ftx_re_part = x_ftx_re[:, :, :self.n_modes[0]]
+        x_ftx_im_part = x_ftx_im[:, :, :self.n_modes[0]]
+
+        re0, re1, re2 = x_ftx_re.shape
+        im0, im1, im2 = x_ftx_im.shape
+        out_ftx_remain_re = ops.zeros((re0, re1, re2 - self.n_modes[0]), x_ftx_re.dtype)
+        out_ftx_remain_im = ops.zeros((im0, im1, im2 - self.n_modes[0]), x_ftx_im.dtype)
+
+        if self.filter_mode == 'full':
+            ftx_re, ftx_im = self._complex_mul(
+                x_ftx_re_part, x_ftx_im_part, self.fourier_weight[0], self.fourier_weight[1], 'x')
+            out_ftx_re = ops.cat([ftx_re, out_ftx_remain_re], axis=2)
+            out_ftx_im = ops.cat([ftx_im, out_ftx_remain_im], axis=2)
+        elif self.filter_mode == 'low_pass':
+            out_ftx_re = ops.cat([x_ftx_re_part, out_ftx_remain_re], axis=2)
+            out_ftx_im = ops.cat([x_ftx_im_part, out_ftx_remain_im], axis=2)
+        else:
+            out_ftx_re = ops.zeros_like(x_ftx_re)
+            out_ftx_im = ops.zeros_like(x_ftx_im)
+
+        x = self._idft1_x_cell(out_ftx_re, out_ftx_im)
+        x = ops.transpose(x, input_perm=self._input_perm)
+
+        return x
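+
+# A minimal usage sketch (illustration only, not part of the library API; all
+# constructor values below are hypothetical). With use_fork=False the layer
+# returns (backcast, None); inputs and outputs are channels-last:
+#
+#     conv = SpectralConv1d(in_channels=8, out_channels=8, n_modes=[16],
+#                           resolutions=[128], forecast_ff=None, backcast_ff=None,
+#                           fourier_weight=None, factor=2, ff_weight_norm=False,
+#                           n_ff_layers=2, layer_norm=True, use_fork=False,
+#                           dropout=0.0, r_padding=0, filter_mode='full')
+#     backcast, forecast = conv(ops.rand((4, 128, 8)))  # backcast: (4, 128, 8), forecast: None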
+
+
+class SpectralConv2d(SpectralConv):
+    """2D Fourier layer. It does DFT, factorization, linear transform, and Inverse DFT."""
+
+    def __init__(self, in_channels, out_channels, n_modes, resolutions, forecast_ff, backcast_ff,
+                 fourier_weight, factor, ff_weight_norm, n_ff_layers, layer_norm, use_fork, dropout, r_padding,
+                 filter_mode, compute_dtype=mstype.float32):
+        super().__init__(in_channels, out_channels, n_modes, resolutions, forecast_ff, backcast_ff, fourier_weight,
+                         factor, ff_weight_norm, n_ff_layers, layer_norm, use_fork, dropout, filter_mode,
+                         compute_dtype=compute_dtype)
+
+        self._dft1_x_cell, self._idft1_x_cell = self._fourier_dimension(resolutions[0] + r_padding, n_modes[0], -2)
+        self._dft1_y_cell, self._idft1_y_cell = self._fourier_dimension(resolutions[1] + r_padding, n_modes[1], -1)
+
+    def construct(self, x: Tensor):
+        x = self.construct_fourier(x)
+        b = self.backcast_ff(x)
+        f = self.forecast_ff(x) if self.use_fork else None
+
+        return b, f
+
+    def construct_fourier(self, x):
+        """2D Fourier layer."""
+        x = ops.transpose(x, input_perm=self._output_perm)  # x shape: batch, in_dim, grid_size, grid_size
+
+        x_ft_re = x
+
+        # Dimension Y
+        x_fty_re, x_fty_im = self._dft1_y_cell(x_ft_re)
+
+        x_fty_re_part = x_fty_re[:, :, :, :self.n_modes[1]]
+        x_fty_im_part = x_fty_im[:, :, :, :self.n_modes[1]]
+
+        re0, re1, re2, re3 = x_fty_re.shape
+        im0, im1, im2, im3 = x_fty_im.shape
+        out_fty_remain_re = ops.zeros((re0, re1, re2, re3 - self.n_modes[1]), x_fty_re.dtype)
+        out_fty_remain_im = ops.zeros((im0, im1, im2, im3 - self.n_modes[1]), x_fty_im.dtype)
+
+        if self.filter_mode == 'full':
+            fty_re, fty_im = self._complex_mul(
+                x_fty_re_part, x_fty_im_part, self.fourier_weight[2], self.fourier_weight[3], 'y')
+            out_fty_re = ops.cat([fty_re, out_fty_remain_re], axis=3)
+            out_fty_im = ops.cat([fty_im, out_fty_remain_im], axis=3)
+        elif self.filter_mode == 'low_pass':
+            out_fty_re = ops.cat([x_fty_re_part, out_fty_remain_re], axis=3)
+            out_fty_im = ops.cat([x_fty_im_part, out_fty_remain_im], axis=3)
+        else:
+            out_fty_re = ops.zeros_like(x_fty_re)
+            out_fty_im = ops.zeros_like(x_fty_im)
+
+        xy = self._idft1_y_cell(out_fty_re, out_fty_im)
+
+        # Dimension X
+        x_ftx_re, x_ftx_im = self._dft1_x_cell(x_ft_re)
+
+        x_ftx_re_part = x_ftx_re[:, :, :self.n_modes[0], :]
+        x_ftx_im_part = x_ftx_im[:, :, :self.n_modes[0], :]
+
+        re0, re1, re2, re3 = x_ftx_re.shape
+        im0, im1, im2, im3 = x_ftx_im.shape
+        out_ftx_remain_re = ops.zeros((re0, re1, re2 - self.n_modes[0], re3), x_ftx_re.dtype)
+        out_ftx_remain_im = ops.zeros((im0, im1, im2 - self.n_modes[0], im3), x_ftx_im.dtype)
+
+        if self.filter_mode == 'full':
+            ftx_re, ftx_im = self._complex_mul(
+                x_ftx_re_part, x_ftx_im_part, self.fourier_weight[0], self.fourier_weight[1], 'x')
+            out_ftx_re = ops.cat([ftx_re, out_ftx_remain_re], axis=2)
+            out_ftx_im = ops.cat([ftx_im, out_ftx_remain_im], axis=2)
+        elif self.filter_mode == 'low_pass':
+            out_ftx_re = ops.cat([x_ftx_re_part, out_ftx_remain_re], axis=2)
+            out_ftx_im = ops.cat([x_ftx_im_part, out_ftx_remain_im], axis=2)
+        else:
+            out_ftx_re = ops.zeros_like(x_ftx_re)
+            out_ftx_im = ops.zeros_like(x_ftx_im)
+
+        xx = self._idft1_x_cell(out_ftx_re, out_ftx_im)
+
+        # Combining Dimensions
+        x = xx + xy
+
+        x = ops.transpose(x, input_perm=self._input_perm)
+
+        return x
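+
+# filter_mode semantics, sketched for a single transformed axis (descriptive
+# comment only; the einsum string below is schematic, not literal code):
+#     'full'     -> kept modes are mixed across channels with the learned complex
+#                   weights, out[..., :modes] ~ einsum('bi...,io...->bo...',
+#                   x_ft[..., :modes], w), and the remaining modes are zeroed;
+#     'low_pass' -> kept modes pass through unchanged, the rest are zeroed;
+#     otherwise  -> the whole spectrum is zeroed.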
+
+
+class SpectralConv3d(SpectralConv):
+    """3D Fourier layer. It does DFT, factorization, linear transform, and Inverse DFT."""
+
+    def __init__(self, in_channels, out_channels, n_modes, resolutions, forecast_ff, backcast_ff,
+                 fourier_weight, factor, ff_weight_norm, n_ff_layers, layer_norm, use_fork, dropout, r_padding,
+                 filter_mode, compute_dtype=mstype.float32):
+        super().__init__(in_channels, out_channels, n_modes, resolutions, forecast_ff, backcast_ff, fourier_weight,
+                         factor, ff_weight_norm, n_ff_layers, layer_norm, use_fork, dropout, filter_mode,
+                         compute_dtype=compute_dtype)
+
+        self._dft1_x_cell, self._idft1_x_cell = self._fourier_dimension(resolutions[0] + r_padding, n_modes[0], -3)
+        self._dft1_y_cell, self._idft1_y_cell = self._fourier_dimension(resolutions[1] + r_padding, n_modes[1], -2)
+        self._dft1_z_cell, self._idft1_z_cell = self._fourier_dimension(resolutions[2] + r_padding, n_modes[2], -1)
+
+    def construct(self, x: Tensor):
+        x = self.construct_fourier(x)
+        b = self.backcast_ff(x)
+        f = self.forecast_ff(x) if self.use_fork else None
+
+        return b, f
+
+    def construct_fourier(self, x):
+        """3D Fourier layer."""
+        x = ops.transpose(x, input_perm=self._output_perm)  # x shape: batch, in_dim, grid_size, grid_size, grid_size
+
+        x_ft_re = x
+
+        # Dimension Z
+        x_ftz_re, x_ftz_im = self._dft1_z_cell(x_ft_re)
+
+        x_ftz_re_part = x_ftz_re[:, :, :, :, :self.n_modes[2]]
+        x_ftz_im_part = x_ftz_im[:, :, :, :, :self.n_modes[2]]
+
+        re0, re1, re2, re3, re4 = x_ftz_re.shape
+        im0, im1, im2, im3, im4 = x_ftz_im.shape
+        out_ftz_remain_re = ops.zeros((re0, re1, re2, re3, re4 - self.n_modes[2]), x_ftz_re.dtype)
+        out_ftz_remain_im = ops.zeros((im0, im1, im2, im3, im4 - self.n_modes[2]), x_ftz_im.dtype)
+
+        if self.filter_mode == 'full':
+            ftz_re, ftz_im = self._complex_mul(
+                x_ftz_re_part, x_ftz_im_part, self.fourier_weight[4], self.fourier_weight[5], 'z')
+            out_ftz_re = ops.cat([ftz_re, out_ftz_remain_re], axis=4)
+            out_ftz_im = ops.cat([ftz_im, out_ftz_remain_im], axis=4)
+        elif self.filter_mode == 'low_pass':
+            out_ftz_re = ops.cat([x_ftz_re_part, out_ftz_remain_re], axis=4)
+            out_ftz_im = ops.cat([x_ftz_im_part, out_ftz_remain_im], axis=4)
+        else:
+            out_ftz_re = ops.zeros_like(x_ftz_re)
+            out_ftz_im = ops.zeros_like(x_ftz_im)
+
+        xz = self._idft1_z_cell(out_ftz_re, out_ftz_im)
+
+        # Dimension Y
+        x_fty_re, x_fty_im = self._dft1_y_cell(x_ft_re)
+
+        x_fty_re_part = x_fty_re[:, :, :, :self.n_modes[1], :]
+        x_fty_im_part = x_fty_im[:, :, :, :self.n_modes[1], :]
+
+        re0, re1, re2, re3, re4 = x_fty_re.shape
+        im0, im1, im2, im3, im4 = x_fty_im.shape
+        out_fty_remain_re = ops.zeros((re0, re1, re2, re3 - self.n_modes[1], re4), x_fty_re.dtype)
+        out_fty_remain_im = ops.zeros((im0, im1, im2, im3 - self.n_modes[1], im4), x_fty_im.dtype)
+
+        if self.filter_mode == 'full':
+            fty_re, fty_im = self._complex_mul(
+                x_fty_re_part, x_fty_im_part, self.fourier_weight[2], self.fourier_weight[3], 'y')
+            out_fty_re = ops.cat([fty_re, out_fty_remain_re], axis=3)
+            out_fty_im = ops.cat([fty_im, out_fty_remain_im], axis=3)
+        elif self.filter_mode == 'low_pass':
+            out_fty_re = ops.cat([x_fty_re_part, out_fty_remain_re], axis=3)
+            out_fty_im = ops.cat([x_fty_im_part, out_fty_remain_im], axis=3)
+        else:
+            out_fty_re = ops.zeros_like(x_fty_re)
+            out_fty_im = ops.zeros_like(x_fty_im)
+
+        xy = self._idft1_y_cell(out_fty_re, out_fty_im)
+
+        # Dimension X
+        x_ftx_re, x_ftx_im = self._dft1_x_cell(x_ft_re)
+
+        x_ftx_re_part = x_ftx_re[:, :, :self.n_modes[0], :, :]
+        x_ftx_im_part = x_ftx_im[:, :, :self.n_modes[0], :, :]
+
+        re0, re1, re2, re3, re4 = x_ftx_re.shape
+        im0, im1, im2, im3, im4 = x_ftx_im.shape
+        out_ftx_remain_re = ops.zeros((re0, re1, re2 - self.n_modes[0],
re3, re4), x_ftx_re.dtype)
+        out_ftx_remain_im = ops.zeros((im0, im1, im2 - self.n_modes[0], im3, im4), x_ftx_im.dtype)
+
+        if self.filter_mode == 'full':
+            ftx_re, ftx_im = self._complex_mul(
+                x_ftx_re_part, x_ftx_im_part, self.fourier_weight[0], self.fourier_weight[1], 'x')
+            out_ftx_re = ops.cat([ftx_re, out_ftx_remain_re], axis=2)
+            out_ftx_im = ops.cat([ftx_im, out_ftx_remain_im], axis=2)
+        elif self.filter_mode == 'low_pass':
+            out_ftx_re = ops.cat([x_ftx_re_part, out_ftx_remain_re], axis=2)
+            out_ftx_im = ops.cat([x_ftx_im_part, out_ftx_remain_im], axis=2)
+        else:
+            out_ftx_re = ops.zeros_like(x_ftx_re)
+            out_ftx_im = ops.zeros_like(x_ftx_im)
+
+        xx = self._idft1_x_cell(out_ftx_re, out_ftx_im)
+
+        # Combining Dimensions
+        x = xx + xy + xz
+
+        x = ops.transpose(x, input_perm=self._input_perm)
+
+        return x
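+
+# Note (descriptive comment only): each SpectralConvNd above filters the
+# spectrum of one axis at a time and sums the per-axis results
+# (x = xx + xy [+ xz]). This is the factorized contraction that gives FFNO
+# its name; a plain FNO layer instead mixes all axes jointly inside a single
+# N-D transform.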
diff --git a/mindscience/models/neural_operator/fno.py b/mindscience/models/neural_operator/fno.py
index d4b254f76acea42cca845196913255d3d5286cc0..edd214bb5a8e349567d73234f44d49aceb0fed90 100644
--- a/mindscience/models/neural_operator/fno.py
+++ b/mindscience/models/neural_operator/fno.py
@@ -16,10 +16,10 @@
 '''
 # pylint: disable=W0235
-from mindspore import nn, ops, Tensor
+from mindspore import nn, ops, Tensor, mint
 import mindspore.common.dtype as mstype
 
-from .dft import SpectralConv1dDft, SpectralConv2dDft, SpectralConv3dDft
+from .fno_sp import SpectralConv1dDft, SpectralConv2dDft, SpectralConv3dDft
 from ..layers.activation import get_activation
 from ...common.math import get_grid_1d, get_grid_2d, get_grid_3d
 from ...utils.check_func import check_param_type
@@ -294,7 +294,7 @@ class FNO(nn.Cell):
     def construct(self, x: Tensor):
         """construct"""
         batch_size = x.shape[0]
-        grid = self._positional_embedding.repeat(batch_size, axis=0).astype(x.dtype)
+        grid = mint.repeat_interleave(self._positional_embedding.astype(x.dtype), batch_size, dim=0)
         if self.data_format != "channels_last":
             x = ops.transpose(x, input_perm=self._output_perm)
         if self.positional_embedding:
diff --git a/mindscience/models/neural_operator/fno_sp.py b/mindscience/models/neural_operator/fno_sp.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b2e4e8baa0b742fab1c003d27d15cc84373b050
--- /dev/null
+++ b/mindscience/models/neural_operator/fno_sp.py
@@ -0,0 +1,243 @@
+'''
+# Copyright 2023 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+'''
+import numpy as np
+
+import mindspore.common.dtype as mstype
+from mindspore import nn, ops, Tensor, Parameter, mint
+from mindspore.common.initializer import Zero
+from mindspore.ops import operations as P
+
+from ...sciops.fourier import RDFTn, IRDFTn
+
+
+class SpectralConvDft(nn.Cell):
+    """Base Class for Fourier Layer, including DFT, linear transform, and Inverse DFT"""
+
+    def __init__(self, in_channels, out_channels, n_modes, resolutions, compute_dtype=mstype.float32):
+        super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        if isinstance(n_modes, int):
+            n_modes = [n_modes]
+        self.n_modes = n_modes
+        if isinstance(resolutions, int):
+            resolutions = [resolutions]
+        self.resolutions = resolutions
+        if len(self.n_modes) != len(self.resolutions):
+            raise ValueError(
+                "The dimension of n_modes should be equal to that of resolutions, \
+                but got dimension of n_modes {} and dimension of resolutions {}".format(len(self.n_modes),
+                                                                                        len(self.resolutions)))
+        self.compute_dtype = compute_dtype
+
+    def construct(self, x: Tensor):
+        raise NotImplementedError()
+
+    def _einsum(self, inputs, weights):
+        weights = weights.expand_dims(0)
+        inputs = inputs.expand_dims(2)
+        out = inputs * weights
+        return out.sum(1)
+
+
+class SpectralConv1dDft(SpectralConvDft):
+    """1D Fourier Layer. It does DFT, linear transform, and Inverse DFT."""
+
+    def __init__(self, in_channels, out_channels, n_modes, resolutions, compute_dtype=mstype.float32):
+        super().__init__(in_channels, out_channels, n_modes, resolutions, compute_dtype=compute_dtype)
+        self._scale = (1. / (self.in_channels * self.out_channels))
+        w_re = Tensor(self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0]),
+                      dtype=mstype.float32)
+        w_im = Tensor(self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0]),
+                      dtype=mstype.float32)
+        self._w_re = Parameter(w_re, requires_grad=True)
+        self._w_im = Parameter(w_im, requires_grad=True)
+        self._dft1_cell = RDFTn(
+            shape=(self.resolutions[0],), norm='ortho', modes=self.n_modes[0], compute_dtype=self.compute_dtype)
+        self._idft1_cell = IRDFTn(
+            shape=(self.resolutions[0],), norm='ortho', modes=self.n_modes[0], compute_dtype=self.compute_dtype)
+
+    def construct(self, x: Tensor):
+        x_re = x
+        x_ft_re, x_ft_im = self._dft1_cell(x_re)
+        w_re = P.Cast()(self._w_re, self.compute_dtype)
+        w_im = P.Cast()(self._w_im, self.compute_dtype)
+        out_ft_re = self._einsum(x_ft_re[:, :, :self.n_modes[0]], w_re) - self._einsum(x_ft_im[:, :, :self.n_modes[0]],
+                                                                                       w_im)
+        out_ft_im = self._einsum(x_ft_re[:, :, :self.n_modes[0]], w_im) + self._einsum(x_ft_im[:, :, :self.n_modes[0]],
+                                                                                       w_re)
+
+        x = self._idft1_cell(out_ft_re, out_ft_im)
+
+        return x
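+
+# A minimal usage sketch (illustration only; the shapes are hypothetical):
+# inputs are channels-first, the layer keeps the first n_modes frequencies and
+# maps in_channels -> out_channels there:
+#
+#     conv = SpectralConv1dDft(in_channels=8, out_channels=16, n_modes=16, resolutions=128)
+#     y = conv(ops.rand((4, 8, 128)))  # y: (4, 16, 128)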
+
+
+class SpectralConv2dDft(SpectralConvDft):
+    """2D Fourier Layer. It does DFT, linear transform, and Inverse DFT."""
+
+    def __init__(self, in_channels, out_channels, n_modes, resolutions, compute_dtype=mstype.float32):
+        super().__init__(in_channels, out_channels, n_modes, resolutions, compute_dtype=compute_dtype)
+        self._scale = (1. / (self.in_channels * self.out_channels))
+        w_re1 = Tensor(
+            self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0], self.n_modes[1]),
+            dtype=self.compute_dtype)
+        w_im1 = Tensor(
+            self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0], self.n_modes[1]),
+            dtype=self.compute_dtype)
+        w_re2 = Tensor(
+            self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0], self.n_modes[1]),
+            dtype=self.compute_dtype)
+        w_im2 = Tensor(
+            self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0], self.n_modes[1]),
+            dtype=self.compute_dtype)
+
+        self._w_re1 = Parameter(w_re1, requires_grad=True)
+        self._w_im1 = Parameter(w_im1, requires_grad=True)
+        self._w_re2 = Parameter(w_re2, requires_grad=True)
+        self._w_im2 = Parameter(w_im2, requires_grad=True)
+
+        self._dft2_cell = RDFTn(shape=(self.resolutions[0], self.resolutions[1]), norm='ortho',
+                                modes=(self.n_modes[0], self.n_modes[1]), compute_dtype=self.compute_dtype)
+        self._idft2_cell = IRDFTn(shape=(self.resolutions[0], self.resolutions[1]), norm='ortho',
+                                  modes=(self.n_modes[0], self.n_modes[1]), compute_dtype=self.compute_dtype)
+        self._mat = Tensor(shape=(1, self.out_channels, self.resolutions[0] - 2 * self.n_modes[0], self.n_modes[1]),
+                           dtype=self.compute_dtype, init=Zero())
+        self._concat = ops.Concat(-2)
+
+    def construct(self, x: Tensor):
+        x_re = x
+        x_ft_re, x_ft_im = self._dft2_cell(x_re)
+
+        out_ft_re1 = self._einsum(x_ft_re[:, :, :self.n_modes[0], :self.n_modes[1]], self._w_re1) - self._einsum(
+            x_ft_im[:, :, :self.n_modes[0], :self.n_modes[1]], self._w_im1)
+        out_ft_im1 = self._einsum(x_ft_re[:, :, :self.n_modes[0], :self.n_modes[1]], self._w_im1) + self._einsum(
+            x_ft_im[:, :, :self.n_modes[0], :self.n_modes[1]], self._w_re1)
+
+        out_ft_re2 = self._einsum(x_ft_re[:, :, -self.n_modes[0]:, :self.n_modes[1]], self._w_re2) - self._einsum(
+            x_ft_im[:, :, -self.n_modes[0]:, :self.n_modes[1]], self._w_im2)
+        out_ft_im2 = self._einsum(x_ft_re[:, :, -self.n_modes[0]:, :self.n_modes[1]], self._w_im2) + self._einsum(
+            x_ft_im[:, :, -self.n_modes[0]:, :self.n_modes[1]], self._w_re2)
+
+        batch_size = x.shape[0]
+        mat = mint.repeat_interleave(self._mat, batch_size, 0)
+        out_re = self._concat((out_ft_re1, mat, out_ft_re2))
+        out_im = self._concat((out_ft_im1, mat, out_ft_im2))
+
+        x = self._idft2_cell(out_re, out_im)
+
+        return x
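+
+# Why several weight pairs? The rDFT keeps only non-negative frequencies along
+# the last axis, so the retained low-frequency block splits into corners of
+# the remaining axes: two corners (w1, w2) in 2D and four corners (w1..w4) in
+# 3D. The zero tensors built in __init__ fill the gaps between the corners
+# before the inverse transform (descriptive comment only).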
+
+
+class SpectralConv3dDft(SpectralConvDft):
+    """3D Fourier layer. It does DFT, linear transform, and Inverse DFT."""
+
+    def __init__(self, in_channels, out_channels, n_modes, resolutions, compute_dtype=mstype.float32):
+        super().__init__(in_channels, out_channels, n_modes, resolutions, compute_dtype=compute_dtype)
+        self._scale = (1 / (self.in_channels * self.out_channels))
+
+        w_re1 = Tensor(
+            self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0], self.n_modes[1],
+                                         self.n_modes[2]), dtype=self.compute_dtype)
+        w_im1 = Tensor(
+            self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0], self.n_modes[1],
+                                         self.n_modes[2]), dtype=self.compute_dtype)
+        w_re2 = Tensor(
+            self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0], self.n_modes[1],
+                                         self.n_modes[2]), dtype=self.compute_dtype)
+        w_im2 = Tensor(
+            self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0], self.n_modes[1],
+                                         self.n_modes[2]), dtype=self.compute_dtype)
+        w_re3 = Tensor(
+            self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0], self.n_modes[1],
+                                         self.n_modes[2]), dtype=self.compute_dtype)
+        w_im3 = Tensor(
+            self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0], self.n_modes[1],
+                                         self.n_modes[2]), dtype=self.compute_dtype)
+        w_re4 = Tensor(
+            self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0], self.n_modes[1],
+                                         self.n_modes[2]), dtype=self.compute_dtype)
+        w_im4 = Tensor(
+            self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0], self.n_modes[1],
+                                         self.n_modes[2]), dtype=self.compute_dtype)
+
+        self._w_re1 = Parameter(w_re1, requires_grad=True)
+        self._w_im1 = Parameter(w_im1, requires_grad=True)
+        self._w_re2 = Parameter(w_re2, requires_grad=True)
+        self._w_im2 = Parameter(w_im2, requires_grad=True)
+        self._w_re3 = Parameter(w_re3, requires_grad=True)
+        self._w_im3 = Parameter(w_im3, requires_grad=True)
+        self._w_re4 = Parameter(w_re4, requires_grad=True)
+        self._w_im4 = Parameter(w_im4, requires_grad=True)
+
+        self._dft3_cell = RDFTn(shape=(self.resolutions[0], self.resolutions[1], self.resolutions[2]), norm='ortho',
+                                modes=(self.n_modes[0], self.n_modes[1], self.n_modes[2]),
+                                compute_dtype=self.compute_dtype)
+        self._idft3_cell = IRDFTn(shape=(self.resolutions[0], self.resolutions[1], self.resolutions[2]), norm='ortho',
+                                  modes=(self.n_modes[0], self.n_modes[1], self.n_modes[2]),
+                                  compute_dtype=self.compute_dtype)
+        self._mat_x = Tensor(
+            shape=(1, self.out_channels, self.resolutions[0] - 2 * self.n_modes[0], self.n_modes[1], self.n_modes[2]),
+            dtype=self.compute_dtype, init=Zero())
+        self._mat_y = Tensor(
+            shape=(1, self.out_channels, self.resolutions[0], self.resolutions[1] - 2 * self.n_modes[1],
+                   self.n_modes[2]),
+            dtype=self.compute_dtype, init=Zero())
+        self._concat = ops.Concat(-2)
+
+    def construct(self, x: Tensor):
+        x_re = x
+        x_ft_re, x_ft_im = self._dft3_cell(x_re)
+
+        out_ft_re1 = self._einsum(x_ft_re[:, :, :self.n_modes[0], :self.n_modes[1], :self.n_modes[2]],
+                                  self._w_re1) - self._einsum(x_ft_im[:, :, :self.n_modes[0], :self.n_modes[1],
+                                                                      :self.n_modes[2]], self._w_im1)
+        out_ft_im1 = self._einsum(x_ft_re[:, :, :self.n_modes[0], :self.n_modes[1], :self.n_modes[2]],
+                                  self._w_im1) + self._einsum(x_ft_im[:, :, :self.n_modes[0], :self.n_modes[1],
+                                                                      :self.n_modes[2]], self._w_re1)
+        out_ft_re2 = self._einsum(x_ft_re[:, :, -self.n_modes[0]:, :self.n_modes[1], :self.n_modes[2]],
+                                  self._w_re2) - self._einsum(x_ft_im[:, :, -self.n_modes[0]:, :self.n_modes[1],
:self.n_modes[2]], self._w_im2) + out_ft_im2 = self._einsum(x_ft_re[:, :, -self.n_modes[0]:, :self.n_modes[1], :self.n_modes[2]], + self._w_im2) + self._einsum(x_ft_im[:, :, -self.n_modes[0]:, :self.n_modes[1], + :self.n_modes[2]], self._w_re2) + out_ft_re3 = self._einsum(x_ft_re[:, :, :self.n_modes[0], -self.n_modes[1]:, :self.n_modes[2]], + self._w_re3) - self._einsum(x_ft_im[:, :, :self.n_modes[0], -self.n_modes[1]:, + :self.n_modes[2]], self._w_im3) + out_ft_im3 = self._einsum(x_ft_re[:, :, :self.n_modes[0], -self.n_modes[1]:, :self.n_modes[2]], + self._w_im3) + self._einsum(x_ft_im[:, :, :self.n_modes[0], -self.n_modes[1]:, + :self.n_modes[2]], self._w_re3) + out_ft_re4 = self._einsum(x_ft_re[:, :, -self.n_modes[0]:, -self.n_modes[1]:, :self.n_modes[2]], + self._w_re4) - self._einsum(x_ft_im[:, :, -self.n_modes[0]:, -self.n_modes[1]:, + :self.n_modes[2]], self._w_im4) + out_ft_im4 = self._einsum(x_ft_re[:, :, -self.n_modes[0]:, -self.n_modes[1]:, :self.n_modes[2]], + self._w_im4) + self._einsum(x_ft_im[:, :, -self.n_modes[0]:, -self.n_modes[1]:, + :self.n_modes[2]], self._w_re4) + + batch_size = x.shape[0] + mat_x = mint.repeat_interleave(self._mat_x, batch_size, 0) + mat_y = mint.repeat_interleave(self._mat_y, batch_size, 0) + + out_re1 = ops.concat((out_ft_re1, mat_x, out_ft_re2), -3) + out_im1 = ops.concat((out_ft_im1, mat_x, out_ft_im2), -3) + + out_re2 = ops.concat((out_ft_re3, mat_x, out_ft_re4), -3) + out_im2 = ops.concat((out_ft_im3, mat_x, out_ft_im4), -3) + out_re = ops.concat((out_re1, mat_y, out_re2), -2) + out_im = ops.concat((out_im1, mat_y, out_im2), -2) + x = self._idft3_cell(out_re, out_im) + + return x diff --git a/mindscience/models/neural_operator/kno1d.py b/mindscience/models/neural_operator/kno1d.py index b4c11bb5ddfcd5bd3d698281b2f234adc110f344..81820de728a5a6d5563e04c6819e28eb6db5bf1b 100644 --- a/mindscience/models/neural_operator/kno1d.py +++ b/mindscience/models/neural_operator/kno1d.py @@ -16,7 +16,7 @@ import mindspore.common.dtype as mstype from mindspore import ops, nn, Tensor -from .dft import SpectralConv1dDft +from .fno_sp import SpectralConv1dDft from ...utils.check_func import check_param_type diff --git a/mindscience/models/neural_operator/kno2d.py b/mindscience/models/neural_operator/kno2d.py index 9674709f952725070d9c333b06055652342596cd..79f9ae98a2a824cb2339287ced8f7c66e3655af5 100644 --- a/mindscience/models/neural_operator/kno2d.py +++ b/mindscience/models/neural_operator/kno2d.py @@ -1,119 +1,119 @@ -# Copyright 2023 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""KNO2D""" -import mindspore.common.dtype as mstype -from mindspore import ops, nn, Tensor - -from .dft import SpectralConv2dDft -from ...utils.check_func import check_param_type - - -class KNO2D(nn.Cell): - r""" - The 2-dimensional Koopman Neural Operator (KNO2D) contains a encoder layer and a decoder layer, - multiple Koopman layers. 
- The details can be found in `KoopmanLab: machine learning for solving complex physics equations - `_. - - Args: - in_channels (int): The number of channels in the input space. Default: ``1``. - channels (int): The number of channels after dimension lifting of the input. Default: ``32``. - modes (int): The number of low-frequency components to keep. Default: ``16``. - resolution (int): The spatial resolution of the input. Default: ``1024``. - depths (int): The number of KNO layers. Default: ``4``. - compute_dtype (dtype.Number): The computation type of dense. Default: ``mstype.float16``. - Should be ``mstype.float32`` or ``mstype.float16``. mstype.float32 is recommended for - the GPU backend, mstype.float16 is recommended for the Ascend backend. - - Inputs: - - **x** (Tensor) - Tensor of shape :math:`(batch\_size, resolution, in\_channels)`. - - Outputs: - Tensor, the output of this KNO network. - - - **output** (Tensor) -Tensor of shape :math:`(batch\_size, resolution, in\_channels)`. - - Raises: - TypeError: If `in_channels` is not an int. - TypeError: If `channels` is not an int. - TypeError: If `modes` is not an int. - TypeError: If `depths` is not an int. - TypeError: If `resolution` is not an int. - - Supported Platforms: - ``Ascend`` ``GPU`` - - Examples: - >>> import numpy as np - >>> from mindflow.cell.neural_operators import KNO2D - >>> input_ = Tensor(np.ones([32, 64, 64, 10]), mstype.float32) - >>> net = KNO2D() - >>> x, x_reconstruct = net(input_) - >>> print(x.shape, x_reconstruct.shape) - (32, 64, 64, 10) (32, 64, 64, 10) - """ - - def __init__(self, - in_channels=10, - channels=32, - modes=16, - depths=4, - resolution=64, - compute_dtype=mstype.float32): - super().__init__() - check_param_type(in_channels, "in_channels", - data_type=int, exclude_type=bool) - check_param_type(channels, "channels", - data_type=int, exclude_type=bool) - check_param_type(modes, "modes", - data_type=int, exclude_type=bool) - check_param_type(depths, "depths", - data_type=int, exclude_type=bool) - check_param_type(resolution, "resolution", - data_type=int, exclude_type=bool) - self.in_channels = in_channels - self.channels = channels - self.modes = modes - self.depths = depths - self.resolution = resolution - self.enc = nn.Dense(in_channels, channels, has_bias=True) - self.dec = nn.Dense(channels, in_channels, has_bias=True) - self.koopman_layer = SpectralConv2dDft(channels, channels, [modes, modes], [resolution, resolution], - compute_dtype=compute_dtype) - self.w0 = nn.Conv2d(channels, channels, 1, has_bias=True) - - def construct(self, x: Tensor): - """KNO2D forward function. - - Args: - x (Tensor): Input Tensor. - """ - # reconstruct - x_reconstruct = self.enc(x) - x_reconstruct = ops.tanh(x_reconstruct) - x_reconstruct = self.dec(x_reconstruct) - - # predict - x = self.enc(x) - x = ops.tanh(x) - x = x.transpose(0, 3, 1, 2) - x_w = x - for _ in range(self.depths): - x1 = self.koopman_layer(x) - x = ops.tanh(x + x1) - x = ops.tanh(self.w0(x_w) + x) - x = x.transpose(0, 2, 3, 1) - x = self.dec(x) - return x, x_reconstruct +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""KNO2D"""
+import mindspore.common.dtype as mstype
+from mindspore import ops, nn, Tensor
+
+from .fno_sp import SpectralConv2dDft
+from ...utils.check_func import check_param_type
+
+
+class KNO2D(nn.Cell):
+    r"""
+    The 2-dimensional Koopman Neural Operator (KNO2D) contains an encoder layer, a decoder layer and
+    multiple Koopman layers.
+    The details can be found in `KoopmanLab: machine learning for solving complex physics equations
+    `_.
+
+    Args:
+        in_channels (int): The number of channels in the input space. Default: ``10``.
+        channels (int): The number of channels after dimension lifting of the input. Default: ``32``.
+        modes (int): The number of low-frequency components to keep. Default: ``16``.
+        resolution (int): The spatial resolution of the input. Default: ``64``.
+        depths (int): The number of KNO layers. Default: ``4``.
+        compute_dtype (dtype.Number): The computation type of dense. Default: ``mstype.float32``.
+            Should be ``mstype.float32`` or ``mstype.float16``. mstype.float32 is recommended for
+            the GPU backend, mstype.float16 is recommended for the Ascend backend.
+
+    Inputs:
+        - **x** (Tensor) - Tensor of shape :math:`(batch\_size, resolution, resolution, in\_channels)`.
+
+    Outputs:
+        Tensor, the output of this KNO network.
+
+        - **output** (Tensor) - Tensor of shape :math:`(batch\_size, resolution, resolution, in\_channels)`.
+
+    Raises:
+        TypeError: If `in_channels` is not an int.
+        TypeError: If `channels` is not an int.
+        TypeError: If `modes` is not an int.
+        TypeError: If `depths` is not an int.
+        TypeError: If `resolution` is not an int.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
+    Examples:
+        >>> import numpy as np
+        >>> import mindspore.common.dtype as mstype
+        >>> from mindspore import Tensor
+        >>> from mindflow.cell.neural_operators import KNO2D
+        >>> input_ = Tensor(np.ones([32, 64, 64, 10]), mstype.float32)
+        >>> net = KNO2D()
+        >>> x, x_reconstruct = net(input_)
+        >>> print(x.shape, x_reconstruct.shape)
+        (32, 64, 64, 10) (32, 64, 64, 10)
+    """
+
+    def __init__(self,
+                 in_channels=10,
+                 channels=32,
+                 modes=16,
+                 depths=4,
+                 resolution=64,
+                 compute_dtype=mstype.float32):
+        super().__init__()
+        check_param_type(in_channels, "in_channels",
+                         data_type=int, exclude_type=bool)
+        check_param_type(channels, "channels",
+                         data_type=int, exclude_type=bool)
+        check_param_type(modes, "modes",
+                         data_type=int, exclude_type=bool)
+        check_param_type(depths, "depths",
+                         data_type=int, exclude_type=bool)
+        check_param_type(resolution, "resolution",
+                         data_type=int, exclude_type=bool)
+        self.in_channels = in_channels
+        self.channels = channels
+        self.modes = modes
+        self.depths = depths
+        self.resolution = resolution
+        self.enc = nn.Dense(in_channels, channels, has_bias=True)
+        self.dec = nn.Dense(channels, in_channels, has_bias=True)
+        self.koopman_layer = SpectralConv2dDft(channels, channels, [modes, modes], [resolution, resolution],
+                                               compute_dtype=compute_dtype)
+        self.w0 = nn.Conv2d(channels, channels, 1, has_bias=True)
+
+    def construct(self, x: Tensor):
+        """KNO2D forward function.
+
+        Args:
+            x (Tensor): Input Tensor.
+ """ + # reconstruct + x_reconstruct = self.enc(x) + x_reconstruct = ops.tanh(x_reconstruct) + x_reconstruct = self.dec(x_reconstruct) + + # predict + x = self.enc(x) + x = ops.tanh(x) + x = x.transpose(0, 3, 1, 2) + x_w = x + for _ in range(self.depths): + x1 = self.koopman_layer(x) + x = ops.tanh(x + x1) + x = ops.tanh(self.w0(x_w) + x) + x = x.transpose(0, 2, 3, 1) + x = self.dec(x) + return x, x_reconstruct diff --git a/mindscience/sciops/__init__.py b/mindscience/sciops/__init__.py index 69a14b29e1ced3fa627e5dada3f5f6ba239fdc1c..2220d4f499b98524fd237e5573211a8551c9818c 100644 --- a/mindscience/sciops/__init__.py +++ b/mindscience/sciops/__init__.py @@ -15,5 +15,6 @@ """ init """ +from .fourier import RDFTn, IRDFTn, DFTn, IDFTn, DCT, IDCT, DST, IDST -__all__ = [] \ No newline at end of file +__all__ = ["RDFTn", "IRDFTn", "DFTn", "IDFTn", "DCT", "IDCT", "DST", "IDST"] diff --git a/mindscience/sciops/fourier.py b/mindscience/sciops/fourier.py new file mode 100644 index 0000000000000000000000000000000000000000..f1d29a13c3eb3f5525a92ea3f5267c621dc61639 --- /dev/null +++ b/mindscience/sciops/fourier.py @@ -0,0 +1,666 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +''' provide complex dft based on the real dft API in mindflow.dft ''' +import numpy as np +import scipy +import mindspore as ms +import mindspore.common.dtype as mstype +from mindspore import nn, ops, Tensor, mint +from mindspore.common.initializer import Zero +from mindspore.ops import operations as P + +from ..utils.check_func import check_param_no_greater, check_param_value + + +class MyRoll(nn.Cell): + ''' Custom defined roll operator to avoid bug in MindSpore ''' + def __init__(self): + super().__init__() + + if ms.get_context('device_target') == 'Ascend' and ms.get_context('mode') == ms.GRAPH_MODE: + self.roller = mint.roll + else: + self.roller = None + + def construct(self, x, shifts, dims): + ''' Same as mint.roll ''' + shifts = np.atleast_1d(shifts).astype(int).tolist() + dims = np.atleast_1d(dims).astype(int).tolist() + + if self.roller: + return self.roller(x, shifts, dims) + + for i, j in zip(shifts, dims): + n = x.shape[j] + x = ops.swapaxes(x, j, 0) + x = ops.cat([x[n - i % n:], x[:n - i % n]], axis=0) + x = ops.swapaxes(x, j, 0) + return x + +class MyFlip(nn.Cell): + ''' Custom defined flip operator to avoid bug in MindSpore ''' + def __init__(self, shape=None, compute_dtype=ms.float32): + super().__init__() + + if shape: + shape = np.atleast_1d(shape).astype(int).tolist() + self.rev_mats = [ms.Tensor(np.eye(n)[::-1], dtype=compute_dtype) for n in shape] + self.cast = P.Cast() + else: + self.rev_mats = None + + def construct(self, x, dims): + ''' same as mint.flip ''' + dims = np.atleast_1d(dims).astype(int).tolist() + + if self.rev_mats is not None: + for i, m in enumerate(self.rev_mats): + self.rev_mats[i] = self.cast(m, x.dtype) + + for i, j in enumerate(dims): + x = ops.swapaxes(x, 
j, -1)
+            mat = None
+            if self.rev_mats is None:
+                mat = ms.Tensor(np.eye(x.shape[-1])[::-1], dtype=x.dtype)
+            else:
+                mat = self.rev_mats[i]
+            x = mint.matmul(x, mat)  # TODO: differentiation through this matmul is problematic
+            x = ops.swapaxes(x, j, -1)
+
+        return x
+
+
+def convert_shape(shape):
+    ''' convert shape to a suitable format '''
+    if isinstance(shape, int):
+        n = shape
+    elif len(shape) == 1:
+        n, = shape
+    else:
+        raise TypeError("Only support 1D dct/dst, but got shape {}".format(shape))
+    return n
+
+
+def convert_params(shape, modes, dim):
+    ''' convert input arguments to a suitable format '''
+    shape = tuple(np.atleast_1d(shape).astype(int).tolist())
+    ndim = len(shape)
+
+    if dim is None:
+        dim = tuple([n - ndim for n in range(ndim)])
+    else:
+        dim = tuple(np.atleast_1d(dim).astype(int).tolist())
+
+    if modes is None or isinstance(modes, int):
+        modes = tuple([modes] * ndim)
+    else:
+        modes = tuple(np.atleast_1d(modes).astype(int).tolist())
+
+    return shape, modes, dim
+
+
+def check_params(shape, modes, dim):
+    ''' check the validity of input arguments '''
+    check_param_no_greater(len(dim), "dim length", 3)
+    check_param_value(len(shape), "shape length", len(dim))
+    check_param_value(len(modes), "modes length", len(dim))
+    if np.any(modes):
+        for i, (m, n) in enumerate(zip(modes, shape)):
+            # if the last axis needs mode n//2+1, mode should be set to None
+            check_param_no_greater(m, f'mode{i+1}', n // 2)
+
+
+class _DFT1d(nn.Cell):
+    '''One dimensional Discrete Fourier Transformation'''
+
+    def __init__(self, n, mode, last_index, idx=0, scale='sqrtn', inv=False, compute_dtype=mstype.float32):
+        super().__init__()
+
+        self.n = n
+        self.dft_mat = scipy.linalg.dft(n, scale=scale)
+        self.last_index = last_index
+        self.inv = inv
+        self.odd = bool(n % 2)
+        self.idx = idx
+        self.mode_upper = mode if mode else n // 2 + (self.last_index or self.odd)
+        self.mode_lower = mode if mode else n - self.mode_upper
+        self.compute_dtype = compute_dtype
+
+        # generate DFT matrix for positive and negative frequencies
+        dft_mat_mode = self.dft_mat[:, :self.mode_upper]
+        self.a_re_upper = Tensor(dft_mat_mode.real, dtype=compute_dtype)
+        self.a_im_upper = Tensor(dft_mat_mode.imag, dtype=compute_dtype)
+
+        dft_mat_mode = self.dft_mat[:, -self.mode_lower:]
+        self.a_re_lower = Tensor(dft_mat_mode.real, dtype=compute_dtype)
+        self.a_im_lower = Tensor(dft_mat_mode.imag, dtype=compute_dtype)
+
+        # the zero matrix to fill the un-transformed modes
+        m = self.n - (self.mode_upper + self.mode_lower)
+        if m > 0:
+            self.mat = Tensor(shape=m, dtype=compute_dtype, init=Zero())
+
+        self.concat = ops.Concat(axis=-1)
+        self.cast = P.Cast()
+
+        if self.inv:
+            self.a_re_upper = self.a_re_upper.T
+            self.a_im_upper = -self.a_im_upper.T
+            self.a_re_lower = self.a_re_lower.T
+            self.a_im_lower = -self.a_im_lower.T
+
+            # last axis is real-transformed, so the inverse is conjugate of the positive frequencies
+            if last_index:
+                mode_res = min(self.mode_lower, self.mode_upper - 1)
+                dft_mat_res = self.dft_mat[:, -mode_res:]
+                a_re_res = MyFlip()(Tensor(dft_mat_res.real, dtype=compute_dtype), dims=-1)
+                a_im_res = MyFlip()(Tensor(dft_mat_res.imag, dtype=compute_dtype), dims=-1)
+
+                a_re_res = ops.pad(a_re_res, (1, self.mode_upper - mode_res - 1))
+                a_im_res = ops.pad(a_im_res, (1, self.mode_upper - mode_res - 1))
+
+                self.a_re_upper += a_re_res.T
+                self.a_im_upper += a_im_res.T
+
+    def swap_axes(self, x_re, x_im):
+        return x_re.swapaxes(-1, self.idx), x_im.swapaxes(-1, self.idx)
+
+    def complex_matmul(self, x_re, x_im, a_re, a_im):
+        y_re = ops.matmul(x_re, a_re) -
ops.matmul(x_im, a_im)
+        y_im = ops.matmul(x_im, a_re) + ops.matmul(x_re, a_im)
+        return y_re, y_im
+
+    def zero_mat(self, dims):
+        mat = self.mat
+        for n in dims[::-1]:
+            mat = mint.repeat_interleave(mat.expand_dims(0), n, 0)
+        return mat
+
+    def compute_forward(self, x_re, x_im):
+        ''' Forward transform for rdft '''
+        y_re, y_im = self.complex_matmul(
+            x_re=x_re, x_im=x_im, a_re=self.a_re_upper, a_im=self.a_im_upper)
+
+        if self.last_index:
+            return y_re, y_im
+
+        y_re2, y_im2 = self.complex_matmul(
+            x_re=x_re, x_im=x_im, a_re=self.a_re_lower, a_im=self.a_im_lower)
+
+        if self.n == self.mode_upper + self.mode_lower:
+            y_re = self.concat((y_re, y_re2))
+            y_im = self.concat((y_im, y_im2))
+        else:
+            mat = self.zero_mat(x_re.shape[:-1])
+            y_re = self.concat((y_re, mat, y_re2))
+            y_im = self.concat((y_im, mat, y_im2))
+
+        return y_re, y_im
+
+    def compute_inverse(self, x_re, x_im):
+        ''' Inverse transform for irdft '''
+        y_re, y_im = self.complex_matmul(x_re=x_re[..., :self.mode_upper],
+                                         x_im=x_im[..., :self.mode_upper],
+                                         a_re=self.a_re_upper,
+                                         a_im=self.a_im_upper)
+        if self.last_index:
+            return y_re, y_im
+
+        y_re_res, y_im_res = self.complex_matmul(x_re=x_re[..., -self.mode_lower:],
+                                                 x_im=x_im[..., -self.mode_lower:],
+                                                 a_re=self.a_re_lower,
+                                                 a_im=self.a_im_lower)
+        return y_re + y_re_res, y_im + y_im_res
+
+    def construct(self, x):
+        ''' perform 1d rdft/irdft with matmul operations '''
+        x_re, x_im = x
+        x_re, x_im = self.cast(x_re, self.compute_dtype), self.cast(x_im, self.compute_dtype)
+        x_re, x_im = self.swap_axes(x_re, x_im)
+        if self.inv:
+            y_re, y_im = self.compute_inverse(x_re, x_im)
+        else:
+            y_re, y_im = self.compute_forward(x_re, x_im)
+        y_re, y_im = self.swap_axes(y_re, y_im)
+        return y_re, y_im
+
+
+class _DFTn(nn.Cell):
+    ''' Base class for n-D DFT transform '''
+    def __init__(self, shape, dim=None, norm='backward', modes=None, compute_dtype=mstype.float32):
+        super().__init__()
+
+        shape, modes, dim = convert_params(shape, modes, dim)
+        check_params(shape, modes, dim)
+
+        ndim = len(shape)
+        inv, scale, r2c_flags = self.set_options(ndim, norm)
+        self.dft1_seq = nn.SequentialCell()
+        for n, m, r, d in zip(shape, modes, r2c_flags, dim):
+            self.dft1_seq.append(_DFT1d(
+                n=n, mode=m, last_index=r, idx=d, scale=scale, inv=inv, compute_dtype=compute_dtype))
+
+    def set_options(self, ndim, norm):
+        '''
+        Choose the dimensions, normalization, and transformation mode (forward/backward).
+        Derived APIs override the options to achieve their specific goals.
+        '''
+        inv = False
+        scale = {
+            'backward': None,
+            'forward': 'n',
+            'ortho': 'sqrtn',
+        }[norm]
+        r2c_flags = np.zeros(ndim, dtype=bool).tolist()
+        r2c_flags[-1] = True
+        return inv, scale, r2c_flags
+
+    def construct(self, *args, **kwargs):
+        raise NotImplementedError
+
+
+class RDFTn(_DFTn):
+    r"""
+    1/2/3D discrete real Fourier transform of a real-valued tensor. The results should be the same as
+    `scipy.fft.rfftn() `_ .
+
+    Args:
+        shape (tuple): The shape of the dimensions to be transformed, other dimensions need not be included.
+        dim (tuple): Dimensions to be transformed. Default: None, the trailing dimensions will be transformed.
+        norm (str): Normalization mode, should be one of 'forward', 'backward', 'ortho'. Default: 'backward',
+            same as torch.fft.rfftn.
+        modes (tuple, int, None): The number of modes kept along each transformed axis. Each mode must be no
+            greater than half of the corresponding dimension of the input. Default: None, all modes are kept.
+        compute_dtype (mindspore.dtype): The type of input tensor. Default: mindspore.float32.
+
+
+class RDFTn(_DFTn):
+    r"""
+    1/2/3D discrete real Fourier transformation on real numbers. The results should be the same as
+    `scipy.fft.rfftn() <https://docs.scipy.org/doc/scipy/reference/generated/scipy.fft.rfftn.html>`_ .
+
+    Args:
+        shape (tuple): The shape of the dimensions to be transformed, other dimensions need not be included.
+        dim (tuple): Dimensions to be transformed. Default: None, the trailing dimensions will be transformed.
+        norm (str): Normalization mode, should be one of 'forward', 'backward', 'ortho'. Default: 'backward',
+            same as torch.fft.rfftn.
+        modes (tuple, int, None): The length of the output transform axis. The `modes` must be no greater
+            than half of the dimension of input 'x'.
+        compute_dtype (mindspore.dtype): The type of input tensor. Default: mindspore.float32.
+
+    Inputs:
+        - **ar** (Tensor) - The real tensor to be transformed, with trailing dimensions aligned with `shape`.
+
+    Outputs:
+        - **br** (Tensor) - Real part of the output tensor, with trailing dimensions aligned with `shape`,
+          except for the last dimension, which should be shape[-1] // 2 + 1.
+        - **bi** (Tensor) - Imag part of the output tensor, with trailing dimensions aligned with `shape`,
+          except for the last dimension, which should be shape[-1] // 2 + 1.
+
+    Supported Platforms:
+        ``Ascend`` ``CPU``
+
+    Examples:
+        >>> from mindspore import ops
+        >>> from mindscience import RDFTn
+        >>> ar = ops.rand((2, 32, 512))
+        >>> dft_cell = RDFTn(ar.shape[-2:])
+        >>> br, bi = dft_cell(ar)
+        >>> print(br.shape)
+        (2, 32, 257)
+    """
+    def construct(self, ar):
+        ''' perform n-dimensional rDFT on real tensor '''
+        # n-D Fourier transform with last axis real-transformed, output dimension (..., m, n//2+1);
+        # the last ndim dimensions of ar must accord with shape
+        return self.dft1_seq((ar, ar * 0))
+
+
+class IRDFTn(_DFTn):
+    r"""
+    1/2/3D discrete inverse real Fourier transformation on complex numbers. The results should be the same as
+    `scipy.fft.irfftn() <https://docs.scipy.org/doc/scipy/reference/generated/scipy.fft.irfftn.html>`_ .
+
+    Args:
+        shape (tuple): The shape of the dimensions to be transformed, other dimensions need not be included.
+        dim (tuple): Dimensions to be transformed. Default: None, the trailing dimensions will be transformed.
+        norm (str): Normalization mode, should be one of 'forward', 'backward', 'ortho'. Default: 'backward',
+            same as torch.fft.irfftn.
+        modes (tuple, int, None): The length of the output transform axis. The `modes` must be no greater
+            than half of the dimension of input 'x'.
+        compute_dtype (mindspore.dtype): The type of input tensor. Default: mindspore.float32.
+
+    Inputs:
+        - **ar** (Tensor) - Real part of the tensor to be transformed, with trailing dimensions aligned with
+          `shape`, except for the last dimension, which should be shape[-1] // 2 + 1.
+        - **ai** (Tensor) - Imag part of the tensor to be transformed, with trailing dimensions aligned with
+          `shape`, except for the last dimension, which should be shape[-1] // 2 + 1.
+
+    Outputs:
+        - **br** (Tensor) - The output real tensor, with trailing dimensions aligned with `shape`.
+
+    Supported Platforms:
+        ``Ascend`` ``CPU``
+
+    Examples:
+        >>> from mindspore import ops
+        >>> from mindscience import IRDFTn
+        >>> ar = ops.rand((2, 32, 257))
+        >>> ai = ops.rand((2, 32, 257))
+        >>> dft_cell = IRDFTn((32, 512))
+        >>> br = dft_cell(ar, ai)
+        >>> print(br.shape)
+        (2, 32, 512)
+    """
+    def set_options(self, ndim, norm):
+        inv = True
+        scale = {
+            'forward': None,
+            'backward': 'n',
+            'ortho': 'sqrtn',
+        }[norm]
+        r2c_flags = np.zeros(ndim, dtype=bool).tolist()
+        r2c_flags[-1] = True
+        return inv, scale, r2c_flags
+
+    def construct(self, ar, ai):
+        ''' perform n-dimensional irDFT on complex tensor and output real tensor '''
+        return self.dft1_seq((ar, ai))[0]
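+
+
+# Round-trip sketch (assuming RDFTn/IRDFTn are exported from the package root,
+# as in the unit tests): the inverse recovers the input up to floating-point error.
+#
+# >>> import numpy as np
+# >>> from mindspore import Tensor, dtype as mstype
+# >>> x = Tensor(np.random.rand(2, 32, 512), dtype=mstype.float32)
+# >>> yr, yi = RDFTn(x.shape[-2:])(x)        # (2, 32, 257) each
+# >>> x_rec = IRDFTn(x.shape[-2:])(yr, yi)   # (2, 32, 512), approximately x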
+
+
+class DFTn(_DFTn):
+    r"""
+    1/2/3D discrete Fourier transformation on complex numbers. The results should be the same as
+    `scipy.fft.fftn() <https://docs.scipy.org/doc/scipy/reference/generated/scipy.fft.fftn.html>`_ .
+
+    Args:
+        shape (tuple): The shape of the dimensions to be transformed, other dimensions need not be included.
+        dim (tuple): Dimensions to be transformed. Default: None, the trailing dimensions will be transformed.
+        norm (str): Normalization mode, should be one of 'forward', 'backward', 'ortho'. Default: 'backward',
+            same as torch.fft.fftn.
+        modes (tuple, int, None): The length of the output transform axis. The `modes` must be no greater
+            than half of the dimension of input 'x'.
+        compute_dtype (mindspore.dtype): The type of input tensor. Default: mindspore.float32.
+
+    Inputs:
+        - **ar** (Tensor) - Real part of the tensor to be transformed, with trailing dimensions aligned with `shape`.
+        - **ai** (Tensor) - Imag part of the tensor to be transformed, with trailing dimensions aligned with `shape`.
+
+    Outputs:
+        - **br** (Tensor) - Real part of the output tensor, with trailing dimensions aligned with `shape`.
+        - **bi** (Tensor) - Imag part of the output tensor, with trailing dimensions aligned with `shape`.
+
+    Supported Platforms:
+        ``Ascend`` ``CPU``
+
+    Examples:
+        >>> from mindspore import ops
+        >>> from mindscience import DFTn
+        >>> ar = ops.rand((2, 32, 512))
+        >>> ai = ops.rand((2, 32, 512))
+        >>> dft_cell = DFTn(ar.shape[-2:])
+        >>> br, bi = dft_cell(ar, ai)
+        >>> print(br.shape)
+        (2, 32, 512)
+    """
+    def set_options(self, ndim, norm):
+        inv = False
+        scale = {
+            'forward': 'n',
+            'backward': None,
+            'ortho': 'sqrtn',
+        }[norm]
+        r2c_flags = np.zeros(ndim, dtype=bool).tolist()
+        return inv, scale, r2c_flags
+
+    def construct(self, ar, ai):
+        ''' perform n-dimensional DFT on complex tensor '''
+        # n-D complex Fourier transform, output dimension (..., m, n)
+        return self.dft1_seq((ar, ai))
+
+
+class IDFTn(DFTn):
+    r"""
+    1/2/3D discrete inverse Fourier transformation on complex numbers. The results should be the same as
+    `scipy.fft.ifftn() <https://docs.scipy.org/doc/scipy/reference/generated/scipy.fft.ifftn.html>`_ .
+
+    Args:
+        shape (tuple): The shape of the dimensions to be transformed, other dimensions need not be included.
+        dim (tuple): Dimensions to be transformed. Default: None, the trailing dimensions will be transformed.
+        norm (str): Normalization mode, should be one of 'forward', 'backward', 'ortho'. Default: 'backward',
+            same as torch.fft.ifftn.
+        modes (tuple, int, None): The length of the output transform axis. The `modes` must be no greater
+            than half of the dimension of input 'x'.
+        compute_dtype (mindspore.dtype): The type of input tensor. Default: mindspore.float32.
+
+    Inputs:
+        - **ar** (Tensor) - Real part of the tensor to be transformed, with trailing dimensions aligned with `shape`.
+        - **ai** (Tensor) - Imag part of the tensor to be transformed, with trailing dimensions aligned with `shape`.
+
+    Outputs:
+        - **br** (Tensor) - Real part of the output tensor, with trailing dimensions aligned with `shape`.
+        - **bi** (Tensor) - Imag part of the output tensor, with trailing dimensions aligned with `shape`.
+
+    Supported Platforms:
+        ``Ascend`` ``CPU``
+
+    Examples:
+        >>> from mindspore import ops
+        >>> from mindscience import IDFTn
+        >>> ar = ops.rand((2, 32, 512))
+        >>> ai = ops.rand((2, 32, 512))
+        >>> dft_cell = IDFTn(ar.shape[-2:])
+        >>> br, bi = dft_cell(ar, ai)
+        >>> print(br.shape)
+        (2, 32, 512)
+    """
+    def set_options(self, ndim, norm):
+        inv = True
+        scale = {
+            'forward': None,
+            'backward': 'n',
+            'ortho': 'sqrtn',
+        }[norm]
+        r2c_flags = np.zeros(ndim, dtype=bool).tolist()
+        return inv, scale, r2c_flags
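+
+
+# Consistency sketch against numpy (mirrors tests/sciops/test_fourier.py; the
+# imports and variable names follow the earlier sketch and are illustrative only):
+#
+# >>> a = np.random.rand(4, 8) + 1j * np.random.rand(4, 8)
+# >>> br, bi = DFTn(a.shape)(Tensor(a.real, mstype.float32), Tensor(a.imag, mstype.float32))
+# >>> np.allclose(br.asnumpy(), np.fft.fftn(a).real, atol=1e-4)   # True within fp32 tolerance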
+
+
+class DCT(nn.Cell):
+    r"""
+    1D discrete cosine transformation on real numbers on the last axis. The results should be the same as
+    `scipy.fft.dct() <https://docs.scipy.org/doc/scipy/reference/generated/scipy.fft.dct.html>`_ .
+    Reference: `Type 2 DCT using N FFT (Makhoul) `_ .
+
+    Args:
+        shape (tuple): The shape of the dimensions to be transformed, other dimensions need not be included.
+            Must be a length-1 tuple.
+        compute_dtype (mindspore.dtype): The type of input tensor. Default: mindspore.float32.
+
+    Inputs:
+        - **a** (Tensor) - The real tensor to be transformed, with trailing dimensions aligned with `shape`.
+
+    Outputs:
+        - **b** (Tensor) - The output real tensor, with trailing dimensions aligned with `shape`.
+
+    Supported Platforms:
+        ``Ascend`` ``CPU``
+
+    Examples:
+        >>> from mindspore import ops
+        >>> from mindscience import DCT
+        >>> a = ops.rand((2, 32, 512))
+        >>> dft_cell = DCT(a.shape[-1:])
+        >>> b = dft_cell(a)
+        >>> print(b.shape)
+        (2, 32, 512)
+    """
+    def __init__(self, shape, compute_dtype=mstype.float32):
+        super().__init__()
+
+        n = convert_shape(shape)
+
+        self.dft_cell = DFTn(n, compute_dtype=compute_dtype)
+
+        w = Tensor(np.arange(n) * np.pi / (2 * n), dtype=compute_dtype)
+        self.cosw = ops.cos(w)
+        self.sinw = ops.sin(w)
+
+        self.fliper = MyFlip((n // 2,), compute_dtype)
+
+    def construct(self, a):
+        ''' perform 1-dimensional DCT on real tensor '''
+        b_half1 = a[..., ::2]
+        b_half2 = self.fliper(a[..., 1::2], dims=-1)
+        b = ops.cat([b_half1, b_half2], axis=-1)
+        cr, ci = self.dft_cell(b, b * 0)
+        return 2 * (cr * self.cosw + ci * self.sinw)
+
+
+class IDCT(nn.Cell):
+    r"""
+    1D inverse discrete cosine transformation on real numbers on the last axis. The results should be the same as
+    `scipy.fft.idct() <https://docs.scipy.org/doc/scipy/reference/generated/scipy.fft.idct.html>`_ .
+    Reference: `A fast cosine transform in one and two dimensions
+    `_ .
+
+    Args:
+        shape (tuple): The shape of the dimensions to be transformed, other dimensions need not be included.
+            Must be a length-1 tuple.
+        compute_dtype (mindspore.dtype): The type of input tensor. Default: mindspore.float32.
+
+    Inputs:
+        - **a** (Tensor) - The real tensor to be transformed, with trailing dimensions aligned with `shape`.
+
+    Outputs:
+        - **b** (Tensor) - The output real tensor, with trailing dimensions aligned with `shape`.
+
+    Supported Platforms:
+        ``Ascend`` ``CPU``
+
+    Examples:
+        >>> from mindspore import ops
+        >>> from mindscience import IDCT
+        >>> a = ops.rand((2, 32, 512))
+        >>> dft_cell = IDCT(a.shape[-1:])
+        >>> b = dft_cell(a)
+        >>> print(b.shape)
+        (2, 32, 512)
+    """
+    def __init__(self, shape, compute_dtype=mstype.float32):
+        super().__init__()
+
+        n = convert_shape(shape)
+
+        # assert n % 2 == 0, 'only support even length'  # n has to be even, or IRDFTn would fail
+
+        self.dft_cell = IRDFTn(n, compute_dtype=compute_dtype)
+
+        w = Tensor(np.arange(n // 2 + 1) * np.pi / (2 * n), dtype=compute_dtype)
+        self.cosw = ops.cos(w)
+        self.sinw = ops.sin(w)
+
+        self.fliper = MyFlip((n // 2,), compute_dtype)
+
+    def construct(self, a):
+        ''' perform 1-dimensional iDCT on real tensor '''
+        n = a.shape[-1]
+
+        br = a[..., :n // 2 + 1]
+        bi = ops.pad(self.fliper(- a[..., -(n // 2):], dims=-1), (1, 0))
+        vr = (br * self.cosw - bi * self.sinw) / 2
+        vi = (bi * self.cosw + br * self.sinw) / 2
+
+        c = self.dft_cell(vr, vi)  # (..., n)
+        c1 = c[..., :(n + 1) // 2]
+        c2 = self.fliper(c[..., (n + 1) // 2:], dims=-1)
+        d1 = ops.pad(c1.reshape(-1)[..., None], (0, 1)).reshape(*c1.shape[:-1], -1)
+        d2 = ops.pad(c2.reshape(-1)[..., None], (1, 0)).reshape(*c2.shape[:-1], -1)
+        # in case n is odd, d1 and d2 need to be aligned
+        d1 = d1[..., :n]
+        d2 = ops.pad(d2, (0, n % 2))
+        return d1 + d2
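+
+
+# Note on the construction above: DCT follows Makhoul's N-point FFT recipe.
+# Even-indexed samples are kept in order, odd-indexed samples are appended in
+# reverse, and the complex DFT of that permutation yields the Type-II DCT after
+# the twiddle by exp(-1j * pi * k / (2n)), realized here via the precomputed
+# cosw/sinw. IDCT undoes the same steps: untwiddle, inverse real DFT, then
+# de-interleave the two halves back into the even/odd positions.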
+
+
+class DST(nn.Cell):
+    r"""
+    1D discrete sine transformation on real numbers on the last axis. The results should be the same as
+    `scipy.fft.dst() <https://docs.scipy.org/doc/scipy/reference/generated/scipy.fft.dst.html>`_ .
+    Reference: `Wikipedia `_ .
+
+    Args:
+        shape (tuple): The shape of the dimensions to be transformed, other dimensions need not be included.
+            Must be a length-1 tuple.
+        compute_dtype (mindspore.dtype): The type of input tensor. Default: mindspore.float32.
+
+    Inputs:
+        - **a** (Tensor) - The real tensor to be transformed, with trailing dimensions aligned with `shape`.
+
+    Outputs:
+        - **b** (Tensor) - The output real tensor, with trailing dimensions aligned with `shape`.
+
+    Supported Platforms:
+        ``Ascend`` ``CPU``
+
+    Examples:
+        >>> from mindspore import ops
+        >>> from mindscience import DST
+        >>> a = ops.rand((2, 32, 512))
+        >>> dft_cell = DST(a.shape[-1:])
+        >>> b = dft_cell(a)
+        >>> print(b.shape)
+        (2, 32, 512)
+    """
+    def __init__(self, shape, compute_dtype=mstype.float32):
+        super().__init__()
+        n = convert_shape(shape)
+        self.dft_cell = DCT(n, compute_dtype=compute_dtype)
+        multiplier = np.ones(n)
+        multiplier[..., 1::2] *= -1
+        self.multiplier = Tensor(multiplier, dtype=compute_dtype)
+
+        self.fliper = MyFlip((n,), compute_dtype)
+
+    def construct(self, a):
+        ''' perform 1-dimensional DST on real tensor '''
+        return self.fliper(self.dft_cell(a * self.multiplier), dims=-1)
+
+
+class IDST(nn.Cell):
+    r"""
+    1D inverse discrete sine transformation on real numbers on the last axis. The results should be the same as
+    `scipy.fft.idst() <https://docs.scipy.org/doc/scipy/reference/generated/scipy.fft.idst.html>`_ .
+    Reference: `Wikipedia `_ .
+
+    Args:
+        shape (tuple): The shape of the dimensions to be transformed, other dimensions need not be included.
+            Must be a length-1 tuple.
+        compute_dtype (mindspore.dtype): The type of input tensor. Default: mindspore.float32.
+
+    Inputs:
+        - **a** (Tensor) - The real tensor to be transformed, with trailing dimensions aligned with `shape`.
+
+    Outputs:
+        - **b** (Tensor) - The output real tensor, with trailing dimensions aligned with `shape`.
+
+    Supported Platforms:
+        ``Ascend`` ``CPU``
+
+    Examples:
+        >>> from mindspore import ops
+        >>> from mindscience import IDST
+        >>> a = ops.rand((2, 32, 512))
+        >>> dft_cell = IDST(a.shape[-1:])
+        >>> b = dft_cell(a)
+        >>> print(b.shape)
+        (2, 32, 512)
+    """
+    def __init__(self, shape, compute_dtype=mstype.float32):
+        super().__init__()
+        n = convert_shape(shape)
+        self.dft_cell = IDCT(n, compute_dtype=compute_dtype)
+        multiplier = np.ones(n)
+        multiplier[..., 1::2] *= -1
+        self.multiplier = Tensor(multiplier, dtype=compute_dtype)
+
+        self.fliper = MyFlip((n,), compute_dtype)
+
+    def construct(self, a):
+        ''' perform 1-dimensional iDST on real tensor '''
+        return self.dft_cell(self.fliper(a, dims=-1)) * self.multiplier
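+
+
+# DST/IDST reuse DCT/IDCT through the standard parity identity: negating every
+# other input sample and reversing the result converts a Type-II DCT into a
+# Type-II DST (and symmetrically for the inverse), so no extra transform matrix
+# is constructed.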
diff --git a/tests/common/test_optimizers.py b/tests/common/test_optimizers.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d829845a1db120e4e75fc3f6c7848d913db7c9d
--- /dev/null
+++ b/tests/common/test_optimizers.py
@@ -0,0 +1,277 @@
+# ============================================================================
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Optimizers Test Case"""
+import os
+import random
+import sys
+
+import pytest
+import numpy as np
+
+import mindspore as ms
+from mindspore import ops, set_seed, nn, mint
+from mindspore import dtype as mstype
+from mindscience.models.layers.unet2d import UNet2D, Down
+from mindscience.models.transformer.attention import TransformerBlock, MultiHeadAttention, FeedForward
+from mindscience.common.optimizers import AdaHessian
+
+PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../"))
+sys.path.append(PROJECT_ROOT)
+
+# pylint: disable=wrong-import-position
+
+from tools import FP32_RTOL
+
+# pylint: enable=wrong-import-position
+
+set_seed(0)
+np.random.seed(0)
+random.seed(0)
+
+test_data_path = '/home/workspace/mindspore_dataset/mindscience/mindflow/optimizers'
+
+
+class TestAdaHessianAccuracy(AdaHessian):
+    ''' Child class for testing the accuracy of the AdaHessian optimizer '''
+
+    def gen_rand_vecs(self, grads):
+        ''' generate deterministic vectors for the accuracy test '''
+        return [ms.Tensor(np.arange(p.size).reshape(p.shape) - p.size // 2, dtype=ms.float32) for p in grads]
+
+
+class TestUNet2D(UNet2D):
+    ''' Child class for testing optimizing UNet with AdaHessian '''
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        class TestDown(Down):
+            def __init__(self, *args, **kwargs):
+                super().__init__(*args, **kwargs)
+                in_channels = args[0]
+                kernel_size = kwargs['kernel_size']
+                stride = kwargs['stride']
+                # replace the `maxpool` layer in the original UNet with `conv` to avoid `vjp` problem
+                self.maxpool = nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, stride=stride)
+
+        self.layers_down = nn.CellList()
+        for i in range(self.n_layers):
+            self.layers_down.append(TestDown(self.base_channels * 2**i, self.base_channels * 2 ** (i+1),
+                                             kernel_size=self.kernel_size, stride=self.stride,
+                                             activation=self.activation, enable_bn=self.enable_bn))
+
+
+class TestAttentionBlock(TransformerBlock):
+    ''' Child class for testing optimizing Attention with AdaHessian '''
+
+    def __init__(self,
+                 in_channels: int,
+                 num_heads: int,
+                 enable_flash_attn: bool = False,
+                 fa_dtype: mstype = mstype.bfloat16,
+                 drop_mode: str = "dropout",
+                 dropout_rate: float = 0.0,
+                 compute_dtype: mstype = mstype.float32,
+                 ):
+        super().__init__(in_channels=in_channels,
+                         num_heads=num_heads,
+                         enable_flash_attn=enable_flash_attn,
+                         fa_dtype=fa_dtype,
+                         drop_mode=drop_mode,
+                         dropout_rate=dropout_rate,
+                         compute_dtype=compute_dtype,
+                         )
+
+        class TestMlp(FeedForward):
+            def __init__(self, *args, **kwargs):
+                super().__init__(*args, **kwargs)
+                self.act_fn = nn.ReLU()  # replace `gelu` with `relu` to avoid `vjp` problem
+
+        class TestMultiHeadAttention(MultiHeadAttention):
+            ''' MultiHeadAttention modified to support vjp '''
+            def get_qkv(self, x: ms.Tensor) -> tuple[ms.Tensor]:
+                ''' use masks to select out q, k, v, instead of tensor reshaping & indexing '''
+                b, n, c_full = x.shape
+                c = c_full // self.num_heads
+
+                # use matmul with masks to select out q, k, v to avoid vjp problem
+                q_mask = ms.Tensor(np.vstack([np.eye(c), np.zeros([2 * c, c])]), dtype=self.compute_dtype)
+                k_mask = ms.Tensor(np.vstack([np.zeros([c, c]), np.eye(c), np.zeros([c, c])]), dtype=self.compute_dtype)
+                v_mask = ms.Tensor(np.vstack([np.zeros([2 * c, c]), np.eye(c)]), dtype=self.compute_dtype)
+
+                qkv = self.qkv(x)
+                qkv = qkv.reshape(b, n, self.num_heads, -1).swapaxes(1, 2)
+
+                q = mint.matmul(qkv, q_mask)
+                k = mint.matmul(qkv, k_mask)
+                v = mint.matmul(qkv, v_mask)
+
+                return q, k, v
+
+        self.ffn = TestMlp(
+            in_channels=in_channels,
+            dropout_rate=dropout_rate,
+            compute_dtype=compute_dtype,
+        )
+        self.attention = TestMultiHeadAttention(
+            in_channels=in_channels,
+            num_heads=num_heads,
+            enable_flash_attn=enable_flash_attn,
+            fa_dtype=fa_dtype,
+            drop_mode=drop_mode,
+            dropout_rate=dropout_rate,
+            compute_dtype=compute_dtype,
+        )
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend910b_training
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
+def test_adahessian_accuracy(mode):
+    """
+    Feature: AdaHessian forward accuracy test
+    Description: Test the accuracy of the AdaHessian optimizer in both GRAPH_MODE and PYNATIVE_MODE
+                 with input data specified in the code below.
+                 The expected output is compared to a reference output stored in
+                 'adahessian_output.npy' under the test data directory.
+    Expectation: The output should match the target data within the defined relative tolerance,
+                 ensuring the AdaHessian computation is accurate.
+    """
+    ms.set_context(mode=mode)
+
+    weight_init = ms.Tensor(np.reshape(range(72), [4, 2, 3, 3]), dtype=ms.float32)
+    bias_init = ms.Tensor(np.arange(4), dtype=ms.float32)
+
+    net = nn.Conv2d(
+        in_channels=2, out_channels=4, kernel_size=3, has_bias=True, weight_init=weight_init, bias_init=bias_init)
+
+    def forward(a):
+        return ops.sqrt(ops.mean(ops.square(net(a))))
+
+    grad_fn = ms.grad(forward, grad_position=None, weights=net.trainable_params())
+
+    optimizer = TestAdaHessianAccuracy(
+        net.trainable_params(),
+        learning_rate=0.1, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.)
+
+    inputs = ms.Tensor(np.reshape(range(100), [2, 2, 5, 5]), dtype=ms.float32)
+
+    for _ in range(4):
+        optimizer(grad_fn, inputs)
+
+    outputs = net(inputs).numpy()
+    outputs_ref = np.load(os.path.join(test_data_path, 'adahessian_output.npy'))
+    relative_error = np.max(np.abs(outputs - outputs_ref)) / np.max(np.abs(outputs_ref))
+    assert relative_error < FP32_RTOL, "The verification of adahessian accuracy is not successful."
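+
+
+# Calling convention exercised throughout this file (a sketch, not a full API
+# reference): unlike first-order MindSpore optimizers, which are fed gradients,
+# AdaHessian receives the gradient *function* together with the inputs, so that
+# it can form the Hutchinson-style Hessian-vector products itself:
+#
+#     grad_fn = ms.grad(forward, grad_position=None, weights=net.trainable_params())
+#     optimizer = AdaHessian(net.trainable_params(), learning_rate=0.1)
+#     optimizer(grad_fn, inputs)   # one optimization step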
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend910b_training
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
+@pytest.mark.parametrize('model_option', ['unet', 'attention'])
+def test_adahessian_st(mode, model_option):
+    """
+    Feature: AdaHessian ST test
+    Description: Test the function of the AdaHessian optimizer in both GRAPH_MODE and PYNATIVE_MODE
+                 on complex networks such as UNet and Attention. The input is a Tensor specified in the code
+                 and the output is the loss after 4 rounds of optimization.
+    Expectation: The output should be finite, ensuring that AdaHessian runs successfully on these networks.
+    """
+    ms.set_context(mode=mode)
+
+    # default test with Attention network
+    net = TestAttentionBlock(in_channels=256, num_heads=4)
+    inputs = ms.Tensor(np.sin(np.arange(102400)).reshape(4, 100, 256), dtype=ms.float32)
+
+    # test with UNet network
+    if model_option.lower() == 'unet':
+        net = TestUNet2D(
+            in_channels=2,
+            out_channels=4,
+            base_channels=8,
+            n_layers=4,
+            kernel_size=2,
+            stride=2,
+            activation='relu',
+            data_format="NCHW",
+            enable_bn=False,  # bn leads to bug in PYNATIVE_MODE for MS2.5.0
+        )
+        inputs = ms.Tensor(np.random.rand(2, 2, 64, 64), dtype=ms.float32)
+
+    def forward(a):
+        return ops.sqrt(ops.mean(ops.square(net(a))))
+
+    grad_fn = ms.grad(forward, grad_position=None, weights=net.trainable_params())
+
+    optimizer = AdaHessian(
+        net.trainable_params(),
+        learning_rate=0.1, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.)
+
+    for _ in range(4):
+        optimizer(grad_fn, inputs)
+
+    loss = forward(inputs)
+    assert ops.isfinite(loss)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend910b_training
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('mode', [ms.PYNATIVE_MODE])
+def test_adahessian_compare(mode):
+    """
+    Feature: AdaHessian compare with Adam
+    Description: Compare the algorithm results of the AdaHessian optimizer with Adam.
+                 The code runs in PYNATIVE_MODE and the network under comparison is TransformerBlock.
+                 The optimization runs 20 rounds to demonstrate a substantial loss decrease.
+    Expectation: The loss of AdaHessian outperforms Adam by 20% under the same configuration on an Attention network.
+    """
+    ms.set_context(mode=mode)
+
+    def get_loss(optimizer_option):
+        ''' compare Adam and AdaHessian '''
+        net = TestAttentionBlock(in_channels=256, num_heads=4)
+        inputs = ms.Tensor(np.sin(np.arange(102400)).reshape(4, 100, 256), dtype=ms.float32)
+
+        def forward(a):
+            return ops.sqrt(ops.mean(ops.square(net(a))))
+
+        grad_fn = ms.grad(forward, grad_position=None, weights=net.trainable_params())
+
+        if optimizer_option.lower() == 'adam':
+            optimizer = nn.Adam(
+                net.trainable_params(),
+                learning_rate=0.01, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.)
+        else:
+            optimizer = AdaHessian(
+                net.trainable_params(),
+                learning_rate=0.01, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.)
+
+        for _ in range(20):
+            if optimizer_option.lower() == 'adam':
+                optimizer(grad_fn(inputs))
+            else:
+                optimizer(grad_fn, inputs)
+
+        loss = forward(inputs)
+        return loss
+
+    loss_adam = get_loss('adam')
+    loss_adahessian = get_loss('adahessian')
+
+    assert loss_adam * 0.8 > loss_adahessian, (loss_adam, loss_adahessian)
diff --git a/tests/models/ffno/test_ffno.py b/tests/models/ffno/test_ffno.py
new file mode 100644
index 0000000000000000000000000000000000000000..13ca82cf225cf40f149c9c488e8fe61fb45a342d
--- /dev/null
+++ b/tests/models/ffno/test_ffno.py
@@ -0,0 +1,380 @@
+# Copyright 2023 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================ +"""ffno testcase""" + +import os +import sys +import time + +import pytest +import numpy as np + +import mindspore as ms +from mindspore import nn, Tensor, set_seed, load_param_into_net, load_checkpoint +from mindspore import dtype as mstype + +from mindscience.models import FFNO1D, FFNO2D, FFNO3D + +PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../")) +sys.path.append(PROJECT_ROOT) + +# pylint: disable=wrong-import-position + +from tools import compare_output, FP32_RTOL + +# pylint: enable=wrong-import-position + +set_seed(123456) +folder_path = "/home/workspace/mindspore_dataset/mindscience/ffno" + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_ffno1d_output(mode): + """ + Feature: Test FFNO1D network in platform ascend. + Description: None. + Expectation: Success or throw AssertionError. + """ + ms.set_context(mode=mode) + model1d = FFNO1D(in_channels=2, + out_channels=2, + n_modes=[2], + resolutions=[6], + hidden_channels=2, + n_layers=2, + share_weight=True, + r_padding=8, + ffno_compute_dtype=mstype.float32) + + data1d = Tensor(np.load(os.path.join(folder_path, "ffno_data1d.npy")), dtype=mstype.float32) + param1d = load_checkpoint(os.path.join(folder_path, "ffno1d.ckpt")) + load_param_into_net(model1d, param1d) + output1d = model1d(data1d) + target1d = np.load(os.path.join(folder_path, "ffno_target1d.npy")) + + assert output1d.shape == (2, 6, 2) + assert output1d.dtype == mstype.float32 + assert compare_output(output1d.asnumpy(), target1d, rtol=FP32_RTOL, atol=FP32_RTOL) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_ffno1d_mse_loss_output(mode): + """ + Feature: Test FFNO1D MSE Loss in platform ascend. + Description: None. + Expectation: Success or throw AssertionError. + """ + ms.set_context(mode=mode) + model1d = FFNO1D(in_channels=2, + out_channels=2, + n_modes=[2], + resolutions=[6], + hidden_channels=2, + n_layers=2, + share_weight=True, + r_padding=8, + ffno_compute_dtype=mstype.float32) + + data1d = Tensor(np.ones((2, 6, 2)), dtype=mstype.float32) + label_1d = Tensor(np.ones((2, 6, 2)), dtype=mstype.float32) + param1d = load_checkpoint(os.path.join(folder_path, "ffno1d.ckpt")) + load_param_into_net(model1d, param1d) + + loss_fn = nn.MSELoss() + optimizer_1d = nn.SGD(model1d.trainable_params(), learning_rate=0.01) + net_with_loss_1d = nn.WithLossCell(model1d, loss_fn) + train_step_1d = nn.TrainOneStepCell(net_with_loss_1d, optimizer_1d) + + # calculate two steps of loss + loss_1d = train_step_1d(data1d, label_1d) + target_loss_1_1d = 0.63846040 + assert compare_output(loss_1d.asnumpy(), target_loss_1_1d, rtol=FP32_RTOL, atol=FP32_RTOL) + + loss_1d = train_step_1d(data1d, label_1d) + target_loss_2_1d = 0.04462930 + assert compare_output(loss_1d.asnumpy(), target_loss_2_1d, rtol=FP32_RTOL, atol=FP32_RTOL) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_ffno2d_output(mode): + """ + Feature: Test FFNO2D network in platform ascend. + Description: None. + Expectation: Success or throw AssertionError. 
+ """ + ms.set_context(mode=mode) + model2d = FFNO2D(in_channels=2, + out_channels=2, + n_modes=[2, 2], + resolutions=[6, 6], + hidden_channels=2, + n_layers=2, + share_weight=True, + r_padding=8, + ffno_compute_dtype=mstype.float32) + + data2d = Tensor(np.load(os.path.join(folder_path, "ffno_data2d.npy")), dtype=mstype.float32) + param2d = load_checkpoint(os.path.join(folder_path, "ffno2d.ckpt")) + load_param_into_net(model2d, param2d) + output2d = model2d(data2d) + target2d = np.load(os.path.join(folder_path, "ffno_target2d.npy")) + + assert output2d.shape == (2, 6, 6, 2) + assert output2d.dtype == mstype.float32 + assert compare_output(output2d.asnumpy(), target2d, rtol=FP32_RTOL, atol=FP32_RTOL) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_ffno2d_mse_loss_output(mode): + """ + Feature: Test FFNO2D MSE Loss in platform ascend. + Description: None. + Expectation: Success or throw AssertionError. + """ + ms.set_context(mode=mode) + model2d = FFNO2D(in_channels=2, + out_channels=2, + n_modes=[2, 2], + resolutions=[6, 6], + hidden_channels=2, + n_layers=2, + share_weight=True, + r_padding=8, + ffno_compute_dtype=mstype.float32) + + data2d = Tensor(np.ones((2, 6, 6, 2)), dtype=mstype.float32) + label_2d = Tensor(np.ones((2, 6, 6, 2)), dtype=mstype.float32) + param2d = load_checkpoint(os.path.join(folder_path, "ffno2d.ckpt")) + load_param_into_net(model2d, param2d) + + loss_fn = nn.MSELoss() + optimizer_2d = nn.SGD(model2d.trainable_params(), learning_rate=0.01) + net_with_loss_2d = nn.WithLossCell(model2d, loss_fn) + train_step_2d = nn.TrainOneStepCell(net_with_loss_2d, optimizer_2d) + + # calculate two steps of loss + loss_2d = train_step_2d(data2d, label_2d) + target_loss_1_2d = 1.70347130 + assert compare_output(loss_2d.asnumpy(), target_loss_1_2d, rtol=FP32_RTOL, atol=FP32_RTOL) + + loss_2d = train_step_2d(data2d, label_2d) + target_loss_2_2d = 0.28143430 + assert compare_output(loss_2d.asnumpy(), target_loss_2_2d, rtol=FP32_RTOL, atol=FP32_RTOL) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_ffno3d_output(mode): + """ + Feature: Test FFNO3D network in platform ascend. + Description: None. + Expectation: Success or throw AssertionError. + """ + ms.set_context(mode=mode) + model3d = FFNO3D(in_channels=2, + out_channels=2, + n_modes=[2, 2, 2], + resolutions=[6, 6, 6], + hidden_channels=2, + n_layers=2, + share_weight=True, + r_padding=8, + ffno_compute_dtype=mstype.float32) + + data3d = Tensor(np.load(os.path.join(folder_path, "ffno_data3d.npy")), dtype=mstype.float32) + param3d = load_checkpoint(os.path.join(folder_path, "ffno3d.ckpt")) + load_param_into_net(model3d, param3d) + output3d = model3d(data3d) + target3d = np.load(os.path.join(folder_path, "ffno_target3d.npy")) + + assert output3d.shape == (2, 6, 6, 6, 2) + assert output3d.dtype == mstype.float32 + assert compare_output(output3d.asnumpy(), target3d, rtol=FP32_RTOL, atol=FP32_RTOL) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_ffno3d_mse_loss_output(mode): + """ + Feature: Test FFNO3D MSE Loss in platform ascend. + Description: None. + Expectation: Success or throw AssertionError. 
+ """ + ms.set_context(mode=mode) + model3d = FFNO3D(in_channels=2, + out_channels=2, + n_modes=[2, 2, 2], + resolutions=[6, 6, 6], + hidden_channels=2, + n_layers=2, + share_weight=True, + r_padding=8, + ffno_compute_dtype=mstype.float32) + + data3d = Tensor(np.ones((2, 6, 6, 6, 2)), dtype=mstype.float32) + label_3d = Tensor(np.ones((2, 6, 6, 6, 2)), dtype=mstype.float32) + param3d = load_checkpoint(os.path.join(folder_path, "ffno3d.ckpt")) + load_param_into_net(model3d, param3d) + + loss_fn = nn.MSELoss() + optimizer_3d = nn.SGD(model3d.trainable_params(), learning_rate=0.01) + net_with_loss_3d = nn.WithLossCell(model3d, loss_fn) + train_step_3d = nn.TrainOneStepCell(net_with_loss_3d, optimizer_3d) + + # calculate two steps of loss + loss_3d = train_step_3d(data3d, label_3d) + target_loss_1_3d = 1.94374371 + assert compare_output(loss_3d.asnumpy(), target_loss_1_3d, rtol=FP32_RTOL, atol=FP32_RTOL) + + loss_3d = train_step_3d(data3d, label_3d) + target_loss_2_3d = 0.24034855 + assert compare_output(loss_3d.asnumpy(), target_loss_2_3d, rtol=FP32_RTOL, atol=FP32_RTOL) + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_ffno1d_speed(mode): + """ + Feature: Test FFNO1D training speed in platform ascend. + Description: The speed of each training step. + Expectation: Success or throw AssertionError. + """ + ms.set_context(mode=mode) + model1d = FFNO1D(in_channels=32, + out_channels=32, + n_modes=[16], + resolutions=[128], + hidden_channels=2, + n_layers=2, + share_weight=True, + r_padding=8, + ffno_compute_dtype=mstype.float32) + + data1d = Tensor(np.ones((32, 128, 32)), dtype=mstype.float32) + label_1d = Tensor(np.ones((32, 128, 32)), dtype=mstype.float32) + + loss_fn = nn.MSELoss() + optimizer_1d = nn.SGD(model1d.trainable_params(), learning_rate=0.01) + net_with_loss_1d = nn.WithLossCell(model1d, loss_fn) + train_step_1d = nn.TrainOneStepCell(net_with_loss_1d, optimizer_1d) + + steps = 10 + for _ in range(10): + train_step_1d(data1d, label_1d) + + start_time = time.time() + for _ in range(10): + train_step_1d(data1d, label_1d) + end_time = time.time() + + assert (end_time - start_time) / steps < 0.5 + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_ffno2d_speed(mode): + """ + Feature: Test FFNO2D training speed in platform ascend. + Description: The speed of each training step. + Expectation: Success or throw AssertionError. 
+ """ + ms.set_context(mode=mode) + model2d = FFNO2D(in_channels=32, + out_channels=32, + n_modes=[16, 16], + resolutions=[64, 64], + hidden_channels=2, + n_layers=2, + share_weight=True, + r_padding=8, + ffno_compute_dtype=mstype.float32) + + data2d = Tensor(np.ones((32, 64, 64, 32)), dtype=mstype.float32) + label_2d = Tensor(np.ones((32, 64, 64, 32)), dtype=mstype.float32) + + loss_fn = nn.MSELoss() + optimizer_2d = nn.SGD(model2d.trainable_params(), learning_rate=0.01) + net_with_loss_2d = nn.WithLossCell(model2d, loss_fn) + train_step_2d = nn.TrainOneStepCell(net_with_loss_2d, optimizer_2d) + + steps = 10 + for _ in range(steps): + train_step_2d(data2d, label_2d) + + start_time = time.time() + for _ in range(steps): + train_step_2d(data2d, label_2d) + end_time = time.time() + + assert (end_time - start_time) / steps < 1 + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_ffno3d_speed(mode): + """ + Feature: Test FFNO3D training speed in platform ascend. + Description: The speed of each training step. + Expectation: Success or throw AssertionError. + """ + ms.set_context(mode=mode) + model3d = FFNO3D(in_channels=2, + out_channels=2, + n_modes=[16, 16, 16], + resolutions=[32, 32, 32], + hidden_channels=2, + n_layers=2, + share_weight=True, + r_padding=8, + ffno_compute_dtype=mstype.float32) + + data3d = Tensor(np.ones((2, 32, 32, 32, 2)), dtype=mstype.float32) + label_3d = Tensor(np.ones((2, 32, 32, 32, 2)), dtype=mstype.float32) + + loss_fn = nn.MSELoss() + optimizer_3d = nn.SGD(model3d.trainable_params(), learning_rate=0.01) + net_with_loss_3d = nn.WithLossCell(model3d, loss_fn) + train_step_3d = nn.TrainOneStepCell(net_with_loss_3d, optimizer_3d) + + steps = 10 + for _ in range(steps): + train_step_3d(data3d, label_3d) + + start_time = time.time() + for _ in range(steps): + train_step_3d(data3d, label_3d) + end_time = time.time() + + assert (end_time - start_time) / steps < 3 diff --git a/tests/models/fno/fno1d.yaml b/tests/models/fno/fno1d.yaml new file mode 100644 index 0000000000000000000000000000000000000000..42533114fd93258b6834fc24fa72d6475244d054 --- /dev/null +++ b/tests/models/fno/fno1d.yaml @@ -0,0 +1,8 @@ +model: + name: FNO1D + in_channels: 1 + out_channels: 1 + modes: 16 + resolutions: 1024 + hidden_channels: 10 + depths: 1 diff --git a/tests/models/fno/test_fno.py b/tests/models/fno/test_fno.py new file mode 100644 index 0000000000000000000000000000000000000000..bfb72593fe5597b510e7c80357512565440df857 --- /dev/null +++ b/tests/models/fno/test_fno.py @@ -0,0 +1,99 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================
+"""fno testcase"""
+
+import os
+import pytest
+import numpy as np
+
+from mindspore import Tensor, context, set_seed, load_param_into_net, load_checkpoint
+from mindspore import dtype as mstype
+from mindscience import FNO1D, FNO2D, FNO3D
+from mindscience.models.neural_operator.fno_sp import SpectralConv1dDft, SpectralConv2dDft, SpectralConv3dDft
+
+RTOL = 0.001
+set_seed(123456)
+
+test_data_path = '/home/workspace/mindspore_dataset/mindscience/fno'
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend910b_training
+@pytest.mark.env_onecard
+def test_fno_output():
+    """
+    Feature: Test FNO1D, FNO2D and FNO3D networks on the Ascend platform.
+    Description: None.
+    Expectation: Success or throw AssertionError.
+    Needs adaptation for 910B.
+    """
+    context.set_context(mode=context.GRAPH_MODE)
+    model1d = FNO1D(
+        in_channels=2, out_channels=2, n_modes=[2], resolutions=[6], fno_compute_dtype=mstype.float32)
+    model2d = FNO2D(
+        in_channels=2, out_channels=2, n_modes=[2, 2], resolutions=[6, 6], fno_compute_dtype=mstype.float32)
+    model3d = FNO3D(
+        in_channels=2, out_channels=2, n_modes=[2, 2, 2], resolutions=[6, 6, 6], fno_compute_dtype=mstype.float32)
+    data1d = Tensor(np.ones((2, 6, 2)), dtype=mstype.float32)
+    data2d = Tensor(np.ones((2, 6, 6, 2)), dtype=mstype.float32)
+    data3d = Tensor(np.ones((2, 6, 6, 6, 2)), dtype=mstype.float32)
+    output1d = model1d(data1d)
+    output2d = model2d(data2d)
+    output3d = model3d(data3d)
+    assert output1d.shape == (2, 6, 2)
+    assert output1d.dtype == mstype.float32
+    assert output2d.shape == (2, 6, 6, 2)
+    assert output2d.dtype == mstype.float32
+    assert output3d.shape == (2, 6, 6, 6, 2)
+    assert output3d.dtype == mstype.float32
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend910b_training
+@pytest.mark.env_onecard
+def test_spectralconvdft_output():
+    """
+    Feature: Test SpectralConv1dDft, SpectralConv2dDft and SpectralConv3dDft networks on the Ascend platform.
+    Description: None.
+    Expectation: Success or throw AssertionError.
+    """
+    context.set_context(mode=context.GRAPH_MODE)
+    model1d = SpectralConv1dDft(in_channels=2, out_channels=2, n_modes=[2], resolutions=[6])
+    model2d = SpectralConv2dDft(in_channels=2, out_channels=2, n_modes=[2, 2], resolutions=[6, 6])
+    model3d = SpectralConv3dDft(in_channels=2, out_channels=2, n_modes=[2, 2, 2], resolutions=[6, 6, 6])
+    data1d = Tensor(np.ones((2, 2, 6)), dtype=mstype.float32)
+    data2d = Tensor(np.ones((2, 2, 6, 6)), dtype=mstype.float32)
+    data3d = Tensor(np.ones((2, 2, 6, 6, 6)), dtype=mstype.float32)
+    target1d = 3.64671636
+    target2d = 35.93239212
+    target3d = 149.64256287
+    param1 = load_checkpoint(os.path.join(test_data_path, "spectralconv1d.ckpt"))
+    param2 = load_checkpoint(os.path.join(test_data_path, "spectralconv2d.ckpt"))
+    param3 = load_checkpoint(os.path.join(test_data_path, "spectralconv3d.ckpt"))
+    load_param_into_net(model1d, param1)
+    load_param_into_net(model2d, param2)
+    load_param_into_net(model3d, param3)
+    output1d = model1d(data1d)
+    output2d = model2d(data2d)
+    output3d = model3d(data3d)
+    assert output1d.shape == (2, 2, 6)
+    assert output1d.dtype == mstype.float32
+    assert abs(output1d.sum() - target1d) < RTOL
+    assert output2d.shape == (2, 2, 6, 6)
+    assert output2d.dtype == mstype.float32
+    assert abs(output2d.sum() - target2d) < RTOL
+    assert output3d.shape == (2, 2, 6, 6, 6)
+    assert output3d.dtype == mstype.float32
+    assert abs(output3d.sum() - target3d) < RTOL
diff --git a/tests/models/fno/test_fno1d.py b/tests/models/fno/test_fno1d.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a7538db77c48bb54e732e60329948dd47ffff6c
--- /dev/null
+++ b/tests/models/fno/test_fno1d.py
@@ -0,0 +1,233 @@
+# ============================================================================
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""FNO1D Test Case"""
+import os
+import random
+import sys
+
+import pytest
+import numpy as np
+
+import mindspore as ms
+from mindspore.train.serialization import load_checkpoint, load_param_into_net
+from mindspore import Tensor, ops, set_seed
+from mindspore import dtype as mstype
+from mindscience import FNO1D, RelativeRMSELoss, load_yaml_config
+from mindscience.pde import SteadyFlowWithLoss
+
+PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../"))
+sys.path.append(PROJECT_ROOT)
+
+# pylint: disable=wrong-import-position
+
+from tools import validate_checkpoint, compare_output, FP16_RTOL, FP16_ATOL
+
+# pylint: enable=wrong-import-position
+
+set_seed(0)
+np.random.seed(0)
+random.seed(0)
+
+test_data_path = '/home/workspace/mindspore_dataset/mindscience/fno1d'
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend910b_training
+@pytest.mark.env_onecard
+def test_fno1d_checkpoint():
+    """
+    Feature: FNO1D checkpoint loading and verification
+    Description: Test the consistency of the FNO1D model when loading from a saved checkpoint.
+ Two FNO1D models are initialized with the same parameters, and one of them + loads weights from the specified checkpoint located at './mindflow/cell/fno1d/ckpt/fno1d.ckpt'. + The test input is a randomly generated tensor, and the validation checks if + both models (one with loaded parameters) produce the same outputs. + Expectation: The model loaded from the checkpoint should behave identically to a newly initialized + model with the same parameters, verifying that the checkpoint restores the model's state correctly. + """ + config = load_yaml_config('fno1d.yaml') + model_params = config["model"] + ckpt_path = os.path.join(test_data_path, 'fno1d.ckpt') + + model1 = FNO1D(in_channels=model_params["in_channels"], + out_channels=model_params["out_channels"], + n_modes=model_params["modes"], + resolutions=model_params["resolutions"], + hidden_channels=model_params["hidden_channels"], + n_layers=model_params["depths"], + projection_channels=4*model_params["hidden_channels"], + ) + + model2 = FNO1D(in_channels=model_params["in_channels"], + out_channels=model_params["out_channels"], + n_modes=model_params["modes"], + resolutions=model_params["resolutions"], + hidden_channels=model_params["hidden_channels"], + n_layers=model_params["depths"], + projection_channels=4*model_params["hidden_channels"], + ) + + params = load_checkpoint(ckpt_path) + load_param_into_net(model1, params) + test_inputs = Tensor(np.random.randn(1, 1024, 1), mstype.float32) + + validate_ans = validate_checkpoint(model1, model2, (test_inputs,)) + assert validate_ans, "The verification of FNO1D checkpoint is not successful." + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_fno1d_forward_accuracy(mode): + """ + Feature: FNO1D forward accuracy test + Description: Test the forward accuracy of the FNO1D model in both GRAPH_MODE and PYNATIVE_MODE. + The model is initialized with parameters from './mindflow/cell/fno1d/configs/fno1d.yaml', + and weights are loaded from the checkpoint located at './mindflow/cell/fno1d/ckpt/fno1d.ckpt'. + The input data is loaded from './mindflow/cell/fno1d/data/fno1d_input.npy', and the output + is compared against the expected prediction stored in './mindflow/cell/fno1d/data/fno1d_pred.npy'. + Expectation: The output should match the target prediction data within the specified relative and absolute + tolerance values, ensuring the forward pass of the FNO1D model is accurate. + """ + ms.set_context(mode=mode) + config = load_yaml_config('fno1d.yaml') + model_params = config["model"] + ckpt_path = os.path.join(test_data_path, 'fno1d.ckpt') + + model = FNO1D(in_channels=model_params["in_channels"], + out_channels=model_params["out_channels"], + n_modes=model_params["modes"], + resolutions=model_params["resolutions"], + hidden_channels=model_params["hidden_channels"], + n_layers=model_params["depths"], + projection_channels=4*model_params["hidden_channels"], + ) + + params = load_checkpoint(ckpt_path) + load_param_into_net(model, params) + input_data = np.load(os.path.join(test_data_path, 'fno1d_input.npy')) + test_inputs = Tensor(input_data, mstype.float32) + output = model(test_inputs) + output = output.asnumpy() + output_target = np.load(os.path.join(test_data_path, 'fno1d_pred.npy')) + validate_ans = compare_output(output, output_target, rtol=FP16_RTOL, atol=FP16_ATOL) + assert validate_ans, "The verification of FNO1D forward accuracy is not successful." 
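+
+
+# Note (illustrative): fno1d.yaml pins in_channels=1, out_channels=1 and
+# resolutions=1024, so the configured model maps (batch, 1024, 1) to
+# (batch, 1024, 1); this is why the checkpoint test above feeds a tensor of
+# shape (1, 1024, 1).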
+ + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +def test_fno1d_amp(): + """ + Feature: FNO1D AMP (Automatic Mixed Precision) accuracy test + Description: Test the accuracy of FNO1D model with and without AMP (Automatic Mixed Precision). + Two FNO1D models are initialized with identical parameters. The first model uses the + default precision(float16), while the second model is set to use float32 precision for computation. + Both models load the same checkpoint from './mindflow/cell/fno1d/ckpt/fno1d.ckpt'. + The input data is loaded from './mindflow/cell/fno1d/data/fno1d_input.npy', and outputs + of the two models are compared to check if they match within the specified tolerance. + Expectation: The outputs of both models (with and without AMP) should match within the defined + relative and absolute tolerance values, verifying that AMP does not affect the accuracy. + """ + config = load_yaml_config('fno1d.yaml') + model_params = config["model"] + ckpt_path = os.path.join(test_data_path, 'fno1d.ckpt') + + model1 = FNO1D(in_channels=model_params["in_channels"], + out_channels=model_params["out_channels"], + n_modes=model_params["modes"], + resolutions=model_params["resolutions"], + hidden_channels=model_params["hidden_channels"], + n_layers=model_params["depths"], + projection_channels=4*model_params["hidden_channels"], + ) + + model2 = FNO1D(in_channels=model_params["in_channels"], + out_channels=model_params["out_channels"], + n_modes=model_params["modes"], + resolutions=model_params["resolutions"], + hidden_channels=model_params["hidden_channels"], + n_layers=model_params["depths"], + projection_channels=4*model_params["hidden_channels"], + fno_compute_dtype=mstype.float32, + ) + + params = load_checkpoint(ckpt_path) + load_param_into_net(model1, params) + load_param_into_net(model2, params) + input_data = np.load(os.path.join(test_data_path, 'fno1d_input.npy')) + test_inputs = Tensor(input_data, mstype.float32) + output1 = model1(test_inputs) + output1 = output1.asnumpy() + output2 = model2(test_inputs) + output2 = output2.asnumpy() + validate_ans = compare_output(output1, output2, rtol=FP16_RTOL, atol=FP16_ATOL) + assert validate_ans, "The verification of FNO1D AMP accuracy is not successful." + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +def test_fno1d_grad_accuracy(): + """ + Feature: FNO1D gradient accuracy test + Description: Test the accuracy of the computed gradients for the FNO1D model. The model is initialized + with parameters from './mindflow/cell/fno1d/configs/fno1d.yaml' and weights are loaded + from the checkpoint located at './mindflow/cell/fno1d/ckpt/fno1d.ckpt'. The loss function used + is RelativeRMSELoss. The input data is loaded from './mindflow/cell/fno1d/data/fno1d_input.npy' + and the label is from './mindflow/cell/fno1d/data/fno1d_input_label.npy'. Gradients are computed + using MindSpore's value_and_grad and compared against the reference gradients stored in + './mindflow/cell/fno1d/data/fno1d_grads.npz'. + Expectation: The computed gradients should match the reference gradients within the specified relative and + absolute tolerance values, ensuring the gradient calculation is accurate. 
+ """ + config = load_yaml_config('fno1d.yaml') + model_params = config["model"] + ckpt_path = os.path.join(test_data_path, 'fno1d.ckpt') + + model = FNO1D(in_channels=model_params["in_channels"], + out_channels=model_params["out_channels"], + n_modes=model_params["modes"], + resolutions=model_params["resolutions"], + hidden_channels=model_params["hidden_channels"], + n_layers=model_params["depths"], + projection_channels=4*model_params["hidden_channels"], + ) + + params = load_checkpoint(ckpt_path) + load_param_into_net(model, params) + input_data = np.load(os.path.join(test_data_path, 'fno1d_input.npy')) + input_label = np.load(os.path.join(test_data_path, 'fno1d_input_label.npy')) + test_inputs = Tensor(input_data, mstype.float32) + test_label = Tensor(input_label, mstype.float32) + + problem = SteadyFlowWithLoss( + model, loss_fn=RelativeRMSELoss()) + + def forward_fn(data, label): + loss = problem.get_loss(data, label) + return loss + + grad_fn = ops.value_and_grad( + forward_fn, None, model.trainable_params(), has_aux=False) + + _, grads = grad_fn(test_inputs, test_label) + convert_grads = tuple(grad.asnumpy() for grad in grads) + with np.load(os.path.join(test_data_path, 'fno1d_grads.npz')) as data: + output_target = tuple(data[key] for key in data.files) + validate_ans = compare_output(convert_grads, output_target, rtol=FP16_RTOL, atol=FP16_ATOL) + assert validate_ans, "The verification of FNO1D grad accuracy is not successful." diff --git a/tests/sciops/test_fourier.py b/tests/sciops/test_fourier.py new file mode 100644 index 0000000000000000000000000000000000000000..221972edb9c81abd3e08c542d1fb2012606dfa5b --- /dev/null +++ b/tests/sciops/test_fourier.py @@ -0,0 +1,247 @@ +# ============================================================================ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================
+"""Fourier Transform Cells Test Case"""
+import os
+import random
+import sys
+from time import time as toc
+import pytest
+import numpy as np
+from scipy.fft import dct, dst
+import mindspore as ms
+from mindspore import set_seed, ops
+from mindscience import DFTn, IDFTn, RDFTn, IRDFTn, DCT, IDCT, DST, IDST
+
+PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../"))
+sys.path.append(PROJECT_ROOT)
+
+# pylint: disable=wrong-import-position
+
+from tools import FP32_RTOL, FP16_RTOL, FP32_ATOL, FP16_ATOL, compare_output
+
+# pylint: enable=wrong-import-position
+
+set_seed(0)
+np.random.seed(0)
+random.seed(0)
+
+
+def gen_input(shape=(5, 6, 4, 8), rand_test=True):
+    ''' Generate a random or deterministic tensor as input for the tests
+    '''
+    a = np.random.rand(*shape) + 1j * np.random.rand(*shape)
+    if not rand_test:
+        a = sum([np.arange(n).reshape([n] + [1] * j) for j, n in enumerate(shape[::-1])]) + 1j * \
+            sum([np.arange(n).reshape([n] + [1] * j) for j, n in enumerate(shape[::-1])])
+    ar, ai = (ms.Tensor(a.real, dtype=ms.float32), ms.Tensor(a.imag, dtype=ms.float32))
+    return a, ar, ai
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend910b_training
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('device_target', ['CPU', 'Ascend'])
+@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
+@pytest.mark.parametrize('ndim', [1, 2, 3])
+@pytest.mark.parametrize('compute_dtype', [ms.float32, ms.float16])
+def test_rdft_accuracy(device_target, mode, ndim, compute_dtype):
+    """
+    Feature: Test RDFTn & IRDFTn accuracy
+    Description: Input a random tensor, compare the results of RDFTn and IRDFTn with numpy results
+    Expectation: The output tensors should be equal within tolerance
+    """
+    ms.set_context(device_target=device_target, mode=mode)
+    a, ar, _ = gen_input()
+    shape = a.shape
+
+    b = np.fft.rfftn(a.real, s=a.shape[-ndim:], axes=range(-ndim, 0))
+    br, bi = RDFTn(shape[-ndim:], compute_dtype=compute_dtype)(ar)
+    cr = IRDFTn(shape[-ndim:], compute_dtype=compute_dtype)(br, bi)
+
+    rtol = FP32_RTOL if compute_dtype == ms.float32 else FP16_RTOL * 10
+    atol = FP32_ATOL if compute_dtype == ms.float32 else FP16_ATOL * 20
+
+    assert compare_output(br.numpy(), b.real, rtol, atol)
+    assert compare_output(bi.numpy(), b.imag, rtol, atol)
+    assert compare_output(cr.numpy(), a.real, rtol, atol)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend910b_training
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('device_target', ['CPU', 'Ascend'])
+@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
+@pytest.mark.parametrize('ndim', [1, 2, 3])
+@pytest.mark.parametrize('compute_dtype', [ms.float32, ms.float16])
+def test_dft_accuracy(device_target, mode, ndim, compute_dtype):
+    """
+    Feature: Test DFTn & IDFTn accuracy
+    Description: Input a random tensor, compare the results of DFTn and IDFTn with numpy results
+    Expectation: The output tensors should be equal within tolerance
+    """
+    ms.set_context(device_target=device_target, mode=mode)
+    a, ar, ai = gen_input()
+    shape = a.shape
+
+    b = np.fft.fftn(a, s=a.shape[-ndim:], axes=range(-ndim, 0))
+    br, bi = DFTn(shape[-ndim:], compute_dtype=compute_dtype)(ar, ai)
+    cr, ci = IDFTn(shape[-ndim:], compute_dtype=compute_dtype)(br, bi)
+
+    rtol = FP32_RTOL if compute_dtype == ms.float32 else FP16_RTOL * 10
+    atol = FP32_ATOL if compute_dtype == ms.float32 else FP16_ATOL * 20
+
+    assert compare_output(br.numpy(), b.real, rtol, atol)
+    assert compare_output(bi.numpy(), b.imag, rtol, atol)
+    assert compare_output(cr.numpy(), a.real, rtol, atol)
+    assert compare_output(ci.numpy(), a.imag, rtol, atol)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend910b_training
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('device_target', ['CPU', 'Ascend'])
+@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
+@pytest.mark.parametrize('compute_dtype', [ms.float32, ms.float16])
+def test_dct_accuracy(device_target, mode, compute_dtype):
+    """
+    Feature: Test DCT & IDCT accuracy
+    Description: Input a random tensor, compare the results of DCT and IDCT with scipy results
+    Expectation: The output tensors should be equal within tolerance
+    """
+    ms.set_context(device_target=device_target, mode=mode)
+    a, ar, _ = gen_input()
+    shape = a.shape
+
+    b = dct(a.real)
+    br = DCT(shape[-1:], compute_dtype=compute_dtype)(ar)
+    cr = IDCT(shape[-1:], compute_dtype=compute_dtype)(br)
+
+    rtol = FP32_RTOL if compute_dtype == ms.float32 else FP16_RTOL * 10
+    atol = FP32_ATOL if compute_dtype == ms.float32 else FP16_ATOL * 20
+
+    assert compare_output(br.numpy(), b.real, rtol, atol)
+    assert compare_output(cr.numpy(), a.real, rtol, atol)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend910b_training
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('device_target', ['CPU', 'Ascend'])
+@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
+@pytest.mark.parametrize('compute_dtype', [ms.float32, ms.float16])
def test_dst_accuracy(device_target, mode, compute_dtype):
+    """
+    Feature: Test DST & IDST accuracy
+    Description: Input a random tensor, compare the results of DST and IDST with scipy results
+    Expectation: The output tensors should be equal within tolerance
+    """
+    ms.set_context(device_target=device_target, mode=mode)
+    a, ar, _ = gen_input()
+    shape = a.shape
+
+    b = dst(a.real)
+    br = DST(shape[-1:], compute_dtype=compute_dtype)(ar)
+    cr = IDST(shape[-1:], compute_dtype=compute_dtype)(br)
+
+    rtol = FP32_RTOL if compute_dtype == ms.float32 else FP16_RTOL * 10
+    atol = FP32_ATOL if compute_dtype == ms.float32 else FP16_ATOL * 20
+
+    assert compare_output(br.numpy(), b.real, rtol, atol)
+    assert compare_output(cr.numpy(), a.real, rtol, atol)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend910b_training
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('device_target', ['Ascend'])
+@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
+@pytest.mark.parametrize('ndim', [1, 2, 3])
+def test_dft_speed(device_target, mode, ndim):
+    """
+    Feature: Test DFTn & IDFTn speed
+    Description: Input a random tensor, clock the time of 10 runs of the
+                 gradient function containing DFT & iDFT operators
+    Expectation: The average time of each run should be within 0.5s
+    """
+    # test dftn & idftn speed
+    ms.set_context(device_target=device_target, mode=mode)
+    a, ar, ai = gen_input(shape=(64, 128, 256))
+    shape = a.shape
+
+    warmup_steps = 10
+    timed_steps = 10
+
+    dft_cell = DFTn(shape[-ndim:])
+    idft_cell = IDFTn(shape[-ndim:])
+
+    def forward_fn(xr, xi):
+        br, bi = dft_cell(xr, xi)
+        cr, ci = idft_cell(br, bi)
+        return ops.sum(cr * cr + ci * ci)
+
+    grad_fn = ms.value_and_grad(forward_fn, grad_position=(0, 1))
+
+    # warmup run
+    for _ in range(warmup_steps):
+        _, (g1, g2) = grad_fn(ar, ai)
+        ar = ar - .1 * g1
+        ai = ai - .1 * g2
+
+    # timed run
+    tic = toc()
+    for _ in range(timed_steps):
+        _, (g1, g2) = grad_fn(ar, ai)
+        ar = ar - .1 * g1
+        ai = ai - .1 * g2
+
+    assert (toc() - tic) / timed_steps < 0.5
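+
+
+# Background for the gradient test below: with an unnormalized DFT, Parseval's
+# theorem gives sum|X|^2 = N * sum|x|^2 for X = fftn(x) and N = prod(transformed
+# dims), so the gradient of sum(yr^2 + yi^2) with respect to the input is 2*N*x,
+# which can be written as ifftn(fftn(x)) * 2 * N, the analytic reference used in
+# the assertions.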
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend910b_training
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('device_target', ['CPU', 'Ascend'])
+@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
+@pytest.mark.parametrize('ndim', [1, 2, 3])
+@pytest.mark.parametrize('compute_dtype', [ms.float32, ms.float16])
+def test_dft_grad(device_target, mode, ndim, compute_dtype):
+    """
+    Feature: Test the correctness of DFTn & IDFTn grad calculation
+    Description: Input a random tensor, compare the autograd results with the analytic solution
+    Expectation: The autograd results should be equal to the analytic solution
+    """
+    ms.set_context(device_target=device_target, mode=mode)
+    a, ar, ai = gen_input()
+    shape = a.shape
+
+    dft_cell = DFTn(shape[-ndim:], compute_dtype=compute_dtype)
+
+    def forward_fn(xr, xi):
+        yr, yi = dft_cell(xr, xi)
+        return ops.sum(yr * yr + yi * yi)
+
+    grad_fn = ms.value_and_grad(forward_fn, grad_position=(0, 1))
+    _, (g1, g2) = grad_fn(ar, ai)
+
+    # analytic solution of the gradient
+    b = np.fft.fftn(a, s=a.shape[-ndim:], axes=range(-ndim, 0))
+    g = np.fft.ifftn(b, s=a.shape[-ndim:], axes=range(-ndim, 0)) * 2 * np.prod(a.shape[-ndim:])
+
+    rtol = FP32_RTOL if compute_dtype == ms.float32 else FP16_RTOL * 10
+    atol = FP32_ATOL if compute_dtype == ms.float32 else FP16_ATOL * 500  # grad func leads to larger error
+
+    assert compare_output(g1.numpy(), g.real, rtol, atol)
+    assert compare_output(g2.numpy(), g.imag, rtol, atol)