diff --git a/MindFlow/applications/cfd/acoustic/cbs/cbs.py b/MindFlow/applications/cfd/acoustic/cbs/cbs.py index 0aba76a0a21787002931c38477508bea5a0d340c..7496a8997482ed1e923352db325485ef4f0f3b5b 100644 --- a/MindFlow/applications/cfd/acoustic/cbs/cbs.py +++ b/MindFlow/applications/cfd/acoustic/cbs/cbs.py @@ -19,7 +19,7 @@ import numpy as np import mindspore as ms from mindspore import Tensor, nn, ops, numpy as mnp, lazy_inline -from .dft import MyDFTn, MyiDFTn +from mindflow import DFTn, IDFTn class CBSBlock(nn.Cell): @@ -32,8 +32,8 @@ class CBSBlock(nn.Cell): shape: tuple of int, only the spatial shape, not including the batch and channel dimensions ''' super().__init__() - self.dft_cell = MyDFTn(shape) - self.idft_cell = MyiDFTn(shape) + self.dft_cell = DFTn(shape) + self.idft_cell = IDFTn(shape) # Scattering potential calculation for real and imaginary parts def op_v(self, ur, ui, vr, vi): diff --git a/MindFlow/applications/cfd/acoustic/cbs/dft.py b/MindFlow/applications/cfd/acoustic/cbs/dft.py deleted file mode 100644 index d93d9a8b16e1b970ee1f7c6f08e9ee831bdffbf7..0000000000000000000000000000000000000000 --- a/MindFlow/applications/cfd/acoustic/cbs/dft.py +++ /dev/null @@ -1,124 +0,0 @@ -# Copyright 2025 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -''' provide complex dft based on the real dft API in mindflow.dft ''' -import numpy as np -import mindspore as ms -from mindspore import nn, ops, numpy as mnp, mint -from mindflow.cell.neural_operators.dft import dft1, dft2, dft3 - - -class MyDFTn(nn.Cell): - def __init__(self, shape): - super().__init__() - assert len(shape) in (1, 2, 3), 'only ndim 1, 2, 3 supported' - - n = shape[-1] - ndim = len(shape) - modes = tuple([_ // 2 for _ in shape[-ndim:-1]] + [n // 2 + 1]) if ndim > 1 else n // 2 + 1 - - self.shape = tuple(shape) - self.dft_cell = { - 1: dft1, - 2: dft2, - 3: dft3, - }[ndim](shape, modes) - - # use mask to assemble slices of Tensors, avoiding dynamic shape - # bug note: for unknown reasons, GRAPH_MODE cannot work with mask Tensors allocated using ops.ones() - mask_x0 = np.ones(n//2 + 1) - mask_xm = np.ones(n//2 + 1) - mask_y0 = np.ones(shape) - mask_z0 = np.ones(shape) - mask_x0[0] = 0 - mask_xm[-1] = 0 - if ndim > 1: - mask_y0[..., 0, :] = 0 - if ndim > 2: - mask_z0[..., 0, :, :] = 0 - - self.mask_x0 = ms.Tensor(mask_x0, dtype=ms.float32, const_arg=True) - self.mask_xm = ms.Tensor(mask_xm, dtype=ms.float32, const_arg=True) - self.mask_y0 = ms.Tensor(mask_y0, dtype=ms.float32, const_arg=True) - self.mask_z0 = ms.Tensor(mask_z0, dtype=ms.float32, const_arg=True) - - # bug note: ops.flip/mint.flip/mint.roll has bug for MS2.4.0 in PYNATIVE_MODE - # mnp.flip has bug after MS2.4.0 in GRAPH_MODE - # ops.roll only supports GPU, mnp.roll is ok but slow - msver = tuple([int(s) for s in ms.__version__.split('.')]) - kwargs1 = (dict(axis=-1), dict(axis=-2), dict(axis=-3)) - kwargs2 = (dict(dims=(-1,)), dict(dims=(-2,)), dict(dims=(-3,))) - - if msver <= (2, 4, 0) and ms.get_context('mode') == ms.PYNATIVE_MODE: - self.fliper = mnp.flip - self.roller = mnp.roll - self.flipkw = kwargs1 - self.rollkw = kwargs1 - else: - self.fliper = mint.flip - self.roller = mint.roll - self.flipkw 
= kwargs2 - self.rollkw = kwargs2 - - def construct(self, ar, ai): - shape = tuple(self.shape) - n = shape[-1] - ndim = len(shape) - scale = float(np.prod(shape) ** .5) - - assert ai is None or ar.shape == ai.shape - assert ar.shape[-ndim:] == shape - - brr, bri = self.dft_cell((ar, ar * 0)) - - # n-D Fourier transform with last axis being real-transformed, output dimension (..., m, n//2+1) - if ai is None: - return brr * scale, bri * scale - - # n-D complex Fourier transform, output dimension (..., m, n) - # call dft for real & imag parts separately and then assemble - bir, bii = self.dft_cell((ai, ai * 0)) - - br_half1 = ops.pad((brr - bii) * self.mask_xm, [0, n//2 - 1]) - bi_half1 = ops.pad((bri + bir) * self.mask_xm, [0, n//2 - 1]) - - br_half2 = self.roller(self.fliper( - ops.pad((brr + bii) * self.mask_x0, [n//2 - 1, 0]), **self.flipkw[0]), n//2, **self.rollkw[0]) - bi_half2 = self.roller(self.fliper( - ops.pad((bir - bri) * self.mask_x0, [n//2 - 1, 0]), **self.flipkw[0]), n//2, **self.rollkw[0]) - if ndim > 1: - br_half2 = br_half2 * (1 - self.mask_y0) + self.roller(self.fliper( - br_half2 * self.mask_y0, **self.flipkw[1]), 1, **self.rollkw[1]) - bi_half2 = bi_half2 * (1 - self.mask_y0) + self.roller(self.fliper( - bi_half2 * self.mask_y0, **self.flipkw[1]), 1, **self.rollkw[1]) - if ndim > 2: - br_half2 = br_half2 * (1 - self.mask_z0) + self.roller(self.fliper( - br_half2 * self.mask_z0, **self.flipkw[2]), 1, **self.rollkw[2]) - bi_half2 = bi_half2 * (1 - self.mask_z0) + self.roller(self.fliper( - bi_half2 * self.mask_z0, **self.flipkw[2]), 1, **self.rollkw[2]) - - br = br_half1 + br_half2 - bi = bi_half1 + bi_half2 - - return br * scale, bi * scale - -class MyiDFTn(MyDFTn): - def __init__(self, shape): - super().__init__(shape) - - def construct(self, ar, ai): - ndim = len(self.shape) - scale = float(np.prod(ar.shape[-ndim:])) - br, bi = super().construct(ar, -ai) - return br / scale, -bi / scale diff --git 
a/MindFlow/applications/data_driven/airfoil/2D_unsteady/src/fno2d.py b/MindFlow/applications/data_driven/airfoil/2D_unsteady/src/fno2d.py index 64d0e055130488358cc4fe721af74fc72c24262e..8cd1d90be33c5977d52996c374a574f2fb7f9408 100644 --- a/MindFlow/applications/data_driven/airfoil/2D_unsteady/src/fno2d.py +++ b/MindFlow/applications/data_driven/airfoil/2D_unsteady/src/fno2d.py @@ -23,7 +23,7 @@ from mindspore.common.initializer import Zero from mindflow.utils.check_func import check_param_type from mindflow.core.math import get_grid_2d -from mindflow.cell.neural_operators.dft import dft2, idft2 +from mindflow import RDFTn, IRDFTn class FNO2D(nn.Cell): @@ -177,10 +177,10 @@ class SpectralConv2dDft(nn.Cell): self.w_im1 = Parameter(w_im1, requires_grad=True) self.w_re2 = Parameter(w_re2, requires_grad=True) self.w_im2 = Parameter(w_im2, requires_grad=True) - self.dft2_cell = dft2(shape=(column_resolution, raw_resolution), - modes=(modes1, modes2), compute_dtype=self.compute_dtype) - self.idft2_cell = idft2(shape=(column_resolution, raw_resolution), - modes=(modes1, modes2), compute_dtype=self.compute_dtype) + self.dft2_cell = RDFTn(shape=(column_resolution, raw_resolution), norm='ortho', + modes=(modes1, modes2), compute_dtype=self.compute_dtype) + self.idft2_cell = IRDFTn(shape=(column_resolution, raw_resolution), norm='ortho', + modes=(modes1, modes2), compute_dtype=self.compute_dtype) self.mat = Tensor(shape=(1, out_channels, column_resolution - 2 * modes1, modes2), dtype=self.compute_dtype, init=Zero()) self.concat = ops.Concat(-2) @@ -195,8 +195,7 @@ class SpectralConv2dDft(nn.Cell): def construct(self, x: Tensor): """forward""" x_re = x - x_im = ops.zeros_like(x_re) - x_ft_re, x_ft_im = self.dft2_cell((x_re, x_im)) + x_ft_re, x_ft_im = self.dft2_cell(x_re) out_ft_re1 = \ self.mul2d(x_ft_re[:, :, :self.modes1, :self.modes2], self.w_re1) \ @@ -217,5 +216,5 @@ class SpectralConv2dDft(nn.Cell): out_re = self.concat((out_ft_re1, mat, out_ft_re2)) out_im = 
self.concat((out_ft_im1, mat, out_ft_im2)) - x, _ = self.idft2_cell((out_re, out_im)) + x = self.idft2_cell(out_re, out_im) return x diff --git a/MindFlow/mindflow/cell/neural_operators/ffno.py b/MindFlow/mindflow/cell/neural_operators/ffno.py index 8d8fe66944001e41ac1bc99a5f243637b8116dc0..be22763c34b691723a5e3af21d48447f94efc047 100644 --- a/MindFlow/mindflow/cell/neural_operators/ffno.py +++ b/MindFlow/mindflow/cell/neural_operators/ffno.py @@ -339,7 +339,7 @@ class FFNO(nn.Cell): self.dft_compute_dtype = dft_compute_dtype self.ffno_compute_dtype = ffno_compute_dtype self._concat = ops.Concat(axis=-1) - self._positional_embedding, self._input_perm, self._output_perm = self._transpose(len(self.resolutions)) + self._positional_embedding = self._transpose(len(self.resolutions)) self._padding = self._pad(len(self.resolutions)) if self.lifting_channels: self._lifting = nn.SequentialCell([ @@ -401,57 +401,50 @@ class FFNO(nn.Cell): """construct""" batch_size = x.shape[0] grid = mint.repeat_interleave(self._positional_embedding.astype(x.dtype), repeats=batch_size, dim=0) + if self.data_format != "channels_last": - x = ops.transpose(x, input_perm=self._output_perm) + x = ops.movedim(x, 1, -1) + if self.positional_embedding: x = self._concat((x, grid)) x = self._lifting(x) - x = ops.transpose(x, input_perm=self._input_perm) if self.r_padding != 0: - x = ops.Pad(self._padding)(x) - - x = ops.transpose(x, input_perm=self._output_perm) + x = ops.movedim(x, -1, 1) + x = ops.pad(x, self._padding) + x = ops.movedim(x, 1, -1) b = Tensor(0, dtype=mstype.float32) for block in self._ffno_blocks: x, b = block(x) + if self.r_padding != 0: b = self._remove_padding(len(self.resolutions), b) + x = self._projection(b) + if self.data_format != "channels_last": - x = ops.transpose(x, input_perm=self._input_perm) + x = ops.movedim(x, -1, 1) + return x def _transpose(self, n_dim): """transpose tensor""" if n_dim == 1: positional_embedding = Tensor(get_grid_1d(resolution=self.resolutions)) - 
input_perm = (0, 2, 1) - output_perm = (0, 2, 1) elif n_dim == 2: positional_embedding = Tensor(get_grid_2d(resolution=self.resolutions)) - input_perm = (0, 3, 1, 2) - output_perm = (0, 2, 3, 1) elif n_dim == 3: positional_embedding = Tensor(get_grid_3d(resolution=self.resolutions)) - input_perm = (0, 4, 1, 2, 3) - output_perm = (0, 2, 3, 4, 1) else: raise ValueError(f"The length of input resolutions dimensions should be in [1, 2, 3], but got: {n_dim}") - return positional_embedding, input_perm, output_perm + return positional_embedding def _pad(self, n_dim): """pad the domain if input is non-periodic""" - if n_dim == 1: - pad = ([0, 0], [0, 0], [0, self.r_padding]) - elif n_dim == 2: - pad = ([0, 0], [0, 0], [0, self.r_padding], [0, self.r_padding]) - elif n_dim == 3: - pad = ([0, 0], [0, 0], [0, self.r_padding], [0, self.r_padding], [0, self.r_padding]) - else: + if not n_dim in {1, 2, 3}: raise ValueError(f"The length of input resolutions dimensions should be in [1, 2, 3], but got: {n_dim}") - return pad + return n_dim * [0, self.r_padding] def _remove_padding(self, n_dim, b_input): """remove pad domain""" diff --git a/MindFlow/mindflow/cell/neural_operators/ffno_sp.py b/MindFlow/mindflow/cell/neural_operators/ffno_sp.py index 8ad65613cc3f63a0119a8cada9dd5526e78ef426..b1fa1382fd96061f05209937c8c79976c17a1448 100644 --- a/MindFlow/mindflow/cell/neural_operators/ffno_sp.py +++ b/MindFlow/mindflow/cell/neural_operators/ffno_sp.py @@ -19,7 +19,7 @@ import mindspore.common.dtype as mstype from mindspore import nn, ops, Tensor, Parameter, ParameterTuple, mint from mindspore.common.initializer import XavierNormal, initializer from ...core.math import get_grid_1d, get_grid_2d, get_grid_3d -from .dft import dft1, idft1 +from ...core.fourier import RDFTn, IRDFTn class FeedForward(nn.Cell): @@ -106,8 +106,8 @@ class SpectralConv(nn.Cell): """" n- shape - 3D: S1 S2 S3 / 2D: M N / 1D: C mode - output length - n//2 +1 dim - 3D: -1 -2 -3 / 2D: -1 -2 / 1D: -1 """ - dft_cell = 
dft1(shape=(n,), modes=mode, dim=(n_dim,), compute_dtype=self.compute_dtype) - idft_cell = idft1(shape=(n,), modes=mode, dim=(n_dim,), compute_dtype=self.compute_dtype) + dft_cell = RDFTn(shape=n, dim=n_dim, norm='ortho', modes=mode, compute_dtype=self.compute_dtype) + idft_cell = IRDFTn(shape=n, dim=n_dim, norm='ortho', modes=mode, compute_dtype=self.compute_dtype) return dft_cell, idft_cell @@ -244,9 +244,8 @@ class SpectralConv1d(SpectralConv): x = ops.transpose(x, input_perm=self._output_perm) # x shape: batch, in_dim, grid_size x_ft_re = x - x_ft_im = ops.zeros_like(x_ft_re) - x_ftx_re, x_ftx_im = self._dft1_x_cell((x_ft_re, x_ft_im)) + x_ftx_re, x_ftx_im = self._dft1_x_cell(x_ft_re) x_ftx_re_part = x_ftx_re[:, :, :self.n_modes[0]] x_ftx_im_part = x_ftx_im[:, :, :self.n_modes[0]] @@ -268,7 +267,7 @@ class SpectralConv1d(SpectralConv): out_ftx_re = ops.zeros_like(x_ftx_re) out_ftx_im = ops.zeros_like(x_ftx_im) - x, _ = self._idft1_x_cell((out_ftx_re, out_ftx_im)) + x = self._idft1_x_cell(out_ftx_re, out_ftx_im) x = ops.transpose(x, input_perm=self._input_perm) return x @@ -298,10 +297,9 @@ class SpectralConv2d(SpectralConv): x = ops.transpose(x, input_perm=self._output_perm) # x shape: batch, in_dim, grid_size, grid_size x_ft_re = x - x_ft_im = ops.zeros_like(x_ft_re) # Dimesion Y - x_fty_re, x_fty_im = self._dft1_y_cell((x_ft_re, x_ft_im)) + x_fty_re, x_fty_im = self._dft1_y_cell(x_ft_re) x_fty_re_part = x_fty_re[:, :, :, :self.n_modes[1]] x_fty_im_part = x_fty_im[:, :, :, :self.n_modes[1]] @@ -323,10 +321,10 @@ class SpectralConv2d(SpectralConv): out_fty_re = ops.zeros_like(x_fty_re) out_fty_im = ops.zeros_like(x_fty_im) - xy, _ = self._idft1_y_cell((out_fty_re, out_fty_im)) + xy = self._idft1_y_cell(out_fty_re, out_fty_im) # Dimesion X - x_ftx_re, x_ftx_im = self._dft1_x_cell((x_ft_re, x_ft_im)) + x_ftx_re, x_ftx_im = self._dft1_x_cell(x_ft_re) x_ftx_re_part = x_ftx_re[:, :, :self.n_modes[0], :] x_ftx_im_part = x_ftx_im[:, :, :self.n_modes[0], :] @@ -348,7 
+346,7 @@ class SpectralConv2d(SpectralConv): out_ftx_re = ops.zeros_like(x_ftx_re) out_ftx_im = ops.zeros_like(x_ftx_im) - xx, _ = self._idft1_x_cell((out_ftx_re, out_ftx_im)) + xx = self._idft1_x_cell(out_ftx_re, out_ftx_im) # Combining Dimensions x = xx + xy @@ -383,10 +381,9 @@ class SpectralConv3d(SpectralConv): x = ops.transpose(x, input_perm=self._output_perm) # x shape: batch, in_dim, grid_size, grid_size, grid_size x_ft_re = x - x_ft_im = ops.zeros_like(x_ft_re) # Dimesion Z - x_ftz_re, x_ftz_im = self._dft1_z_cell((x_ft_re, x_ft_im)) + x_ftz_re, x_ftz_im = self._dft1_z_cell(x_ft_re) x_ftz_re_part = x_ftz_re[:, :, :, :, :self.n_modes[2]] x_ftz_im_part = x_ftz_im[:, :, :, :, :self.n_modes[2]] @@ -408,10 +405,10 @@ class SpectralConv3d(SpectralConv): out_ftz_re = ops.zeros_like(x_ftz_re) out_ftz_im = ops.zeros_like(x_ftz_im) - xz, _ = self._idft1_z_cell((out_ftz_re, out_ftz_im)) + xz = self._idft1_z_cell(out_ftz_re, out_ftz_im) # Dimesion Y - x_fty_re, x_fty_im = self._dft1_y_cell((x_ft_re, x_ft_im)) + x_fty_re, x_fty_im = self._dft1_y_cell(x_ft_re) x_fty_re_part = x_fty_re[:, :, :, :self.n_modes[1], :] x_fty_im_part = x_fty_im[:, :, :, :self.n_modes[1], :] @@ -433,10 +430,10 @@ class SpectralConv3d(SpectralConv): out_fty_re = ops.zeros_like(x_fty_re) out_fty_im = ops.zeros_like(x_fty_im) - xy, _ = self._idft1_y_cell((out_fty_re, out_fty_im)) + xy = self._idft1_y_cell(out_fty_re, out_fty_im) # Dimesion X - x_ftx_re, x_ftx_im = self._dft1_x_cell((x_ft_re, x_ft_im)) + x_ftx_re, x_ftx_im = self._dft1_x_cell(x_ft_re) x_ftx_re_part = x_ftx_re[:, :, :self.n_modes[0], :, :] x_ftx_im_part = x_ftx_im[:, :, :self.n_modes[0], :, :] @@ -458,7 +455,7 @@ class SpectralConv3d(SpectralConv): out_ftx_re = ops.zeros_like(x_ftx_re) out_ftx_im = ops.zeros_like(x_ftx_im) - xx, _ = self._idft1_x_cell((out_ftx_re, out_ftx_im)) + xx = self._idft1_x_cell(out_ftx_re, out_ftx_im) # Combining Dimensions x = xx + xy + xz diff --git a/MindFlow/mindflow/cell/neural_operators/fno.py 
b/MindFlow/mindflow/cell/neural_operators/fno.py index 8eda5b682ec4aa40212533eef6f3eaff9d754829..4c4c644ac94e06d71fa68b3593bea8854d8c545f 100644 --- a/MindFlow/mindflow/cell/neural_operators/fno.py +++ b/MindFlow/mindflow/cell/neural_operators/fno.py @@ -19,7 +19,7 @@ from mindspore import nn, ops, Tensor, mint import mindspore.common.dtype as mstype -from .dft import SpectralConv1dDft, SpectralConv2dDft, SpectralConv3dDft +from .fno_sp import SpectralConv1dDft, SpectralConv2dDft, SpectralConv3dDft from ..activation import get_activation from ...core.math import get_grid_1d, get_grid_2d, get_grid_3d from ...utils.check_func import check_param_type diff --git a/MindFlow/mindflow/cell/neural_operators/fno_sp.py b/MindFlow/mindflow/cell/neural_operators/fno_sp.py new file mode 100644 index 0000000000000000000000000000000000000000..bb02333507dc94592727b632597b22d661b6b7d5 --- /dev/null +++ b/MindFlow/mindflow/cell/neural_operators/fno_sp.py @@ -0,0 +1,243 @@ +'''' +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +''' +import numpy as np + +import mindspore.common.dtype as mstype +from mindspore import nn, ops, Tensor, Parameter, mint +from mindspore.common.initializer import Zero +from mindspore.ops import operations as P + +from ...core.fourier import RDFTn, IRDFTn + + +class SpectralConvDft(nn.Cell): + """Base Class for Fourier Layer, including DFT, linear transform, and Inverse DFT""" + + def __init__(self, in_channels, out_channels, n_modes, resolutions, compute_dtype=mstype.float32): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + if isinstance(n_modes, int): + n_modes = [n_modes] + self.n_modes = n_modes + if isinstance(resolutions, int): + resolutions = [resolutions] + self.resolutions = resolutions + if len(self.n_modes) != len(self.resolutions): + raise ValueError( + "The dimension of n_modes should be equal to that of resolutions, \ + but got dimension of n_modes {} and dimension of resolutions {}".format(len(self.n_modes), + len(self.resolutions))) + self.compute_dtype = compute_dtype + + def construct(self, x: Tensor): + raise NotImplementedError() + + def _einsum(self, inputs, weights): + weights = weights.expand_dims(0) + inputs = inputs.expand_dims(2) + out = inputs * weights + return out.sum(1) + + +class SpectralConv1dDft(SpectralConvDft): + """1D Fourier Layer. It does DFT, linear transform, and Inverse DFT.""" + + def __init__(self, in_channels, out_channels, n_modes, resolutions, compute_dtype=mstype.float32): + super().__init__(in_channels, out_channels, n_modes, resolutions) + self._scale = (1. 
/ (self.in_channels * self.out_channels)) + w_re = Tensor(self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0]), + dtype=mstype.float32) + w_im = Tensor(self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0]), + dtype=mstype.float32) + self._w_re = Parameter(w_re, requires_grad=True) + self._w_im = Parameter(w_im, requires_grad=True) + self._dft1_cell = RDFTn( + shape=(self.resolutions[0],), norm='ortho', modes=self.n_modes[0], compute_dtype=self.compute_dtype) + self._idft1_cell = IRDFTn( + shape=(self.resolutions[0],), norm='ortho', modes=self.n_modes[0], compute_dtype=self.compute_dtype) + + def construct(self, x: Tensor): + x_re = x + x_ft_re, x_ft_im = self._dft1_cell(x_re) + w_re = P.Cast()(self._w_re, self.compute_dtype) + w_im = P.Cast()(self._w_im, self.compute_dtype) + out_ft_re = self._einsum(x_ft_re[:, :, :self.n_modes[0]], w_re) - self._einsum(x_ft_im[:, :, :self.n_modes[0]], + w_im) + out_ft_im = self._einsum(x_ft_re[:, :, :self.n_modes[0]], w_im) + self._einsum(x_ft_im[:, :, :self.n_modes[0]], + w_re) + + x = self._idft1_cell(out_ft_re, out_ft_im) + + return x + + +class SpectralConv2dDft(SpectralConvDft): + """2D Fourier Layer. It does DFT, linear transform, and Inverse DFT.""" + + def __init__(self, in_channels, out_channels, n_modes, resolutions, compute_dtype=mstype.float32): + super().__init__(in_channels, out_channels, n_modes, resolutions) + self._scale = (1. 
/ (self.in_channels * self.out_channels)) + w_re1 = Tensor( + self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0], self.n_modes[1]), + dtype=self.compute_dtype) + w_im1 = Tensor( + self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0], self.n_modes[1]), + dtype=self.compute_dtype) + w_re2 = Tensor( + self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0], self.n_modes[1]), + dtype=self.compute_dtype) + w_im2 = Tensor( + self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0], self.n_modes[1]), + dtype=self.compute_dtype) + + self._w_re1 = Parameter(w_re1, requires_grad=True) + self._w_im1 = Parameter(w_im1, requires_grad=True) + self._w_re2 = Parameter(w_re2, requires_grad=True) + self._w_im2 = Parameter(w_im2, requires_grad=True) + + self._dft2_cell = RDFTn(shape=(self.resolutions[0], self.resolutions[1]), norm='ortho', + modes=(self.n_modes[0], self.n_modes[1]), compute_dtype=self.compute_dtype) + self._idft2_cell = IRDFTn(shape=(self.resolutions[0], self.resolutions[1]), norm='ortho', + modes=(self.n_modes[0], self.n_modes[1]), compute_dtype=self.compute_dtype) + self._mat = Tensor(shape=(1, self.out_channels, self.resolutions[1] - 2 * self.n_modes[0], self.n_modes[1]), + dtype=self.compute_dtype, init=Zero()) + self._concat = ops.Concat(-2) + + def construct(self, x: Tensor): + x_re = x + x_ft_re, x_ft_im = self._dft2_cell(x_re) + + out_ft_re1 = self._einsum(x_ft_re[:, :, :self.n_modes[0], :self.n_modes[1]], self._w_re1) - self._einsum( + x_ft_im[:, :, :self.n_modes[0], :self.n_modes[1]], self._w_im1) + out_ft_im1 = self._einsum(x_ft_re[:, :, :self.n_modes[0], :self.n_modes[1]], self._w_im1) + self._einsum( + x_ft_im[:, :, :self.n_modes[0], :self.n_modes[1]], self._w_re1) + + out_ft_re2 = self._einsum(x_ft_re[:, :, -self.n_modes[0]:, :self.n_modes[1]], self._w_re2) - self._einsum( + x_ft_im[:, :, -self.n_modes[0]:, :self.n_modes[1]], self._w_im2) + 
out_ft_im2 = self._einsum(x_ft_re[:, :, -self.n_modes[0]:, :self.n_modes[1]], self._w_im2) + self._einsum( + x_ft_im[:, :, -self.n_modes[0]:, :self.n_modes[1]], self._w_re2) + + batch_size = x.shape[0] + mat = mint.repeat_interleave(self._mat, batch_size, 0) + out_re = self._concat((out_ft_re1, mat, out_ft_re2)) + out_im = self._concat((out_ft_im1, mat, out_ft_im2)) + + x = self._idft2_cell(out_re, out_im) + + return x + + +class SpectralConv3dDft(SpectralConvDft): + """3D Fourier layer. It does DFT, linear transform, and Inverse DFT.""" + + def __init__(self, in_channels, out_channels, n_modes, resolutions, compute_dtype=mstype.float32): + super().__init__(in_channels, out_channels, n_modes, resolutions) + self._scale = (1 / (self.in_channels * self.out_channels)) + + w_re1 = Tensor( + self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0], self.n_modes[1], + self.n_modes[2]), dtype=self.compute_dtype) + w_im1 = Tensor( + self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0], self.n_modes[1], + self.n_modes[2]), dtype=self.compute_dtype) + w_re2 = Tensor( + self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0], self.n_modes[1], + self.n_modes[2]), dtype=self.compute_dtype) + w_im2 = Tensor( + self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0], self.n_modes[1], + self.n_modes[2]), dtype=self.compute_dtype) + w_re3 = Tensor( + self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0], self.n_modes[1], + self.n_modes[2]), dtype=self.compute_dtype) + w_im3 = Tensor( + self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0], self.n_modes[1], + self.n_modes[2]), dtype=self.compute_dtype) + w_re4 = Tensor( + self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0], self.n_modes[1], + self.n_modes[2]), dtype=self.compute_dtype) + w_im4 = Tensor( + self._scale * 
np.random.rand(self.in_channels, self.out_channels, self.n_modes[0], self.n_modes[1], + self.n_modes[2]), dtype=self.compute_dtype) + + self._w_re1 = Parameter(w_re1, requires_grad=True) + self._w_im1 = Parameter(w_im1, requires_grad=True) + self._w_re2 = Parameter(w_re2, requires_grad=True) + self._w_im2 = Parameter(w_im2, requires_grad=True) + self._w_re3 = Parameter(w_re3, requires_grad=True) + self._w_im3 = Parameter(w_im3, requires_grad=True) + self._w_re4 = Parameter(w_re4, requires_grad=True) + self._w_im4 = Parameter(w_im4, requires_grad=True) + + self._dft3_cell = RDFTn(shape=(self.resolutions[0], self.resolutions[1], self.resolutions[2]), norm='ortho', + modes=(self.n_modes[0], self.n_modes[1], self.n_modes[2]), + compute_dtype=self.compute_dtype) + self._idft3_cell = IRDFTn(shape=(self.resolutions[0], self.resolutions[1], self.resolutions[2]), norm='ortho', + modes=(self.n_modes[0], self.n_modes[1], self.n_modes[2]), + compute_dtype=self.compute_dtype) + self._mat_x = Tensor( + shape=(1, self.out_channels, self.resolutions[0] - 2 * self.n_modes[0], self.n_modes[1], self.n_modes[2]), + dtype=self.compute_dtype, init=Zero()) + self._mat_y = Tensor( + shape=(1, self.out_channels, self.resolutions[0], self.resolutions[1] - 2 * self.n_modes[1], + self.n_modes[2]), + dtype=self.compute_dtype, init=Zero()) + self._concat = ops.Concat(-2) + + def construct(self, x: Tensor): + x_re = x + x_ft_re, x_ft_im = self._dft3_cell(x_re) + + out_ft_re1 = self._einsum(x_ft_re[:, :, :self.n_modes[0], :self.n_modes[1], :self.n_modes[2]], + self._w_re1) - self._einsum(x_ft_im[:, :, :self.n_modes[0], :self.n_modes[1], + :self.n_modes[2]], self._w_im1) + out_ft_im1 = self._einsum(x_ft_re[:, :, :self.n_modes[0], :self.n_modes[1], :self.n_modes[2]], + self._w_im1) + self._einsum(x_ft_im[:, :, :self.n_modes[0], :self.n_modes[1], + :self.n_modes[2]], self._w_re1) + out_ft_re2 = self._einsum(x_ft_re[:, :, -self.n_modes[0]:, :self.n_modes[1], :self.n_modes[2]], + self._w_re2) - 
self._einsum(x_ft_im[:, :, -self.n_modes[0]:, :self.n_modes[1], + :self.n_modes[2]], self._w_im2) + out_ft_im2 = self._einsum(x_ft_re[:, :, -self.n_modes[0]:, :self.n_modes[1], :self.n_modes[2]], + self._w_im2) + self._einsum(x_ft_im[:, :, -self.n_modes[0]:, :self.n_modes[1], + :self.n_modes[2]], self._w_re2) + out_ft_re3 = self._einsum(x_ft_re[:, :, :self.n_modes[0], -self.n_modes[1]:, :self.n_modes[2]], + self._w_re3) - self._einsum(x_ft_im[:, :, :self.n_modes[0], -self.n_modes[1]:, + :self.n_modes[2]], self._w_im3) + out_ft_im3 = self._einsum(x_ft_re[:, :, :self.n_modes[0], -self.n_modes[1]:, :self.n_modes[2]], + self._w_im3) + self._einsum(x_ft_im[:, :, :self.n_modes[0], -self.n_modes[1]:, + :self.n_modes[2]], self._w_re3) + out_ft_re4 = self._einsum(x_ft_re[:, :, -self.n_modes[0]:, -self.n_modes[1]:, :self.n_modes[2]], + self._w_re4) - self._einsum(x_ft_im[:, :, -self.n_modes[0]:, -self.n_modes[1]:, + :self.n_modes[2]], self._w_im4) + out_ft_im4 = self._einsum(x_ft_re[:, :, -self.n_modes[0]:, -self.n_modes[1]:, :self.n_modes[2]], + self._w_im4) + self._einsum(x_ft_im[:, :, -self.n_modes[0]:, -self.n_modes[1]:, + :self.n_modes[2]], self._w_re4) + + batch_size = x.shape[0] + mat_x = mint.repeat_interleave(self._mat_x, batch_size, 0) + mat_y = mint.repeat_interleave(self._mat_y, batch_size, 0) + + out_re1 = ops.concat((out_ft_re1, mat_x, out_ft_re2), -3) + out_im1 = ops.concat((out_ft_im1, mat_x, out_ft_im2), -3) + + out_re2 = ops.concat((out_ft_re3, mat_x, out_ft_re4), -3) + out_im2 = ops.concat((out_ft_im3, mat_x, out_ft_im4), -3) + out_re = ops.concat((out_re1, mat_y, out_re2), -2) + out_im = ops.concat((out_im1, mat_y, out_im2), -2) + x = self._idft3_cell(out_re, out_im) + + return x diff --git a/MindFlow/mindflow/cell/neural_operators/kno1d.py b/MindFlow/mindflow/cell/neural_operators/kno1d.py index b4c11bb5ddfcd5bd3d698281b2f234adc110f344..81820de728a5a6d5563e04c6819e28eb6db5bf1b 100644 --- a/MindFlow/mindflow/cell/neural_operators/kno1d.py +++ 
b/MindFlow/mindflow/cell/neural_operators/kno1d.py @@ -16,7 +16,7 @@ import mindspore.common.dtype as mstype from mindspore import ops, nn, Tensor -from .dft import SpectralConv1dDft +from .fno_sp import SpectralConv1dDft from ...utils.check_func import check_param_type diff --git a/MindFlow/mindflow/cell/neural_operators/kno2d.py b/MindFlow/mindflow/cell/neural_operators/kno2d.py index 07036e972f1fc419804e74406a7499ba6da71e5e..79f9ae98a2a824cb2339287ced8f7c66e3655af5 100644 --- a/MindFlow/mindflow/cell/neural_operators/kno2d.py +++ b/MindFlow/mindflow/cell/neural_operators/kno2d.py @@ -16,7 +16,7 @@ import mindspore.common.dtype as mstype from mindspore import ops, nn, Tensor -from .dft import SpectralConv2dDft +from .fno_sp import SpectralConv2dDft from ...utils.check_func import check_param_type diff --git a/MindFlow/mindflow/core/fourier.py b/MindFlow/mindflow/core/fourier.py index 17c5c64678f52470cdbdee447189327a2757a577..64f980668abda0e4b045f4645c35202cb0b24445 100644 --- a/MindFlow/mindflow/core/fourier.py +++ b/MindFlow/mindflow/core/fourier.py @@ -14,14 +14,14 @@ # ============================================================================== ''' provide complex dft based on the real dft API in mindflow.dft ''' import numpy as np -from scipy.linalg import dft +import scipy import mindspore as ms import mindspore.common.dtype as mstype from mindspore import nn, ops, Tensor, mint from mindspore.common.initializer import Zero from mindspore.ops import operations as P -from ..utils.check_func import check_param_no_greater, check_param_value, check_param_type, check_param_even +from ..utils.check_func import check_param_no_greater, check_param_value class MyRoll(nn.Cell): @@ -76,88 +76,97 @@ class MyFlip(nn.Cell): return x +def convert_shape(shape): + ''' convert shape to suitable format ''' + if isinstance(shape, int): + n = shape + elif len(shape) == 1: + n, = shape + else: + raise TypeError("Only support 1D dct/dst, but got shape {}".format(shape)) + return 
n + + def convert_params(shape, modes, dim): ''' convert input arguments to suitable format ''' + shape = tuple(np.atleast_1d(shape).astype(int).tolist()) + ndim = len(shape) + if dim is None: - ndim = len(shape) dim = tuple([n - ndim for n in range(ndim)]) else: dim = tuple(np.atleast_1d(dim).astype(int).tolist()) - shape = tuple(np.atleast_1d(shape).astype(int).tolist()) - modes = tuple(np.atleast_1d(modes).astype(int).tolist()) + if modes is None or isinstance(modes, int): + modes = tuple([modes] * ndim) + else: + modes = tuple(np.atleast_1d(modes).astype(int).tolist()) return shape, modes, dim def check_params(shape, modes, dim): ''' check lawfulness of input arguments ''' - check_param_type(dim, "dim", data_type=tuple) - check_param_type(shape, "shape", data_type=tuple) - check_param_type(modes, "modes", data_type=tuple) check_param_no_greater(len(dim), "dim length", 3) check_param_value(len(shape), "shape length", len(dim)) check_param_value(len(modes), "modes length", len(dim)) - check_param_even(shape, "shape") - for i, (m, n) in enumerate(zip(modes, shape)): - check_param_no_greater(m, f'mode{i+1}', n // 2 + (i == len(dim) - 1)) + if np.any(modes): + for i, (m, n) in enumerate(zip(modes, shape)): + # if for last axis mode need to be n//2+1, mode should be set to None + check_param_no_greater(m, f'mode{i+1}', n // 2) class _DFT1d(nn.Cell): '''One dimensional Discrete Fourier Transformation''' - def __init__(self, n, modes, last_index, idx=0, inv=False, compute_dtype=mstype.float32): + def __init__(self, n, mode, last_index, idx=0, scale='sqrtn', inv=False, compute_dtype=mstype.float32): super().__init__() self.n = n - self.dft_mat = dft(n, scale="sqrtn") - self.modes = modes + self.dft_mat = scipy.linalg.dft(n, scale=scale) self.last_index = last_index self.inv = inv + self.odd = bool(n % 2) self.idx = idx + self.mode_upper = mode if mode else n // 2 + (self.last_index or self.odd) + self.mode_lower = mode if mode else n - self.mode_upper self.compute_dtype 
= compute_dtype - self.dft_mode_mat_upper = self.dft_mat[:, :modes] - self.a_re_upper = Tensor( - self.dft_mode_mat_upper.real, dtype=compute_dtype) - self.a_im_upper = Tensor( - self.dft_mode_mat_upper.imag, dtype=compute_dtype) - - self.dft_mode_mat_lower = self.dft_mat[:, -modes:] - self.a_re_lower = Tensor( - self.dft_mode_mat_lower.real, dtype=compute_dtype) - self.a_im_lower = Tensor( - self.dft_mode_mat_lower.imag, dtype=compute_dtype) + # generate DFT matrix for positive and negative frequencies + dft_mat_mode = self.dft_mat[:, :self.mode_upper] + self.a_re_upper = Tensor(dft_mat_mode.real, dtype=compute_dtype) + self.a_im_upper = Tensor(dft_mat_mode.imag, dtype=compute_dtype) + + dft_mat_mode = self.dft_mat[:, -self.mode_lower:] + self.a_re_lower = Tensor(dft_mat_mode.real, dtype=compute_dtype) + self.a_im_lower = Tensor(dft_mat_mode.imag, dtype=compute_dtype) + + # the zero matrix to fill the un-transformed modes + m = self.n - (self.mode_upper + self.mode_lower) + if m > 0: + self.mat = Tensor(shape=m, dtype=compute_dtype, init=Zero()) + self.concat = ops.Concat(axis=-1) + self.cast = P.Cast() if self.inv: self.a_re_upper = self.a_re_upper.T self.a_im_upper = -self.a_im_upper.T + self.a_re_lower = self.a_re_lower.T + self.a_im_lower = -self.a_im_lower.T + + # last axis is real-transformed, so the inverse is conjugate of the positive frequencies if last_index: - if modes == n // 2 + 1: - self.dft_mat_res = self.dft_mat[:, -modes + 2:] - else: - self.dft_mat_res = self.dft_mat[:, -modes + 1:] - - mat = Tensor(np.zeros(n,), dtype=compute_dtype).reshape(n, 1) - self.a_re_res = MyFlip()(Tensor(self.dft_mat_res.real, dtype=compute_dtype), dims=-1) - self.a_im_res = MyFlip()(Tensor(self.dft_mat_res.imag, dtype=compute_dtype), dims=-1) - if modes == n // 2 + 1: - self.a_re_res = self.concat((mat, self.a_re_res, mat)) - self.a_im_res = self.concat((mat, self.a_im_res, mat)) - else: - self.a_re_res = self.concat((mat, self.a_re_res)) - self.a_im_res = 
self.concat((mat, self.a_im_res)) - - self.a_re_res = self.a_re_res.T - self.a_im_res = -self.a_im_res.T - else: - self.a_re_res = self.a_re_lower.T - self.a_im_res = -self.a_im_lower.T - - if (self.n - 2 * self.modes) > 0: - self.mat = Tensor(shape=(self.n - 2 * self.modes), - dtype=compute_dtype, init=Zero()) + mode_res = min(self.mode_lower, self.mode_upper - 1) + dft_mat_res = self.dft_mat[:, -mode_res:] + a_re_res = MyFlip()(Tensor(dft_mat_res.real, dtype=compute_dtype), dims=-1) + a_im_res = MyFlip()(Tensor(dft_mat_res.imag, dtype=compute_dtype), dims=-1) + + a_re_res = ops.pad(a_re_res, (1, self.mode_upper - mode_res - 1)) + a_im_res = ops.pad(a_im_res, (1, self.mode_upper - mode_res - 1)) + + self.a_re_upper += a_re_res.T + self.a_im_upper += a_im_res.T def swap_axes(self, x_re, x_im): return x_re.swapaxes(-1, self.idx), x_im.swapaxes(-1, self.idx) @@ -168,10 +177,9 @@ class _DFT1d(nn.Cell): return y_re, y_im def zero_mat(self, dims): - length = len(dims) mat = self.mat - for i in range(length - 1, -1, -1): - mat = mint.repeat_interleave(mat.expand_dims(0), dims[i], 0) + for n in dims[::-1]: + mat = mint.repeat_interleave(mat.expand_dims(0), n, 0) return mat def compute_forward(self, x_re, x_im): @@ -185,7 +193,7 @@ class _DFT1d(nn.Cell): y_re2, y_im2 = self.complex_matmul( x_re=x_re, x_im=x_im, a_re=self.a_re_lower, a_im=self.a_im_lower) - if self.n == self.modes * 2: + if self.n == self.mode_upper + self.mode_lower: y_re = self.concat((y_re, y_re2)) y_im = self.concat((y_im, y_im2)) else: @@ -197,26 +205,23 @@ class _DFT1d(nn.Cell): def compute_inverse(self, x_re, x_im): ''' Inverse transform for irdft ''' - y_re, y_im = self.complex_matmul(x_re=x_re[..., :self.modes], - x_im=x_im[..., :self.modes], + y_re, y_im = self.complex_matmul(x_re=x_re[..., :self.mode_upper], + x_im=x_im[..., :self.mode_upper], a_re=self.a_re_upper, a_im=self.a_im_upper) if self.last_index: - y_re_res, y_im_res = self.complex_matmul(x_re=x_re, - x_im=x_im, - a_re=self.a_re_res, - 
a_im=-self.a_im_res) - else: - y_re_res, y_im_res = self.complex_matmul(x_re=x_re[..., -self.modes:], - x_im=x_im[..., -self.modes:], - a_re=self.a_re_res, - a_im=self.a_im_res) + return y_re, y_im + + y_re_res, y_im_res = self.complex_matmul(x_re=x_re[..., -self.mode_lower:], + x_im=x_im[..., -self.mode_lower:], + a_re=self.a_re_lower, + a_im=self.a_im_lower) return y_re + y_re_res, y_im + y_im_res def construct(self, x): ''' perform 1d rdft/irdft with matmul operations ''' x_re, x_im = x - x_re, x_im = P.Cast()(x_re, self.compute_dtype), P.Cast()(x_im, self.compute_dtype) + x_re, x_im = self.cast(x_re, self.compute_dtype), self.cast(x_im, self.compute_dtype) x_re, x_im = self.swap_axes(x_re, x_im) if self.inv: y_re, y_im = self.compute_inverse(x_re, x_im) @@ -227,52 +232,51 @@ class _DFT1d(nn.Cell): class _DFTn(nn.Cell): - r""" - N dimensional Discrete Fourier Transformation - - Args: - shape (tuple): Dimension of the input 'x'. - modes (tuple): The length of the output transform axis. The `modes` must be no greater than half of the - dimension of input 'x'. - dim (tuple): Dimensions to be transformed. Default: None, the leading dimensions will be transformed. - inv (bool): Whether to compute inverse transformation. Default: False. - compute_dtype (mindspore.dtype): The type of input tensor. Default: mindspore.float32. - - Inputs: - - **x** (Tensor, Tensor): The input data. It's 3-D tuple of Tensor. It's a complex, - including x real and imaginary. Tensor of shape :math:`(*, *)`. - - Returns: - Complex tensor with the same shape of input x. - - Raises: - TypeError: If `shape` is not a tuple. - ValueError: If the length of `shape` is greater than 3. 
- """ - def __init__(self, shape, modes, dim=None, inv=False, compute_dtype=mstype.float32): + ''' Base class for n-D DFT transform ''' + def __init__(self, shape, dim=None, norm='backward', modes=None, compute_dtype=mstype.float32): super().__init__() - if dim is None: - dim = range(len(shape)) - self.dft1_seq = nn.SequentialCell() - last_index = [False for _ in range(len(shape))] - last_index[-1] = True - for dim_id, idx in enumerate(dim): - self.dft1_seq.append( - _DFT1d(n=shape[dim_id], modes=modes[dim_id], last_index=last_index[dim_id], idx=idx, inv=inv, - compute_dtype=compute_dtype)) - - def construct(self, x): - return self.dft1_seq(x) - + shape, modes, dim = convert_params(shape, modes, dim) + check_params(shape, modes, dim) -class RDFTn(nn.Cell): + ndim = len(shape) + inv, scale, r2c_flags = self.set_options(ndim, norm) + self.dft1_seq = nn.SequentialCell() + for n, m, r, d in zip(shape, modes, r2c_flags, dim): + self.dft1_seq.append(_DFT1d( + n=n, mode=m, last_index=r, idx=d, scale=scale, inv=inv, compute_dtype=compute_dtype)) + + def set_options(self, ndim, norm): + ''' + Choose the dimensions, normalization, and transformation mode (forward/backward). + Derived classes override the options to achieve their specific goals. + ''' + inv = False + scale = { + 'backward': None, + 'forward': 'n', + 'ortho': 'sqrtn', + }[norm] + r2c_flags = np.zeros(ndim, dtype=bool).tolist() + r2c_flags[-1] = True + return inv, scale, r2c_flags + + def construct(self, *args, **kwargs): + raise NotImplementedError + + +class RDFTn(_DFTn): r""" 1/2/3D discrete real Fourier transformation on real number. The results should be same as `scipy.fft.rfftn() `_ . Args: shape (tuple): The shape of the dimensions to be transformed, other dimensions need not be included. + dim (tuple): Dimensions to be transformed. Default: None, the last len(shape) dimensions will be transformed. + norm (str): Normalization mode, should be one of 'forward', 'backward', 'ortho'.
Default: 'backward', + same as torch.fft.rfftn + modes (tuple, int, None): The length of the output transform axis. The `modes` must be no greater than half of the + dimension of input 'x'. compute_dtype (mindspore.dtype): The type of input tensor. Default: mindspore.float32. Inputs: @@ -296,37 +300,25 @@ class RDFTn(nn.Cell): >>> print(br.shape) (2, 32, 257) """ - def __init__(self, shape, compute_dtype=mstype.float32): - super().__init__() - - n = shape[-1] - ndim = len(shape) - modes = tuple([_ // 2 for _ in shape[-ndim:-1]] + [n // 2 + 1]) if ndim > 1 else n // 2 + 1 - - shape, modes, dim = convert_params(shape, modes, dim=None) - check_params(shape, modes, dim) - - self.n = n - self.ndim = ndim - self.shape = shape - self.scale = float(np.prod(shape) ** .5) - - self.dft_cell = _DFTn(shape, modes, dim, inv=False, compute_dtype=compute_dtype) - def construct(self, ar): ''' perform n-dimensional rDFT on real tensor ''' # n-D Fourier transform with last axis being real-transformed, output dimension (..., m, n//2+1) - br, bi = self.dft_cell((ar, ar * 0)) # the last ndim dimensions of ar must accord with shape - return br * self.scale, bi * self.scale + # the last ndim dimensions of ar must accord with shape + return self.dft1_seq((ar, ar * 0)) -class IRDFTn(nn.Cell): +class IRDFTn(_DFTn): r""" 1/2/3D discrete inverse real Fourier transformation on complex number. The results should be same as `scipy.fft.irfftn() `_ . Args: shape (tuple): The shape of the dimensions to be transformed, other dimensions need not be included. + dim (tuple): Dimensions to be transformed. Default: None, the last len(shape) dimensions will be transformed. + norm (str): Normalization mode, should be one of 'forward', 'backward', 'ortho'. Default: 'backward', + same as torch.fft.irfftn + modes (tuple, int, None): The length of the output transform axis. The `modes` must be no greater than half of the + dimension of input 'x'. compute_dtype (mindspore.dtype): The type of input tensor.
Default: mindspore.float32. Inputs: @@ -351,36 +343,34 @@ class IRDFTn(nn.Cell): >>> print(br.shape) (2, 32, 512) """ - def __init__(self, shape, compute_dtype=mstype.float32): - super().__init__() - - n = shape[-1] - ndim = len(shape) - modes = tuple([_ // 2 for _ in shape[-ndim:-1]] + [n // 2 + 1]) if ndim > 1 else n // 2 + 1 - - shape, modes, dim = convert_params(shape, modes, dim=None) - check_params(shape, modes, dim) - - self.n = n - self.ndim = ndim - self.shape = shape - self.scale = float(np.prod(self.shape) ** .5) - - self.idft_cell = _DFTn(shape, modes, dim, inv=True, compute_dtype=compute_dtype) + def set_options(self, ndim, norm): + inv = True + scale = { + 'forward': None, + 'backward': 'n', + 'ortho': 'sqrtn', + }[norm] + r2c_flags = np.zeros(ndim, dtype=bool).tolist() + r2c_flags[-1] = True + return inv, scale, r2c_flags def construct(self, ar, ai): ''' perform n-dimensional irDFT on complex tensor and output real tensor ''' - br, _ = self.idft_cell((ar, ai)) - return br / self.scale + return self.dft1_seq((ar, ai))[0] -class DFTn(nn.Cell): +class DFTn(_DFTn): r""" 1/2/3D discrete Fourier transformation on complex number. The results should be same as `scipy.fft.fftn() `_ . Args: shape (tuple): The shape of the dimensions to be transformed, other dimensions need not be included. + dim (tuple): Dimensions to be transformed. Default: None, the last len(shape) dimensions will be transformed. + norm (str): Normalization mode, should be one of 'forward', 'backward', 'ortho'. Default: 'backward', + same as torch.fft.fftn + modes (tuple, int, None): The length of the output transform axis. The `modes` must be no greater than half of the + dimension of input 'x'. compute_dtype (mindspore.dtype): The type of input tensor. Default: mindspore.float32.
Inputs: @@ -404,89 +394,34 @@ class DFTn(nn.Cell): >>> print(br.shape) (2, 32, 512) """ - def __init__(self, shape, compute_dtype=mstype.float32): - super().__init__() - - n = shape[-1] - ndim = len(shape) - modes = tuple([_ // 2 for _ in shape[-ndim:-1]] + [n // 2 + 1]) if ndim > 1 else n // 2 + 1 - - shape, modes, dim = convert_params(shape, modes, dim=None) - check_params(shape, modes, dim) - - self.n = n - self.ndim = ndim - self.shape = shape - self.scale = float(np.prod(shape) ** .5) - - self.dft_cell = RDFTn(shape, compute_dtype) - - # use mask to assemble slices of Tensors, avoiding dynamic shape - mask_x0 = np.ones(self.n//2 + 1) - mask_xm = np.ones(self.n//2 + 1) - mask_y0 = np.ones(self.shape) - mask_z0 = np.ones(self.shape) - mask_x0[0] = 0 - mask_xm[-1] = 0 - if self.ndim > 1: - mask_y0[..., 0, :] = 0 - if self.ndim > 2: - mask_z0[..., 0, :, :] = 0 - - self.mask_x0 = Tensor(mask_x0, dtype=compute_dtype, const_arg=True) - self.mask_xm = Tensor(mask_xm, dtype=compute_dtype, const_arg=True) - self.mask_y0 = Tensor(mask_y0, dtype=compute_dtype, const_arg=True) - self.mask_z0 = Tensor(mask_z0, dtype=compute_dtype, const_arg=True) - - self.fliper = MyFlip() - self.roller = MyRoll() + def set_options(self, ndim, norm): + inv = False + scale = { + 'forward': 'n', + 'backward': None, + 'ortho': 'sqrtn', + }[norm] + r2c_flags = np.zeros(ndim, dtype=bool).tolist() + return inv, scale, r2c_flags def construct(self, ar, ai): ''' perform n-dimensional DFT on complex tensor ''' # n-D complex Fourier transform, output dimension (..., m, n) - # call dft for real & imag parts separately and then assemble - brr, bri = self.dft_cell(ar) # ar and ai must have same shape - bir, bii = self.dft_cell(ai) # the last ndim dimensions of ai must accord with shape - - n = self.n - - br_half1 = ops.pad((brr - bii) * self.mask_xm, [0, n//2 - 1]) - bi_half1 = ops.pad((bri + bir) * self.mask_xm, [0, n//2 - 1]) - - br_half2 = ops.pad((brr + bii) * self.mask_x0, [n//2 - 1, 0]) - bi_half2 
= ops.pad((bir - bri) * self.mask_x0, [n//2 - 1, 0]) - br_half2 = self.roller(self.fliper(br_half2, dims=-1), n//2, dims=-1) - bi_half2 = self.roller(self.fliper(bi_half2, dims=-1), n//2, dims=-1) - - if self.ndim > 1: - br_half2_1 = br_half2 * (1 - self.mask_y0) - bi_half2_1 = bi_half2 * (1 - self.mask_y0) - br_half2_2 = br_half2 * self.mask_y0 - bi_half2_2 = bi_half2 * self.mask_y0 - br_half2 = br_half2_1 + self.roller(self.fliper(br_half2_2, dims=-2), 1, dims=-2) - bi_half2 = bi_half2_1 + self.roller(self.fliper(bi_half2_2, dims=-2), 1, dims=-2) - - if self.ndim > 2: - br_half2_1 = br_half2 * (1 - self.mask_z0) - bi_half2_1 = bi_half2 * (1 - self.mask_z0) - br_half2_2 = br_half2 * self.mask_z0 - bi_half2_2 = bi_half2 * self.mask_z0 - br_half2 = br_half2_1 + self.roller(self.fliper(br_half2_2, dims=-3), 1, dims=-3) - bi_half2 = bi_half2_1 + self.roller(self.fliper(bi_half2_2, dims=-3), 1, dims=-3) - - br = br_half1 + br_half2 - bi = bi_half1 + bi_half2 - - return br, bi + return self.dft1_seq((ar, ai)) -class IDFTn(nn.Cell): +class IDFTn(DFTn): r""" 1/2/3D discrete inverse Fourier transformation on complex number. The results should be same as `scipy.fft.ifftn() `_ . Args: shape (tuple): The shape of the dimensions to be transformed, other dimensions need not be included. + dim (tuple): Dimensions to be transformed. Default: None, the last len(shape) dimensions will be transformed. + norm (str): Normalization mode, should be one of 'forward', 'backward', 'ortho'. Default: 'backward', + same as torch.fft.ifftn + modes (tuple, int, None): The length of the output transform axis. The `modes` must be no greater than half of the + dimension of input 'x'. compute_dtype (mindspore.dtype): The type of input tensor. Default: mindspore.float32.
Inputs: @@ -510,15 +445,15 @@ class IDFTn(nn.Cell): >>> print(br.shape) (2, 32, 512) """ - def __init__(self, shape, compute_dtype=mstype.float32): - super().__init__() - self.dft_cell = DFTn(shape, compute_dtype) - - def construct(self, ar, ai): - ''' perform n-dimensional iDFT on complex tensor ''' - scale = self.dft_cell.scale**2 - br, bi = self.dft_cell(ar, -ai) - return br / scale, -bi / scale + def set_options(self, ndim, norm): + inv = True + scale = { + 'forward': None, + 'backward': 'n', + 'ortho': 'sqrtn', + }[norm] + r2c_flags = np.zeros(ndim, dtype=bool).tolist() + return inv, scale, r2c_flags class DCT(nn.Cell): @@ -552,9 +487,11 @@ class DCT(nn.Cell): """ def __init__(self, shape, compute_dtype=mstype.float32): super().__init__() - self.dft_cell = DFTn(shape, compute_dtype) - assert self.dft_cell.ndim == 1, 'only support 1D dct' - n, = self.dft_cell.shape + + n = convert_shape(shape) + + self.dft_cell = DFTn(n, compute_dtype=compute_dtype) + w = Tensor(np.arange(n) * np.pi / (2 * n), dtype=compute_dtype) self.cosw = ops.cos(w) self.sinw = ops.sin(w) @@ -603,12 +540,13 @@ class IDCT(nn.Cell): def __init__(self, shape, compute_dtype=mstype.float32): super().__init__() - self.dft_cell = IRDFTn(shape, compute_dtype) - assert self.dft_cell.ndim == 1, 'only support 1D dct' - n, = self.dft_cell.shape - assert n % 2 == 0, 'only support even length' # n has to be even, or IRDFTn would fail + n = convert_shape(shape) + + # assert n % 2 == 0, 'only support even length' # n has to be even, or IRDFTn would fail + + self.dft_cell = IRDFTn(n, compute_dtype=compute_dtype) - w = Tensor(np.arange(0, n // 2 + 1, 1) * np.pi / (2 * n), dtype=compute_dtype) + w = Tensor(np.arange(n // 2 + 1) * np.pi / (2 * n), dtype=compute_dtype) self.cosw = ops.cos(w) self.sinw = ops.sin(w) @@ -628,6 +566,9 @@ class IDCT(nn.Cell): c2 = self.fliper(c[..., (n + 1) // 2:], dims=-1) d1 = ops.pad(c1.reshape(-1)[..., None], (0, 1)).reshape(*c1.shape[:-1], -1) d2 = ops.pad(c2.reshape(-1)[..., 
None], (1, 0)).reshape(*c2.shape[:-1], -1) + # in case n is odd, d1 and d2 need to be aligned + d1 = d1[..., :n] + d2 = ops.pad(d2, (0, n % 2)) return d1 + d2 @@ -662,8 +603,9 @@ class DST(nn.Cell): """ def __init__(self, shape, compute_dtype=mstype.float32): super().__init__() - self.dft_cell = DCT(shape, compute_dtype) - multiplier = np.ones(shape) + n = convert_shape(shape) + self.dft_cell = DCT(n, compute_dtype=compute_dtype) + multiplier = np.ones(n) multiplier[..., 1::2] *= -1 self.multiplier = Tensor(multiplier, dtype=compute_dtype) @@ -703,8 +645,9 @@ class IDST(nn.Cell): """ def __init__(self, shape, compute_dtype=mstype.float32): super().__init__() - self.dft_cell = IDCT(shape, compute_dtype) - multiplier = np.ones(shape) + n = convert_shape(shape) + self.dft_cell = IDCT(n, compute_dtype=compute_dtype) + multiplier = np.ones(n) multiplier[..., 1::2] *= -1 self.multiplier = Tensor(multiplier, dtype=compute_dtype) diff --git a/tests/st/mindflow/networks/fno/test_fno.py b/tests/st/mindflow/networks/fno/test_fno.py index 10e9ad45c0f5c243b97661ac9b3f958709a7788f..7fbc0c437d57b553a2a592c494dadcc3ceccb3f8 100644 --- a/tests/st/mindflow/networks/fno/test_fno.py +++ b/tests/st/mindflow/networks/fno/test_fno.py @@ -20,7 +20,7 @@ import numpy as np from mindspore import Tensor, context, set_seed, load_param_into_net, load_checkpoint from mindspore import dtype as mstype from mindflow.cell import FNO1D, FNO2D, FNO3D -from mindflow.cell.neural_operators.dft import SpectralConv1dDft, SpectralConv2dDft, SpectralConv3dDft +from mindflow.cell.neural_operators.fno_sp import SpectralConv1dDft, SpectralConv2dDft, SpectralConv3dDft RTOL = 0.001 set_seed(123456) diff --git a/tests/st/mindflow/operators/test_fourier.py b/tests/st/mindflow/operators/test_fourier.py index 44f354cc8639c894008407dec472e14ee84ef6da..188e247206e784ac3910b38d4c19b5f27ac98513 100644 --- a/tests/st/mindflow/operators/test_fourier.py +++ b/tests/st/mindflow/operators/test_fourier.py @@ -30,7 +30,7 @@ 
sys.path.append(PROJECT_ROOT) # pylint: disable=wrong-import-position -from common.cell import FP32_RTOL +from common.cell import FP32_RTOL, FP16_RTOL, FP32_ATOL, FP16_ATOL from common.cell.utils import compare_output # pylint: enable=wrong-import-position @@ -57,7 +57,8 @@ def gen_input(shape=(5, 6, 4, 8), rand_test=True): @pytest.mark.parametrize('device_target', ['CPU', 'Ascend']) @pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) @pytest.mark.parametrize('ndim', [1, 2, 3]) -def test_rdft_accuracy(device_target, mode, ndim): +@pytest.mark.parametrize('compute_dtype', [ms.float32, ms.float16]) +def test_rdft_accuracy(device_target, mode, ndim, compute_dtype): """ Feature: Test RDFTn & IRDFTn accuracy Description: Input random tensor, compare the results of RDFTn and IRDFTn with numpy results @@ -68,12 +69,15 @@ def test_rdft_accuracy(device_target, mode, ndim): shape = a.shape b = np.fft.rfftn(a.real, s=a.shape[-ndim:], axes=range(-ndim, 0)) - br, bi = RDFTn(shape[-ndim:])(ar) - cr = IRDFTn(shape[-ndim:])(br, bi) + br, bi = RDFTn(shape[-ndim:], compute_dtype=compute_dtype)(ar) + cr = IRDFTn(shape[-ndim:], compute_dtype=compute_dtype)(br, bi) - assert compare_output(b.real, br.numpy(), rtol=FP32_RTOL, atol=FP32_RTOL * np.linalg.norm(b)) - assert compare_output(b.imag, bi.numpy(), rtol=FP32_RTOL, atol=FP32_RTOL * np.linalg.norm(b)) - assert compare_output(a.real, cr.numpy(), rtol=FP32_RTOL, atol=FP32_RTOL * np.linalg.norm(a)) + rtol = FP32_RTOL if compute_dtype == ms.float32 else FP16_RTOL * 10 + atol = FP32_ATOL if compute_dtype == ms.float32 else FP16_ATOL * 20 + + assert compare_output(br.numpy(), b.real, rtol, atol) + assert compare_output(bi.numpy(), b.imag, rtol, atol) + assert compare_output(cr.numpy(), a.real, rtol, atol) @pytest.mark.level0 @@ -82,7 +86,8 @@ def test_rdft_accuracy(device_target, mode, ndim): @pytest.mark.parametrize('device_target', ['CPU', 'Ascend']) @pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) 
@pytest.mark.parametrize('ndim', [1, 2, 3]) -def test_dft_accuracy(device_target, mode, ndim): +@pytest.mark.parametrize('compute_dtype', [ms.float32, ms.float16]) +def test_dft_accuracy(device_target, mode, ndim, compute_dtype): """ Feature: Test DFTn & IDFTn accuracy Description: Input random tensor, compare the results of DFTn and IDFTn with numpy results @@ -93,13 +98,16 @@ def test_dft_accuracy(device_target, mode, ndim): shape = a.shape b = np.fft.fftn(a, s=a.shape[-ndim:], axes=range(-ndim, 0)) - br, bi = DFTn(shape[-ndim:])(ar, ai) - cr, ci = IDFTn(shape[-ndim:])(br, bi) + br, bi = DFTn(shape[-ndim:], compute_dtype=compute_dtype)(ar, ai) + cr, ci = IDFTn(shape[-ndim:], compute_dtype=compute_dtype)(br, bi) + + rtol = FP32_RTOL if compute_dtype == ms.float32 else FP16_RTOL * 10 + atol = FP32_ATOL if compute_dtype == ms.float32 else FP16_ATOL * 20 - assert compare_output(b.real, br.numpy(), rtol=FP32_RTOL, atol=FP32_RTOL * np.linalg.norm(b)) - assert compare_output(b.imag, bi.numpy(), rtol=FP32_RTOL, atol=FP32_RTOL * np.linalg.norm(b)) - assert compare_output(a.real, cr.numpy(), rtol=FP32_RTOL, atol=FP32_RTOL * np.linalg.norm(a)) - assert compare_output(a.imag, ci.numpy(), rtol=FP32_RTOL, atol=FP32_RTOL * np.linalg.norm(a)) + assert compare_output(br.numpy(), b.real, rtol, atol) + assert compare_output(bi.numpy(), b.imag, rtol, atol) + assert compare_output(cr.numpy(), a.real, rtol, atol) + assert compare_output(ci.numpy(), a.imag, rtol, atol) @pytest.mark.level0 @@ -107,7 +115,8 @@ def test_dft_accuracy(device_target, mode, ndim): @pytest.mark.env_onecard @pytest.mark.parametrize('device_target', ['CPU', 'Ascend']) @pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) -def test_dct_accuracy(device_target, mode): +@pytest.mark.parametrize('compute_dtype', [ms.float32, ms.float16]) +def test_dct_accuracy(device_target, mode, compute_dtype): """ Feature: Test DCT & IDCT accuracy Description: Input random tensor, compare the results of DCT and IDCT 
with numpy results @@ -118,11 +127,14 @@ def test_dct_accuracy(device_target, mode): shape = a.shape b = dct(a.real) - br = DCT(shape[-1:])(ar) - cr = IDCT(shape[-1:])(br) + br = DCT(shape[-1:], compute_dtype=compute_dtype)(ar) + cr = IDCT(shape[-1:], compute_dtype=compute_dtype)(br) + + rtol = FP32_RTOL if compute_dtype == ms.float32 else FP16_RTOL * 10 + atol = FP32_ATOL if compute_dtype == ms.float32 else FP16_ATOL * 20 - assert compare_output(b.real, br.numpy(), rtol=FP32_RTOL, atol=FP32_RTOL * np.linalg.norm(b)) - assert compare_output(a.real, cr.numpy(), rtol=FP32_RTOL, atol=FP32_RTOL * np.linalg.norm(a)) + assert compare_output(br.numpy(), b.real, rtol, atol) + assert compare_output(cr.numpy(), a.real, rtol, atol) @pytest.mark.level0 @@ -130,7 +142,8 @@ def test_dct_accuracy(device_target, mode): @pytest.mark.env_onecard @pytest.mark.parametrize('device_target', ['CPU', 'Ascend']) @pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) -def test_dst_accuracy(device_target, mode): +@pytest.mark.parametrize('compute_dtype', [ms.float32, ms.float16]) +def test_dst_accuracy(device_target, mode, compute_dtype): """ Feature: Test DST & IDST accuracy Description: Input random tensor, compare the results of DST and IDST with numpy results @@ -141,11 +154,14 @@ def test_dst_accuracy(device_target, mode): shape = a.shape b = dst(a.real) - br = DST(shape[-1:])(ar) - cr = IDST(shape[-1:])(br) + br = DST(shape[-1:], compute_dtype=compute_dtype)(ar) + cr = IDST(shape[-1:], compute_dtype=compute_dtype)(br) - assert compare_output(b.real, br.numpy(), rtol=FP32_RTOL, atol=FP32_RTOL * np.linalg.norm(b)) - assert compare_output(a.real, cr.numpy(), rtol=FP32_RTOL, atol=FP32_RTOL * np.linalg.norm(a)) + rtol = FP32_RTOL if compute_dtype == ms.float32 else FP16_RTOL * 10 + atol = FP32_ATOL if compute_dtype == ms.float32 else FP16_ATOL * 20 + + assert compare_output(br.numpy(), b.real, rtol, atol) + assert compare_output(cr.numpy(), a.real, rtol, atol) 
@pytest.mark.level0 @@ -201,7 +217,8 @@ def test_dft_speed(device_target, mode, ndim): @pytest.mark.parametrize('device_target', ['CPU', 'Ascend']) @pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) @pytest.mark.parametrize('ndim', [1, 2, 3]) -def test_dft_grad(device_target, mode, ndim): +@pytest.mark.parametrize('compute_dtype', [ms.float32, ms.float16]) +def test_dft_grad(device_target, mode, ndim, compute_dtype): """ Feature: Test the correctness of DFTn & IDFTn grad calculation Description: Input random tensor, compare the autograd results with theoretic solutions @@ -211,7 +228,7 @@ def test_dft_grad(device_target, mode, ndim): a, ar, ai = gen_input() shape = a.shape - dft_cell = DFTn(shape[-ndim:]) + dft_cell = DFTn(shape[-ndim:], compute_dtype=compute_dtype) def forward_fn(xr, xi): yr, yi = dft_cell(xr, xi) @@ -224,5 +241,8 @@ def test_dft_grad(device_target, mode, ndim): b = np.fft.fftn(a, s=a.shape[-ndim:], axes=range(-ndim, 0)) g = np.fft.ifftn(b, s=a.shape[-ndim:], axes=range(-ndim, 0)) * 2 * np.prod(a.shape[-ndim:]) - assert compare_output(g.real, g1.numpy(), rtol=FP32_RTOL, atol=FP32_RTOL * np.linalg.norm(g)) - assert compare_output(g.imag, g2.numpy(), rtol=FP32_RTOL, atol=FP32_RTOL * np.linalg.norm(g)) + rtol = FP32_RTOL if compute_dtype == ms.float32 else FP16_RTOL * 10 + atol = FP32_ATOL if compute_dtype == ms.float32 else FP16_ATOL * 500 # grad func leads to larger error + + assert compare_output(g1.numpy(), g.real, rtol, atol) + assert compare_output(g2.numpy(), g.imag, rtol, atol)