diff --git a/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/pipelines/pipeline_cogvideox.py b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/pipelines/pipeline_cogvideox.py
index 761f3049ebdbe8d13f1507cf3af9c7f8e05755e5..afac04d276120e355752ae713b6dc0316f1b96df 100644
--- a/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/pipelines/pipeline_cogvideox.py
+++ b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/pipelines/pipeline_cogvideox.py
@@ -685,14 +685,25 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
                 timestep = t.expand(latent_model_input.shape[0])
 
                 # predict noise model_output
-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    encoder_hidden_states=prompt_embeds,
-                    timestep=timestep,
-                    image_rotary_emb=image_rotary_emb,
-                    attention_kwargs=attention_kwargs,
-                    return_dict=False,
-                )[0]
+                if hasattr(self, "skip_strategy"):
+                    noise_pred = self.skip_strategy(
+                        self.transformer,
+                        hidden_states=latent_model_input,
+                        encoder_hidden_states=prompt_embeds,
+                        timestep=timestep,
+                        image_rotary_emb=image_rotary_emb,
+                        attention_kwargs=attention_kwargs,
+                        return_dict=False,
+                    )[0]
+                else:
+                    noise_pred = self.transformer(
+                        hidden_states=latent_model_input,
+                        encoder_hidden_states=prompt_embeds,
+                        timestep=timestep,
+                        image_rotary_emb=image_rotary_emb,
+                        attention_kwargs=attention_kwargs,
+                        return_dict=False,
+                    )[0]
                 noise_pred = noise_pred.float()
@@ -720,6 +731,9 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
                     )
                     latents = latents.to(prompt_embeds.dtype)
 
+                if hasattr(self, "skip_strategy"):
+                    self.skip_strategy.update_strategy(latents)
+
                 # call the callback, if provided
                 if callback_on_step_end is not None:
                     callback_kwargs = {}
diff --git a/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/inference.py b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/inference.py
index 6b2d8bd1a94f0c79e79c0acc0c0d4fe30fabe73f..7e72f78fca6a21ebe54a0817cb83690d9073954a 100644
--- a/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/inference.py
+++ b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/inference.py
@@ -14,6 +14,7 @@ from typing import List, Optional, Tuple, Union, Literal
 
 from cogvideox_5b import CogVideoXPipeline, CogVideoXTransformer3DModel, get_rank, get_world_size, all_gather
+from mindiesd.pipeline.sampling_optm import AdaStep
 
 
 def parallelize_transformer(pipe):
     transformer = pipe.transformer
@@ -122,6 +123,10 @@ def generate_video(
     pipe.vae.enable_slicing()
     pipe.vae.enable_tiling()
     pipe.transformer.switch_to_qkvLinear()
+    # sampling optm
+    skip_strategy = AdaStep(skip_thr=0.006, max_skip_steps=1, decay_ratio=0.99, device="npu")
+    pipe.skip_strategy = skip_strategy
+
    if get_world_size() > 1:
        parallelize_transformer(pipe)
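
For context: the patch routes each denoising step through an `AdaStep` skip strategy from `mindiesd` whenever one is attached to the pipeline, and feeds the updated latents back to it after every scheduler step so it can decide whether the next transformer forward pass can be skipped. The sketch below illustrates one plausible way such an adaptive skipping strategy can behave. Only the constructor arguments and the `__call__`/`update_strategy` entry points are confirmed by the diff; the internals (caching the last noise prediction and thresholding the relative latent change) are assumptions, and the class name `AdaStepSketch` is hypothetical, not the real `mindiesd` implementation.

```python
import torch


class AdaStepSketch:
    """Conceptual sketch of an AdaStep-style skip strategy (not mindiesd's code).

    Assumed behavior: cache the last real noise prediction and reuse it while
    successive latents change little, capping consecutive reuses.
    """

    def __init__(self, skip_thr=0.006, max_skip_steps=1, decay_ratio=0.99, device="npu"):
        self.skip_thr = skip_thr              # relative latent-change threshold for skipping
        self.max_skip_steps = max_skip_steps  # cap on consecutive skipped forward passes
        self.decay_ratio = decay_ratio        # shrinks the threshold as sampling proceeds
        self.device = device
        self.prev_output = None               # noise prediction from the last real model call
        self.prev_latents = None              # latents seen at the previous update_strategy()
        self.skip_next = False                # decision computed by update_strategy()
        self.skipped = 0                      # consecutive skips so far

    def __call__(self, model, **kwargs):
        # Reuse the cached prediction instead of running the transformer when
        # the strategy judged this step skippable and the skip budget allows it.
        if self.skip_next and self.prev_output is not None and self.skipped < self.max_skip_steps:
            self.skipped += 1
            return (self.prev_output,)
        self.skipped = 0
        out = model(**kwargs)
        self.prev_output = out[0]
        return out

    def update_strategy(self, latents):
        # A small relative change between successive latents suggests the next
        # denoising step would produce a near-identical noise prediction.
        if self.prev_latents is not None:
            rel_change = (latents - self.prev_latents).abs().mean() / (self.prev_latents.abs().mean() + 1e-8)
            self.skip_next = rel_change.item() < self.skip_thr
        self.prev_latents = latents.detach().clone()
        self.skip_thr *= self.decay_ratio  # be stricter about skipping in later, detail-heavy steps
```

Because the pipeline only checks `hasattr(self, "skip_strategy")`, the optimization stays strictly opt-in: a pipeline built without the `inference.py` change falls through to the unmodified transformer call, and the `update_strategy(latents)` bookkeeping runs only when a strategy is attached.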