diff --git a/MindFlow/applications/physics_driven/boltzmann/src/boltzmann.py b/MindFlow/applications/physics_driven/boltzmann/src/boltzmann.py index e2b95171410d938c7786b3e58fdee3e9640fe02a..42f54f555ada1eaaa8608cb7d0a9ac970b7a0e9b 100644 --- a/MindFlow/applications/physics_driven/boltzmann/src/boltzmann.py +++ b/MindFlow/applications/physics_driven/boltzmann/src/boltzmann.py @@ -149,7 +149,10 @@ class FSMLAKernel(nn.Cell): def construct(self, x, kn_bzm=1.0): ff = x @ self.f - q = ops.einsum("...i,...j,ijk->...k", ff, ff, self.k) + reduce_sum = ops.ReduceSum(keep_dims=False) + tmp = ff.unsqueeze(-1)*ff.unsqueeze(-2) + q = reduce_sum(tmp.unsqueeze(-1)*self.k, axis=-2) + q = reduce_sum(q, axis=-2) qr = q @ self.g.T return qr / kn_bzm diff --git a/MindFlow/applications/physics_driven/boltzmann/src/cells.py b/MindFlow/applications/physics_driven/boltzmann/src/cells.py index db7a9a8d228ed5e05b29adbd00a27850d3656ccf..68536c77532d5684bb18868920ecdad4149e002b 100644 --- a/MindFlow/applications/physics_driven/boltzmann/src/cells.py +++ b/MindFlow/applications/physics_driven/boltzmann/src/cells.py @@ -399,10 +399,11 @@ class MaxwellianLR(nn.Cell): def f_sum_lowrank(ft, wt): fx, fy, fz = ft wx, wy, wz = wt - sx = ops.einsum("...ir,i->...r", fx, wx) - sy = ops.einsum("...jr,j->...r", fy, wy) - sz = ops.einsum("...kr,k->...r", fz, wz) - s = ops.einsum("...r,...r,...r->...", sx, sy, sz) + reduce_sum = ops.ReduceSum(keep_dims=False) + sx = reduce_sum(fx*wx.unsqueeze(-1), axis=-2) + sy = reduce_sum(fy*wy.unsqueeze(-1), axis=-2) + sz = reduce_sum(fz*wz.unsqueeze(-1), axis=-2) + s = reduce_sum(sx*sy*sz, axis=-1) return s