From c0aa04033f222f78cbe1c3537fb2ec40f09baad1 Mon Sep 17 00:00:00 2001 From: brian Date: Tue, 21 Jan 2025 17:29:08 +0800 Subject: [PATCH] [fix]: replace einsum --- .../physics_driven/boltzmann/src/boltzmann.py | 5 ++++- .../applications/physics_driven/boltzmann/src/cells.py | 9 +++++---- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/MindFlow/applications/physics_driven/boltzmann/src/boltzmann.py b/MindFlow/applications/physics_driven/boltzmann/src/boltzmann.py index e2b951714..42f54f555 100644 --- a/MindFlow/applications/physics_driven/boltzmann/src/boltzmann.py +++ b/MindFlow/applications/physics_driven/boltzmann/src/boltzmann.py @@ -149,7 +149,10 @@ class FSMLAKernel(nn.Cell): def construct(self, x, kn_bzm=1.0): ff = x @ self.f - q = ops.einsum("...i,...j,ijk->...k", ff, ff, self.k) + reduce_sum = ops.ReduceSum(keep_dims=False) + tmp = ff.unsqueeze(-1)*ff.unsqueeze(-2) + q = reduce_sum(tmp.unsqueeze(-1)*self.k, axis=-2) + q = reduce_sum(q, axis=-2) qr = q @ self.g.T return qr / kn_bzm diff --git a/MindFlow/applications/physics_driven/boltzmann/src/cells.py b/MindFlow/applications/physics_driven/boltzmann/src/cells.py index db7a9a8d2..68536c775 100644 --- a/MindFlow/applications/physics_driven/boltzmann/src/cells.py +++ b/MindFlow/applications/physics_driven/boltzmann/src/cells.py @@ -399,10 +399,11 @@ class MaxwellianLR(nn.Cell): def f_sum_lowrank(ft, wt): fx, fy, fz = ft wx, wy, wz = wt - sx = ops.einsum("...ir,i->...r", fx, wx) - sy = ops.einsum("...jr,j->...r", fy, wy) - sz = ops.einsum("...kr,k->...r", fz, wz) - s = ops.einsum("...r,...r,...r->...", sx, sy, sz) + reduce_sum = ops.ReduceSum(keep_dims=False) + sx = reduce_sum(fx*wx.unsqueeze(-1), axis=-2) + sy = reduce_sum(fy*wy.unsqueeze(-1), axis=-2) + sz = reduce_sum(fz*wz.unsqueeze(-1), axis=-2) + s = reduce_sum(sx*sy*sz, axis=-1) return s -- Gitee