diff --git a/mindscience/sciops/fft/asd_fft_custom_op.py b/mindscience/sciops/fft/asd_fft_custom_op.py
index bf01799d65f8b97d7621c0f21f00ee49f325bdc9..4cc9a9bfbecb7198092b341f2b93308fdb7013d5 100644
--- a/mindscience/sciops/fft/asd_fft_custom_op.py
+++ b/mindscience/sciops/fft/asd_fft_custom_op.py
@@ -193,7 +193,7 @@ class ASD_FFT(nn.Cell): # pylint: disable=invalid-name
     optimized for Ascend NPU hardware acceleration.
 
     Args:
-        None
+        complex_backend (bool): Whether to use the complex-number backend. Default: ``False``, which uses the separate real/imaginary (real-number) backend.
 
     Inputs:
         - **xr** (Tensor): Real part of input complex tensor with data type float32.
@@ -221,16 +221,17 @@ class ASD_FFT(nn.Cell): # pylint: disable=invalid-name
         >>> print(yi.shape)
         (1, 4)
     """
-    def __init__(self):
+    def __init__(self, complex_backend=False):
         super(ASD_FFT, self).__init__()
         self.asd_fft_op = _get_asd_fft_op().asd_fft_1d
+        self.asd_fft_op_sep = _get_asd_fft_op().asd_fft_1d_sep
         self.make_complex = CustomComplex()
         self.used_bprop_inputs = []
+        self.complex_backend = complex_backend
 
     def get_fft_size_and_scale(self, xr):
         return xr.shape[-1], None
-
     def construct(self, xr, xi):
         return self.forward(xr, xi)
 
@@ -249,20 +250,31 @@ class ASD_FFT(nn.Cell): # pylint: disable=invalid-name
         xi = mint.reshape(xi, (batch_size, org_shape[-1])) if xi is not None else None
         fft_size, scale_factor = self.get_fft_size_and_scale(xr)
 
-        x = self.make_complex(xr, xi) if xi is not None else xr
-        output = self.asd_fft_op(x, xr.shape[0], fft_size)
+        if not self.complex_backend:
+            real, imag = self.asd_fft_op_sep(xr, xi, xr.shape[0], fft_size)
+            if scale_factor is not None:
+                real.mul_(scale_factor)
+                imag.mul_(scale_factor)
+            if org_shape != list(real.shape):
+                org_shape[-1] = real.shape[-1]
+            real = mint.reshape(real, tuple(org_shape))
+            imag = mint.reshape(imag, tuple(org_shape))
+            return real, imag
+        else:
+            x = self.make_complex(xr, xi) if xi is not None else xr
+            output = self.asd_fft_op(x, xr.shape[0], fft_size)
 
-        if scale_factor is not None:
-            output.mul_(scale_factor)
+            if scale_factor is not None:
+                output.mul_(scale_factor)
 
-        if org_shape != list(output.shape):
-            org_shape[-1] = output.shape[-1]
-        output = mint.reshape(output, tuple(org_shape))
-        # If output is not complex, return directly
-        if not ops.is_complex(output):
-            return output
+            if org_shape != list(output.shape):
+                org_shape[-1] = output.shape[-1]
+            output = mint.reshape(output, tuple(org_shape))
+            # If output is not complex, return directly
+            if not ops.is_complex(output):
+                return output
 
-        return _get_real(output), _get_imag(output)
+            return _get_real(output), _get_imag(output)
 
     def bprop(self, xr, xi, out, dout):  # pylint: disable=unused-argument
         dreal, dimag = dout
@@ -278,7 +290,8 @@ class ASD_IFFT(ASD_FFT): # pylint: disable=invalid-name
     optimized for Ascend NPU hardware acceleration.
 
     Args:
-        None
+        complex_backend (bool): Whether to use the complex-number backend. Default: ``False``,
+            which uses the separate real/imaginary (real-number) backend.
 
     Inputs:
         - **xr** (Tensor): Real part of input complex tensor with data type float32.
@@ -306,10 +319,12 @@ class ASD_IFFT(ASD_FFT): # pylint: disable=invalid-name
         >>> print(yi.shape)
         (1, 4)
     """
-    def __init__(self):
-        super(ASD_IFFT, self).__init__()
+    def __init__(self, complex_backend=False):
+        super(ASD_IFFT, self).__init__(complex_backend)
         self.asd_fft_op = _get_asd_fft_op().asd_ifft_1d
+        self.asd_fft_op_sep = _get_asd_fft_op().asd_ifft_1d_sep
         self.used_bprop_inputs = []
+        self.complex_backend = complex_backend
 
     def get_fft_size_and_scale(self, xr):
         fft_size = xr.shape[-1]
@@ -331,7 +346,8 @@ class ASD_RFFT(ASD_FFT): # pylint: disable=invalid-name
     optimized for Ascend NPU hardware acceleration.
 
    Args:
-        None
+        complex_backend (bool): Whether to use the complex-number backend. Default: ``False``,
+            which uses the separate real/imaginary (real-number) backend.
 
     Inputs:
         - **xr** (Tensor): Input real tensor with data type float32.
@@ -699,20 +715,25 @@ _asd_fft2_instance = None
 _asd_ifft2_instance = None
 _asd_rfft2_instance = None
 _asd_irfft2_instance = None
+_complex_backend = False
+
+def set_backend(complex_backend):
+    global _complex_backend
+    _complex_backend = complex_backend
 
 # Lazily initialized FFT operator functions
 def asd_fft(*args, **kwargs):
     """Lazily initialized ASD_FFT call"""
     global _asd_fft_instance
     if _asd_fft_instance is None:
-        _asd_fft_instance = ASD_FFT()
+        _asd_fft_instance = ASD_FFT(complex_backend=_complex_backend)
     return _asd_fft_instance(*args, **kwargs)
 
 def asd_ifft(*args, **kwargs):
     """Lazily initialized ASD_IFFT call"""
     global _asd_ifft_instance
     if _asd_ifft_instance is None:
-        _asd_ifft_instance = ASD_IFFT()
+        _asd_ifft_instance = ASD_IFFT(complex_backend=_complex_backend)
     return _asd_ifft_instance(*args, **kwargs)
 
 def asd_rfft(*args, **kwargs):
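Not part of the patch: a minimal usage sketch of the new backend switch. The import path and the input values are assumptions; the shapes and dtypes follow the ASD_FFT docstring above, and an Ascend device with the custom-op extension built is presumed. Note that `set_backend` has to be called before the first `asd_fft`/`asd_ifft` call, because the module-level instances are created lazily and then cached.

    import numpy as np
    from mindspore import Tensor

    # Hypothetical import path, assumed from the file location above.
    from mindscience.sciops.fft import asd_fft_custom_op as asd_fft_ops

    asd_fft_ops.set_backend(False)  # default: separate real/imag (real-number) backend

    xr = Tensor(np.random.rand(1, 4).astype(np.float32))  # real part
    xi = Tensor(np.random.rand(1, 4).astype(np.float32))  # imaginary part
    yr, yi = asd_fft_ops.asd_fft(xr, xi)    # both outputs have shape (1, 4)
    zr, zi = asd_fft_ops.asd_ifft(yr, yi)   # inverse transform, same shapes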
diff --git a/mindscience/sciops/fft/asd_fft_op_ext.cpp b/mindscience/sciops/fft/asd_fft_op_ext.cpp
index d615c4ede774cbf2cf1add3401bdc6d99527e95f..aed3ee49d7ca94923472f58ccb5abb07ef254313 100644
--- a/mindscience/sciops/fft/asd_fft_op_ext.cpp
+++ b/mindscience/sciops/fft/asd_fft_op_ext.cpp
@@ -22,11 +22,7 @@ using ms::pynative::asdFftDirection;
 using ms::pynative::asdFft1dDimType;
 using ms::pynative::asdFftType;
 
-ms::Tensor GetResultTensor(const ms::Tensor &t, const FFTParam &param) {
-  auto type_id = mindspore::TypeId::kNumberTypeComplex64;
-  if (param.fftType == asdFftType::ASCEND_FFT_C2R) {
-    type_id = mindspore::TypeId::kNumberTypeFloat32;
-  }
+ms::Tensor GetResultTensor(const ms::Tensor &t, const FFTParam &param, mindspore::TypeId type_id) {
   // for fft and ifft, the output shape is equal to input shape
   ShapeVector out_shape(t.shape());
@@ -41,17 +37,74 @@ ms::Tensor GetResultTensor(const ms::Tensor &t, const FFTParam &param) {
   return ms::Tensor(type_id, out_shape);
 }
 
-ms::Tensor exec_asdsip_fft(const string &op_name, const FFTParam &param, const ms::Tensor &input) {
+// Tag types used to select the exec_asdsip_fft overloads
+template <int NumInputs, int NumOutputs>
+struct FFTConfig {};
+
+// Aliases for convenience
+using FFT_1In1Out = FFTConfig<1, 1>;
+using FFT_1In2Out = FFTConfig<1, 2>;
+using FFT_2In1Out = FFTConfig<2, 1>;
+using FFT_2In2Out = FFTConfig<2, 2>;
+
+ms::Tensor exec_asdsip_fft(const string &op_name, const FFTParam &param,
+                           const ms::Tensor &input, FFT_1In1Out) {
   MS_LOG(INFO) << "Run device task LaunchAsdSipFFT start, op_name: " << op_name << ", param.fftX: " << param.fftXSize
                << ", param.fftY: " << param.fftYSize << ", param.fftType: " << param.fftType << ", param.direction: "
                << param.direction << ", param.batchSize: " << param.batchSize << ", param.dimType: " << param.dimType;
-  auto output = GetResultTensor(input, param);
+  auto type_id = mindspore::TypeId::kNumberTypeComplex64;
+  if (param.fftType == asdFftType::ASCEND_FFT_C2R) {
+    type_id = mindspore::TypeId::kNumberTypeFloat32;
+  }
+  auto output = GetResultTensor(input, param, type_id);
   ms::pynative::RunAsdSipFFTOp(op_name, param, input, output);
   MS_LOG(INFO) << "Run device task LaunchAsdSipFFT end";
   return output;
 }
 
+std::vector<ms::Tensor> exec_asdsip_fft(const string &op_name, const FFTParam &param,
+                                        const ms::Tensor &real_in, FFT_1In2Out) {
+  MS_LOG(INFO) << "Run device task LaunchAsdSipFFT start, op_name: " << op_name << ", param.fftX: " << param.fftXSize
+               << ", param.fftY: " << param.fftYSize << ", param.fftType: " << param.fftType << ", param.direction: "
+               << param.direction << ", param.batchSize: " << param.batchSize << ", param.dimType: " << param.dimType;
+  auto type_id = mindspore::TypeId::kNumberTypeFloat32;
+  auto real_out = GetResultTensor(real_in, param, type_id);
+  auto imag_out = GetResultTensor(real_in, param, type_id);
+  ms::pynative::RunAsdSipFFTOp(op_name, param, {real_in}, {real_out, imag_out});
+  MS_LOG(INFO) << "Run device task LaunchAsdSipFFT end";
+
+  return {real_out, imag_out};
+}
+
+std::vector<ms::Tensor> exec_asdsip_fft(const string &op_name, const FFTParam &param,
+                                        const ms::Tensor &real_in, const ms::Tensor &imag_in, FFT_2In2Out) {
+  MS_LOG(INFO) << "Run device task LaunchAsdSipFFT start, op_name: " << op_name << ", param.fftX: " << param.fftXSize
+               << ", param.fftY: " << param.fftYSize << ", param.fftType: " << param.fftType << ", param.direction: "
+               << param.direction << ", param.batchSize: " << param.batchSize << ", param.dimType: " << param.dimType;
+  auto type_id = mindspore::TypeId::kNumberTypeFloat32;
+  auto real_output = GetResultTensor(real_in, param, type_id);
+  auto imag_output = GetResultTensor(imag_in, param, type_id);
+  ms::pynative::RunAsdSipFFTOp(op_name, param, {real_in, imag_in}, {real_output, imag_output});
+  MS_LOG(INFO) << "Run device task LaunchAsdSipFFT end";
+
+  return {real_output, imag_output};
+}
+
+ms::Tensor exec_asdsip_fft(const string &op_name, const FFTParam &param,
+                           const ms::Tensor &real_in, const ms::Tensor &imag_in, FFT_2In1Out) {
+  MS_LOG(INFO) << "Run device task LaunchAsdSipFFT start, op_name: " << op_name << ", param.fftX: " << param.fftXSize
+               << ", param.fftY: " << param.fftYSize << ", param.fftType: " << param.fftType << ", param.direction: "
+               << param.direction << ", param.batchSize: " << param.batchSize << ", param.dimType: " << param.dimType;
+  auto type_id = mindspore::TypeId::kNumberTypeFloat32;
+  auto real_out = GetResultTensor(real_in, param, type_id);
+  ms::pynative::RunAsdSipFFTOp(op_name, param, {real_in, imag_in}, {real_out});
+  MS_LOG(INFO) << "Run device task LaunchAsdSipFFT end";
+
+  return real_out;
+}
+
+
 auto pyboost_npu_set_cache_size(int64_t cache_size) {
   MS_LOG(INFO) << "Set cache size for NPU FFT, cache_size: " << cache_size;
   ms::pynative::AsdSipFFTOpRunner::SetCacheSize(cache_size);
@@ -68,7 +121,7 @@ auto pyboost_npu_fft_1d(const ms::Tensor &input, int64_t batch_size, int64_t x_s
   param.batchSize = batch_size;
   param.dimType = asdFft1dDimType::ASCEND_FFT_HORIZONTAL;
   const string op_name = "asdFftExecC2C";
-  return ms::pynative::PyboostRunner::Call<1>(exec_asdsip_fft, op_name, param, input);
+  return ms::pynative::PyboostRunner::Call<1>(exec_asdsip_fft, op_name, param, input, FFT_1In1Out{});
 }
 
 // 1D FFT inverse
@@ -81,7 +134,33 @@ auto pyboost_npu_ifft_1d(const ms::Tensor &input, int64_t batch_size, int64_t x_
   param.direction = asdFftDirection::ASCEND_FFT_BACKWARD;
   param.dimType = asdFft1dDimType::ASCEND_FFT_HORIZONTAL;
   const string op_name = "asdFftExecC2C";
-  return ms::pynative::PyboostRunner::Call<1>(exec_asdsip_fft, op_name, param, input);
+  return ms::pynative::PyboostRunner::Call<1>(exec_asdsip_fft, op_name, param, input, FFT_1In1Out{});
+}
+
+// 1D FFT forward
+auto pyboost_npu_fft_1d_sep(const ms::Tensor &real, const ms::Tensor &imag, int64_t batch_size, int64_t x_size) {
+  FFTParam param;
+  param.fftXSize = x_size;
+  param.fftYSize = 0;
+  param.fftType = asdFftType::ASCEND_FFT_C2C;
+  param.direction = asdFftDirection::ASCEND_FFT_FORWARD;
+  param.batchSize = batch_size;
+  param.dimType = asdFft1dDimType::ASCEND_FFT_HORIZONTAL;
+  const string op_name = "asdFftExecC2C";
+  return ms::pynative::PyboostRunner::Call<2>(exec_asdsip_fft, op_name, param, real, imag, FFT_2In2Out{});
+}
+
+// 1D FFT inverse
+auto pyboost_npu_ifft_1d_sep(const ms::Tensor &real, const ms::Tensor &imag, int64_t batch_size, int64_t x_size) {
+  ms::pynative::FFTParam param;
+  param.batchSize = batch_size;
+  param.fftXSize = x_size;
+  param.fftYSize = 0;
+  param.fftType = asdFftType::ASCEND_FFT_C2C;
+  param.direction = asdFftDirection::ASCEND_FFT_BACKWARD;
+  param.dimType = asdFft1dDimType::ASCEND_FFT_HORIZONTAL;
+  const string op_name = "asdFftExecC2C";
+  return ms::pynative::PyboostRunner::Call<2>(exec_asdsip_fft, op_name, param, real, imag, FFT_2In2Out{});
 }
 
 // 1D RFFT forward
@@ -94,7 +173,7 @@ auto pyboost_npu_rfft_1d(const ms::Tensor &input, int64_t batch_size, int64_t x_
   param.direction = asdFftDirection::ASCEND_FFT_FORWARD;
   param.dimType = asdFft1dDimType::ASCEND_FFT_HORIZONTAL;
   const string op_name = "asdFftExecR2C";
-  return ms::pynative::PyboostRunner::Call<1>(exec_asdsip_fft, op_name, param, input);
+  return ms::pynative::PyboostRunner::Call<1>(exec_asdsip_fft, op_name, param, input, FFT_1In1Out{});
 }
 
 // 1D IRFFT inverse
@@ -107,7 +186,7 @@ auto pyboost_npu_irfft_1d(const ms::Tensor &input, int64_t batch_size, int64_t x
   param.direction = asdFftDirection::ASCEND_FFT_BACKWARD;
   param.dimType = asdFft1dDimType::ASCEND_FFT_HORIZONTAL;
   const string op_name = "asdFftExecC2R";
-  return ms::pynative::PyboostRunner::Call<1>(exec_asdsip_fft, op_name, param, input);
+  return ms::pynative::PyboostRunner::Call<1>(exec_asdsip_fft, op_name, param, input, FFT_1In1Out{});
 }
 
 // 2D FFT forward
@@ -119,7 +198,7 @@ auto pyboost_npu_fft_2d(const ms::Tensor &input, int64_t batch_size, int64_t x_s
   param.direction = asdFftDirection::ASCEND_FFT_FORWARD;
   param.batchSize = batch_size;
   const string op_name = "asdFftExecC2C";
-  return ms::pynative::PyboostRunner::Call<1>(exec_asdsip_fft, op_name, param, input);
+  return ms::pynative::PyboostRunner::Call<1>(exec_asdsip_fft, op_name, param, input, FFT_1In1Out{});
 }
 
 // 2D IFFT inverse
@@ -131,7 +210,7 @@ auto pyboost_npu_ifft_2d(const ms::Tensor &input, int64_t batch_size, int64_t x_
   param.fftType = asdFftType::ASCEND_FFT_C2C;
   param.direction = asdFftDirection::ASCEND_FFT_BACKWARD;
   const string op_name = "asdFftExecC2C";
-  return ms::pynative::PyboostRunner::Call<1>(exec_asdsip_fft, op_name, param, input);
+  return ms::pynative::PyboostRunner::Call<1>(exec_asdsip_fft, op_name, param, input, FFT_1In1Out{});
 }
 
 // 2D RFFT forward
@@ -143,7 +222,7 @@ auto pyboost_npu_rfft_2d(const ms::Tensor &input, int64_t batch_size, int64_t x_
   param.fftType = asdFftType::ASCEND_FFT_R2C;
   param.direction = asdFftDirection::ASCEND_FFT_FORWARD;
   const string op_name = "asdFftExecR2C";
-  return ms::pynative::PyboostRunner::Call<1>(exec_asdsip_fft, op_name, param, input);
+  return ms::pynative::PyboostRunner::Call<1>(exec_asdsip_fft, op_name, param, input, FFT_1In1Out{});
 }
 
 // 2D IRFFT inverse
@@ -155,7 +234,7 @@ auto pyboost_npu_irfft_2d(const ms::Tensor &input, int64_t batch_size, int64_t x
   param.fftType = asdFftType::ASCEND_FFT_C2R;
   param.direction = asdFftDirection::ASCEND_FFT_BACKWARD;
   const string op_name = "asdFftExecC2R";
-  return ms::pynative::PyboostRunner::Call<1>(exec_asdsip_fft, op_name, param, input);
+  return ms::pynative::PyboostRunner::Call<1>(exec_asdsip_fft, op_name, param, input, FFT_1In1Out{});
 }
 
 // Python binding module
@@ -177,6 +256,20 @@ PYBIND11_MODULE(MS_EXTENSION_NAME, m) {
         pybind11::arg("batch_size"),
         pybind11::arg("x_size"));
+  // 1D FFT SEP
+  m.def("asd_fft_1d_sep", &pyboost_npu_fft_1d_sep, "1D FFT on NPU",
+        pybind11::arg("real"),
+        pybind11::arg("imag"),
+        pybind11::arg("batch_size"),
+        pybind11::arg("x_size"));
+
+  // 1D IFFT SEP
+  m.def("asd_ifft_1d_sep", &pyboost_npu_ifft_1d_sep, "1D IFFT on NPU",
+        pybind11::arg("real"),
+        pybind11::arg("imag"),
+        pybind11::arg("batch_size"),
+        pybind11::arg("x_size"));
+
 
   // 1D RFFT
   m.def("asd_rfft_1d", &pyboost_npu_rfft_1d, "1D RFFT on NPU",
         pybind11::arg("input"),
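Not part of the patch: a sketch of how the two new bindings are expected to be driven from Python, mirroring the `asd_fft_op_sep` call in `ASD_FFT.forward` above. The argument order (`real`, `imag`, `batch_size`, `x_size`) comes from the `m.def` entries; the `_get_asd_fft_op()` loader and the 2-D input layout come from the Python diff, and everything else (import path, values) is illustrative. Any scale factor returned by `get_fft_size_and_scale` is applied by the Python wrapper on top of the raw op output.

    import numpy as np
    from mindspore import Tensor

    # Hypothetical import; the loader is the one used inside asd_fft_custom_op.py.
    from mindscience.sciops.fft.asd_fft_custom_op import _get_asd_fft_op

    ops_mod = _get_asd_fft_op()
    batch, n = 1, 4
    xr = Tensor(np.random.rand(batch, n).astype(np.float32))
    xi = Tensor(np.random.rand(batch, n).astype(np.float32))

    # asd_fft_1d_sep(real, imag, batch_size, x_size) -> (real_out, imag_out)
    yr, yi = ops_mod.asd_fft_1d_sep(xr, xi, batch, n)
    zr, zi = ops_mod.asd_ifft_1d_sep(yr, yi, batch, n)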
diff --git a/mindscience/solvers/cbs.py b/mindscience/solvers/cbs.py
index 1ad9108f54b6ae8ded4a67485e17e62479d2d546..78af030427875876f651d6ecd685fd975fd1ea61 100644
--- a/mindscience/solvers/cbs.py
+++ b/mindscience/solvers/cbs.py
@@ -127,12 +127,14 @@ class CBSBlock(nn.Cell):
         ''' Run one iteration and return the incremental '''
         vur, vui = self.op_v(ur, ui, vr, vi)
         gvr, gvi = self.op_g(vur + rhs, vui, gr, gi)
-        vgr, vgi = self.op_v(gvr - ur, gvi - ui, vr, vi)
+        gvr1 = mint.sub(gvr, ur)
+        gvi1 = mint.sub(gvi, ui)
+        vgr, vgi = self.op_v(gvr1, gvi1, vr, vi)
 
         # eps > 0: Convergent Born series; eps == 0: Original Born Series
         cond = ops.broadcast_to(eps, ur.shape) > 0
-        dur = ops.select(cond, -vgi / (eps + 1e-8), gvr - ur)  # '* (-1.)' comes from imag part multiplying i/eps
-        dui = ops.select(cond, vgr / (eps + 1e-8), gvi - ui)
+        dur = ops.select(cond, -vgi / (eps + 1e-8), gvr1)  # '* (-1.)' comes from imag part multiplying i/eps
+        dui = ops.select(cond, vgr / (eps + 1e-8), gvi1)
 
         return ops.stack([dur, dui])
 
@@ -241,6 +243,8 @@ class CBS(nn.Cell):
 
         self.cbs_block = CBSBlock(self.shape_padded, self.btype)
 
+        self.dist = self.cbs_pml_dist(self.shape, self.dxs, self.pml_size)
+
     def cbs_params(self, c_star, f_star):
         ''' Compute constant variables for CBS iteration '''
         omg = 1.0
@@ -284,20 +288,28 @@ class CBS(nn.Cell):
-    @staticmethod
-    def cbs_pml(shape, dxs, k0, pml_size, alpha, rampup):
+    def cbs_pml(self, shape, dxs, k0, pml_size, alpha, rampup):
         ''' Construct the heterogeneous k field with PML BC embedded '''
-        def num(x):
-            num_real = (alpha ** 2) * (rampup - alpha * x) * ((alpha * x) ** (rampup - 1))
-            num_imag = (alpha ** 2) * (2 * k0 * x) * ((alpha * x) ** (rampup - 1))
-            return num_real, num_imag
-
-        def den(x):
-            return sum([(alpha * x) ** i / float(factorial(i)) for i in range(rampup + 1)]) * factorial(rampup)
+        def num_den(x):
+            alpha_2 = mint.mul(alpha, alpha)
+            alpha_x = mint.mul(alpha, x)
+            num_real = alpha_2 * (rampup - alpha_x) * (alpha_x ** (rampup - 1))
+            num_imag = alpha_2 * (2 * k0 * x) * (alpha_x ** (rampup - 1))
+            den_x = sum([alpha_x ** i / float(factorial(i)) for i in range(rampup + 1)]) * factorial(rampup)
+            return num_real, num_imag, den_x
 
         def transform_fun(x):
-            num_real, num_imag = num(x)
-            den_x = den(x)
+            num_real, num_imag, den_x = num_den(x)
             transform_real, transform_imag = num_real / den_x, num_imag / den_x
             return transform_real, transform_imag
 
+        k_k0_real, k_k0_imag = transform_fun(self.dist)
+        ksq_r = k_k0_real + k0 ** 2
+        ksq_i = k_k0_imag
+
+        return ksq_r, ksq_i
+
+    def cbs_pml_dist(self, shape, dxs, pml_size):
+        ''' Construct the PML distance field (precomputed in __init__) '''
+
         def pml_padding(n, d, s1, s2):
             original = (ops.abs(mnp.linspace(1 - n, n - 1, n)) - n) * (d / 2)
             left = ops.arange((s1 - 0.5)*d, -d/2, -d, dtype=ms.float32)
@@ -313,11 +325,7 @@ class CBS(nn.Cell):
         diff *= (diff > 0).astype(ms.float32) / 4.
         dist = ops.norm(diff, dim=0)
 
-        k_k0_real, k_k0_imag = transform_fun(dist)
-        ksq_r = k_k0_real + k0 ** 2
-        ksq_i = k_k0_imag
-
-        return ksq_r, ksq_i
+        return dist
 
     @staticmethod
     def extrapolate(data, padding, values):
@@ -414,7 +422,7 @@ class CBS(nn.Cell):
         nz, ny, nx = ur.shape[-self.dim:]
         ur = ur[..., n0[0]:nz - n0[1], n0[2]:ny - n0[3], n0[4]:nx - n0[5]]
         ui = ui[..., n0[0]:nz - n0[1], n0[2]:ny - n0[3], n0[4]:nx - n0[5]]
-        ui *= -1.
+        ui = mint.mul(ui, -1.)
 
         # Note: the conjugate here is because we define Fourier modes differently to JAX in that the frequencies
         # are opposite, leading to opposite attenuation in PML, and finally the conjugation in results
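Not part of the patch: a small self-contained check (plain NumPy) of the algebra behind the "'* (-1.)' comes from imag part multiplying i/eps" comment in CBSBlock. Multiplying the residual by i/eps sends its negated imaginary part into the real increment and its real part into the imaginary increment, which is exactly how `dur`/`dui` are built from `vgr`/`vgi` in the eps > 0 branch above.

    import numpy as np

    eps = 0.5
    vg = np.array([1.0 + 2.0j, -3.0 + 0.5j])  # stand-in for the complex residual V(G(Vu + rhs) - u)

    du = (1j / eps) * vg          # complex form of the CBS update term

    dur = -vg.imag / eps          # real part, as computed in CBSBlock (eps > 0 branch)
    dui = vg.real / eps           # imaginary part

    assert np.allclose(du.real, dur) and np.allclose(du.imag, dui)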