From 0a6f43b801bc2b3d47c3498133a947f2644ad5c7 Mon Sep 17 00:00:00 2001
From: panyiwei1994 <11376635+panyiwei1994@user.noreply.gitee.com>
Date: Tue, 16 Jul 2024 09:39:32 +0000
Subject: [PATCH] update
 PyTorch/built-in/mlm/OpenSora1.1/opensora/models/layers/blocks.py. Adapt nfa
 for the parallel scenario
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: panyiwei1994 <11376635+panyiwei1994@user.noreply.gitee.com>
---
 .../built-in/mlm/OpenSora1.1/opensora/models/layers/blocks.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/PyTorch/built-in/mlm/OpenSora1.1/opensora/models/layers/blocks.py b/PyTorch/built-in/mlm/OpenSora1.1/opensora/models/layers/blocks.py
index cca035ac08..010f772849 100644
--- a/PyTorch/built-in/mlm/OpenSora1.1/opensora/models/layers/blocks.py
+++ b/PyTorch/built-in/mlm/OpenSora1.1/opensora/models/layers/blocks.py
@@ -291,7 +291,7 @@ class SeqParallelAttention(Attention):
         if self.enable_flashattn:
             if is_npu_available() and q.dtype in [torch.float16, torch.bfloat16]:
                 x = torch_npu.npu_fusion_attention(
-                    q, k, v, self.num_heads, input_layout="BSND",
+                    q, k, v, self.num_heads // sp_size, input_layout="BSND",
                     pse=None,
                     scale=self.scale,
                     pre_tockens=65536,
@@ -449,7 +449,7 @@ class SeqParallelMultiHeadCrossAttention(MultiHeadCrossAttention):
                 ans += m
                 actual_seq_kvlen.append(ans)
             x = torch_npu.npu_fusion_attention(
-                q, k, v, self.num_heads, input_layout="TND",
+                q, k, v, self.num_heads // sp_size, input_layout="TND",
                 pse=None,
                 scale=1.0 / math.sqrt(self.head_dim),
                 pre_tockens=65536,
--
Gitee
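
For context, the change passes the per-rank head count to the fused attention kernel: under sequence parallelism the all-to-all redistributes tensors so that each rank holds the full sequence but only 1/sp_size of the attention heads. Below is a minimal shape sketch in plain PyTorch (the dimensions B, S, num_heads, head_dim, sp_size are illustrative assumptions, not values from the patch) showing why the head-count argument must be self.num_heads // sp_size rather than self.num_heads:

```python
import torch

# Hypothetical dimensions for illustration only.
B, S, num_heads, head_dim, sp_size = 2, 64, 16, 32, 4

# Before the sequence-parallel all-to-all, each rank holds all heads
# but only a 1/sp_size slice of the sequence: [B, S/sp_size, H, D].
q_seq_sharded = torch.randn(B, S // sp_size, num_heads, head_dim)

# After the all-to-all performed inside SeqParallelAttention, each rank
# sees the full sequence but only its shard of the heads: [B, S, H/sp_size, D].
q_head_sharded = torch.randn(B, S, num_heads // sp_size, head_dim)

# With input_layout="BSND", the head-count argument must match dim 2 of the
# tensor actually handed to the kernel, i.e. the local head count.
local_num_heads = num_heads // sp_size
assert q_head_sharded.shape[2] == local_num_heads
```

The same reasoning applies to the "TND" cross-attention hunk: the tokens on each rank carry only the local shard of heads, so the head-count argument is divided by sp_size there as well.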