From c9d63e6ad4a9ff12eb9515753cbcf8f93ac7517a Mon Sep 17 00:00:00 2001
From: lyu-xingjia
Date: Mon, 13 Jan 2025 21:28:24 +0800
Subject: [PATCH] [MindIE-SD] CogVideo: enable qkvLinear
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../cogvideox_5b/models/attention_processor.py      |  9 ++++++---
 .../models/transformers/cogvideox_transformer_3d.py | 13 ++++++++++++-
 .../built-in/foundation/CogVideoX-5b/inference.py   |  1 +
 3 files changed, 19 insertions(+), 4 deletions(-)

diff --git a/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/models/attention_processor.py b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/models/attention_processor.py
index 54841d2405..2b01678dc1 100644
--- a/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/models/attention_processor.py
+++ b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/models/attention_processor.py
@@ -1896,9 +1896,12 @@ class CogVideoXAttnProcessor2_0:
             attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
             attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
 
-        query = attn.to_q(hidden_states)
-        key = attn.to_k(hidden_states)
-        value = attn.to_v(hidden_states)
+        if hasattr(attn, "qkvLinear"):
+            query, key, value = attn.qkvLinear(hidden_states)
+        else:
+            query = attn.to_q(hidden_states)
+            key = attn.to_k(hidden_states)
+            value = attn.to_v(hidden_states)
 
         inner_dim = key.shape[-1]
         head_dim = inner_dim // attn.heads
diff --git a/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/models/transformers/cogvideox_transformer_3d.py b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/models/transformers/cogvideox_transformer_3d.py
index 0cdc4a0cdd..7a9f62b545 100644
--- a/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/models/transformers/cogvideox_transformer_3d.py
+++ b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/models/transformers/cogvideox_transformer_3d.py
@@ -28,7 +28,7 @@ from ..attention import Attention, FeedForward
 from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
 from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
 from ..normalization import AdaLayerNorm, CogVideoXLayerNormZero
-
+from mindiesd.layers.linear import QKVLinear
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
@@ -239,6 +239,8 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
     ):
         super().__init__()
         inner_dim = num_attention_heads * attention_head_dim
+        self.num_heads = num_attention_heads
+        self.head_dim = attention_head_dim
 
         if not use_rotary_positional_embeddings and use_learned_positional_embeddings:
             raise ValueError(
@@ -504,3 +506,12 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         if not return_dict:
             return (output,)
         return Transformer2DModelOutput(sample=output)
+
+    def switch_to_qkvLinear(self) -> None:
+        for blk in self.transformer_blocks:
+            blk.attn1.qkvLinear = QKVLinear(self.head_dim, self.head_dim * self.num_heads)
+            blk.attn1.qkvLinear.weight.data = torch.cat((blk.attn1.to_q.weight.data.transpose(1, 0).contiguous(), blk.attn1.to_k.weight.data.transpose(1, 0).contiguous(), blk.attn1.to_v.weight.data.transpose(1, 0).contiguous()), -1)
+            blk.attn1.qkvLinear.bias.data = torch.cat((blk.attn1.to_q.bias.data, blk.attn1.to_k.bias.data, blk.attn1.to_v.bias.data), -1)
+            blk.attn1.to_q = None
+            blk.attn1.to_k = None
+            blk.attn1.to_v = None
\ No newline at end of file
diff --git a/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/inference.py b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/inference.py
index 831107e7ae..6b2d8bd1a9 100644
--- a/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/inference.py
+++ b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/inference.py
@@ -121,6 +121,7 @@ def generate_video(
     pipe.vae = pipe.vae.half()
     pipe.vae.enable_slicing()
     pipe.vae.enable_tiling()
+    pipe.transformer.switch_to_qkvLinear()
 
     if get_world_size() > 1:
         parallelize_transformer(pipe)
-- 
Gitee
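For reference, a minimal standalone sketch of the weight fusion that switch_to_qkvLinear performs, written in plain PyTorch; the fused matmul and split below stand in for mindiesd's QKVLinear, whose internals are not shown in this patch, and all names here are illustrative only:

# Sketch: fuse separate q/k/v projections into one matmul, mirroring the
# torch.cat of transposed to_q/to_k/to_v weights in switch_to_qkvLinear.
import torch
import torch.nn as nn

hidden_dim = 64
to_q, to_k, to_v = (nn.Linear(hidden_dim, hidden_dim) for _ in range(3))

# One [hidden_dim, 3 * hidden_dim] weight and one [3 * hidden_dim] bias.
fused_weight = torch.cat(
    (to_q.weight.data.transpose(1, 0),
     to_k.weight.data.transpose(1, 0),
     to_v.weight.data.transpose(1, 0)), dim=-1)
fused_bias = torch.cat((to_q.bias.data, to_k.bias.data, to_v.bias.data), dim=-1)

x = torch.randn(2, 8, hidden_dim)  # [batch, seq, hidden]
q, k, v = (x @ fused_weight + fused_bias).chunk(3, dim=-1)

# The single fused projection reproduces the three separate linears,
# which is why to_q/to_k/to_v can be dropped after the switch.
assert torch.allclose(q, to_q(x), atol=1e-5)
assert torch.allclose(k, to_k(x), atol=1e-5)
assert torch.allclose(v, to_v(x), atol=1e-5)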