From eb13efb576038b0d6d1e3a5a16e8ece8ee76dc4f Mon Sep 17 00:00:00 2001 From: goto Date: Tue, 5 Aug 2025 08:44:49 +0800 Subject: [PATCH 1/2] delete neural_operators.dft --- .../mindflow/cell/neural_operators/dft.py | 723 ------------------ .../st/mindflow/cell/dft/data/dft1_input.npy | Bin 176 -> 0 bytes .../st/mindflow/cell/dft/data/dft1_output.npy | Bin 136 -> 0 bytes .../st/mindflow/cell/dft/data/idft1_input.npy | Bin 160 -> 0 bytes .../mindflow/cell/dft/data/idft1_output.npy | Bin 160 -> 0 bytes tests/st/mindflow/cell/test_dft.py | 89 --- tests/st/mindflow/operators/test_dft.py | 127 --- 7 files changed, 939 deletions(-) delete mode 100644 MindFlow/mindflow/cell/neural_operators/dft.py delete mode 100644 tests/st/mindflow/cell/dft/data/dft1_input.npy delete mode 100644 tests/st/mindflow/cell/dft/data/dft1_output.npy delete mode 100644 tests/st/mindflow/cell/dft/data/idft1_input.npy delete mode 100644 tests/st/mindflow/cell/dft/data/idft1_output.npy delete mode 100644 tests/st/mindflow/cell/test_dft.py delete mode 100644 tests/st/mindflow/operators/test_dft.py diff --git a/MindFlow/mindflow/cell/neural_operators/dft.py b/MindFlow/mindflow/cell/neural_operators/dft.py deleted file mode 100644 index e41ba2b49..000000000 --- a/MindFlow/mindflow/cell/neural_operators/dft.py +++ /dev/null @@ -1,723 +0,0 @@ -'''' -# Copyright 2023 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -''' -import numpy as np -from scipy.linalg import dft - -import mindspore -import mindspore.common.dtype as mstype -from mindspore import nn, ops, Tensor, Parameter, mint -from mindspore.common.initializer import Zero -from mindspore.ops import operations as P - -from ...utils.check_func import check_param_no_greater, check_param_value, check_param_type, check_param_even - - -class DFT1d(nn.Cell): - '''One dimensional Discrete Fourier Transformation''' - - def __init__(self, n, modes, last_index, idx=0, inv=False, compute_dtype=mindspore.float32): - super().__init__() - - self.n = n - self.dft_mat = dft(n, scale="sqrtn") - self.modes = modes - self.last_index = last_index - self.inv = inv - self.idx = idx - self.compute_dtype = compute_dtype - - self.dft_mode_mat_upper = self.dft_mat[:, :modes] - self.a_re_upper = Tensor( - self.dft_mode_mat_upper.real, dtype=compute_dtype) - self.a_im_upper = Tensor( - self.dft_mode_mat_upper.imag, dtype=compute_dtype) - - self.dft_mode_mat_lower = self.dft_mat[:, -modes:] - self.a_re_lower = Tensor( - self.dft_mode_mat_lower.real, dtype=compute_dtype) - self.a_im_lower = Tensor( - self.dft_mode_mat_lower.imag, dtype=compute_dtype) - self.concat = ops.Concat(axis=-1) - - if self.inv: - self.a_re_upper = self.a_re_upper.T - self.a_im_upper = -self.a_im_upper.T - if last_index: - if modes == n // 2 + 1: - self.dft_mat_res = self.dft_mat[:, -modes + 2:] - else: - self.dft_mat_res = self.dft_mat[:, -modes + 1:] - - mat = Tensor(np.zeros(n, ), dtype=compute_dtype).reshape(n, 1) - self.a_re_res = mindspore.numpy.flip( - Tensor(self.dft_mat_res.real, dtype=compute_dtype), axis=-1) - self.a_im_res = mindspore.numpy.flip( - Tensor(self.dft_mat_res.imag, dtype=compute_dtype), axis=-1) - if modes == n // 2 + 1: - self.a_re_res = self.concat((mat, self.a_re_res, mat)) - self.a_im_res = self.concat((mat, self.a_im_res, mat)) - else: - self.a_re_res = self.concat((mat, self.a_re_res)) - self.a_im_res = self.concat((mat, self.a_im_res)) - - self.a_re_res = self.a_re_res.T - self.a_im_res = -self.a_im_res.T - else: - self.a_re_res = self.a_re_lower.T - self.a_im_res = -self.a_im_lower.T - - if (self.n - 2 * self.modes) > 0: - self.mat = Tensor(shape=(self.n - 2 * self.modes), - dtype=compute_dtype, init=Zero()) - - def swap_axes(self, x_re, x_im): - return x_re.swapaxes(-1, self.idx), x_im.swapaxes(-1, self.idx) - - def complex_matmul(self, x_re, x_im, a_re, a_im): - y_re = ops.matmul(x_re, a_re) - ops.matmul(x_im, a_im) - y_im = ops.matmul(x_im, a_re) + ops.matmul(x_re, a_im) - return y_re, y_im - - def construct(self, x): - x_re, x_im = x - x_re, x_im = P.Cast()(x_re, self.compute_dtype), P.Cast()(x_im, self.compute_dtype) - if not self.inv: - x_re, x_im = self.swap_axes(x_re, x_im) - y_re, y_im = self.complex_matmul( - x_re=x_re, x_im=x_im, a_re=self.a_re_upper, a_im=self.a_im_upper) - - if not self.last_index: - y_re2, y_im2 = self.complex_matmul( - x_re=x_re, x_im=x_im, a_re=self.a_re_lower, a_im=self.a_im_lower) - - if self.n == self.modes * 2: - y_re = self.concat((y_re, y_re2)) - y_im = self.concat((y_im, y_im2)) - else: - dims = x_re.shape[:-1] - length = len(dims) - mat = self.mat - for i in range(length - 1, -1, -1): - mat = mint.repeat_interleave(mat.expand_dims(0), dims[i], 0) - y_re = self.concat((y_re, mat, y_re2)) - y_im = self.concat((y_im, mat, y_im2)) - - y_re, y_im = self.swap_axes(y_re, y_im) - return y_re, y_im - - x_re, x_im = self.swap_axes(x_re, x_im) - y_re, y_im = self.complex_matmul(x_re=x_re[..., :self.modes], x_im=x_im[..., :self.modes], - a_re=self.a_re_upper, - a_im=self.a_im_upper) - y_re, y_im = self.swap_axes(y_re, y_im) - - if self.last_index: - y_re_res, y_im_res = self.complex_matmul( - x_re=x_re, x_im=x_im, a_re=self.a_re_res, a_im=-self.a_im_res) - else: - y_re_res, y_im_res = self.complex_matmul(x_re=x_re[..., -self.modes:], x_im=x_im[..., -self.modes:], - a_re=self.a_re_res, a_im=self.a_im_res) - - y_re_res, y_im_res = self.swap_axes(y_re_res, y_im_res) - return y_re + y_re_res, y_im + y_im_res - - -class DFTn(nn.Cell): - '''N dimensional Discrete Fourier Transformation''' - - def __init__(self, shape, modes, dim=None, inv=False, compute_dtype=mindspore.float32): - super().__init__() - - if dim is None: - dim = range(len(shape)) - self.dft1_seq = nn.SequentialCell() - last_index = [False for _ in range(len(shape))] - last_index[-1] = True - for dim_id, idx in enumerate(dim): - self.dft1_seq.append( - DFT1d(n=shape[dim_id], modes=modes[dim_id], last_index=last_index[dim_id], idx=idx, inv=inv, - compute_dtype=compute_dtype)) - - def construct(self, x): - return self.dft1_seq(x) - - -def _dftn(shape, modes, dim=None, compute_dtype=mindspore.float32): - dftn_ = DFTn(shape=shape, modes=modes, dim=dim, - inv=False, compute_dtype=compute_dtype) - return dftn_ - - -def _idftn(shape, modes, dim=None, compute_dtype=mindspore.float32): - idftn_ = DFTn(shape=shape, modes=modes, dim=dim, - inv=True, compute_dtype=compute_dtype) - return idftn_ - - -def dft3(shape, modes, dim=(-3, -2, -1), compute_dtype=mindspore.float32): - r""" - Calculate three-dimensional discrete Fourier transform. Corresponding to the rfftn operator in torch. - - Args: - shape (tuple): Dimension of the input 'x'. - modes (tuple): The length of the output transform axis. The `modes` must be no greater than half of the - dimension of input 'x'. - dim (tuple): Dimensions to be transformed. - compute_dtype (mindspore.dtype): The type of input tensor. Default: mindspore.float32. - - Inputs: - - **x** (Tensor, Tensor): The input data. It's 3-D tuple of Tensor. It's a complex, - including x real and imaginary. Tensor of shape :math:`(*, *)`. - - Returns: - Complex tensor with the same shape of input x. - - Raises: - TypeError: If `shape` is not a tuple. - ValueError: If the length of `shape` is no equal to 3. - - Examples: - >>> import numpy as np - >>> from mindspore import Tensor, ops - >>> import mindspore.common.dtype as mstype - >>> from mindflow.cell.neural_operators.dft import dft3 - >>> array = np.ones((6, 6, 6)) * np.arange(1, 7) - >>> x_re = Tensor(array, dtype=mstype.float32) - >>> x_im = x_re - >>> dft3_cell = dft3(shape=array.shape, modes=(2, 2, 2), compute_dtype=mstype.float32) - >>> ret, _ = dft3_cell((x_re, x_im)) - >>> print(ret) - [[[ 5.1439293e+01 -2.0076393e+01] - [ 7.9796671e-08 -1.9494735e-08] - [ 0.0000000e+00 0.0000000e+00] - [ 0.0000000e+00 0.0000000e+00] - [ 9.0537789e-08 1.0553553e-07] - [ 3.3567730e-07 1.0368046e-07]] - - [[ 4.7683722e-07 -3.1770034e-07] - [ 6.5267522e-15 -2.7775875e-15] - [ 0.0000000e+00 0.0000000e+00] - [ 0.0000000e+00 0.0000000e+00] - [-2.1755840e-15 -1.5215135e-15] - [ 3.6259736e-15 -4.0336615e-15]] - - [[ 0.0000000e+00 0.0000000e+00] - [ 0.0000000e+00 0.0000000e+00] - [ 0.0000000e+00 0.0000000e+00] - [ 0.0000000e+00 0.0000000e+00] - [ 0.0000000e+00 0.0000000e+00] - [ 0.0000000e+00 0.0000000e+00]] - - [[ 0.0000000e+00 0.0000000e+00] - [ 0.0000000e+00 0.0000000e+00] - [ 0.0000000e+00 0.0000000e+00] - [ 0.0000000e+00 0.0000000e+00] - [ 0.0000000e+00 0.0000000e+00] - [ 0.0000000e+00 0.0000000e+00]] - - [[ 1.1920930e-07 -5.1619136e-08] - [-3.6259733e-16 -1.0747753e-15] - [ 0.0000000e+00 0.0000000e+00] - [ 0.0000000e+00 0.0000000e+00] - [ 3.6259733e-16 -1.8129867e-16] - [ 3.6259733e-16 -1.4373726e-15]] - - [[ 5.9604650e-07 -2.5809570e-07] - [ 8.7023360e-15 -1.9812689e-15] - [ 0.0000000e+00 0.0000000e+00] - [ 0.0000000e+00 0.0000000e+00] - [ 2.9007787e-15 7.2519467e-16] - [ 8.7023360e-15 -1.7869532e-15]]] - - """ - check_param_type(shape, "shape", data_type=tuple) - check_param_type(modes, "modes", data_type=tuple) - check_param_value(len(shape), "shape length", 3) - check_param_value(len(modes), "modes length", 3) - check_param_even(shape, "shape") - check_param_no_greater(modes[0], "mode1", shape[0] // 2) - check_param_no_greater(modes[1], "mode2", shape[1] // 2) - check_param_no_greater(modes[2], "mode3", shape[2] // 2 + 1) - return _dftn(shape, modes, dim=dim, compute_dtype=compute_dtype) - - -def idft3(shape, modes, dim=(-3, -2, -1), compute_dtype=mindspore.float32): - r""" - Calculate three-dimensional discrete Fourier transform. Corresponding to the irfftn operator in torch. - - Args: - shape (tuple): Dimension of the input 'x'. - modes (tuple): The length of the output transform axis. The `modes` must be no greater than half of the - dimension of input 'x'. - dim (tuple): Dimensions to be transformed. - compute_dtype (mindspore.dtype): The type of input tensor. Default: mindspore.float32. - - Inputs: - - **x** (Tensor, Tensor): The input data. It's 3-D tuple of Tensor. It's a complex, including x real and - imaginary. Tensor of shape :math:`(*, *)`. - - Returns: - Complex tensor with the same shape of input x. - - Raises: - TypeError: If `shape` is not a tuple. - ValueError: If the length of `shape` is no equal to 3. - - Examples: - >>> import numpy as np - >>> from mindspore import Tensor, ops - >>> import mindspore.common.dtype as mstype - >>> from mindflow.cell.neural_operators.dft import idft3 - >>> array = np.ones((2, 2, 2)) * np.arange(1, 3) - >>> x_re = Tensor(array, dtype=mstype.float32) - >>> x_im = ops.zeros_like(x_re) - >>> idft3_cell = idft3(shape=(6, 6, 6), modes=(2, 2, 2), compute_dtype=mstype.float32) - >>> ret, _ = idft3_cell((x_re, x_im)) - >>> print(ret) - [[[ 5.44331074e+00 3.26598644e+00 -1.08866215e+00 -3.26598644e+00 -1.08866215e+00 3.26598644e+00] - [ 2.04124165e+00 2.04124165e+00 4.08248246e-01 -1.22474492e+00 -1.22474492e+00 4.08248365e-01] - [-6.80413842e-01 -1.22474492e+00 -6.80413783e-01 4.08248305e-01 9.52579379e-01 4.08248246e-01] - [ 0.00000000e+00 -2.30921616e-16 -2.30921616e-16 6.53092730e-32 2.30921616e-16 2.30921616e-16] - [-6.80413842e-01 4.08248246e-01 9.52579379e-01 4.08248305e-01 -6.80413783e-01 -1.22474492e+00] - [ 2.04124165e+00 4.08248365e-01 -1.22474492e+00 -1.22474492e+00 4.08248246e-01 2.04124165e+00]] - ...... - [[ 2.04124165e+00 4.08248544e-01 -1.22474492e+00 -1.22474504e+00 4.08248186e-01 2.04124165e+00] - [ 1.02062082e+00 6.12372518e-01 -2.04124182e-01 -6.12372518e-01 -2.04124182e-01 6.12372518e-01] - [-5.10310411e-01 -5.10310411e-01 -1.02062061e-01 3.06186229e-01 3.06186229e-01 -1.02062091e-01] - [-7.21630050e-17 -1.29893429e-16 -7.21630183e-17 4.32978030e-17 1.01028220e-16 4.32978163e-17] - [-6.08337416e-08 4.08248246e-01 4.08248305e-01 3.65002428e-08 -4.08248246e-01 -4.08248305e-01] - [ 5.10310471e-01 -3.06186140e-01 -7.14434564e-01 -3.06186318e-01 5.10310352e-01 9.18558717e-01]]] - - """ - check_param_type(shape, "shape", data_type=tuple) - check_param_type(modes, "modes", data_type=tuple) - check_param_value(len(shape), "shape length", 3) - check_param_value(len(modes), "modes length", 3) - check_param_even(shape, "shape") - check_param_no_greater(modes[0], "mode1", shape[0] // 2) - check_param_no_greater(modes[1], "mode2", shape[1] // 2) - check_param_no_greater(modes[2], "mode3", shape[2] // 2 + 1) - return _idftn(shape, modes, dim=dim, compute_dtype=compute_dtype) - - -def dft2(shape, modes, dim=(-2, -1), compute_dtype=mindspore.float32): - """ - Calculate two-dimensional discrete Fourier transform. Corresponding to the rfft2 operator in torch. - - Args: - shape (tuple): Dimension of the input 'x'. - modes (tuple): The length of the output transform axis. The `modes` must be no greater than half of the - dimension of input 'x'. - dim (tuple): Dimensions to be transformed. - compute_dtype (:class:`mindspore.dtype`): The type of input tensor. Default: mindspore.float32. - - Inputs: - - **x** (Tensor, Tensor): The input data. It's 2-D tuple of Tensor. It's a complex, - including x real and imaginary. Tensor of shape :math:`(*, *)`. - - Returns: - Complex tensor with the same shape of input x. - - Raises: - TypeError: If `shape` is not a tuple. - ValueError: If the length of `shape` is no equal to 2. - - Examples: - >>> import numpy as np - >>> from mindspore import Tensor, ops - >>> import mindspore.common.dtype as mstype - >>> from mindflow.cell.neural_operators.dft import dft2 - >>> array = np.ones((5, 5)) * np.arange(1, 6) - >>> x_re = Tensor(array, dtype=mstype.float32) - >>> x_im = x_re - >>> dft2_cell = dft2(shape=array.shape, modes=(2, 2), compute_dtype=mstype.float32) - >>> ret, _ = dft2_cell((x_re, x_im)) - >>> print(ret) - [[ 1.5000000e+01 -5.9409552e+00] - [-2.4656805e-07 7.6130398e-08] - [ 0.0000000e+00 0.0000000e+00] - [-1.9992007e-07 7.3572544e-08] - [-2.4656805e-07 7.6130398e-08]] - - """ - check_param_type(shape, "shape", data_type=tuple) - check_param_type(modes, "modes", data_type=tuple) - check_param_value(len(shape), "shape length", 2) - check_param_value(len(modes), "modes length", 2) - check_param_even(shape, "shape") - check_param_no_greater(modes[0], "mode1", shape[0] // 2) - check_param_no_greater(modes[1], "mode2", shape[1] // 2 + 1) - return _dftn(shape, modes, dim=dim, compute_dtype=compute_dtype) - - -def idft2(shape, modes, dim=(-2, -1), compute_dtype=mindspore.float32): - """ - Calculate two-dimensional discrete Fourier transform. Corresponding to the irfft2 operator in torch. - - Args: - shape (tuple): Dimension of the input 'x'. - modes (tuple): The length of the output transform axis. The `modes` must be no greater than half of the - dimension of input 'x'. - dim (tuple): Dimensions to be transformed. - compute_dtype (:class:`mindspore.dtype`): The type of input tensor. Default: mindspore.float32. - - Inputs: - - **x** (Tensor, Tensor): The input data. It's 2-D tuple of Tensor. It's a complex, - including x real and imaginary. Tensor of shape :math:`(*, *)`. - - Returns: - Complex tensor with the same shape of input x. - - Raises: - TypeError: If `shape` is not a tuple. - ValueError: If the length of `shape` is no equal to 2. - - Examples: - >>> import numpy as np - >>> from mindspore import Tensor, ops - >>> import mindspore.common.dtype as mstype - >>> from mindflow.cell.neural_operators.dft import idft2 - >>> array = np.ones((2, 2)) * np.arange(1, 3) - >>> x_re = Tensor(array, dtype=mstype.float32) - >>> x_im = ops.zeros_like(x_re) - >>> idft2_cell = idft2(shape=(5, 5), modes=(2, 2), compute_dtype=mstype.float32) - >>> ret, _ = idft2_cell((x_re, x_im)) - >>> print(ret) - [[ 3.9999998 1.7888544 -1.7888546 -1.7888546 1.7888544 ] - [ 0.80901694 0.80901694 -0.08541022 -0.6381966 -0.08541021] - [-0.30901706 -0.8618034 -0.30901694 0.5854102 0.5854101 ] - [-0.30901706 0.5854101 0.5854102 -0.30901694 -0.8618034 ] - [ 0.80901694 -0.08541021 -0.6381966 -0.08541022 0.80901694]] - - """ - check_param_type(shape, "shape", data_type=tuple) - check_param_type(modes, "modes", data_type=tuple) - check_param_value(len(shape), "shape length", 2) - check_param_value(len(modes), "modes length", 2) - check_param_even(shape, "shape") - check_param_no_greater(modes[0], "mode1", shape[0] // 2) - check_param_no_greater(modes[1], "mode2", shape[1] // 2 + 1) - return _idftn(shape, modes, dim=dim, compute_dtype=compute_dtype) - - -def dft1(shape, modes, dim=(-1,), compute_dtype=mindspore.float32): - """ - Calculate one-dimensional discrete Fourier transform. Corresponding to the rfft operator in torch. - - Args: - shape (tuple): Dimension of the input 'x'. - modes (int): The length of the output transform axis. The `modes` must be no greater than half of the - dimension of input 'x'. - dim (tuple): Dimensions to be transformed. - compute_dtype (:class:`mindspore.dtype`): The type of input tensor. - Default: mindspore.float32. - - Inputs: - - **x** (Tensor, Tensor): The input data. It's 2-D tuple of Tensor. It's a complex, - including x real and imaginary. Tensor of shape :math:`(*, *)`. - - Returns: - Complex tensor with the same shape of input x. - - Raises: - TypeError: If `shape` is not a tuple. - ValueError: If the length of `shape` is no equal to 1. - - Examples: - >>> from mindspore import Tensor, ops - >>> import mindspore.common.dtype as mstype - >>> from mindflow.cell.neural_operators.dft import dft1 - >>> array = [i for i in range(5)] - >>> x_re = Tensor(array, dtype=mstype.float32) - >>> x_im = ops.zeros_like(x_re) - >>> dft1_cell = dft1(shape=(len(x_re),), modes=2, compute_dtype=mstype.float32) - >>> ret, _ = dft1_cell((x_re, x_im)) - >>> print(ret) - [ 4.4721355 -1.1180341] - - """ - check_param_type(shape, "shape", data_type=tuple) - check_param_type(modes, "modes", data_type=int) - check_param_value(len(shape), "shape length", 1) - check_param_even(shape, "shape") - check_param_no_greater(modes, "mode1", shape[0] // 2 + 1) - modes = (modes,) - return _dftn(shape, modes, dim=dim, compute_dtype=compute_dtype) - - -def idft1(shape, modes, dim=(-1,), compute_dtype=mindspore.float32): - """ - Calculate one-dimensional discrete Fourier transform. Corresponding to the irfft operator in torch. - - Args: - shape (tuple): Dimension of the input 'x'. - modes (int): The length of the output transform axis. The `modes` must be no greater than half of the - dimension of input 'x'. - dim (tuple): Dimensions to be transformed. - compute_dtype (:class:`mindspore.dtype`): The type of input tensor. Default: mindspore.float32. - - Inputs: - - **x** (Tensor, Tensor): The input data. It's 2-D tuple of Tensor. It's a complex, - including x real and imaginary. Tensor of shape :math:`(*, *)`. - - Returns: - Complex tensor with the same shape of input x. - - Raises: - TypeError: If `shape` is not a tuple. - ValueError: If the length of `shape` is no equal to 1. - - Examples: - >>> from mindspore import Tensor, ops - >>> import mindspore.common.dtype as mstype - >>> from mindflow.cell.neural_operators.dft import idft1 - >>> array = [i for i in range(2)] - >>> x_re = Tensor(array, dtype=mstype.float32) - >>> x_im = x_re - >>> idft1_cell = idft1(shape=(len(x_re),), modes=2, compute_dtype=mstype.float32) - >>> ret, _ = idft1_cell((x_re, x_im)) - >>> print(ret) - [ 0.8944272 -0.5742576 -1.2493379 -0.19787574 1.127044 ] - - """ - check_param_type(shape, "shape", data_type=tuple) - check_param_type(modes, "modes", data_type=int) - check_param_value(len(shape), "shape length", 1) - check_param_even(shape, "shape") - check_param_no_greater(modes, "mode1", shape[0] // 2 + 1) - modes = (modes,) - return _idftn(shape, modes, dim=dim, compute_dtype=compute_dtype) - - -class SpectralConvDft(nn.Cell): - """Base Class for Fourier Layer, including DFT, linear transform, and Inverse DFT""" - - def __init__(self, in_channels, out_channels, n_modes, resolutions, compute_dtype=mstype.float32): - super().__init__() - self.in_channels = in_channels - self.out_channels = out_channels - if isinstance(n_modes, int): - n_modes = [n_modes] - self.n_modes = n_modes - if isinstance(resolutions, int): - resolutions = [resolutions] - self.resolutions = resolutions - if len(self.n_modes) != len(self.resolutions): - raise ValueError( - "The dimension of n_modes should be equal to that of resolutions, \ - but got dimension of n_modes {} and dimension of resolutions {}".format(len(self.n_modes), - len(self.resolutions))) - self.compute_dtype = compute_dtype - - def construct(self, x: Tensor): - raise NotImplementedError() - - def _einsum(self, inputs, weights): - weights = weights.expand_dims(0) - inputs = inputs.expand_dims(2) - out = inputs * weights - return out.sum(1) - - -class SpectralConv1dDft(SpectralConvDft): - """1D Fourier Layer. It does DFT, linear transform, and Inverse DFT.""" - - def __init__(self, in_channels, out_channels, n_modes, resolutions, compute_dtype=mstype.float32): - super().__init__(in_channels, out_channels, n_modes, resolutions) - self._scale = (1. / (self.in_channels * self.out_channels)) - w_re = Tensor(self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0]), - dtype=mstype.float32) - w_im = Tensor(self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0]), - dtype=mstype.float32) - self._w_re = Parameter(w_re, requires_grad=True) - self._w_im = Parameter(w_im, requires_grad=True) - self._dft1_cell = dft1(shape=(self.resolutions[0],), modes=self.n_modes[0], compute_dtype=self.compute_dtype) - self._idft1_cell = idft1(shape=(self.resolutions[0],), modes=self.n_modes[0], compute_dtype=self.compute_dtype) - - def construct(self, x: Tensor): - x_re = x - x_im = ops.zeros_like(x_re) - x_ft_re, x_ft_im = self._dft1_cell((x_re, x_im)) - w_re = P.Cast()(self._w_re, self.compute_dtype) - w_im = P.Cast()(self._w_im, self.compute_dtype) - out_ft_re = self._einsum(x_ft_re[:, :, :self.n_modes[0]], w_re) - self._einsum(x_ft_im[:, :, :self.n_modes[0]], - w_im) - out_ft_im = self._einsum(x_ft_re[:, :, :self.n_modes[0]], w_im) + self._einsum(x_ft_im[:, :, :self.n_modes[0]], - w_re) - - x, _ = self._idft1_cell((out_ft_re, out_ft_im)) - - return x - - -class SpectralConv2dDft(SpectralConvDft): - """2D Fourier Layer. It does DFT, linear transform, and Inverse DFT.""" - - def __init__(self, in_channels, out_channels, n_modes, resolutions, compute_dtype=mstype.float32): - super().__init__(in_channels, out_channels, n_modes, resolutions) - self._scale = (1. / (self.in_channels * self.out_channels)) - w_re1 = Tensor( - self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0], self.n_modes[1]), - dtype=self.compute_dtype) - w_im1 = Tensor( - self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0], self.n_modes[1]), - dtype=self.compute_dtype) - w_re2 = Tensor( - self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0], self.n_modes[1]), - dtype=self.compute_dtype) - w_im2 = Tensor( - self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0], self.n_modes[1]), - dtype=self.compute_dtype) - - self._w_re1 = Parameter(w_re1, requires_grad=True) - self._w_im1 = Parameter(w_im1, requires_grad=True) - self._w_re2 = Parameter(w_re2, requires_grad=True) - self._w_im2 = Parameter(w_im2, requires_grad=True) - - self._dft2_cell = dft2(shape=(self.resolutions[0], self.resolutions[1]), - modes=(self.n_modes[0], self.n_modes[1]), compute_dtype=self.compute_dtype) - self._idft2_cell = idft2(shape=(self.resolutions[0], self.resolutions[1]), - modes=(self.n_modes[0], self.n_modes[1]), compute_dtype=self.compute_dtype) - self._mat = Tensor(shape=(1, self.out_channels, self.resolutions[1] - 2 * self.n_modes[0], self.n_modes[1]), - dtype=self.compute_dtype, init=Zero()) - self._concat = ops.Concat(-2) - - def construct(self, x: Tensor): - x_re = x - x_im = ops.zeros_like(x_re) - x_ft_re, x_ft_im = self._dft2_cell((x_re, x_im)) - - out_ft_re1 = self._einsum(x_ft_re[:, :, :self.n_modes[0], :self.n_modes[1]], self._w_re1) - self._einsum( - x_ft_im[:, :, :self.n_modes[0], :self.n_modes[1]], self._w_im1) - out_ft_im1 = self._einsum(x_ft_re[:, :, :self.n_modes[0], :self.n_modes[1]], self._w_im1) + self._einsum( - x_ft_im[:, :, :self.n_modes[0], :self.n_modes[1]], self._w_re1) - - out_ft_re2 = self._einsum(x_ft_re[:, :, -self.n_modes[0]:, :self.n_modes[1]], self._w_re2) - self._einsum( - x_ft_im[:, :, -self.n_modes[0]:, :self.n_modes[1]], self._w_im2) - out_ft_im2 = self._einsum(x_ft_re[:, :, -self.n_modes[0]:, :self.n_modes[1]], self._w_im2) + self._einsum( - x_ft_im[:, :, -self.n_modes[0]:, :self.n_modes[1]], self._w_re2) - - batch_size = x.shape[0] - mat = mint.repeat_interleave(self._mat, batch_size, 0) - out_re = self._concat((out_ft_re1, mat, out_ft_re2)) - out_im = self._concat((out_ft_im1, mat, out_ft_im2)) - - x, _ = self._idft2_cell((out_re, out_im)) - - return x - - -class SpectralConv3dDft(SpectralConvDft): - """3D Fourier layer. It does DFT, linear transform, and Inverse DFT.""" - - def __init__(self, in_channels, out_channels, n_modes, resolutions, compute_dtype=mstype.float32): - super().__init__(in_channels, out_channels, n_modes, resolutions) - self._scale = (1 / (self.in_channels * self.out_channels)) - - w_re1 = Tensor( - self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0], self.n_modes[1], - self.n_modes[2]), dtype=self.compute_dtype) - w_im1 = Tensor( - self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0], self.n_modes[1], - self.n_modes[2]), dtype=self.compute_dtype) - w_re2 = Tensor( - self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0], self.n_modes[1], - self.n_modes[2]), dtype=self.compute_dtype) - w_im2 = Tensor( - self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0], self.n_modes[1], - self.n_modes[2]), dtype=self.compute_dtype) - w_re3 = Tensor( - self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0], self.n_modes[1], - self.n_modes[2]), dtype=self.compute_dtype) - w_im3 = Tensor( - self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0], self.n_modes[1], - self.n_modes[2]), dtype=self.compute_dtype) - w_re4 = Tensor( - self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0], self.n_modes[1], - self.n_modes[2]), dtype=self.compute_dtype) - w_im4 = Tensor( - self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0], self.n_modes[1], - self.n_modes[2]), dtype=self.compute_dtype) - - self._w_re1 = Parameter(w_re1, requires_grad=True) - self._w_im1 = Parameter(w_im1, requires_grad=True) - self._w_re2 = Parameter(w_re2, requires_grad=True) - self._w_im2 = Parameter(w_im2, requires_grad=True) - self._w_re3 = Parameter(w_re3, requires_grad=True) - self._w_im3 = Parameter(w_im3, requires_grad=True) - self._w_re4 = Parameter(w_re4, requires_grad=True) - self._w_im4 = Parameter(w_im4, requires_grad=True) - - self._dft3_cell = dft3(shape=(self.resolutions[0], self.resolutions[1], self.resolutions[2]), - modes=(self.n_modes[0], self.n_modes[1], self.n_modes[2]), - compute_dtype=self.compute_dtype) - self._idft3_cell = idft3(shape=(self.resolutions[0], self.resolutions[1], self.resolutions[2]), - modes=(self.n_modes[0], self.n_modes[1], self.n_modes[2]), - compute_dtype=self.compute_dtype) - self._mat_x = Tensor( - shape=(1, self.out_channels, self.resolutions[0] - 2 * self.n_modes[0], self.n_modes[1], self.n_modes[2]), - dtype=self.compute_dtype, init=Zero()) - self._mat_y = Tensor( - shape=(1, self.out_channels, self.resolutions[0], self.resolutions[1] - 2 * self.n_modes[1], - self.n_modes[2]), - dtype=self.compute_dtype, init=Zero()) - self._concat = ops.Concat(-2) - - def construct(self, x: Tensor): - x_re = x - x_im = ops.zeros_like(x_re) - x_ft_re, x_ft_im = self._dft3_cell((x_re, x_im)) - - out_ft_re1 = self._einsum(x_ft_re[:, :, :self.n_modes[0], :self.n_modes[1], :self.n_modes[2]], - self._w_re1) - self._einsum(x_ft_im[:, :, :self.n_modes[0], :self.n_modes[1], - :self.n_modes[2]], self._w_im1) - out_ft_im1 = self._einsum(x_ft_re[:, :, :self.n_modes[0], :self.n_modes[1], :self.n_modes[2]], - self._w_im1) + self._einsum(x_ft_im[:, :, :self.n_modes[0], :self.n_modes[1], - :self.n_modes[2]], self._w_re1) - out_ft_re2 = self._einsum(x_ft_re[:, :, -self.n_modes[0]:, :self.n_modes[1], :self.n_modes[2]], - self._w_re2) - self._einsum(x_ft_im[:, :, -self.n_modes[0]:, :self.n_modes[1], - :self.n_modes[2]], self._w_im2) - out_ft_im2 = self._einsum(x_ft_re[:, :, -self.n_modes[0]:, :self.n_modes[1], :self.n_modes[2]], - self._w_im2) + self._einsum(x_ft_im[:, :, -self.n_modes[0]:, :self.n_modes[1], - :self.n_modes[2]], self._w_re2) - out_ft_re3 = self._einsum(x_ft_re[:, :, :self.n_modes[0], -self.n_modes[1]:, :self.n_modes[2]], - self._w_re3) - self._einsum(x_ft_im[:, :, :self.n_modes[0], -self.n_modes[1]:, - :self.n_modes[2]], self._w_im3) - out_ft_im3 = self._einsum(x_ft_re[:, :, :self.n_modes[0], -self.n_modes[1]:, :self.n_modes[2]], - self._w_im3) + self._einsum(x_ft_im[:, :, :self.n_modes[0], -self.n_modes[1]:, - :self.n_modes[2]], self._w_re3) - out_ft_re4 = self._einsum(x_ft_re[:, :, -self.n_modes[0]:, -self.n_modes[1]:, :self.n_modes[2]], - self._w_re4) - self._einsum(x_ft_im[:, :, -self.n_modes[0]:, -self.n_modes[1]:, - :self.n_modes[2]], self._w_im4) - out_ft_im4 = self._einsum(x_ft_re[:, :, -self.n_modes[0]:, -self.n_modes[1]:, :self.n_modes[2]], - self._w_im4) + self._einsum(x_ft_im[:, :, -self.n_modes[0]:, -self.n_modes[1]:, - :self.n_modes[2]], self._w_re4) - - batch_size = x.shape[0] - mat_x = mint.repeat_interleave(self._mat_x, batch_size, 0) - mat_y = mint.repeat_interleave(self._mat_y, batch_size, 0) - - out_re1 = ops.concat((out_ft_re1, mat_x, out_ft_re2), -3) - out_im1 = ops.concat((out_ft_im1, mat_x, out_ft_im2), -3) - - out_re2 = ops.concat((out_ft_re3, mat_x, out_ft_re4), -3) - out_im2 = ops.concat((out_ft_im3, mat_x, out_ft_im4), -3) - out_re = ops.concat((out_re1, mat_y, out_re2), -2) - out_im = ops.concat((out_im1, mat_y, out_im2), -2) - x, _ = self._idft3_cell((out_re, out_im)) - - return x diff --git a/tests/st/mindflow/cell/dft/data/dft1_input.npy b/tests/st/mindflow/cell/dft/data/dft1_input.npy deleted file mode 100644 index 1b52bf4ea655eb0162bcf87fbf56dbe54d8c1f38..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 176 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1ZlWC!@qoAIaUsO_*m=~X4l#&V(cT3DEP6dh= lXCxM+0{I$dI+{8PwF(pfE(R!IgwjkqoAIaUsO_*m=~X4l#&V(cT3DEP6dh= eXCxM+0{I$7I+{8PwF(pfuHLVQ9SV-j*$)6jTp(=# diff --git a/tests/st/mindflow/cell/dft/data/idft1_input.npy b/tests/st/mindflow/cell/dft/data/idft1_input.npy deleted file mode 100644 index 20a800f355b94b0a3643f24e85dac84d99f57097..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 160 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1ZlV+i=qoAIaUsO_*m=~X4l#&V(cT3DEP6dh= eXCxM+0{I$7ItoUbItsN4WCJb+F!*4Piv|D(Js>Lp diff --git a/tests/st/mindflow/cell/dft/data/idft1_output.npy b/tests/st/mindflow/cell/dft/data/idft1_output.npy deleted file mode 100644 index 67c37f23a945888f45c41e3e2549ac7b1a5413ed..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 160 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1ZlV+l>qoAIaUsO_*m=~X4l#&V(cT3DEP6dh= lXCxM+0{I$7ItnJ5ItsN4WCJb+h6DBt3=I1r7{o^90{}1;B4q#o diff --git a/tests/st/mindflow/cell/test_dft.py b/tests/st/mindflow/cell/test_dft.py deleted file mode 100644 index 6034c47d7..000000000 --- a/tests/st/mindflow/cell/test_dft.py +++ /dev/null @@ -1,89 +0,0 @@ -# ============================================================================ -# Copyright 2024 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""DFT Test Case""" -import os -import random -import sys - -import pytest -import numpy as np - -import mindspore as ms -from mindspore import Tensor, ops, set_seed -from mindspore import dtype as mstype -from mindflow.cell.neural_operators.dft import dft1, idft1 - -PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../")) -sys.path.append(PROJECT_ROOT) - -from common.cell import compare_output -from common.cell import FP32_RTOL, FP32_ATOL - -set_seed(0) -np.random.seed(0) -random.seed(0) - - -@pytest.mark.level0 -@pytest.mark.platform_arm_ascend910b_training -@pytest.mark.env_onecard -@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) -def test_dft1_forward_accuracy(mode): - """ - Feature: DFT1 forward accuracy test - Description: Test the accuracy of the DFT1 operation in both GRAPH_MODE and PYNATIVE_MODE - with input data loaded from './mindflow/cell/dft/data/dft1_input.npy'. - The expected output is compared to a reference output stored in - './mindflow/cell/dft/data/dft1_output.npy'. - Expectation: The output should match the target data within the defined relative and absolute tolerance, - ensuring the DFT1 forward computation is accurate. - """ - ms.set_context(mode=mode) - input_data = np.load('./dft/data/dft1_input.npy') - x_re = Tensor(input_data, dtype=mstype.float32) - x_im = ops.zeros_like(x_re) - dft1_cell = dft1(shape=(len(x_re),), modes=2, compute_dtype=mstype.float32) - ret, _ = dft1_cell((x_re, x_im)) - ret = ret.asnumpy() - output_target = np.load('./dft/data/dft1_output.npy') - validate_ans = compare_output(ret, output_target, rtol=FP32_RTOL, atol=FP32_ATOL) - assert validate_ans, "The verification of dft1 forward accuracy is not successful." - - -@pytest.mark.level0 -@pytest.mark.platform_arm_ascend910b_training -@pytest.mark.env_onecard -@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) -def test_idft1_forward_accuracy(mode): - """ - Feature: IDFT1 forward accuracy test - Description: Test the accuracy of the IDFT1 operation in both GRAPH_MODE and PYNATIVE_MODE - with input data loaded from './mindflow/cell/dft/data/idft1_input.npy'. - The expected output is compared to a reference output stored in - './mindflow/cell/dft/data/idft1_output.npy'. - Expectation: The output should match the target data within the defined relative and absolute tolerance, - ensuring the IDFT1 forward computation is accurate. - """ - ms.set_context(mode=mode) - input_data = np.load('./dft/data/idft1_input.npy') - x_re = Tensor(input_data, dtype=mstype.float32) - x_im = x_re - idft1_cell = idft1(shape=(4,), modes=2, compute_dtype=mstype.float32) - ret, _ = idft1_cell((x_re, x_im)) - ret = ret.asnumpy() - output_target = np.load('./dft/data/idft1_output.npy') - validate_ans = compare_output(ret, output_target, rtol=FP32_RTOL, atol=FP32_ATOL) - assert validate_ans, "The verification of dft1 forward accuracy is not successful." diff --git a/tests/st/mindflow/operators/test_dft.py b/tests/st/mindflow/operators/test_dft.py deleted file mode 100644 index e001cee12..000000000 --- a/tests/st/mindflow/operators/test_dft.py +++ /dev/null @@ -1,127 +0,0 @@ -# Copyright 2023 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Test mindflow dft""" - -import torch -import numpy as np -import pytest - -import mindspore as ms -from mindspore import ops -from mindflow.cell.neural_operators.dft import dft1, dft2, idft1, idft2 - - -def dft_1d_torch(x, dim=-1): - x = torch.Tensor(x) - - x_re_im = torch.fft.fft(x, dim=dim, norm="ortho") - x_re, x_im = x_re_im.real, x_re_im.imag - return x_re.numpy(), x_im.numpy() - - -def dft_2d_torch(x, dim=-1): - x = torch.Tensor(x) - - x_re_im = torch.fft.rfft2(x, dim=dim, norm="ortho") - x_re, x_im = x_re_im.real, x_re_im.imag - return x_re.numpy(), x_im.numpy() - - -def idft_1d_torch(x_re, x_im, dim=-1): - x = torch.stack([torch.Tensor(x_re), torch.Tensor(x_im)], dim=-1) - x = torch.view_as_complex(x) - x = torch.fft.ifft(x, norm="ortho", dim=dim) - return x.numpy() - - -def idft_2d_torch(x_re, x_im, dim=-1): - x = torch.stack([torch.Tensor(x_re), torch.Tensor(x_im)], dim=-1) - x = torch.view_as_complex(x) - x = torch.fft.irfft2(x, norm="ortho", dim=dim) - return x.numpy() - - -def dft_1d_ms(x, shape, mode, dim=(-1,)): - x = ms.Tensor(x) - x_re = x - x_im = ops.zeros_like(x_re) - dft1_cell = dft1(shape=shape, modes=mode, dim=dim) - x_ft_re, x_ft_im = dft1_cell((x_re, x_im)) - return x_ft_re.asnumpy(), x_ft_im.asnumpy() - - -def dft_2d_ms(x, shape, mode, dim=(-1,)): - x = ms.Tensor(x) - x_re = x - x_im = ops.zeros_like(x_re) - dft2_cell = dft2(shape=shape, modes=mode, dim=dim) - x_ft_re, x_ft_im = dft2_cell((x_re, x_im)) - return x_ft_re.asnumpy(), x_ft_im.asnumpy() - - -def idft_1d_ms(x_re, x_im, shape, mode, dim=(-1)): - x_re = ms.Tensor(x_re) - x_im = ms.Tensor(x_im) - idft1_cell = idft1(shape=shape, modes=mode, dim=dim) - x_ms, _ = idft1_cell((x_re, x_im)) - return x_ms.asnumpy() - - -def idft_2d_ms(x_re, x_im, shape, mode, dim=(-1)): - x_re = ms.Tensor(x_re) - x_im = ms.Tensor(x_im) - idft2_cell = idft2(shape=shape, modes=mode, dim=dim) - x_ms, _ = idft2_cell((x_re, x_im)) - return x_ms.asnumpy() - - -@pytest.mark.level0 -@pytest.mark.platform_arm_ascend910b_training -@pytest.mark.env_onecard -def test_dft1d(): - """ - Feature: Test dft1d in platform gpu and ascend. - Description: None. - Expectation: Success or throw AssertionError. - Torch problem, need to adaptive 910B - """ - x = np.random.randn(1, 6, 8, 2) - x_re_torch1d, x_im_torch1d = dft_1d_torch(x, dim=-2) - x_re_ms1d, x_im_ms1d = dft_1d_ms(x, shape=(8,), mode=5, dim=(-2,)) - - x_torch1d = idft_1d_torch(x_re_torch1d, x_im_torch1d, dim=-2) - x_ms1d = idft_1d_ms(x_re_ms1d, x_im_ms1d, shape=(8,), mode=5, dim=(-2,)) - - assert np.sum(x_torch1d - x_ms1d) < 0.001 - - -@pytest.mark.level0 -@pytest.mark.platform_arm_ascend910b_training -@pytest.mark.env_onecard -def test_dft2d(): - """ - Feature: Test dft2d in platform gpu and ascend. - Description: None. - Expectation: Success or throw AssertionError. - Torch problem, need to adaptive 910B - """ - x = np.random.randn(1, 6, 8, 2) - x_re_torch2d, x_im_torch2d = dft_2d_torch(x, dim=(-3, -2)) - x_re_ms2d, x_im_ms2d = dft_2d_ms(x, shape=(6, 8), mode=(3, 5), dim=(-3, -2)) - - x_torch2d = idft_2d_torch(x_re_torch2d, x_im_torch2d, dim=(-3, -2)) - x_ms2d = idft_2d_ms(x_re_ms2d, x_im_ms2d, shape=(6, 8), mode=(3, 5), dim=(-3, -2)) - - assert np.sum(x_torch2d - x_ms2d) < 0.001 -- Gitee From 6ff9b89b02841db5c7b62ba9b1b88a5aeffcdf9a Mon Sep 17 00:00:00 2001 From: goto Date: Mon, 25 Aug 2025 08:47:59 +0800 Subject: [PATCH 2/2] modify reverse --- MindFlow/mindflow/core/fourier.py | 44 +++++++++++++++++++------------ 1 file changed, 27 insertions(+), 17 deletions(-) diff --git a/MindFlow/mindflow/core/fourier.py b/MindFlow/mindflow/core/fourier.py index 64f980668..f1d29a13c 100644 --- a/MindFlow/mindflow/core/fourier.py +++ b/MindFlow/mindflow/core/fourier.py @@ -51,28 +51,34 @@ class MyRoll(nn.Cell): class MyFlip(nn.Cell): ''' Custom defined flip operator to avoid bug in MindSpore ''' - def __init__(self): + def __init__(self, shape=None, compute_dtype=ms.float32): super().__init__() - msver = tuple([int(s) for s in ms.__version__.split('.')]) - if msver <= (2, 4, 0) and \ - ms.get_context('device_target') == 'Ascend' and \ - ms.get_context('mode') == ms.PYNATIVE_MODE: - self.fliper = None + if shape: + shape = np.atleast_1d(shape).astype(int).tolist() + self.rev_mats = [ms.Tensor(np.eye(n)[::-1], dtype=compute_dtype) for n in shape] + self.cast = P.Cast() else: - self.fliper = mint.flip + self.rev_mats = None def construct(self, x, dims): ''' same as mint.flip ''' dims = np.atleast_1d(dims).astype(int).tolist() - if self.fliper: - return self.fliper(x, dims) + if self.rev_mats is not None: + for i, m in enumerate(self.rev_mats): + self.rev_mats[i] = self.cast(m, x.dtype) + + for i, j in enumerate(dims): + x = ops.swapaxes(x, j, -1) + mat = None + if self.rev_mats is None: + mat = ms.Tensor(np.eye(x.shape[-1])[::-1], dtype=x.dtype) + else: + mat = self.rev_mats[i] + x = mint.matmul(x, mat) # todo: 求导有问题 + x = ops.swapaxes(x, j, -1) - for j in dims: - x = ops.swapaxes(x, j, 0) - x = x[::-1] - x = ops.swapaxes(x, j, 0) return x @@ -496,7 +502,7 @@ class DCT(nn.Cell): self.cosw = ops.cos(w) self.sinw = ops.sin(w) - self.fliper = MyFlip() + self.fliper = MyFlip((n // 2,), compute_dtype) def construct(self, a): ''' perform 1-dimensional DCT on real tensor ''' @@ -550,7 +556,7 @@ class IDCT(nn.Cell): self.cosw = ops.cos(w) self.sinw = ops.sin(w) - self.fliper = MyFlip() + self.fliper = MyFlip((n // 2,), compute_dtype) def construct(self, a): ''' perform 1-dimensional iDCT on real tensor ''' @@ -609,9 +615,11 @@ class DST(nn.Cell): multiplier[..., 1::2] *= -1 self.multiplier = Tensor(multiplier, dtype=compute_dtype) + self.fliper = MyFlip((n,), compute_dtype) + def construct(self, a): ''' perform 1-dimensional DST on real tensor ''' - return self.dft_cell.fliper(self.dft_cell(a * self.multiplier), dims=-1) + return self.fliper(self.dft_cell(a * self.multiplier), dims=-1) class IDST(nn.Cell): @@ -651,6 +659,8 @@ class IDST(nn.Cell): multiplier[..., 1::2] *= -1 self.multiplier = Tensor(multiplier, dtype=compute_dtype) + self.fliper = MyFlip((n,), compute_dtype) + def construct(self, a): ''' perform 1-dimensional iDST on real tensor ''' - return self.dft_cell(self.dft_cell.fliper(a, dims=-1)) * self.multiplier + return self.dft_cell(self.fliper(a, dims=-1)) * self.multiplier -- Gitee