diff --git a/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/models/attention_processor.py b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/models/attention_processor.py
index 54841d240516af52964aa211fa187ae208337579..2b01678dc17c425a9dd869fb8753951f2e07b19a 100644
--- a/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/models/attention_processor.py
+++ b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/models/attention_processor.py
@@ -1896,9 +1896,12 @@ class CogVideoXAttnProcessor2_0:
         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
         attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
 
-        query = attn.to_q(hidden_states)
-        key = attn.to_k(hidden_states)
-        value = attn.to_v(hidden_states)
+        if hasattr(attn, "qkvLinear"):
+            query, key, value = attn.qkvLinear(hidden_states)
+        else:
+            query = attn.to_q(hidden_states)
+            key = attn.to_k(hidden_states)
+            value = attn.to_v(hidden_states)
 
         inner_dim = key.shape[-1]
         head_dim = inner_dim // attn.heads
diff --git a/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/models/transformers/cogvideox_transformer_3d.py b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/models/transformers/cogvideox_transformer_3d.py
index 0cdc4a0cdd8008862bfd69f37abcc1e0156b0098..7a9f62b5453fa48ca8fb53ccf14c77c70890a0a7 100644
--- a/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/models/transformers/cogvideox_transformer_3d.py
+++ b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/models/transformers/cogvideox_transformer_3d.py
@@ -28,7 +28,7 @@
 from ..attention import Attention, FeedForward
 from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
 from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
 from ..normalization import AdaLayerNorm, CogVideoXLayerNormZero
-
+from mindiesd.layers.linear import QKVLinear
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -239,6 +239,8 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
     ):
         super().__init__()
         inner_dim = num_attention_heads * attention_head_dim
+        self.num_heads = num_attention_heads
+        self.head_dim = attention_head_dim
 
         if not use_rotary_positional_embeddings and use_learned_positional_embeddings:
             raise ValueError(
@@ -504,3 +506,12 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         if not return_dict:
             return (output,)
         return Transformer2DModelOutput(sample=output)
+
+    def switch_to_qkvLinear(self) -> None:
+        for blk in self.transformer_blocks:
+            blk.attn1.qkvLinear = QKVLinear(self.head_dim, self.head_dim * self.num_heads)
+            blk.attn1.qkvLinear.weight.data = torch.cat((blk.attn1.to_q.weight.data.transpose(1, 0).contiguous(), blk.attn1.to_k.weight.data.transpose(1, 0).contiguous(), blk.attn1.to_v.weight.data.transpose(1, 0).contiguous()), -1)
+            blk.attn1.qkvLinear.bias.data = torch.cat((blk.attn1.to_q.bias.data, blk.attn1.to_k.bias.data, blk.attn1.to_v.bias.data), -1)
+            blk.attn1.to_q = None
+            blk.attn1.to_k = None
+            blk.attn1.to_v = None
\ No newline at end of file
diff --git a/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/inference.py b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/inference.py
index 831107e7aea87823bc50575e2b094f4120ff2e75..6b2d8bd1a94f0c79e79c0acc0c0d4fe30fabe73f 100644
--- a/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/inference.py
+++ b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/inference.py
@@ -121,6 +121,7 @@ def generate_video(
     pipe.vae = pipe.vae.half()
     pipe.vae.enable_slicing()
     pipe.vae.enable_tiling()
+    pipe.transformer.switch_to_qkvLinear()
 
     if get_world_size() > 1:
         parallelize_transformer(pipe)