diff --git a/MindFlow/mindflow/__init__.py b/MindFlow/mindflow/__init__.py
index 1721724bf7e12cb03a88b256e1a66d0306b36574..7186e5af0b08968f002e5dc30ca5cc3de2cb1c48 100644
--- a/MindFlow/mindflow/__init__.py
+++ b/MindFlow/mindflow/__init__.py
@@ -67,6 +67,7 @@ from .pde import *
 from .cell import *
 from .cfd import *
 from .utils import *
+from .fft import *

 __all__ = []
 __all__.extend(data.__all__)
@@ -76,3 +77,4 @@ __all__.extend(pde.__all__)
 __all__.extend(cell.__all__)
 __all__.extend(cfd.__all__)
 __all__.extend(utils.__all__)
+__all__.extend(fft.__all__)
\ No newline at end of file
diff --git a/MindFlow/mindflow/fft/__init__.py b/MindFlow/mindflow/fft/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf39dd67aa55447256a5940fbf5b232fd9acee4d
--- /dev/null
+++ b/MindFlow/mindflow/fft/__init__.py
@@ -0,0 +1,21 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""init"""
+
+from .asd_fft_custom_op import *
+
+__all__ = ["set_cache_size", "asd_fftn", "asd_ifftn", "asd_rfftn", "asd_irfftn",
+           "asd_fft", "asd_ifft", "asd_rfft", "asd_irfft", "asd_fft2", "asd_ifft2", "asd_rfft2", "asd_irfft2",
+           "ASD_FFT", "ASD_IFFT", "ASD_RFFT", "ASD_IRFFT", "ASD_FFT2D", "ASD_IFFT2D", "ASD_RFFT2D", "ASD_IRFFT2D"]
\ No newline at end of file
diff --git a/MindFlow/mindflow/fft/asd_fft_custom_op.py b/MindFlow/mindflow/fft/asd_fft_custom_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..b61640de05a2f831ca61bb4737e6445922f2412e
--- /dev/null
+++ b/MindFlow/mindflow/fft/asd_fft_custom_op.py
@@ -0,0 +1,385 @@
+#!/usr/bin/env python3
+"""
+Python wrappers for the NPU FFT operators.
+Provide an easy-to-use interface for calling the FFT functionality on the NPU.
+"""
+import os
+import mindspore as ms
+from mindspore import nn, ops, dtype as mstype, mint, jit
+from mindspore.ops import DataType, CustomRegOp, CustomOpBuilder
+
+cur_asd_fft_op_source_file = os.path.join(os.path.dirname(__file__), "asd_fft_op_ext.cpp")
+asd_fft_op = CustomOpBuilder("asd_fft_op", cur_asd_fft_op_source_file, backend="Ascend", enable_asdsip=True).load()
+
+class CustomReal(nn.Cell):
+    def __init__(self):
+        super(CustomReal, self).__init__()
+        aclnn_ref_info = CustomRegOp("aclnnReal") \
+            .input(0, "x", "required") \
+            .output(0, "z", "required") \
+            .dtype_format(DataType.C64_Default, DataType.F32_Default) \
+            .target("Ascend") \
+            .get_op_info()
+
+        self.real = ops.Custom(func="aclnnReal", out_shape=lambda x: x,
+                               out_dtype=mstype.float32, func_type="aot", reg_info=aclnn_ref_info)
+    @jit(backend="ms_backend")
+    def construct(self, x):
+        return self.real(x)
+
+class CustomComplex(nn.Cell):
+    def __init__(self, auto_prefix=True, flags=None):
+        super().__init__(auto_prefix, flags)
+        aclnn_ref_info = CustomRegOp("aclnnComplex") \
+            .input(0, "real", "required") \
+            .input(1, "imag", "required") \
+            .output(0, "out", "required") \
+            .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.C64_Default) \
+            .target("Ascend") \
+            .get_op_info()
+
+        self.complex = ops.Custom(func="aclnnComplex", out_shape=lambda real, imag: real,
+                                  out_dtype=mstype.complex64, func_type="aot", reg_info=aclnn_ref_info)
+    @jit(backend="ms_backend")
+    def construct(self, real, imag):
+        return self.complex(real, imag)
+
+
+def set_cache_size(cache_size):
+    asd_fft_op.asd_set_cache_size(cache_size)
+
+make_complex = CustomComplex()
+_real_op = CustomReal()
+_neg1j = ms.tensor(0-1j, dtype=mstype.complex64)
+
+def _get_real(x):
+    return _real_op(x)
+
+def _get_imag(x):
+    return _real_op(ops.mul(_neg1j, x))
+
+def _get_r2c_alf(xr):
+    n = xr.shape[-1]
+    m = n // 2 + 1
+    k = (n + 1) // 2
+    alf = ops.ones(m, xr.dtype)
+    # alf[1:k] should be 2, but the gradient needs the scale a / alf, so set alf[1:k] = 0.5 and use mul(dreal, alf)
+    alf[1:k] = 0.5
+    return alf
+
+def _get_c2r_alf(xr, scale_factor):
+    n = (xr.shape[-1] - 1) * 2
+    m = n // 2 + 1
+    k = (n + 1) // 2
+    alf = ops.ones(m, xr.dtype)
+    # for indices 1..k-1 set alf = 2.0, otherwise alf = 1.0
+    alf[1:k] = 2.0
+    alf = mint.mul(alf, scale_factor)
+    return alf
+
+
+
+# C2C
+class ASD_FFT(nn.Cell):
+    def __init__(self):
+        super(ASD_FFT, self).__init__()
+        self.asd_fft_op = asd_fft_op.asd_fft_1d
+        self.scale_factor = None
+
+    def get_fft_size(self, xr):
+        return xr.shape[1]
+
+
+    def construct(self, xr, xi):
+        return self.forward(xr, xi)
+
+    def forward(self, xr, xi=None):
+        if xr.dtype != mstype.float32 or (xi is not None and xi.dtype != mstype.float32):
+            raise ValueError("ASD_FFT Input tensor must be float32")
+        org_shape = list(xr.shape)
+        # Flatten the input to 2D (batch, n)
+        if len(org_shape) == 1:
+            xr = ops.expand_dims(xr, 0)
+            xi = ops.expand_dims(xi, 0) if xi is not None else None
+        elif len(org_shape) > 2:
+            xr = ops.reshape(xr, (-1, org_shape[-1]))
+            xi = ops.reshape(xi, (-1, org_shape[-1])) if xi is not None else None
+
+        fft_size = self.get_fft_size(xr)
+        x = make_complex(xr, xi) if xi is not None else xr
+        output = self.asd_fft_op(x, xr.shape[0], fft_size)
+
+        # If the output is not complex (C2R case), return it directly
+        if not ops.is_complex(output):
+            org_shape[-1] = output.shape[-1]
+            if self.scale_factor is not None:
+                output = mint.mul(output, self.scale_factor)
+            return ops.reshape(output, tuple(org_shape))
+
+        real = _get_real(output)
+        imag = _get_imag(output)
+        # Restore the original shape; note that the last dimension may change:
+        # r2c becomes real.shape[-1] // 2 + 1, c2r becomes (real.shape[-1] - 1) * 2
+        org_shape[-1] = real.shape[-1]
+        org_shape = tuple(org_shape)
+        if len(org_shape) == 1 or len(org_shape) > 2:
+            real = ops.reshape(real, org_shape)
+            imag = ops.reshape(imag, org_shape)
+
+        if self.scale_factor is not None:
+            real = mint.mul(real, self.scale_factor)
+            imag = mint.mul(imag, self.scale_factor)
+
+        return real, imag
+
+    def bprop(self, xr, xi, out, dout):
+        dreal, dimag = dout
+        dxr, dxi = asd_ifft(dreal, dimag)
+        dxr = mint.mul(dxr, xr.shape[-1])
+        dxi = mint.mul(dxi, xr.shape[-1])
+        return dxr, dxi

+# C2C
+class ASD_IFFT(ASD_FFT):
+    def __init__(self):
+        super(ASD_IFFT, self).__init__()
+        self.asd_fft_op = asd_fft_op.asd_ifft_1d
+
+    def get_fft_size(self, xr):
+        fft_size = super(ASD_IFFT, self).get_fft_size(xr)
+        self.scale_factor = 1.0 / fft_size
+        return fft_size
+
+    def bprop(self, xr, xi, out, dout):
+        dreal, dimag = dout
+        dxr, dxi = asd_fft(dreal, dimag)
+        dxr = mint.mul(dxr, self.scale_factor)
+        dxi = mint.mul(dxi, self.scale_factor)
+        return dxr, dxi
+
+# R2C
+class ASD_RFFT(ASD_FFT):
+    def __init__(self):
+        super(ASD_RFFT, self).__init__()
+        self.asd_fft_op = asd_fft_op.asd_rfft_1d
+
+    def get_fft_size(self, xr):
+        return xr.shape[1]
+
+
+    def construct(self, xr):
+        return self.forward(xr)
+
+    def bprop(self, xr, out, dout):
+        dreal, dimag = dout
+        alf = _get_r2c_alf(xr)
+        dreal = mint.mul(dreal, alf)
+        dimag = mint.mul(dimag, alf)
+        dxr = asd_irfftn(dreal, dimag, ndim=1, n=xr.shape[-1])
+        dxr = mint.mul(dxr, xr.shape[-1])
+        return dxr
+
+# C2R
+class ASD_IRFFT(ASD_FFT):
+    def __init__(self):
+        super(ASD_IRFFT, self).__init__()
+        self.asd_fft_op = asd_fft_op.asd_irfft_1d
+        self.n = None
+
+    def get_fft_size(self, xr):
+        if self.n is not None and (self.n // 2 + 1) == xr.shape[-1]:
+            fft_size = self.n
+        else:
+            fft_size = (xr.shape[1] - 1) * 2
+        self.scale_factor = 1.0 / fft_size
+        return fft_size
+
+    def set_n(self, n):
+        self.n = n
+
+    def bprop(self, xr, xi, out, dout):
+        dreal = dout
+        alf = _get_c2r_alf(xr, self.scale_factor)
+        dxr, dxi = asd_rfft(dreal)
+        dxr = mint.mul(dxr, alf)
+        dxi = mint.mul(dxi, alf)
+
+        return dxr, dxi
+
+# C2C forward
+class ASD_FFT2D(nn.Cell):
+    def __init__(self):
+        super(ASD_FFT2D, self).__init__()
+        self.asd_fft_op = asd_fft_op.asd_fft_2d
+        self.scale_factor = None
+
+    def get_fft_size(self, xr):
+        return xr.shape[0], xr.shape[1], xr.shape[2]
+
+    def construct(self, xr, xi):
+        return self.forward(xr, xi)
+
+    def forward(self, xr, xi=None):
+        if xr.dtype != mstype.float32 or (xi is not None and xi.dtype != mstype.float32):
+            raise ValueError("ASD_FFT Input tensor must be float32")
+
+        org_shape = list(xr.shape)
+        if len(org_shape) < 2:
+            raise ValueError("2D FFT Input tensor must have at least 2 dimensions")
+
+        # Flatten the input to 3D (batch, x, y)
+        if len(org_shape) == 2:
+            xr = ops.expand_dims(xr, 0)
+            xi = ops.expand_dims(xi, 0) if xi is not None else None
+        elif len(org_shape) > 3:
+            xr = ops.reshape(xr, (-1, org_shape[-2], org_shape[-1]))
+            xi = ops.reshape(xi, (-1, org_shape[-2], org_shape[-1])) if xi is not None else None
+
+        batch_size, x_size, y_size = self.get_fft_size(xr)
+        x = make_complex(xr, xi) if xi is not None else xr
+        output = self.asd_fft_op(x, batch_size, x_size, y_size)
+
+        # If the output is not complex (C2R case), return it directly
+        if not ops.is_complex(output):
+            org_shape[-1] = output.shape[-1]
+            if self.scale_factor is not None:
+                output = mint.mul(output, self.scale_factor)
+            return ops.reshape(output, tuple(org_shape))
+
+        real = _get_real(output)
+        imag = _get_imag(output)
+        # Restore the original shape; note that the last dimension may change:
+        # r2c becomes real.shape[-1] // 2 + 1, c2r becomes (real.shape[-1] - 1) * 2
+        org_shape[-1] = real.shape[-1]
+        org_shape = tuple(org_shape)
+        if len(org_shape) == 2 or len(org_shape) > 3:
+            real = ops.reshape(real, org_shape)
+            imag = ops.reshape(imag, org_shape)
+
+        if self.scale_factor is not None:
+            real = mint.mul(real, self.scale_factor)
+            imag = mint.mul(imag, self.scale_factor)
+
+        return real, imag
+
+    def bprop(self, xr, xi, out, dout):
+        dreal, dimag = dout
+        dxr, dxi = asd_ifft2(dreal, dimag)
+        N = xr.shape[-1] * xr.shape[-2]
+        dxr = mint.mul(dxr, N)
+        dxi = mint.mul(dxi, N)
+        return dxr, dxi
+
+# C2C inverse
+class ASD_IFFT2D(ASD_FFT2D):
+    def __init__(self):
+        super(ASD_IFFT2D, self).__init__()
+        self.asd_fft_op = asd_fft_op.asd_ifft_2d
+
+    def get_fft_size(self, xr):
+        batch_size, x_size, y_size = super(ASD_IFFT2D, self).get_fft_size(xr)
+        self.scale_factor = 1.0 / (x_size * y_size)
+        return batch_size, x_size, y_size
+
+    def bprop(self, xr, xi, out, dout):
+        dreal, dimag = dout
+        dxr, dxi = asd_fft2(dreal, dimag)
+        dxr = mint.mul(dxr, self.scale_factor)
+        dxi = mint.mul(dxi, self.scale_factor)
+        return dxr, dxi
+
+# R2C forward
+class ASD_RFFT2D(ASD_FFT2D):
+    def __init__(self):
+        super(ASD_RFFT2D, self).__init__()
+        self.asd_fft_op = asd_fft_op.asd_rfft_2d
+
+ def get_fft_size(self, xr): + return xr.shape[0], xr.shape[1], xr.shape[2] + + def construct(self, xr): + return self.forward(xr) + + def bprop(self, xr, out, dout): + dreal, dimag = dout + alf = _get_r2c_alf(xr) + dreal = mint.mul(dreal, alf) + dimag = mint.mul(dimag, alf) + dxr = asd_irfftn(dreal, dimag, ndim=2, n=xr.shape[-1]) + N = xr.shape[-1] * xr.shape[-2] * 1.0 + dxr = mint.mul(dxr, N) + return dxr + +# C2R inverse +class ASD_IRFFT2D(ASD_FFT2D): + def __init__(self): + super(ASD_IRFFT2D, self).__init__() + self.asd_fft_op = asd_fft_op.asd_irfft_2d + self.n = None + + def get_fft_size(self, xr): + batch_size = xr.shape[0] + x_size = xr.shape[1] + y_size = xr.shape[2] + if self.n is not None and (self.n // 2 + 1) == y_size: + output_last_dim = self.n + else: + output_last_dim = (y_size - 1) * 2 + self.scale_factor = 1.0 / (x_size * output_last_dim) + return batch_size, x_size, output_last_dim + + def set_n(self, n): + self.n = n + + def bprop(self, xr, xi, out, dout): + dreal = dout + dxr, dxi = asd_rfft2(dreal) + alf = _get_c2r_alf(xr, self.scale_factor) + dxr = mint.mul(dxr, alf) + dxi = mint.mul(dxi, alf) + return dxr, dxi + +asd_fft = ASD_FFT() +asd_ifft = ASD_IFFT() +asd_rfft = ASD_RFFT() +asd_irfft = ASD_IRFFT() +asd_fft2 = ASD_FFT2D() +asd_ifft2 = ASD_IFFT2D() +asd_rfft2 = ASD_RFFT2D() +asd_irfft2 = ASD_IRFFT2D() + +def asd_fftn(xr, xi, ndim=1): + if ndim == 1: + return asd_fft(xr, xi) + elif ndim == 2: + return asd_fft2(xr, xi) + else: + raise ValueError(f"asd_fftn Unsupported dimension: {ndim}, only support 1D and 2D") + +def asd_ifftn(xr, xi, ndim=1): + if ndim == 1: + return asd_ifft(xr, xi) + elif ndim == 2: + return asd_ifft2(xr, xi) + else: + raise ValueError(f"asd_ifftn Unsupported dimension: {ndim}, only support 1D and 2D") + +def asd_rfftn(xr, ndim=1): + if ndim == 1: + return asd_rfft(xr) + elif ndim == 2: + return asd_rfft2(xr) + else: + raise ValueError(f"asd_rfftn Unsupported dimension: {ndim}, only support 1D and 2D") + +def asd_irfftn(xr, xi, n=None, ndim=1): + if ndim == 1: + instance = ASD_IRFFT() + instance.set_n(n) + return instance(xr, xi) + elif ndim == 2: + instance = ASD_IRFFT2D() + instance.set_n(n) + return instance(xr, xi) + else: + raise ValueError(f"asd_irfftn Unsupported dimension: {ndim}, only support 1D and 2D") diff --git a/MindFlow/mindflow/fft/asd_fft_op_ext.cpp b/MindFlow/mindflow/fft/asd_fft_op_ext.cpp new file mode 100644 index 0000000000000000000000000000000000000000..634f94660ca90066d8ce94dbbebacedeb1f94664 --- /dev/null +++ b/MindFlow/mindflow/fft/asd_fft_op_ext.cpp @@ -0,0 +1,218 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#include <string>
+#include "ms_extension/api.h"
+using std::string;
+using ms::pynative::FFTParam;
+using ms::pynative::asdFftDirection;
+using ms::pynative::asdFft1dDimType;
+using ms::pynative::asdFftType;
+
+ms::Tensor GetResultTensor(const ms::Tensor &t, const FFTParam &param) {
+  auto type_id = mindspore::TypeId::kNumberTypeComplex64;
+  if (param.fftType == asdFftType::ASCEND_FFT_C2R) {
+    type_id = mindspore::TypeId::kNumberTypeFloat32;
+  }
+  // for fft and ifft, the output shape is equal to input shape
+  ShapeVector out_shape(t.shape());
+
+  // for rfft, the output shape is (batch_size, x_size, (y_size // 2) + 1)
+  if (param.fftType == asdFftType::ASCEND_FFT_R2C && param.direction == asdFftDirection::ASCEND_FFT_FORWARD) {
+    out_shape.back() = (out_shape.back() / 2) + 1;
+  }
+  // for irfft, the output shape is (batch_size, x_size, fftYSize or fftXSize1d)
+  if (param.fftType == asdFftType::ASCEND_FFT_C2R && param.direction == asdFftDirection::ASCEND_FFT_BACKWARD) {
+    out_shape.back() = param.fftYSize == 0 ? param.fftXSize : param.fftYSize;
+  }
+  return ms::Tensor(type_id, out_shape);
+}
+
+ms::Tensor exec_asdsip_fft(const string &op_name, const FFTParam &param, const ms::Tensor &input) {
+  MS_LOG(INFO) << "Run device task LaunchAsdSipFFT start, op_name: " << op_name << ", param.fftX: " << param.fftXSize
+               << ", param.fftY: " << param.fftYSize << ", param.fftType: " << param.fftType << ", param.direction: "
+               << param.direction << ", param.batchSize: " << param.batchSize << ", param.dimType: " << param.dimType;
+  auto output = GetResultTensor(input, param);
+  ms::pynative::RunAsdSipFFTOp(op_name, param, input, output);
+  MS_LOG(INFO) << "Run device task LaunchAsdSipFFT end";
+
+  return output;
+}
+
+auto pyboost_npu_set_cache_size(int64_t cache_size) {
+  MS_LOG(INFO) << "Set cache size for NPU FFT, cache_size: " << cache_size;
+  ms::pynative::AsdSipFFTOpRunner::SetCacheSize(cache_size);
+  MS_LOG(INFO) << "Set cache size for NPU FFT end";
+}
+
+// 1D FFT forward
+auto pyboost_npu_fft_1d(const ms::Tensor &input, int64_t batch_size, int64_t x_size) {
+  FFTParam param;
+  param.fftXSize = x_size;
+  param.fftYSize = 0;
+  param.fftType = asdFftType::ASCEND_FFT_C2C;
+  param.direction = asdFftDirection::ASCEND_FFT_FORWARD;
+  param.batchSize = batch_size;
+  param.dimType = asdFft1dDimType::ASCEND_FFT_HORIZONTAL;
+  const string op_name = "asdFftExecC2C";
+  return ms::pynative::PyboostRunner::Call<1>(exec_asdsip_fft, op_name, param, input);
+}
+
+// 1D FFT inverse
+auto pyboost_npu_ifft_1d(const ms::Tensor &input, int64_t batch_size, int64_t x_size) {
+  ms::pynative::FFTParam param;
+  param.batchSize = batch_size;
+  param.fftXSize = x_size;
+  param.fftYSize = 0;
+  param.fftType = asdFftType::ASCEND_FFT_C2C;
+  param.direction = asdFftDirection::ASCEND_FFT_BACKWARD;
+  param.dimType = asdFft1dDimType::ASCEND_FFT_HORIZONTAL;
+  const string op_name = "asdFftExecC2C";
+  return ms::pynative::PyboostRunner::Call<1>(exec_asdsip_fft, op_name, param, input);
+}
+
+// 1D RFFT forward
+auto pyboost_npu_rfft_1d(const ms::Tensor &input, int64_t batch_size, int64_t x_size) {
+  ms::pynative::FFTParam param;
+  param.batchSize = batch_size;
+  param.fftXSize = x_size;
+  param.fftYSize = 0;
+  param.fftType = asdFftType::ASCEND_FFT_R2C;
+  param.direction = asdFftDirection::ASCEND_FFT_FORWARD;
+  param.dimType = asdFft1dDimType::ASCEND_FFT_HORIZONTAL;
+  const string op_name = "asdFftExecR2C";
+  return ms::pynative::PyboostRunner::Call<1>(exec_asdsip_fft, op_name, param, input);
+}
+
+// 1D IRFFT inverse
+auto pyboost_npu_irfft_1d(const ms::Tensor &input, int64_t batch_size, int64_t x_size) {
+  ms::pynative::FFTParam param;
+  param.batchSize = batch_size;
+  param.fftXSize = x_size;
+  param.fftYSize = 0;
+  param.fftType = asdFftType::ASCEND_FFT_C2R;
+  param.direction = asdFftDirection::ASCEND_FFT_BACKWARD;
+  param.dimType = asdFft1dDimType::ASCEND_FFT_HORIZONTAL;
+  const string op_name = "asdFftExecC2R";
+  return ms::pynative::PyboostRunner::Call<1>(exec_asdsip_fft, op_name, param, input);
+}
+
+// 2D FFT forward
+auto pyboost_npu_fft_2d(const ms::Tensor &input, int64_t batch_size, int64_t x_size, int64_t y_size) {
+  ms::pynative::FFTParam param;
+  param.fftXSize = x_size;
+  param.fftYSize = y_size;
+  param.fftType = asdFftType::ASCEND_FFT_C2C;
+  param.direction = asdFftDirection::ASCEND_FFT_FORWARD;
+  param.batchSize = batch_size;
+  const string op_name = "asdFftExecC2C";
+  return ms::pynative::PyboostRunner::Call<1>(exec_asdsip_fft, op_name, param, input);
+}
+
+// 2D IFFT inverse
+auto pyboost_npu_ifft_2d(const ms::Tensor &input, int64_t batch_size, int64_t x_size, int64_t y_size) {
+  ms::pynative::FFTParam param;
+  param.batchSize = batch_size;
+  param.fftXSize = x_size;
+  param.fftYSize = y_size;
+  param.fftType = asdFftType::ASCEND_FFT_C2C;
+  param.direction = asdFftDirection::ASCEND_FFT_BACKWARD;
+  const string op_name = "asdFftExecC2C";
+  return ms::pynative::PyboostRunner::Call<1>(exec_asdsip_fft, op_name, param, input);
+}
+
+// 2D RFFT forward
+auto pyboost_npu_rfft_2d(const ms::Tensor &input, int64_t batch_size, int64_t x_size, int64_t y_size) {
+  ms::pynative::FFTParam param;
+  param.batchSize = batch_size;
+  param.fftXSize = x_size;
+  param.fftYSize = y_size;
+  param.fftType = asdFftType::ASCEND_FFT_R2C;
+  param.direction = asdFftDirection::ASCEND_FFT_FORWARD;
+  const string op_name = "asdFftExecR2C";
+  return ms::pynative::PyboostRunner::Call<1>(exec_asdsip_fft, op_name, param, input);
+}
+
+// 2D IRFFT inverse
+auto pyboost_npu_irfft_2d(const ms::Tensor &input, int64_t batch_size, int64_t x_size, int64_t y_size) {
+  ms::pynative::FFTParam param;
+  param.batchSize = batch_size;
+  param.fftXSize = x_size;
+  param.fftYSize = y_size;
+  param.fftType = asdFftType::ASCEND_FFT_C2R;
+  param.direction = asdFftDirection::ASCEND_FFT_BACKWARD;
+  const string op_name = "asdFftExecC2R";
+  return ms::pynative::PyboostRunner::Call<1>(exec_asdsip_fft, op_name, param, input);
+}
+
+// Python binding module
+PYBIND11_MODULE(MS_EXTENSION_NAME, m) {
+  m.doc() = "NPU FFT operations for MindSpore";
+
+  m.def("asd_set_cache_size", &pyboost_npu_set_cache_size, "Set cache size for NPU FFT",
+        pybind11::arg("cache_size"));
+
+  // 1D FFT
+  m.def("asd_fft_1d", &pyboost_npu_fft_1d, "1D FFT on NPU",
+        pybind11::arg("input"),
+        pybind11::arg("batch_size"),
+        pybind11::arg("x_size"));
+
+  // 1D IFFT
+  m.def("asd_ifft_1d", &pyboost_npu_ifft_1d, "1D IFFT on NPU",
+        pybind11::arg("input"),
+        pybind11::arg("batch_size"),
+        pybind11::arg("x_size"));
+
+  // 1D RFFT
+  m.def("asd_rfft_1d", &pyboost_npu_rfft_1d, "1D RFFT on NPU",
+        pybind11::arg("input"),
+        pybind11::arg("batch_size"),
+        pybind11::arg("x_size"));
+
+  // 1D IRFFT
+  m.def("asd_irfft_1d", &pyboost_npu_irfft_1d, "1D IRFFT on NPU",
+        pybind11::arg("input"),
+        pybind11::arg("batch_size"),
+        pybind11::arg("x_size"));
+
+  // 2D FFT
+  m.def("asd_fft_2d", &pyboost_npu_fft_2d, "2D FFT on NPU",
+        pybind11::arg("input"),
+        pybind11::arg("batch_size"),
+        pybind11::arg("x_size"),
+        pybind11::arg("y_size"));
+
+  // 2D IFFT
+  m.def("asd_ifft_2d", &pyboost_npu_ifft_2d, "2D IFFT on 
NPU", + pybind11::arg("input"), + pybind11::arg("batch_size"), + pybind11::arg("x_size"), + pybind11::arg("y_size")); + + // 2D RFFT + m.def("asd_rfft_2d", &pyboost_npu_rfft_2d, "2D RFFT on NPU", + pybind11::arg("input"), + pybind11::arg("batch_size"), + pybind11::arg("x_size"), + pybind11::arg("y_size")); + + // 2D IRFFT + m.def("asd_irfft_2d", &pyboost_npu_irfft_2d, "2D IRFFT on NPU", + pybind11::arg("input"), + pybind11::arg("batch_size"), + pybind11::arg("x_size"), + pybind11::arg("y_size")); +} diff --git a/MindFlow/setup.py b/MindFlow/setup.py index 02038fabd95ba48c4cdbcd1499bfe85cfcadd713..0f431253597b408e371b07e10f0a60b4e8f54178 100644 --- a/MindFlow/setup.py +++ b/MindFlow/setup.py @@ -51,6 +51,7 @@ package_data = { 'include/*' 'build_info.txt' ], + 'mindflow.fft': ['*.cpp'], '_c_minddata': ['lib_c_minddata*.so'] } diff --git a/tests/st/mindflow/operators/test_asd_fft.py b/tests/st/mindflow/operators/test_asd_fft.py new file mode 100644 index 0000000000000000000000000000000000000000..327a4a791bf8efbcd23d11a37b33df31e94a4226 --- /dev/null +++ b/tests/st/mindflow/operators/test_asd_fft.py @@ -0,0 +1,367 @@ +# ============================================================================ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================
+"""ASD FFT operator test cases"""
+import os
+import random
+import sys
+import numpy as np
+import mindspore as ms
+import pytest
+from mindspore import set_seed, mint, ops
+from mindflow import DFTn, IDFTn, RDFTn, IRDFTn, asd_fftn, asd_ifftn, asd_rfftn, asd_irfftn
+from mindspore.profiler import ProfilerLevel, ProfilerActivity, AicoreMetrics, ExportType
+
+PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../"))
+sys.path.append(PROJECT_ROOT)
+
+# pylint: disable=wrong-import-position
+
+# pylint: enable=wrong-import-position
+
+set_seed(0)
+np.random.seed(0)
+random.seed(0)
+FP32_RTOL = 1e-7
+
+def loss_func_c(yr, yi):
+    return ops.sum(yr * yr + 2 * yi * yi)
+
+def loss_grad_np_c(y):
+    return 2 * y.real + 4j * y.imag
+
+def loss_func_r(yr):
+    return ops.sum(yr * yr * yr)
+
+def loss_grad_np_r(y):
+    return 3 * y * y
+
+def forwad_fn_c2c(xr, xi, dim, dft_cell):
+    if dim is not None:
+        br, bi = dft_cell(xr, xi, ndim=dim)
+    else:
+        br, bi = dft_cell(xr, xi)
+    return loss_func_c(br, bi)
+
+def forwad_fn_r2c(xr, dim, dft_cell):
+    if dim is not None:
+        br, bi = dft_cell(xr, ndim=dim)
+    else:
+        br, bi = dft_cell(xr)
+    return loss_func_c(br, bi)
+
+def forwad_fn_c2r(xr, xi, dim, dft_cell):
+    if dim is not None:
+        br = dft_cell(xr, xi, ndim=dim)
+    else:
+        br = dft_cell(xr, xi)
+    return loss_func_r(br)
+
+def gen_input(shape=(2, 16, 16), rand_test=True):
+    ''' Generate a random or deterministic tensor as test input.
+    '''
+    a = np.random.rand(*shape) + 1j * np.random.rand(*shape)
+    if not rand_test:
+        a = sum([np.arange(n).reshape([n] + [1] * j) for j, n in enumerate(shape[::-1])]) + 1j * \
+            sum([np.arange(n).reshape([n] + [1] * j) for j, n in enumerate(shape[::-1])])
+    ar, ai = (ms.Tensor(a.real, dtype=ms.float32), ms.Tensor(a.imag, dtype=ms.float32))
+    return a, ar, ai
+
+def cal_error(name, ar, ai, br, bi):
+    '''
+    ar, ai, br, bi are numpy arrays; compute the max absolute error, max relative error, and mean relative error.
+    '''
+    print(f"{name} ar.shape: ", ar.shape)
+    print(f"{name} br.shape: ", br.shape)
+    if ai is not None and bi is not None:
+        print(f"{name} ai.shape: ", ai.shape)
+        print(f"{name} bi.shape: ", bi.shape)
+    abs_error_real = np.abs(ar - br)
+    rel_error_real = abs_error_real / (np.abs(ar) + 1e-10)
+    if ai is not None and bi is not None:
+        abs_error_imag = np.abs(ai - bi)
+        rel_error_imag = abs_error_imag / (np.abs(ai) + 1e-10)
+        max_abs_error = max(np.max(abs_error_real), np.max(abs_error_imag))
+        max_rel_error = max(np.max(rel_error_real), np.max(rel_error_imag))
+        mean_rel_error = (np.mean(rel_error_real) + np.mean(rel_error_imag)) / 2
+    else:
+        abs_error_imag = None
+        rel_error_imag = None
+        max_abs_error = np.max(abs_error_real)
+        max_rel_error = np.max(rel_error_real)
+        mean_rel_error = np.mean(rel_error_real)
+
+    print(f"{name} max_abs_error: ", max_abs_error)
+    print(f"{name} max_rel_error: ", max_rel_error)
+    print(f"{name} mean_rel_error: ", mean_rel_error)
+
+    return max_abs_error, max_rel_error, mean_rel_error
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend910b_training
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('device_target', ['Ascend'])
+@pytest.mark.parametrize('mode', [ms.PYNATIVE_MODE])
+@pytest.mark.parametrize('ndim', [1, 2])
+def test_dft_accuracy(device_target, mode, ndim):
+    """
+    Feature: Test DFTn & IDFTn accuracy
+    Description: Input random tensor, compare the results of DFTn and IDFTn with numpy results
+    Expectation: The 
output tensors should be equal within tolerance + """ + print(f"test_dft_accuracy, ndim: {ndim}") + ms.set_context(device_target=device_target, mode=mode) + a, ar, ai = gen_input() + shape = a.shape + + b = np.fft.fftn(a, s=a.shape[-ndim:], axes=range(-ndim, 0)) + # mindflow DFTn, real and imag are both mindspore tensors + br, bi = DFTn(shape[-ndim:])(ar, ai) + cr, ci = IDFTn(shape[-ndim:])(br, bi) + + # ASD FFT, real and imag are both mindspore tensors + ms_br, ms_bi = asd_fftn(ar, ai, ndim=ndim) + ms_ar, ms_ai = asd_ifftn(ms_br, ms_bi, ndim=ndim) + + max_abs_error, max_rel_error, mean_rel_error = cal_error("numpy-vs-mindflow-dft", b.real, b.imag, br.asnumpy(), bi.asnumpy()) + max_abs_error, max_rel_error, mean_rel_error = cal_error("numpy-vs-ms-dft", b.real, b.imag, ms_br.asnumpy(), ms_bi.asnumpy()) + assert max_abs_error < 1e-3 + assert max_rel_error < 1e-3 + assert mean_rel_error < 1e-3 + max_abs_error, max_rel_error, mean_rel_error = cal_error("numpy-vs-mindflow-idft", a.real, a.imag, cr.asnumpy(), ci.asnumpy()) + max_abs_error, max_rel_error, mean_rel_error = cal_error("numpy-vs-ms-idft", a.real, a.imag, ms_ar.asnumpy(), ms_ai.asnumpy()) + assert max_abs_error < 1e-3 + assert max_rel_error < 1e-3 + assert mean_rel_error < 1e-3 + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('device_target', ['Ascend']) +@pytest.mark.parametrize('mode', [ms.PYNATIVE_MODE]) +@pytest.mark.parametrize('ndim', [1, 2]) +def test_rdft_accuracy(device_target, mode, ndim): + """ + Feature: Test RDFTn & IRDFTn accuracy + Description: Input random tensor, compare the results of RDFTn and IRDFTn with numpy results + Expectation: The output tensors should be equal within tolerance + """ + print(f"test_rdft_accuracy, ndim: {ndim}") + ms.set_context(device_target=device_target, mode=mode) + a, ar, _ = gen_input() + shape = a.shape + + b = np.fft.rfftn(a.real, s=a.shape[-ndim:], axes=range(-ndim, 0)) + br, bi = RDFTn(shape[-ndim:])(ar) + cr = IRDFTn(shape[-ndim:])(br, bi) + + ms_br, ms_bi = asd_rfftn(ar, ndim=ndim) + ms_ar = asd_irfftn(ms_br, ms_bi, ndim=ndim) + + ms_ar_n = asd_irfftn(ms_br, ms_bi, n=ar.shape[-1] + 1, ndim=ndim) + np_shape = list(a.shape[-ndim:]) + np_shape[-1] = np_shape[-1] + 1 + np_ar_n = np.fft.irfftn(b, s=np_shape, axes=range(-ndim, 0)) + max_abs_error, max_rel_error, mean_rel_error = cal_error("numpy-vs-ms-irdft-n", np_ar_n, None, ms_ar_n.asnumpy(), None) + assert max_abs_error < 1e-3 + assert max_rel_error < 1e-3 + assert mean_rel_error < 1e-3 + + max_abs_error, max_rel_error, mean_rel_error = cal_error("numpy-vs-mindflow-rdft", b.real, b.imag, br.asnumpy(), bi.asnumpy()) + max_abs_error, max_rel_error, mean_rel_error = cal_error("numpy-vs-ms-rdft", b.real, b.imag, ms_br.asnumpy(), ms_bi.asnumpy()) + assert max_abs_error < 1e-3 + assert max_rel_error < 1e-3 + assert mean_rel_error < 1e-3 + max_abs_error, max_rel_error, mean_rel_error = cal_error("numpy-vs-mindflow-irdft", a.real, a.imag, cr.asnumpy(), None) + max_abs_error, max_rel_error, mean_rel_error = cal_error("numpy-vs-ms-irdft", a.real, a.imag, ms_ar.asnumpy(), None) + assert max_abs_error < 1e-3 + assert max_rel_error < 1e-3 + assert mean_rel_error < 1e-3 + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('device_target', ['Ascend']) +@pytest.mark.parametrize('mode', [ms.PYNATIVE_MODE]) +@pytest.mark.parametrize('ndim', [1, 2]) +@pytest.mark.parametrize('cell', ["c2c_fwd", "c2c_inv"]) +def 
test_dft_accuracy_with_grad(device_target, mode, ndim, cell): + """ + Feature: Test DFTn & IDFTn accuracy with grad + Description: Input random tensor, compare the results of DFTn and IDFTn with numpy results + Expectation: The output tensors should be equal within tolerance + """ + print(f"test_dft_accuracy_with_grad, ndim: {ndim}, cell: {cell}") + ms.set_context(device_target=device_target, mode=mode) + a, ar, ai = gen_input() + shape = a.shape + scale = np.prod(a.shape[-ndim:]) + grad_fn = ms.value_and_grad(forwad_fn_c2c, grad_position=(0, 1)) + + if cell == "c2c_fwd": + asd_result, (asd_ar_g, asd_ai_g) = grad_fn(ar, ai, ndim, asd_fftn) + ms_result, (ms_ar_g, ms_ai_g) = grad_fn(ar, ai, None, DFTn(shape[-ndim:])) + np_result = np.fft.fftn(a, s=a.shape[-ndim:], axes=range(-ndim, 0)) + c = loss_grad_np_c(np_result) + np_grad = scale * np.fft.ifftn(c, s=a.shape[-ndim:], axes=range(-ndim, 0)) + elif cell == "c2c_inv": + asd_result, (asd_ar_g, asd_ai_g) = grad_fn(ar, ai, ndim, asd_ifftn) + ms_result, (ms_ar_g, ms_ai_g) = grad_fn(ar, ai, None, IDFTn(shape[-ndim:])) + np_result = np.fft.ifftn(a, s=a.shape[-ndim:], axes=range(-ndim, 0)) + c = loss_grad_np_c(np_result) + np_grad = (1.0 / scale) * np.fft.fftn(c, s=a.shape[-ndim:], axes=range(-ndim, 0)) + else: + raise ValueError(f"fft: Unsupported cell: {cell}, only support c2c_fwd and c2c_inv") + + abs_error, rel_error, mean_error = cal_error("fft-numpy-vs-mindflow-backward", np_grad.real, np_grad.imag, ms_ar_g.asnumpy(), ms_ai_g.asnumpy()) + abs_error, rel_error, mean_error = cal_error("fft-numpy-vs-asdfft-backward", np_grad.real, np_grad.imag, asd_ar_g.asnumpy(), asd_ai_g.asnumpy()) + assert abs_error < 1e-3 + assert rel_error < 1e-3 + assert mean_error < 1e-3 + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('device_target', ['Ascend']) +@pytest.mark.parametrize('mode', [ms.PYNATIVE_MODE]) +@pytest.mark.parametrize('ndim', [1, 2]) +@pytest.mark.parametrize('cell', ["r2c_fwd", "c2r_inv"]) +def test_rdft_accuracy_with_grad(device_target, mode, ndim, cell): + """ + Feature: Test RDFTn & IRDFTn accuracy with grad + Description: Input random tensor, compare the results of RDFTn and IRDFTn with numpy results + Expectation: The output tensors should be equal within tolerance + """ + print(f"test_rdft_accuracy_with_grad, ndim: {ndim}, cell: {cell}") + ms.set_context(device_target=device_target, mode=mode) + a, ar, ai = gen_input() + shape = a.shape + scale = np.prod(a.shape[-ndim:]) + grad_fn_r2c = ms.value_and_grad(forwad_fn_r2c, grad_position=(0)) + grad_fn_c2r = ms.value_and_grad(forwad_fn_c2r, grad_position=(0, 1)) + + n = shape[-1] + m = n // 2 + 1 + k = (n+1) // 2 + alf = np.ones(m) + alf[1:k] += 1 + + + if cell == "r2c_fwd": + asd_result, asd_ar_g = grad_fn_r2c(ar, ndim, asd_rfftn) + ms_result, ms_ar_g = grad_fn_r2c(ar, None, RDFTn(shape[-ndim:])) + np_result = np.fft.rfftn(a.real, s=a.shape[-ndim:], axes=range(-ndim, 0)) + c = loss_grad_np_c(np_result) / alf + np_grad = scale * np.fft.irfftn(c, s=a.shape[-ndim:], axes=range(-ndim, 0)) + abs_error, rel_error, mean_error = cal_error("rfft-numpy-vs-mindflow-backward", + np_grad, None, ms_ar_g.asnumpy(), None) + abs_error, rel_error, mean_error = cal_error("rfft-numpy-vs-asdfft-backward", + np_grad, None, asd_ar_g.asnumpy(), None) + + elif cell == "c2r_inv": + ar_input, ai_input = asd_rfftn(ar, ndim=ndim) + asd_result, (asd_ar_g, asd_ai_g) = grad_fn_c2r(ar_input, ai_input, ndim, asd_irfftn) + # mindflow IRDFT + ms_result, 
(ms_ar_g, ms_ai_g) = grad_fn_c2r(ar_input, ai_input, None, IRDFTn(ar.shape[-ndim:])) + make_complex = ops.Complex() + ms_c = make_complex(ar_input, ai_input).asnumpy() + np_result = np.fft.irfftn(ms_c, s=a.shape[-ndim:], axes=range(-ndim, 0)) + c = loss_grad_np_r(np_result) + n = (ms_c.shape[-1] - 1) * 2 + m = n // 2 + 1 + k = (n+1) // 2 + alf = np.ones(m) + alf[1:k] += 1 + np_grad = (1.0 / scale) * alf * np.fft.rfftn(c, s=a.shape[-ndim:], axes=range(-ndim, 0)) + abs_error, rel_error, mean_error = cal_error("rfft-numpy-vs-mindflow-backward", + np_grad.real, np_grad.imag, ms_ar_g.asnumpy(), ms_ai_g.asnumpy()) + abs_error, rel_error, mean_error = cal_error("rfft-numpy-vs-asdfft-backward", + np_grad.real, np_grad.imag, asd_ar_g.asnumpy(), asd_ai_g.asnumpy()) + assert abs_error < 1e-3 + assert rel_error < 1e-3 + assert mean_error < 1e-3 + else: + raise ValueError(f"rfft: Unsupported cell: {cell}, only support r2c_fwd and c2r_inv") + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('device_target', ['Ascend']) +@pytest.mark.parametrize('mode', [ms.PYNATIVE_MODE]) +@pytest.mark.parametrize('ndim', [1, 2]) +@pytest.mark.parametrize('alg', ["DFT", "IDFT", "RDFT", "IRDFT", "FFT", "IFFT", "RFFT", "IRFFT"]) +@pytest.mark.parametrize('dshape', [(20, 512, 512), (20, 1024, 1024), (20, 2048, 2048)]) +def test_fft_speed(device_target, mode, ndim, alg, dshape): + """ + Feature: Test DFTn & IDFTn accuracy + Description: Input random tensor, compare the results of DFTn and IDFTn with numpy results + Expectation: The output tensors should be equal within tolerance + """ + print(f"test_fft_speed, ndim: {ndim}, alg: {alg}, dshape: {dshape}") + ms.set_context(device_target=device_target, mode=mode) + + experimental_config = ms.profiler._ExperimentalConfig( + profiler_level=ProfilerLevel.Level1, + aic_metrics=AicoreMetrics.AiCoreNone, + l2_cache=False, + mstx=False, + data_simplification=False, + export_type=[ExportType.Text]) + a, ar, ai = gen_input(shape=dshape) + shape = a.shape + br, bi = DFTn(shape[-ndim:])(ar, ai) + + prof_file_path = f"./data/fft_speed/{alg}_{ndim}_{dshape[0]}_{dshape[1]}_{dshape[2]}" + + with ms.profiler.profile(activities=[ProfilerActivity.CPU, ProfilerActivity.NPU], + schedule=ms.profiler.schedule(wait=0, warmup=4, active=4, repeat=1, skip_first=0), + on_trace_ready=ms.profiler.tensorboard_trace_handler(prof_file_path), + profile_memory=True, + with_stack=True, + record_shapes=True, + experimental_config=experimental_config) as prof: + for i in range(10): + if alg == "DFT": + br, bi = DFTn(shape[-ndim:])(ar, ai) + elif alg == "IDFT": + br, bi = IDFTn(shape[-ndim:])(br, bi) + elif alg == "RDFT": + br = RDFTn(shape[-ndim:])(ar) + elif alg == "IRDFT": + ar = IRDFTn(shape[-ndim:])(br, bi) + + elif alg == "FFT": + br, bi = asd_fftn(ar, ai, ndim=ndim) + elif alg == "IFFT": + ar, ai = asd_ifftn(br, bi, ndim=ndim) + elif alg == "RFFT": + br, bi = asd_rfftn(ar, ndim=ndim) + elif alg == "IRFFT": + ar = asd_irfftn(br, bi, ndim=ndim) + + prof.step() + + +if __name__ == "__main__": + test_dft_accuracy(device_target='Ascend', mode=ms.PYNATIVE_MODE, ndim=1) + test_dft_accuracy(device_target='Ascend', mode=ms.PYNATIVE_MODE, ndim=2) + test_rdft_accuracy(device_target='Ascend', mode=ms.PYNATIVE_MODE, ndim=1) + test_rdft_accuracy(device_target='Ascend', mode=ms.PYNATIVE_MODE, ndim=2) + test_dft_accuracy_with_grad(device_target='Ascend', mode=ms.PYNATIVE_MODE, ndim=1) + test_dft_accuracy_with_grad(device_target='Ascend', 
mode=ms.PYNATIVE_MODE, ndim=2) + test_rdft_accuracy_with_grad(device_target='Ascend', mode=ms.PYNATIVE_MODE, ndim=1) + test_rdft_accuracy_with_grad(device_target='Ascend', mode=ms.PYNATIVE_MODE, ndim=2)
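
For reference, below is a minimal usage sketch of the new mindflow.fft wrappers introduced by this patch. It is illustrative only: it assumes an Ascend environment where the asdsip FFT library is available, so that CustomOpBuilder can compile asd_fft_op_ext.cpp at import time; shapes and values are arbitrary.

    # Minimal usage sketch (not part of the patch), PYNATIVE mode on Ascend.
    import numpy as np
    import mindspore as ms
    from mindflow import asd_fftn, asd_ifftn, asd_rfftn, asd_irfftn

    ms.set_context(device_target="Ascend", mode=ms.PYNATIVE_MODE)

    x = np.random.rand(4, 32, 32)
    xr = ms.Tensor(x, dtype=ms.float32)                   # real part, must be float32
    xi = ms.Tensor(np.zeros_like(x), dtype=ms.float32)    # imaginary part, must be float32

    # Complex 2D FFT over the last two axes; real and imaginary parts are
    # passed and returned as separate float32 tensors.
    yr, yi = asd_fftn(xr, xi, ndim=2)
    zr, zi = asd_ifftn(yr, yi, ndim=2)                    # round trip back to (xr, xi)

    # Real-to-complex transform and its inverse; the last axis of the spectrum
    # has length n // 2 + 1, and the original length can be restored via `n=`.
    sr, si = asd_rfftn(xr, ndim=2)
    x_rec = asd_irfftn(sr, si, n=xr.shape[-1], ndim=2)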