diff --git a/PyTorch/built-in/mlm/OpenSora1.1/opensora/models/layers/blocks.py b/PyTorch/built-in/mlm/OpenSora1.1/opensora/models/layers/blocks.py
index cca035ac0842776fcce3c84f0776affe77541500..010f772849fc6928dbb4cf4c5d8dbeddd9325ce4 100644
--- a/PyTorch/built-in/mlm/OpenSora1.1/opensora/models/layers/blocks.py
+++ b/PyTorch/built-in/mlm/OpenSora1.1/opensora/models/layers/blocks.py
@@ -291,7 +291,7 @@ class SeqParallelAttention(Attention):
         if self.enable_flashattn:
             if is_npu_available() and q.dtype in [torch.float16, torch.bfloat16]:
                 x = torch_npu.npu_fusion_attention(
-                    q, k, v, self.num_heads, input_layout="BSND",
+                    q, k, v, self.num_heads // sp_size, input_layout="BSND",
                     pse=None,
                     scale=self.scale,
                     pre_tockens=65536,
@@ -449,7 +449,7 @@ class SeqParallelMultiHeadCrossAttention(MultiHeadCrossAttention):
                 ans += m
                 actual_seq_kvlen.append(ans)
             x = torch_npu.npu_fusion_attention(
-                q, k, v, self.num_heads, input_layout="TND",
+                q, k, v, self.num_heads // sp_size, input_layout="TND",
                 pse=None,
                 scale=1.0 / math.sqrt(self.head_dim),
                 pre_tockens=65536,