diff --git a/mindscience/sciops/__init__.py b/mindscience/sciops/__init__.py index 2220d4f499b98524fd237e5573211a8551c9818c..7341ea2f530f2b3d6d756d365adf5fe2bce19a65 100644 --- a/mindscience/sciops/__init__.py +++ b/mindscience/sciops/__init__.py @@ -18,3 +18,5 @@ init from .fourier import RDFTn, IRDFTn, DFTn, IDFTn, DCT, IDCT, DST, IDST __all__ = ["RDFTn", "IRDFTn", "DFTn", "IDFTn", "DCT", "IDCT", "DST", "IDST"] +from .fft import * +__all__.extend(fft.__all__) diff --git a/mindscience/sciops/fft/__init__.py b/mindscience/sciops/fft/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c89b6bb90d309503dd9727052694aacfd4eb9cb7 --- /dev/null +++ b/mindscience/sciops/fft/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""init""" + +from .asd_fft_custom_op import * + +__all__ = ["set_fft_cache_size", "asd_fftn", "asd_ifftn", "asd_rfftn", "asd_irfftn", + "asd_fft", "asd_ifft", "asd_rfft", "asd_irfft", "asd_fft2", "asd_ifft2", "asd_rfft2", "asd_irfft2", + "ASD_FFT", "ASD_IFFT", "ASD_RFFT", "ASD_IRFFT", "ASD_FFT2D", "ASD_IFFT2D", "ASD_RFFT2D", "ASD_IRFFT2D"] diff --git a/mindscience/sciops/fft/asd_fft_custom_op.py b/mindscience/sciops/fft/asd_fft_custom_op.py new file mode 100644 index 0000000000000000000000000000000000000000..bf01799d65f8b97d7621c0f21f00ee49f325bdc9 --- /dev/null +++ b/mindscience/sciops/fft/asd_fft_custom_op.py @@ -0,0 +1,934 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +''' provide asd fft custom operators python binding ''' +import os +import mindspore as ms +from mindspore import nn, ops, dtype as mstype, mint +from mindspore.ops import DataType, CustomRegOp, CustomOpBuilder + +_asd_fft_op = None +def _get_asd_fft_op(): + global _asd_fft_op + if _asd_fft_op is not None: + return _asd_fft_op + cur_asd_fft_op_source_file = os.path.join(os.path.dirname(__file__), "asd_fft_op_ext.cpp") + _asd_fft_op = CustomOpBuilder("asd_fft_op", cur_asd_fft_op_source_file, backend="Ascend", enable_asdsip=True).load() + return _asd_fft_op + +class CustomReal(nn.Cell): + r""" + Custom real part extraction operator for complex tensors. + + This operator extracts the real part from complex tensors, used for compatibility + with different versions of MindSpore. 
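+    It is registered as an Ascend AOT custom operator (``aclnnReal``) and is used as a
+    fallback when ``mint.real`` is unavailable in the installed MindSpore version.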
+ + Args: + None + + Inputs: + - **x** (Tensor): Input complex tensor with data type complex64. + + Outputs: + - **output** (Tensor): Real part of the input tensor with data type float32. + + Supported Platforms: + ``Ascend`` ``Pynative`` + + Examples: + >>> import mindspore as ms + >>> from mindscience.sciops.fft import CustomReal + >>> x = ms.Tensor([1+2j, 3+4j], ms.complex64) + >>> real_op = CustomReal() + >>> real_part = real_op(x) + >>> print(real_part) + [1. 3.] + """ + def __init__(self): + super(CustomReal, self).__init__() + aclnn_ref_info = CustomRegOp("aclnnReal") \ + .input(0, "x", "required") \ + .output(0, "z", "required") \ + .dtype_format(DataType.C64_Default, DataType.F32_Default) \ + .target("Ascend") \ + .get_op_info() + + self.real = ops.Custom(func="aclnnReal", out_shape=lambda x: x, + out_dtype=mstype.float32, func_type="aot", reg_info=aclnn_ref_info) + def construct(self, x): + return self.real(x) + +class CustomComplex(nn.Cell): + r""" + Custom complex number construction operator. + + This operator constructs complex tensor from real and imaginary tensor. + + Args: + auto_prefix (bool): Whether to automatically generate prefix. Default: True. + flags (dict): Additional flags for the operator. Default: None. + + Inputs: + - **real** (Tensor): Real part tensor with data type float32. + - **imag** (Tensor): Imaginary part tensor with data type float32. + + Outputs: + - **output** (Tensor): Complex tensor with data type complex64. + + Supported Platforms: + ``Ascend`` ``Pynative`` + + Examples: + >>> import mindspore as ms + >>> from mindscience.sciops.fft import CustomComplex + >>> real = ms.Tensor([1.0, 2.0], ms.float32) + >>> imag = ms.Tensor([3.0, 4.0], ms.float32) + >>> complex_op = CustomComplex() + >>> complex_tensor = complex_op(real, imag) + >>> print(complex_tensor) + [1.+3.j 2.+4.j] + """ + def __init__(self, auto_prefix=True, flags=None): + super().__init__(auto_prefix, flags) + aclnn_ref_info = CustomRegOp("aclnnComplex") \ + .input(0, "real", "required") \ + .input(1, "imag", "required") \ + .output(0, "out", "required") \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.C64_Default) \ + .target("Ascend") \ + .get_op_info() + + self.complex = ops.Custom(func="aclnnComplex", out_shape=lambda real, imag: real, + out_dtype=mstype.complex64, func_type="aot", reg_info=aclnn_ref_info) + def construct(self, real, imag): + output = self.complex(real, imag) + return output + + +def set_fft_cache_size(cache_size): + r""" + Set cache number of ASD FFT operators to optimize the function call performance. + Without cache, every ASD FFT function call will dlopen the function symbols from .so dynamically, + which will introduce some overhead. + With cache, the function symbols will be loaded into cache when the first ASD FFT function call is made, + and will not be loaded again in subsequent calls. + + Args: + cache_size (int): Cache number of ASD FFT operators. 
+ + Inputs: + None + + Outputs: + None + + Supported Platforms: + ``Ascend`` ``Pynative`` + + Examples: + >>> from mindscience.sciops.fft import set_cache_size + >>> set_cache_size(1024) # Set 1024 fft operators cache number + """ + _get_asd_fft_op().asd_set_cache_size(cache_size) + + +# Use mint.real first, if not exist, fallback to custom operator +try: + _mint_real = mint.real + def _get_real(x): + return _mint_real(x) +except AttributeError: + _real_op = CustomReal() + def _get_real(x): + return _real_op(x) + +# Use mint.imag first, if not exist, fallback to custom operator +try: + _mint_imag = mint.imag + def _get_imag(x): + return _mint_imag(x) +except AttributeError: + _neg1j = ms.tensor(0-1j, dtype=mstype.complex64) + def _get_imag(x): + return _get_real(ops.mul(_neg1j, x)) + +def _get_r2c_alf(xr): + n = xr.shape[-1] + m = n // 2 + 1 + k = (n + 1) // 2 + alf = ops.ones(m, xr.dtype) + # alf[1:k] = 2 but we need scale a/alf, so just set alf[1:k] = 0.5, then do mul(dreal, alf) + alf[1:k] = 0.5 + return alf + +def _get_c2r_alf(xr, scale_factor): + n = (xr.shape[-1] - 1) * 2 + m = n // 2 + 1 + k = (n + 1) // 2 + alf = ops.ones(m, xr.dtype) + # for 1->k set alf = 2.0, others set alf = 1.0 + alf[1:k] = 2.0 + alf = alf.mul_(scale_factor) + return alf + + + +# C2C +class ASD_FFT(nn.Cell): # pylint: disable=invalid-name + r""" + 1D complex-to-complex forward FFT transform using Ascend NPU acceleration. + + This operator performs 1D Fast Fourier Transform on complex input tensors, + optimized for Ascend NPU hardware acceleration. + + Args: + None + + Inputs: + - **xr** (Tensor): Real part of input complex tensor with data type float32. + - **xi** (Tensor): Imaginary part of input complex tensor with data type float32. + + Outputs: + - **yr** (Tensor): Real part of output complex tensor with data type float32. + - **yi** (Tensor): Imaginary part of output complex tensor with data type float32. + + Raises: + ValueError: If input tensor data type is not float32. 
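+
+    Note:
+        Inputs of any rank are flattened to ``(batch, n)`` before the transform and
+        reshaped back afterwards, so only the last axis is transformed. The backward
+        pass is computed as ``n * asd_ifft(dout)``.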
+ + Supported Platforms: + ``Ascend`` ``Pynative`` + + Examples: + >>> import mindspore as ms + >>> from mindscience.sciops.fft import ASD_FFT + >>> xr = ms.Tensor([[1.0, 2.0, 3.0, 4.0]], ms.float32) + >>> xi = ms.Tensor([[0.0, 0.0, 0.0, 0.0]], ms.float32) + >>> asd_fft = ASD_FFT() + >>> yr, yi = asd_fft(xr, xi) + >>> print(yr.shape) + (1, 4) + >>> print(yi.shape) + (1, 4) + """ + def __init__(self): + super(ASD_FFT, self).__init__() + self.asd_fft_op = _get_asd_fft_op().asd_fft_1d + self.make_complex = CustomComplex() + self.used_bprop_inputs = [] + + def get_fft_size_and_scale(self, xr): + return xr.shape[-1], None + + + def construct(self, xr, xi): + return self.forward(xr, xi) + + def forward(self, xr, xi=None): + """forward""" + if xr.dtype != mstype.float32 or (xi is not None and xi.dtype != mstype.float32): + raise ValueError("ASD_FFT Input tensor must be float32") + org_shape = list(xr.shape) + + # Unify to two dimensions + batch_size = 1 + for i in range(len(org_shape) - 1): + batch_size *= org_shape[i] + if len(org_shape) != 2: + xr = mint.reshape(xr, (batch_size, org_shape[-1])) + xi = mint.reshape(xi, (batch_size, org_shape[-1])) if xi is not None else None + + fft_size, scale_factor = self.get_fft_size_and_scale(xr) + x = self.make_complex(xr, xi) if xi is not None else xr + output = self.asd_fft_op(x, xr.shape[0], fft_size) + + if scale_factor is not None: + output.mul_(scale_factor) + + if org_shape != list(output.shape): + org_shape[-1] = output.shape[-1] + output = mint.reshape(output, tuple(org_shape)) + # If output is not complex, return directly + if not ops.is_complex(output): + return output + + return _get_real(output), _get_imag(output) + + def bprop(self, xr, xi, out, dout): # pylint: disable=unused-argument + dreal, dimag = dout + dxr, dxi = asd_ifft(dreal, dimag) + return dxr.mul_(dreal.shape[-1]), dxi.mul_(dreal.shape[-1]) + +# C2C +class ASD_IFFT(ASD_FFT): # pylint: disable=invalid-name + r""" + 1D complex-to-complex inverse FFT transform using Ascend NPU acceleration. + + This operator performs 1D Inverse Fast Fourier Transform on complex input tensors, + optimized for Ascend NPU hardware acceleration. + + Args: + None + + Inputs: + - **xr** (Tensor): Real part of input complex tensor with data type float32. + - **xi** (Tensor): Imaginary part of input complex tensor with data type float32. + + Outputs: + - **yr** (Tensor): Real part of output complex tensor with data type float32. + - **yi** (Tensor): Imaginary part of output complex tensor with data type float32. + + Raises: + ValueError: If input tensor data type is not float32. 
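+
+    Note:
+        The output is scaled by ``1 / n``, where ``n`` is the size of the last axis,
+        matching the ``numpy.fft.ifft`` normalization. The backward pass is computed
+        as ``(1 / n) * asd_fft(dout)``.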
+ + Supported Platforms: + ``Ascend`` ``Pynative`` + + Examples: + >>> import mindspore as ms + >>> from mindscience.sciops.fft import ASD_IFFT + >>> xr = ms.Tensor([[1.0, 2.0, 3.0, 4.0]], ms.float32) + >>> xi = ms.Tensor([[0.0, 0.0, 0.0, 0.0]], ms.float32) + >>> asd_ifft = ASD_IFFT() + >>> yr, yi = asd_ifft(xr, xi) + >>> print(yr.shape) + (1, 4) + >>> print(yi.shape) + (1, 4) + """ + def __init__(self): + super(ASD_IFFT, self).__init__() + self.asd_fft_op = _get_asd_fft_op().asd_ifft_1d + self.used_bprop_inputs = [] + + def get_fft_size_and_scale(self, xr): + fft_size = xr.shape[-1] + scale_factor = 1.0 / fft_size + return fft_size, scale_factor + + def bprop(self, xr, xi, out, dout): # pylint: disable=unused-argument + dreal, dimag = dout + dxr, dxi = asd_fft(dreal, dimag) + scale_factor = 1.0 / dreal.shape[-1] + return dxr.mul_(scale_factor), dxi.mul_(scale_factor) + +# R2C +class ASD_RFFT(ASD_FFT): # pylint: disable=invalid-name + r""" + 1D real-to-complex FFT transform using Ascend NPU acceleration. + + This operator performs 1D Real Fast Fourier Transform on real input tensors, + optimized for Ascend NPU hardware acceleration. + + Args: + None + + Inputs: + - **xr** (Tensor): Input real tensor with data type float32. + + Outputs: + - **yr** (Tensor): Real part of output complex tensor with data type float32. + - **yi** (Tensor): Imaginary part of output complex tensor with data type float32. + + Raises: + ValueError: If input tensor data type is not float32. + + Supported Platforms: + ``Ascend`` ``Pynative`` + + Examples: + >>> import mindspore as ms + >>> from mindscience.sciops.fft import ASD_RFFT + >>> xr = ms.Tensor([[1.0, 2.0, 3.0, 4.0]], ms.float32) + >>> asd_rfft = ASD_RFFT() + >>> yr, yi = asd_rfft(xr) + >>> print(yr.shape) + (1, 3) + >>> print(yi.shape) + (1, 3) + """ + def __init__(self): + super(ASD_RFFT, self).__init__() + self.asd_fft_op = _get_asd_fft_op().asd_rfft_1d + self.used_bprop_inputs = [0] + + def construct(self, xr): # pylint: disable=arguments-differ + return self.forward(xr) + + def bprop(self, xr, out, dout): # pylint: disable=arguments-differ, unused-argument + dreal, dimag = dout + alf = _get_r2c_alf(xr) + dxr = asd_irfftn(dreal.mul_(alf), dimag.mul_(alf), ndim=1, n=xr.shape[-1]) + return dxr.mul_(xr.shape[-1]) + +# C2R +class ASD_IRFFT(ASD_FFT): # pylint: disable=invalid-name + r""" + 1D complex-to-real inverse FFT transform using Ascend NPU acceleration. + + This operator performs 1D Inverse Real Fast Fourier Transform on complex input tensors, + optimized for Ascend NPU hardware acceleration. + + Args: + None + + Inputs: + - **xr** (Tensor): Real part of input complex tensor with data type float32. + - **xi** (Tensor): Imaginary part of input complex tensor with data type float32. + + Outputs: + - **yr** (Tensor): Output real tensor with data type float32. + + Raises: + ValueError: If input tensor data type is not float32. 
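+
+    Note:
+        By default the output length is ``(m - 1) * 2``, where ``m`` is the size of the
+        last input axis. Call ``set_n(n)`` (as ``asd_irfftn`` does through its ``n``
+        argument) to recover an explicit output length ``n`` satisfying
+        ``n // 2 + 1 == m``.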
+ + Supported Platforms: + ``Ascend`` ``Pynative`` + + Examples: + >>> import mindspore as ms + >>> from mindscience.sciops.fft import ASD_IRFFT + >>> xr = ms.Tensor([[1.0, 2.0, 3.0]], ms.float32) + >>> xi = ms.Tensor([[0.0, 0.0, 0.0]], ms.float32) + >>> asd_irfft = ASD_IRFFT() + >>> yr = asd_irfft(xr, xi) + >>> print(yr.shape) + (1, 4) + """ + def __init__(self): + super(ASD_IRFFT, self).__init__() + self.asd_fft_op = _get_asd_fft_op().asd_irfft_1d + self.n = None + self.used_bprop_inputs = [0] + + def get_fft_size_and_scale(self, xr): + if self.n is not None and (self.n // 2 + 1) == xr.shape[-1]: + fft_size = self.n + else: + fft_size = (xr.shape[-1] - 1) * 2 + scale_factor = 1.0 / fft_size + return fft_size, scale_factor + + def set_n(self, n): + self.n = n + + def bprop(self, xr, xi, out, dout): # pylint: disable=unused-argument + dreal = dout + fft_size, scale_factor = self.get_fft_size_and_scale(xr) # pylint: disable=unused-variable + alf = _get_c2r_alf(xr, scale_factor) + dxr, dxi = asd_rfft(dreal) + + return dxr.mul_(alf), dxi.mul_(alf) + +# C2C forward +class ASD_FFT2D(nn.Cell): # pylint: disable=invalid-name + r""" + 2D complex-to-complex forward FFT transform using Ascend NPU acceleration. + + This operator performs 2D Fast Fourier Transform on complex input tensors, + optimized for Ascend NPU hardware acceleration. + + Args: + None + + Inputs: + - **xr** (Tensor): Real part of input complex tensor with data type float32, at least 2D. + - **xi** (Tensor): Imaginary part of input complex tensor with data type float32. + + Outputs: + - **yr** (Tensor): Real part of output complex tensor with data type float32. + - **yi** (Tensor): Imaginary part of output complex tensor with data type float32. + + Raises: + ValueError: If input tensor data type is not float32 or tensor has less than 2 dimensions. 
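+
+    Note:
+        The transform is applied over the last two axes. Inputs with more than three
+        dimensions are reshaped to ``(batch, H, W)`` and restored afterwards; 2D inputs
+        are treated as a single batch. The backward pass is computed as
+        ``H * W * asd_ifft2(dout)``.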
+ + Supported Platforms: + ``Ascend`` ``Pynative`` + + Examples: + >>> import mindspore as ms + >>> from mindscience.sciops.fft import ASD_FFT2D + >>> xr = ms.Tensor([[[1.0, 2.0], [3.0, 4.0]]], ms.float32) + >>> xi = ms.Tensor([[[0.0, 0.0], [0.0, 0.0]]], ms.float32) + >>> asd_fft2d = ASD_FFT2D() + >>> yr, yi = asd_fft2d(xr, xi) + >>> print(yr.shape) + (1, 2, 2) + >>> print(yi.shape) + (1, 2, 2) + """ + def __init__(self): + super(ASD_FFT2D, self).__init__() + self.asd_fft_op = _get_asd_fft_op().asd_fft_2d + self.make_complex = CustomComplex() + self.used_bprop_inputs = [] + + def get_fft_size_and_scale(self, xr): + return xr.shape[0], xr.shape[1], xr.shape[2], None + + def construct(self, xr, xi): + return self.forward(xr, xi) + + def forward(self, xr, xi=None): + """forward""" + if xr.dtype != mstype.float32 or (xi is not None and xi.dtype != mstype.float32): + raise ValueError("ASD_FFT Input tensor must be float32") + + org_shape = list(xr.shape) + if len(org_shape) < 2: + raise ValueError("2D FFT Input tensor must have at least 2 dimensions") + # Unify to three dimensions + if len(org_shape) == 2: + xr = ops.expand_dims(xr, 0) + xi = ops.expand_dims(xi, 0) if xi is not None else None + elif len(org_shape) > 3: + xr = ops.reshape(xr, (-1, org_shape[-2], org_shape[-1])) + xi = ops.reshape(xi, (-1, org_shape[-2], org_shape[-1])) if xi is not None else None + + batch_size, x_size, y_size, scale_factor = self.get_fft_size_and_scale(xr) + x = self.make_complex(xr, xi) if xi is not None else xr + output = self.asd_fft_op(x, batch_size, x_size, y_size) + + if scale_factor is not None: + output.mul_(scale_factor) + + if org_shape != list(output.shape): + org_shape[-1] = output.shape[-1] + output = ops.reshape(output, tuple(org_shape)) + + # If output is not complex, return directly + if not ops.is_complex(output): + return output + + return _get_real(output), _get_imag(output) + + def bprop(self, xr, xi, out, dout): # pylint: disable=unused-argument + dreal, dimag = dout + dxr, dxi = asd_ifft2(dreal, dimag) + n = dreal.shape[-1] * dreal.shape[-2] + return dxr.mul_(n), dxi.mul_(n) + +# C2C inverse +class ASD_IFFT2D(ASD_FFT2D): # pylint: disable=invalid-name + r""" + 2D complex-to-complex inverse FFT transform using Ascend NPU acceleration. + + This operator performs 2D Inverse Fast Fourier Transform on complex input tensors, + optimized for Ascend NPU hardware acceleration. + + Args: + None + + Inputs: + - **xr** (Tensor): Real part of input complex tensor with data type float32, at least 2D. + - **xi** (Tensor): Imaginary part of input complex tensor with data type float32. + + Outputs: + - **yr** (Tensor): Real part of output complex tensor with data type float32. + - **yi** (Tensor): Imaginary part of output complex tensor with data type float32. + + Raises: + ValueError: If input tensor data type is not float32 or tensor has less than 2 dimensions. 
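+
+    Note:
+        The output is scaled by ``1 / (H * W)`` over the last two axes, matching the
+        ``numpy.fft.ifftn`` normalization.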
+ + Supported Platforms: + ``Ascend`` ``Pynative`` + + Examples: + >>> import mindspore as ms + >>> from mindscience.sciops.fft import ASD_IFFT2D + >>> xr = ms.Tensor([[[1.0, 2.0], [3.0, 4.0]]], ms.float32) + >>> xi = ms.Tensor([[[0.0, 0.0], [0.0, 0.0]]], ms.float32) + >>> asd_ifft2d = ASD_IFFT2D() + >>> yr, yi = asd_ifft2d(xr, xi) + >>> print(yr.shape) + (1, 2, 2) + >>> print(yi.shape) + (1, 2, 2) + """ + def __init__(self): + super(ASD_IFFT2D, self).__init__() + self.asd_fft_op = _get_asd_fft_op().asd_ifft_2d + self.used_bprop_inputs = [] + + def get_fft_size_and_scale(self, xr): + batch_size, x_size, y_size, scale_factor = super(ASD_IFFT2D, self).get_fft_size_and_scale(xr) + scale_factor = 1.0 / (x_size * y_size) + return batch_size, x_size, y_size, scale_factor + + def bprop(self, xr, xi, out, dout): # pylint: disable=unused-argument + dreal, dimag = dout + dxr, dxi = asd_fft2(dreal, dimag) + scale_factor = 1.0 / (dreal.shape[-1] * dreal.shape[-2]) + return dxr.mul_(scale_factor), dxi.mul_(scale_factor) + +# R2C forward +class ASD_RFFT2D(ASD_FFT2D): # pylint: disable=invalid-name + r""" + 2D real-to-complex FFT transform using Ascend NPU acceleration. + + This operator performs 2D Real Fast Fourier Transform on real input tensors, + optimized for Ascend NPU hardware acceleration. + + Args: + None + + Inputs: + - **xr** (Tensor): Input real tensor with data type float32, at least 2D. + + Outputs: + - **yr** (Tensor): Real part of output complex tensor with data type float32. + - **yi** (Tensor): Imaginary part of output complex tensor with data type float32. + + Raises: + ValueError: If input tensor data type is not float32 or tensor has less than 2 dimensions. + + Supported Platforms: + ``Ascend`` ``Pynative`` + + Examples: + >>> import mindspore as ms + >>> from mindscience.sciops.fft import ASD_RFFT2D + >>> xr = ms.Tensor([[[1.0, 2.0], [3.0, 4.0]]], ms.float32) + >>> asd_rfft2d = ASD_RFFT2D() + >>> yr, yi = asd_rfft2d(xr) + >>> print(yr.shape) + (1, 2, 2) + >>> print(yi.shape) + (1, 2, 2) + """ + def __init__(self): + super(ASD_RFFT2D, self).__init__() + self.asd_fft_op = _get_asd_fft_op().asd_rfft_2d + self.used_bprop_inputs = [0] + + def construct(self, xr): # pylint: disable=arguments-differ + return self.forward(xr) + + def bprop(self, xr, out, dout): # pylint: disable=arguments-differ, unused-argument + dreal, dimag = dout + alf = _get_r2c_alf(xr) + dreal.mul_(alf) + dimag.mul_(alf) + dxr = asd_irfftn(dreal, dimag, ndim=2, n=xr.shape[-1]) + n = xr.shape[-1] * xr.shape[-2] * 1.0 + return dxr.mul_(n) + +# C2R inverse +class ASD_IRFFT2D(ASD_FFT2D): # pylint: disable=invalid-name + r""" + 2D complex-to-real inverse FFT transform using Ascend NPU acceleration. + + This operator performs 2D Inverse Real Fast Fourier Transform on complex input tensors, + optimized for Ascend NPU hardware acceleration. + + Args: + None + + Inputs: + - **xr** (Tensor): Real part of input complex tensor with data type float32, at least 2D. + - **xi** (Tensor): Imaginary part of input complex tensor with data type float32. + + Outputs: + - **yr** (Tensor): Output real tensor with data type float32. + + Raises: + ValueError: If input tensor data type is not float32 or tensor has less than 2 dimensions. 
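+
+    Note:
+        For an input of shape ``(batch, H, m)`` the default output shape is
+        ``(batch, H, (m - 1) * 2)``; ``set_n(n)`` selects an explicit last-axis length
+        ``n`` satisfying ``n // 2 + 1 == m``.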
+ + Supported Platforms: + ``Ascend`` ``Pynative`` + + Examples: + >>> import mindspore as ms + >>> from mindscience.sciops.fft import ASD_IRFFT2D + >>> xr = ms.Tensor([[[1.0, 2.0], [3.0, 4.0]]], ms.float32) + >>> xi = ms.Tensor([[[0.0, 0.0], [0.0, 0.0]]], ms.float32) + >>> asd_irfft2d = ASD_IRFFT2D() + >>> yr = asd_irfft2d(xr, xi) + >>> print(yr.shape) + (1, 2, 4) + """ + def __init__(self): + super(ASD_IRFFT2D, self).__init__() + self.asd_fft_op = _get_asd_fft_op().asd_irfft_2d + self.n = None + self.used_bprop_inputs = [0] + + def get_fft_size_and_scale(self, xr): + batch_size = xr.shape[0] + x_size = xr.shape[1] + y_size = xr.shape[2] + if self.n is not None and (self.n // 2 + 1) == y_size: + output_last_dim = self.n + else: + output_last_dim = (y_size - 1) * 2 + scale_factor = 1.0 / (x_size * output_last_dim) + return batch_size, x_size, output_last_dim, scale_factor + + def set_n(self, n): + self.n = n + + def bprop(self, xr, xi, out, dout): # pylint: disable=unused-argument + dreal = dout + dxr, dxi = asd_rfft2(dreal) + x_size = xr.shape[-2] + y_size = xr.shape[-1] + if self.n is not None and (self.n // 2 + 1) == y_size: + output_last_dim = self.n + else: + output_last_dim = (y_size - 1) * 2 + scale_factor = 1.0 / (x_size * output_last_dim) + alf = _get_c2r_alf(xr, scale_factor) + return dxr.mul_(alf), dxi.mul_(alf) + +# 全局变量用于存储单例实例 +_asd_fft_instance = None +_asd_ifft_instance = None +_asd_rfft_instance = None +_asd_irfft_instance = None +_asd_fft2_instance = None +_asd_ifft2_instance = None +_asd_rfft2_instance = None +_asd_irfft2_instance = None + +# 延迟初始化的FFT操作符函数 +def asd_fft(*args, **kwargs): + """延迟初始化的ASD_FFT调用""" + global _asd_fft_instance + if _asd_fft_instance is None: + _asd_fft_instance = ASD_FFT() + return _asd_fft_instance(*args, **kwargs) + +def asd_ifft(*args, **kwargs): + """延迟初始化的ASD_IFFT调用""" + global _asd_ifft_instance + if _asd_ifft_instance is None: + _asd_ifft_instance = ASD_IFFT() + return _asd_ifft_instance(*args, **kwargs) + +def asd_rfft(*args, **kwargs): + """延迟初始化的ASD_RFFT调用""" + global _asd_rfft_instance + if _asd_rfft_instance is None: + _asd_rfft_instance = ASD_RFFT() + return _asd_rfft_instance(*args, **kwargs) + +def asd_irfft(*args, **kwargs): + """延迟初始化的ASD_IRFFT调用""" + global _asd_irfft_instance + if _asd_irfft_instance is None: + _asd_irfft_instance = ASD_IRFFT() + return _asd_irfft_instance(*args, **kwargs) + +def asd_fft2(*args, **kwargs): + """延迟初始化的ASD_FFT2D调用""" + global _asd_fft2_instance + if _asd_fft2_instance is None: + _asd_fft2_instance = ASD_FFT2D() + return _asd_fft2_instance(*args, **kwargs) + +def asd_ifft2(*args, **kwargs): + """延迟初始化的ASD_IFFT2D调用""" + global _asd_ifft2_instance + if _asd_ifft2_instance is None: + _asd_ifft2_instance = ASD_IFFT2D() + return _asd_ifft2_instance(*args, **kwargs) + +def asd_rfft2(*args, **kwargs): + """延迟初始化的ASD_RFFT2D调用""" + global _asd_rfft2_instance + if _asd_rfft2_instance is None: + _asd_rfft2_instance = ASD_RFFT2D() + return _asd_rfft2_instance(*args, **kwargs) + +def asd_irfft2(*args, **kwargs): + """延迟初始化的ASD_IRFFT2D调用""" + global _asd_irfft2_instance + if _asd_irfft2_instance is None: + _asd_irfft2_instance = ASD_IRFFT2D() + return _asd_irfft2_instance(*args, **kwargs) + +def asd_fftn(xr, xi, ndim=1): + r""" + N-dimensional complex-to-complex forward FFT transform using Ascend NPU acceleration. + + This function provides a unified interface for 1D and 2D complex-to-complex FFT transforms, + optimized for Ascend NPU hardware acceleration. 
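+    It dispatches to ``asd_fft`` when ``ndim=1`` and to ``asd_fft2`` when ``ndim=2``;
+    the transform is applied to the trailing ``ndim`` axes.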
+ + Args: + xr (Tensor): Real part of input complex tensor with data type float32. + xi (Tensor): Imaginary part of input complex tensor with data type float32. + ndim (int): Number of dimensions to transform. Default: 1. Only support 1 and 2. + + Inputs: + - **xr** (Tensor): Real part of input complex tensor with data type float32. + - **xi** (Tensor): Imaginary part of input complex tensor with data type float32. + - **ndim** (int): Number of dimensions to transform. Default: 1. + + Outputs: + - **yr** (Tensor): Real part of output complex tensor with data type float32. + - **yi** (Tensor): Imaginary part of output complex tensor with data type float32. + + Raises: + ValueError: If ndim is not 1 or 2. + + Supported Platforms: + ``Ascend`` ``Pynative`` + + Examples: + >>> import mindspore as ms + >>> from mindscience.sciops.fft import asd_fftn + >>> xr = ms.Tensor([[1.0, 2.0, 3.0, 4.0]], ms.float32) + >>> xi = ms.Tensor([[0.0, 0.0, 0.0, 0.0]], ms.float32) + >>> yr, yi = asd_fftn(xr, xi, ndim=1) + >>> print(yr.shape) + (1, 4) + >>> print(yi.shape) + (1, 4) + """ + if ndim == 1: + return asd_fft(xr, xi) + if ndim == 2: + return asd_fft2(xr, xi) + raise ValueError(f"asd_fftn Unsupported dimension: {ndim}, only support 1D and 2D") + +def asd_ifftn(xr, xi, ndim=1): + r""" + N-dimensional complex-to-complex inverse FFT transform using Ascend NPU acceleration. + + This function provides a unified interface for 1D and 2D complex-to-complex inverse FFT transforms, + optimized for Ascend NPU hardware acceleration. + + Args: + xr (Tensor): Real part of input complex tensor with data type float32. + xi (Tensor): Imaginary part of input complex tensor with data type float32. + ndim (int): Number of dimensions to transform. Default: 1. Only support 1 and 2. + + Inputs: + - **xr** (Tensor): Real part of input complex tensor with data type float32. + - **xi** (Tensor): Imaginary part of input complex tensor with data type float32. + - **ndim** (int): Number of dimensions to transform. Default: 1. + + Outputs: + - **yr** (Tensor): Real part of output complex tensor with data type float32. + - **yi** (Tensor): Imaginary part of output complex tensor with data type float32. + + Raises: + ValueError: If ndim is not 1 or 2. + + Supported Platforms: + ``Ascend`` ``Pynative`` + + Examples: + >>> import mindspore as ms + >>> from mindscience.sciops.fft import asd_ifftn + >>> xr = ms.Tensor([[1.0, 2.0, 3.0, 4.0]], ms.float32) + >>> xi = ms.Tensor([[0.0, 0.0, 0.0, 0.0]], ms.float32) + >>> yr, yi = asd_ifftn(xr, xi, ndim=1) + >>> print(yr.shape) + (1, 4) + >>> print(yi.shape) + (1, 4) + """ + if ndim == 1: + return asd_ifft(xr, xi) + if ndim == 2: + return asd_ifft2(xr, xi) + raise ValueError(f"asd_ifftn Unsupported dimension: {ndim}, only support 1D and 2D") + +def asd_rfftn(xr, ndim=1): + r""" + N-dimensional real-to-complex FFT transform using Ascend NPU acceleration. + + This function provides a unified interface for 1D and 2D real-to-complex FFT transforms, + optimized for Ascend NPU hardware acceleration. + + Args: + xr (Tensor): Input real tensor with data type float32. + ndim (int): Number of dimensions to transform. Default: 1. Only support 1 and 2. + + Inputs: + - **xr** (Tensor): Input real tensor with data type float32. + - **ndim** (int): Number of dimensions to transform. Default: 1. + + Outputs: + - **yr** (Tensor): Real part of output complex tensor with data type float32. + - **yi** (Tensor): Imaginary part of output complex tensor with data type float32. 
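+
+    Note:
+        The size of the last transformed axis of the outputs is ``n // 2 + 1``, where
+        ``n`` is the size of the corresponding input axis.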
+ + Raises: + ValueError: If ndim is not 1 or 2. + + Supported Platforms: + ``Ascend`` ``Pynative`` + + Examples: + >>> import mindspore as ms + >>> from mindscience.sciops.fft import asd_rfftn + >>> xr = ms.Tensor([[1.0, 2.0, 3.0, 4.0]], ms.float32) + >>> yr, yi = asd_rfftn(xr, ndim=1) + >>> print(yr.shape) + (1, 3) + >>> print(yi.shape) + (1, 3) + """ + if ndim == 1: + return asd_rfft(xr) + if ndim == 2: + return asd_rfft2(xr) + raise ValueError(f"asd_rfftn Unsupported dimension: {ndim}, only support 1D and 2D") + +def asd_irfftn(xr, xi, n=None, ndim=1): + r""" + N-dimensional complex-to-real inverse FFT transform using Ascend NPU acceleration. + + This function provides a unified interface for 1D and 2D complex-to-real inverse FFT transforms, + optimized for Ascend NPU hardware acceleration. + + Args: + xr (Tensor): Real part of input complex tensor with data type float32. + xi (Tensor): Imaginary part of input complex tensor with data type float32. + n (int, optional): Length of the output tensor. Default: None. + ndim (int): Number of dimensions to transform. Default: 1. Only support 1 and 2. + + Inputs: + - **xr** (Tensor): Real part of input complex tensor with data type float32. + - **xi** (Tensor): Imaginary part of input complex tensor with data type float32. + - **n** (int, optional): Length of the output tensor. Default: None. + - **ndim** (int): Number of dimensions to transform. Default: 1. + + Outputs: + - **yr** (Tensor): Output real tensor with data type float32. + + Raises: + ValueError: If ndim is not 1 or 2. + + Supported Platforms: + ``Ascend`` ``Pynative`` + + Examples: + >>> import mindspore as ms + >>> from mindscience.sciops.fft import asd_irfftn + >>> xr = ms.Tensor([[1.0, 2.0, 3.0]], ms.float32) + >>> xi = ms.Tensor([[0.0, 0.0, 0.0]], ms.float32) + >>> yr = asd_irfftn(xr, xi, ndim=1) + >>> print(yr.shape) + (1, 4) + """ + if ndim == 1: + instance = ASD_IRFFT() + instance.set_n(n) + return instance(xr, xi) + if ndim == 2: + instance = ASD_IRFFT2D() + instance.set_n(n) + return instance(xr, xi) + raise ValueError(f"asd_irfftn Unsupported dimension: {ndim}, only support 1D and 2D") diff --git a/mindscience/sciops/fft/asd_fft_op_ext.cpp b/mindscience/sciops/fft/asd_fft_op_ext.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d615c4ede774cbf2cf1add3401bdc6d99527e95f --- /dev/null +++ b/mindscience/sciops/fft/asd_fft_op_ext.cpp @@ -0,0 +1,219 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include "ms_extension/api.h" +using std::string; +using ms::pynative::FFTParam; +using ms::pynative::asdFftDirection; +using ms::pynative::asdFft1dDimType; +using ms::pynative::asdFftType; + +ms::Tensor GetResultTensor(const ms::Tensor &t, const FFTParam ¶m) { + auto type_id = mindspore::TypeId::kNumberTypeComplex64; + if (param.fftType == asdFftType::ASCEND_FFT_C2R) { + type_id = mindspore::TypeId::kNumberTypeFloat32; + } + // for fft and ifft, the output shape is equal to input shape + ShapeVector out_shape(t.shape()); + + // for rfft, the output shape is (batch_size, x_size, (y_size // 2) + 1) + if (param.fftType == asdFftType::ASCEND_FFT_R2C && param.direction == asdFftDirection::ASCEND_FFT_FORWARD) { + out_shape.back() = (out_shape.back() / 2) + 1; + } + // for irfft, the output shape is (batch_size, x_size, fftYSize or fftXSize1d) + if (param.fftType == asdFftType::ASCEND_FFT_C2R && param.direction == asdFftDirection::ASCEND_FFT_BACKWARD) { + out_shape.back() = param.fftYSize == 0 ? param.fftXSize : param.fftYSize; + } + return ms::Tensor(type_id, out_shape); +} + +ms::Tensor exec_asdsip_fft(const string &op_name, const FFTParam ¶m, const ms::Tensor &input) { + MS_LOG(INFO) << "Run device task LaunchAsdSipFFT start, op_name: " << op_name << ", param.fftX: " << param.fftXSize + << ", param.fftY: " << param.fftYSize << ", param.fftType: " << param.fftType << ", param.direction: " + << param.direction << ", param.batchSize: " << param.batchSize << ", param.dimType: " << param.dimType; + auto output = GetResultTensor(input, param); + ms::pynative::RunAsdSipFFTOp(op_name, param, input, output); + MS_LOG(INFO) << "Run device task LaunchAsdSipFFT end"; + + return output; +} + +auto pyboost_npu_set_cache_size(int64_t cache_size) { + MS_LOG(INFO) << "Set cache size for NPU FFT, cache_size: " << cache_size; + ms::pynative::AsdSipFFTOpRunner::SetCacheSize(cache_size); + MS_LOG(INFO) << "Set cache size for NPU FFT end"; +} + +// 1D FFT forward +auto pyboost_npu_fft_1d(const ms::Tensor &input, int64_t batch_size, int64_t x_size) { + FFTParam param; + param.fftXSize = x_size; + param.fftYSize = 0; + param.fftType = asdFftType::ASCEND_FFT_C2C; + param.direction = asdFftDirection::ASCEND_FFT_FORWARD; + param.batchSize = batch_size; + param.dimType = asdFft1dDimType::ASCEND_FFT_HORIZONTAL; + const string op_name = "asdFftExecC2C"; + return ms::pynative::PyboostRunner::Call<1>(exec_asdsip_fft, op_name, param, input); +} + +// 1D FFT inverse +auto pyboost_npu_ifft_1d(const ms::Tensor &input, int64_t batch_size, int64_t x_size) { + ms::pynative::FFTParam param; + param.batchSize = batch_size; + param.fftXSize = x_size; + param.fftYSize = 0; + param.fftType = asdFftType::ASCEND_FFT_C2C; + param.direction = asdFftDirection::ASCEND_FFT_BACKWARD; + param.dimType = asdFft1dDimType::ASCEND_FFT_HORIZONTAL; + const string op_name = "asdFftExecC2C"; + return ms::pynative::PyboostRunner::Call<1>(exec_asdsip_fft, op_name, param, input); +} + +// 1D RFFT forward +auto pyboost_npu_rfft_1d(const ms::Tensor &input, int64_t batch_size, int64_t x_size) { + ms::pynative::FFTParam param; + param.batchSize = batch_size; + param.fftXSize = x_size; + param.fftYSize = 0; + param.fftType = asdFftType::ASCEND_FFT_R2C; + param.direction = asdFftDirection::ASCEND_FFT_FORWARD; + param.dimType = asdFft1dDimType::ASCEND_FFT_HORIZONTAL; + const string op_name = "asdFftExecR2C"; + return ms::pynative::PyboostRunner::Call<1>(exec_asdsip_fft, op_name, param, input); +} + +// 1D IRFFT inverse +auto 
pyboost_npu_irfft_1d(const ms::Tensor &input, int64_t batch_size, int64_t x_size) { + ms::pynative::FFTParam param; + param.batchSize = batch_size; + param.fftXSize = x_size; + param.fftYSize = 0; + param.fftType = asdFftType::ASCEND_FFT_C2R; + param.direction = asdFftDirection::ASCEND_FFT_BACKWARD; + param.dimType = asdFft1dDimType::ASCEND_FFT_HORIZONTAL; + const string op_name = "asdFftExecC2R"; + return ms::pynative::PyboostRunner::Call<1>(exec_asdsip_fft, op_name, param, input); +} + +// 2D FFT forward +auto pyboost_npu_fft_2d(const ms::Tensor &input, int64_t batch_size, int64_t x_size, int64_t y_size) { + ms::pynative::FFTParam param; + param.fftXSize = x_size; + param.fftYSize = y_size; + param.fftType = asdFftType::ASCEND_FFT_C2C; + param.direction = asdFftDirection::ASCEND_FFT_FORWARD; + param.batchSize = batch_size; + const string op_name = "asdFftExecC2C"; + return ms::pynative::PyboostRunner::Call<1>(exec_asdsip_fft, op_name, param, input); +} + +// 2D IFFT inverse +auto pyboost_npu_ifft_2d(const ms::Tensor &input, int64_t batch_size, int64_t x_size, int64_t y_size) { + ms::pynative::FFTParam param; + param.batchSize = batch_size; + param.fftXSize = x_size; + param.fftYSize = y_size; + param.fftType = asdFftType::ASCEND_FFT_C2C; + param.direction = asdFftDirection::ASCEND_FFT_BACKWARD; + const string op_name = "asdFftExecC2C"; + return ms::pynative::PyboostRunner::Call<1>(exec_asdsip_fft, op_name, param, input); +} + +// 2D RFFT forward +auto pyboost_npu_rfft_2d(const ms::Tensor &input, int64_t batch_size, int64_t x_size, int64_t y_size) { + ms::pynative::FFTParam param; + param.batchSize = batch_size; + param.fftXSize = x_size; + param.fftYSize = y_size; + param.fftType = asdFftType::ASCEND_FFT_R2C; + param.direction = asdFftDirection::ASCEND_FFT_FORWARD; + const string op_name = "asdFftExecR2C"; + return ms::pynative::PyboostRunner::Call<1>(exec_asdsip_fft, op_name, param, input); +} + +// 2D IRFFT inverse +auto pyboost_npu_irfft_2d(const ms::Tensor &input, int64_t batch_size, int64_t x_size, int64_t y_size) { + ms::pynative::FFTParam param; + param.batchSize = batch_size; + param.fftXSize = x_size; + param.fftYSize = y_size; + param.fftType = asdFftType::ASCEND_FFT_C2R; + param.direction = asdFftDirection::ASCEND_FFT_BACKWARD; + const string op_name = "asdFftExecC2R"; + return ms::pynative::PyboostRunner::Call<1>(exec_asdsip_fft, op_name, param, input); +} + +// Python binding module +PYBIND11_MODULE(MS_EXTENSION_NAME, m) { + m.doc() = "NPU FFT operations for MindSpore"; + + m.def("asd_set_cache_size", &pyboost_npu_set_cache_size, "Set cache size for NPU FFT", + pybind11::arg("cache_size")); + + // 1D FFT + m.def("asd_fft_1d", &pyboost_npu_fft_1d, "1D FFT on NPU", + pybind11::arg("input"), + pybind11::arg("batch_size"), + pybind11::arg("x_size")); + + // 1D IFFT + m.def("asd_ifft_1d", &pyboost_npu_ifft_1d, "1D IFFT on NPU", + pybind11::arg("input"), + pybind11::arg("batch_size"), + pybind11::arg("x_size")); + + // 1D RFFT + m.def("asd_rfft_1d", &pyboost_npu_rfft_1d, "1D RFFT on NPU", + pybind11::arg("input"), + pybind11::arg("batch_size"), + pybind11::arg("x_size")); + + // 1D IRFFT + m.def("asd_irfft_1d", &pyboost_npu_irfft_1d, "1D IRFFT on NPU", + pybind11::arg("input"), + pybind11::arg("batch_size"), + pybind11::arg("x_size")); + + // 2D FFT + m.def("asd_fft_2d", &pyboost_npu_fft_2d, "2D FFT on NPU", + pybind11::arg("input"), + pybind11::arg("batch_size"), + pybind11::arg("x_size"), + pybind11::arg("y_size")); + + // 2D IFFT + m.def("asd_ifft_2d", &pyboost_npu_ifft_2d, "2D 
IFFT on NPU", + pybind11::arg("input"), + pybind11::arg("batch_size"), + pybind11::arg("x_size"), + pybind11::arg("y_size")); + + // 2D RFFT + m.def("asd_rfft_2d", &pyboost_npu_rfft_2d, "2D RFFT on NPU", + pybind11::arg("input"), + pybind11::arg("batch_size"), + pybind11::arg("x_size"), + pybind11::arg("y_size")); + + // 2D IRFFT + m.def("asd_irfft_2d", &pyboost_npu_irfft_2d, "2D IRFFT on NPU", + pybind11::arg("input"), + pybind11::arg("batch_size"), + pybind11::arg("x_size"), + pybind11::arg("y_size")); +} diff --git a/mindscience/sciops/fourier.py b/mindscience/sciops/fourier.py index f1d29a13c3eb3f5525a92ea3f5267c621dc61639..d4e5ecc22837c8bac3f0efa62ed1404eed0a9f31 100644 --- a/mindscience/sciops/fourier.py +++ b/mindscience/sciops/fourier.py @@ -129,7 +129,7 @@ class _DFT1d(nn.Cell): super().__init__() self.n = n - self.dft_mat = scipy.linalg.dft(n, scale=scale) + dft_mat = scipy.linalg.dft(n, scale=scale) self.last_index = last_index self.inv = inv self.odd = bool(n % 2) @@ -139,11 +139,11 @@ class _DFT1d(nn.Cell): self.compute_dtype = compute_dtype # generate DFT matrix for positive and negative frequencies - dft_mat_mode = self.dft_mat[:, :self.mode_upper] + dft_mat_mode = dft_mat[:, :self.mode_upper] self.a_re_upper = Tensor(dft_mat_mode.real, dtype=compute_dtype) self.a_im_upper = Tensor(dft_mat_mode.imag, dtype=compute_dtype) - dft_mat_mode = self.dft_mat[:, -self.mode_lower:] + dft_mat_mode = dft_mat[:, -self.mode_lower:] self.a_re_lower = Tensor(dft_mat_mode.real, dtype=compute_dtype) self.a_im_lower = Tensor(dft_mat_mode.imag, dtype=compute_dtype) @@ -164,7 +164,7 @@ class _DFT1d(nn.Cell): # last axis is real-transformed, so the inverse is conjugate of the positive frequencies if last_index: mode_res = min(self.mode_lower, self.mode_upper - 1) - dft_mat_res = self.dft_mat[:, -mode_res:] + dft_mat_res = dft_mat[:, -mode_res:] a_re_res = MyFlip()(Tensor(dft_mat_res.real, dtype=compute_dtype), dims=-1) a_im_res = MyFlip()(Tensor(dft_mat_res.imag, dtype=compute_dtype), dims=-1) diff --git a/setup.py b/setup.py index 477f77e07c710bd880b1a2b1ea6bd22f86a5afaa..c518711c61ace4f524d12e6ea987afd8b2e6fa6d 100644 --- a/setup.py +++ b/setup.py @@ -51,6 +51,7 @@ package_data = { 'include/*' 'build_info.txt' ], + 'mindscience.sciops.fft': ['*.cpp'], '_c_minddata': ['lib_c_minddata*.so'] } @@ -73,4 +74,3 @@ setup( include_package_data=True, install_requires=required_package, classifiers=['License :: OSI Approved :: Apache Software License']) - diff --git a/tests/sciops/test_asd_fft.py b/tests/sciops/test_asd_fft.py new file mode 100644 index 0000000000000000000000000000000000000000..066e74a30d9576b9cd7c95bf8294897a57453d00 --- /dev/null +++ b/tests/sciops/test_asd_fft.py @@ -0,0 +1,376 @@ +# ============================================================================ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Optimizers Test Case""" +import random + +import numpy as np +import mindspore as ms +import pytest + +# pylint: disable=ungrouped-imports +from mindspore import set_seed, ops +from mindspore.profiler import ProfilerLevel, ProfilerActivity, AicoreMetrics, ExportType + +from mindscience.sciops import DFTn, IDFTn, RDFTn, IRDFTn, asd_fftn, asd_ifftn, asd_rfftn, asd_irfftn + +set_seed(0) +np.random.seed(0) +random.seed(0) +FP32_RTOL = 1e-3 + +def loss_func_c(yr, yi): + return ops.sum(yr * yr + 2 * yi * yi) + +def loss_grad_np_c(y): + return 2 * y.real + 4j * y.imag + +def loss_func_r(yr): + return ops.sum(yr * yr * yr) + +def loss_grad_np_r(y): + return 3 * y * y + +def forwad_fn_c2c(xr, xi, dim, dft_cell): + if dim is not None: + br, bi = dft_cell(xr, xi, ndim=dim) + else: + br, bi = dft_cell(xr, xi) + return loss_func_c(br, bi) + +def forwad_fn_r2c(xr, dim, dft_cell): + if dim is not None: + br, bi = dft_cell(xr, ndim=dim) + else: + br, bi = dft_cell(xr) + return loss_func_c(br, bi) + +def forwad_fn_c2r(xr, xi, dim, dft_cell): + if dim is not None: + br = dft_cell(xr, xi, ndim=dim) + else: + br = dft_cell(xr, xi) + return loss_func_r(br) + +def gen_input(shape=(2, 16, 16), rand_test=True): + ''' Generate random or deterministic tensor for input of the tests + ''' + a = np.random.rand(*shape) + 1j * np.random.rand(*shape) + if not rand_test: + a = sum([np.arange(n).reshape([n] + [1] * j) for j, n in enumerate(shape[::-1])]) + 1j * \ + sum([np.arange(n).reshape([n] + [1] * j) for j, n in enumerate(shape[::-1])]) + ar, ai = (ms.Tensor(a.real, dtype=ms.float32), ms.Tensor(a.imag, dtype=ms.float32)) + return a, ar, ai + +def cal_error(name, ar, ai, br, bi): + ''' + ar, ai, br, bi are all numpy arrays, calculate the max absolute error, max relative error, and mean relative error + ''' + print(f"{name} ar.shape: ", ar.shape) + print(f"{name} br.shape: ", br.shape) + if ai is not None and bi is not None: + print(f"{name} ai.shape: ", ai.shape) + print(f"{name} bi.shape: ", bi.shape) + abs_error_real = np.abs(ar - br) + rel_error_real = abs_error_real / (np.abs(ar) + 1e-10) + if ai is not None and bi is not None: + abs_error_imag = np.abs(ai - bi) + rel_error_imag = abs_error_imag / (np.abs(ai) + 1e-10) + max_abs_error = max(np.max(abs_error_real), np.max(abs_error_imag)) + max_rel_error = max(np.max(rel_error_real), np.max(rel_error_imag)) + mean_rel_error = (np.mean(rel_error_real) + np.mean(rel_error_imag)) / 2 + else: + abs_error_imag = None + rel_error_imag = None + max_abs_error = np.max(abs_error_real) + max_rel_error = np.max(rel_error_real) + mean_rel_error = np.mean(rel_error_real) + + print(f"{name} max_abs_error: ", max_abs_error) + print(f"{name} max_rel_error: ", max_rel_error) + print(f"{name} mean_rel_error: ", mean_rel_error) + + return max_abs_error, max_rel_error, mean_rel_error + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('device_target', ['Ascend']) +@pytest.mark.parametrize('mode', [ms.PYNATIVE_MODE]) +@pytest.mark.parametrize('ndim', [1, 2]) +def test_asd_fft_accuracy(device_target, mode, ndim): + """ + Feature: Test ASD FFT & IFFT accuracy + Description: Input random tensor, compare the results of ASD FFT and IFFT with numpy results + Expectation: The output tensors should be equal within tolerance + """ + print(f"test_asd_fft_accuracy, ndim: {ndim}") + ms.set_context(device_target=device_target, mode=mode) + a, ar, 
ai = gen_input() + shape = a.shape + + b = np.fft.fftn(a, s=a.shape[-ndim:], axes=range(-ndim, 0)) + # mindflow DFTn, real and imag are both mindspore tensors + br, bi = DFTn(shape[-ndim:])(ar, ai) + cr, ci = IDFTn(shape[-ndim:])(br, bi) + + # ASD FFT, real and imag are both mindspore tensors + ms_br, ms_bi = asd_fftn(ar, ai, ndim=ndim) + ms_ar, ms_ai = asd_ifftn(ms_br, ms_bi, ndim=ndim) + + # mindflow dft is just used for reference + max_abs_error, max_rel_error, mean_rel_error = cal_error( + "numpy-vs-mindflow-dft", b.real, b.imag, br.asnumpy(), bi.asnumpy()) + max_abs_error, max_rel_error, mean_rel_error = cal_error( + "numpy-vs-ms-dft", b.real, b.imag, ms_br.asnumpy(), ms_bi.asnumpy()) + assert max_abs_error < FP32_RTOL + assert max_rel_error < FP32_RTOL + assert mean_rel_error < FP32_RTOL + # mindflow idft is just used for reference + max_abs_error, max_rel_error, mean_rel_error = cal_error( + "numpy-vs-mindflow-idft", a.real, a.imag, cr.asnumpy(), ci.asnumpy()) + max_abs_error, max_rel_error, mean_rel_error = cal_error( + "numpy-vs-ms-idft", a.real, a.imag, ms_ar.asnumpy(), ms_ai.asnumpy()) + assert max_abs_error < FP32_RTOL + assert max_rel_error < FP32_RTOL + assert mean_rel_error < FP32_RTOL + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('device_target', ['Ascend']) +@pytest.mark.parametrize('mode', [ms.PYNATIVE_MODE]) +@pytest.mark.parametrize('ndim', [1, 2]) +def test_asd_rfft_accuracy(device_target, mode, ndim): + """ + Feature: Test ASD RFFT & IRFFT accuracy + Description: Input random tensor, compare the results of ASD RFFT and IRFFT with numpy results + Expectation: The output tensors should be equal within tolerance + """ + print(f"test_asd_rfft_accuracy, ndim: {ndim}") + ms.set_context(device_target=device_target, mode=mode) + a, ar, _ = gen_input() + shape = a.shape + + b = np.fft.rfftn(a.real, s=a.shape[-ndim:], axes=range(-ndim, 0)) + br, bi = RDFTn(shape[-ndim:])(ar) + cr = IRDFTn(shape[-ndim:])(br, bi) + + ms_br, ms_bi = asd_rfftn(ar, ndim=ndim) + ms_ar = asd_irfftn(ms_br, ms_bi, ndim=ndim) + + ms_ar_n = asd_irfftn(ms_br, ms_bi, n=ar.shape[-1] + 1, ndim=ndim) + np_shape = list(a.shape[-ndim:]) + np_shape[-1] = np_shape[-1] + 1 + np_ar_n = np.fft.irfftn(b, s=np_shape, axes=range(-ndim, 0)) + max_abs_error, max_rel_error, mean_rel_error = cal_error( + "numpy-vs-ms-irdft-n", np_ar_n, None, ms_ar_n.asnumpy(), None) + assert max_abs_error < FP32_RTOL + assert max_rel_error < FP32_RTOL + assert mean_rel_error < FP32_RTOL + + max_abs_error, max_rel_error, mean_rel_error = cal_error( + "numpy-vs-mindflow-rdft", b.real, b.imag, br.asnumpy(), bi.asnumpy()) + max_abs_error, max_rel_error, mean_rel_error = cal_error( + "numpy-vs-ms-rdft", b.real, b.imag, ms_br.asnumpy(), ms_bi.asnumpy()) + assert max_abs_error < FP32_RTOL + assert max_rel_error < FP32_RTOL + assert mean_rel_error < FP32_RTOL + max_abs_error, max_rel_error, mean_rel_error = cal_error( + "numpy-vs-mindflow-irdft", a.real, a.imag, cr.asnumpy(), None) + max_abs_error, max_rel_error, mean_rel_error = cal_error("numpy-vs-ms-irdft", a.real, a.imag, ms_ar.asnumpy(), None) + assert max_abs_error < FP32_RTOL + assert max_rel_error < FP32_RTOL + assert mean_rel_error < FP32_RTOL + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('device_target', ['Ascend']) +@pytest.mark.parametrize('mode', [ms.PYNATIVE_MODE]) +@pytest.mark.parametrize('ndim', [1, 2]) 
+@pytest.mark.parametrize('cell', ["c2c_fwd", "c2c_inv"]) +# pylint: disable=unused-variable +def test_asd_fft_grad_accuracy(device_target, mode, ndim, cell): + """ + Feature: Test ASD FFT & IFFT grad accuracy + Description: Input random tensor, compare the results of ASD FFT and IFFT with numpy results + Expectation: The output tensors should be equal within tolerance + """ + print(f"test_asd_fft_grad_accuracy, ndim: {ndim}, cell: {cell}") + ms.set_context(device_target=device_target, mode=mode) + a, ar, ai = gen_input() + shape = a.shape + scale = np.prod(a.shape[-ndim:]) + grad_fn = ms.value_and_grad(forwad_fn_c2c, grad_position=(0, 1)) + + if cell == "c2c_fwd": + asd_result, (asd_ar_g, asd_ai_g) = grad_fn(ar, ai, ndim, asd_fftn) + ms_result, (ms_ar_g, ms_ai_g) = grad_fn(ar, ai, None, DFTn(shape[-ndim:])) + np_result = np.fft.fftn(a, s=a.shape[-ndim:], axes=range(-ndim, 0)) + c = loss_grad_np_c(np_result) + np_grad = scale * np.fft.ifftn(c, s=a.shape[-ndim:], axes=range(-ndim, 0)) + elif cell == "c2c_inv": + asd_result, (asd_ar_g, asd_ai_g) = grad_fn(ar, ai, ndim, asd_ifftn) + ms_result, (ms_ar_g, ms_ai_g) = grad_fn(ar, ai, None, IDFTn(shape[-ndim:])) + np_result = np.fft.ifftn(a, s=a.shape[-ndim:], axes=range(-ndim, 0)) + c = loss_grad_np_c(np_result) + np_grad = (1.0 / scale) * np.fft.fftn(c, s=a.shape[-ndim:], axes=range(-ndim, 0)) + else: + raise ValueError(f"fft: Unsupported cell: {cell}, only support c2c_fwd and c2c_inv") + + abs_error, rel_error, mean_error = cal_error( + "fft-numpy-vs-mindflow-backward", np_grad.real, np_grad.imag, ms_ar_g.asnumpy(), ms_ai_g.asnumpy()) + abs_error, rel_error, mean_error = cal_error( + "fft-numpy-vs-asdfft-backward", np_grad.real, np_grad.imag, asd_ar_g.asnumpy(), asd_ai_g.asnumpy()) + assert abs_error < FP32_RTOL + assert rel_error < FP32_RTOL + assert mean_error < FP32_RTOL + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('device_target', ['Ascend']) +@pytest.mark.parametrize('mode', [ms.PYNATIVE_MODE]) +@pytest.mark.parametrize('ndim', [1, 2]) +@pytest.mark.parametrize('cell', ["r2c_fwd", "c2r_inv"]) +# pylint: disable=unused-variable +def test_asd_rfft_grad_accuracy(device_target, mode, ndim, cell): + """ + Feature: Test ASD RFFT & IRFFT grad accuracy + Description: Input random tensor, compare the results of ASD RFFT and IRFFT with numpy results + Expectation: The output tensors should be equal within tolerance + """ + print(f"test_asd_rfft_grad_accuracy, ndim: {ndim}, cell: {cell}") + ms.set_context(device_target=device_target, mode=mode) + a, ar, ai = gen_input() + shape = a.shape + scale = np.prod(a.shape[-ndim:]) + grad_fn_r2c = ms.value_and_grad(forwad_fn_r2c, grad_position=(0)) + grad_fn_c2r = ms.value_and_grad(forwad_fn_c2r, grad_position=(0, 1)) + + n = shape[-1] + m = n // 2 + 1 + k = (n+1) // 2 + alf = np.ones(m) + alf[1:k] += 1 + + + if cell == "r2c_fwd": + asd_result, asd_ar_g = grad_fn_r2c(ar, ndim, asd_rfftn) + ms_result, ms_ar_g = grad_fn_r2c(ar, None, RDFTn(shape[-ndim:])) + np_result = np.fft.rfftn(a.real, s=a.shape[-ndim:], axes=range(-ndim, 0)) + c = loss_grad_np_c(np_result) / alf + np_grad = scale * np.fft.irfftn(c, s=a.shape[-ndim:], axes=range(-ndim, 0)) + abs_error, rel_error, mean_error = cal_error( + "rfft-numpy-vs-mindflow-backward", np_grad, None, ms_ar_g.asnumpy(), None) + abs_error, rel_error, mean_error = cal_error( + "rfft-numpy-vs-asdfft-backward", np_grad, None, asd_ar_g.asnumpy(), None) + + elif cell == "c2r_inv": + ar_input, ai_input 
= asd_rfftn(ar, ndim=ndim) + asd_result, (asd_ar_g, asd_ai_g) = grad_fn_c2r(ar_input, ai_input, ndim, asd_irfftn) + # mindflow IRDFT + ms_result, (ms_ar_g, ms_ai_g) = grad_fn_c2r(ar_input, ai_input, None, IRDFTn(ar.shape[-ndim:])) + make_complex = ops.Complex() + ms_c = make_complex(ar_input, ai_input).asnumpy() + np_result = np.fft.irfftn(ms_c, s=a.shape[-ndim:], axes=range(-ndim, 0)) + c = loss_grad_np_r(np_result) + n = (ms_c.shape[-1] - 1) * 2 + m = n // 2 + 1 + k = (n+1) // 2 + alf = np.ones(m) + alf[1:k] += 1 + np_grad = (1.0 / scale) * alf * np.fft.rfftn(c, s=a.shape[-ndim:], axes=range(-ndim, 0)) + abs_error, rel_error, mean_error = cal_error( + "rfft-numpy-vs-mindflow-backward", np_grad.real, np_grad.imag, ms_ar_g.asnumpy(), ms_ai_g.asnumpy()) + abs_error, rel_error, mean_error = cal_error( + "rfft-numpy-vs-asdfft-backward", np_grad.real, np_grad.imag, asd_ar_g.asnumpy(), asd_ai_g.asnumpy()) + assert abs_error < FP32_RTOL + assert rel_error < FP32_RTOL + assert mean_error < FP32_RTOL + else: + raise ValueError(f"rfft: Unsupported cell: {cell}, only support r2c_fwd and c2r_inv") + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('device_target', ['Ascend']) +@pytest.mark.parametrize('mode', [ms.PYNATIVE_MODE]) +@pytest.mark.parametrize('ndim', [1, 2]) +@pytest.mark.parametrize('alg', ["DFT", "IDFT", "RDFT", "IRDFT", "FFT", "IFFT", "RFFT", "IRFFT"]) +@pytest.mark.parametrize('dshape', [(20, 512, 512), (20, 1024, 1024), (20, 2048, 2048)]) +def test_fft_speed(device_target, mode, ndim, alg, dshape): + """ + Feature: Test ASD FFT & RFFT speed + Description: Input random tensor, compare the results of ASD FFT and RFFT with numpy results + Expectation: The output tensors should be equal within tolerance + """ + print(f"test_fft_speed, ndim: {ndim}, alg: {alg}, dshape: {dshape}") + ms.set_context(device_target=device_target, mode=mode) + + experimental_config = ms.profiler._ExperimentalConfig( # pylint: disable=protected-access + profiler_level=ProfilerLevel.Level1, + aic_metrics=AicoreMetrics.AiCoreNone, + l2_cache=False, + mstx=False, + data_simplification=False, + export_type=[ExportType.Text]) + a, ar, ai = gen_input(shape=dshape) + shape = a.shape + br, bi = DFTn(shape[-ndim:])(ar, ai) + + prof_file_path = f"./data/fft_speed/{alg}_{ndim}_{dshape[0]}_{dshape[1]}_{dshape[2]}" + + with ms.profiler.profile(activities=[ProfilerActivity.CPU, ProfilerActivity.NPU], + schedule=ms.profiler.schedule(wait=0, warmup=4, active=4, repeat=1, skip_first=0), + on_trace_ready=ms.profiler.tensorboard_trace_handler(prof_file_path), + profile_memory=True, + with_stack=True, + record_shapes=True, + experimental_config=experimental_config) as prof: + for _ in range(10): + if alg == "DFT": + br, bi = DFTn(shape[-ndim:])(ar, ai) + elif alg == "IDFT": + br, bi = IDFTn(shape[-ndim:])(br, bi) + elif alg == "RDFT": + br = RDFTn(shape[-ndim:])(ar) + elif alg == "IRDFT": + ar = IRDFTn(shape[-ndim:])(br, bi) + + elif alg == "FFT": + br, bi = asd_fftn(ar, ai, ndim=ndim) + elif alg == "IFFT": + ar, ai = asd_ifftn(br, bi, ndim=ndim) + elif alg == "RFFT": + br, bi = asd_rfftn(ar, ndim=ndim) + elif alg == "IRFFT": + ar = asd_irfftn(br, bi, ndim=ndim) + + prof.step() + + +if __name__ == "__main__": + test_dft_accuracy(device_target='Ascend', mode=ms.PYNATIVE_MODE, ndim=1) + test_dft_accuracy(device_target='Ascend', mode=ms.PYNATIVE_MODE, ndim=2) + test_rdft_accuracy(device_target='Ascend', mode=ms.PYNATIVE_MODE, ndim=1) + 
test_asd_rfft_accuracy(device_target='Ascend', mode=ms.PYNATIVE_MODE, ndim=2)
+    test_asd_fft_grad_accuracy(device_target='Ascend', mode=ms.PYNATIVE_MODE, ndim=1, cell="c2c_fwd")
+    test_asd_fft_grad_accuracy(device_target='Ascend', mode=ms.PYNATIVE_MODE, ndim=2, cell="c2c_inv")
+    test_asd_rfft_grad_accuracy(device_target='Ascend', mode=ms.PYNATIVE_MODE, ndim=1, cell="r2c_fwd")
+    test_asd_rfft_grad_accuracy(device_target='Ascend', mode=ms.PYNATIVE_MODE, ndim=2, cell="c2r_inv")
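+
+    # A minimal round-trip sketch (illustrative only, assumes an Ascend device in
+    # PYNATIVE mode): rfft keeps n // 2 + 1 bins on the last axis and irfft restores
+    # the original length.
+    ms.set_context(device_target='Ascend', mode=ms.PYNATIVE_MODE)
+    x_demo = ms.Tensor(np.random.rand(2, 8), ms.float32)
+    yr_demo, yi_demo = asd_rfftn(x_demo, ndim=1)   # shapes: (2, 5), (2, 5)
+    x_back = asd_irfftn(yr_demo, yi_demo, ndim=1)  # shape: (2, 8)
+    print(np.allclose(x_back.asnumpy(), x_demo.asnumpy(), atol=1e-3))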