diff --git a/MindFlow/mindflow/core/fourier.py b/MindFlow/mindflow/core/fourier.py index 9425d32da57e12d0191b45a11678f7b3896969c1..17c5c64678f52470cdbdee447189327a2757a577 100644 --- a/MindFlow/mindflow/core/fourier.py +++ b/MindFlow/mindflow/core/fourier.py @@ -626,8 +626,8 @@ class IDCT(nn.Cell): c = self.dft_cell(vr, vi) # (..., n) c1 = c[..., :(n + 1) // 2] c2 = self.fliper(c[..., (n + 1) // 2:], dims=-1) - d1 = ops.pad(c1[..., None], (0, 1)).reshape(*c1.shape[:-1], -1) - d2 = ops.pad(c2[..., None], (1, 0)).reshape(*c2.shape[:-1], -1) + d1 = ops.pad(c1.reshape(-1)[..., None], (0, 1)).reshape(*c1.shape[:-1], -1) + d2 = ops.pad(c2.reshape(-1)[..., None], (1, 0)).reshape(*c2.shape[:-1], -1) return d1 + d2