diff --git a/examples/stepvideo/i2v/061.png b/examples/stepvideo/i2v/061.png new file mode 100644 index 0000000000000000000000000000000000000000..1533859673b185df24fd13c74c3ba56eaa655ad6 Binary files /dev/null and b/examples/stepvideo/i2v/061.png differ diff --git a/examples/stepvideo/i2v/inference_i2v.sh b/examples/stepvideo/i2v/inference_i2v.sh new file mode 100644 index 0000000000000000000000000000000000000000..5a8ea272f91af936d5560ac2b7cf38826673cc2b --- /dev/null +++ b/examples/stepvideo/i2v/inference_i2v.sh @@ -0,0 +1,71 @@ +source /usr/local/Ascend/ascend-toolkit/set_env.sh + +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True +export HCCL_CONNECT_TIMEOUT=1200 +export ASCEND_LAUNCH_BLOCKING=1 + +MASTER_ADDR=localhost +MASTER_PORT=6000 +NNODES=1 +NODE_RANK=0 +NPUS_PER_NODE=4 +WORLD_SIZE=$(($NPUS_PER_NODE * $NNODES)) + +TP=4 +PP=1 +CP=1 +MBS=1 +GBS=$(($WORLD_SIZE*$MBS/$CP/$TP)) + +MM_MODEL="examples/stepvideo/i2v/inference_i2v_model.json" +LOAD_PATH="your_converted_dit_ckpt_dir" + +DISTRIBUTED_ARGS=" + --nproc_per_node $NPUS_PER_NODE \ + --nnodes $NNODES \ + --node_rank $NODE_RANK \ + --master_addr $MASTER_ADDR \ + --master_port $MASTER_PORT +" +MM_ARGS=" + --mm-model $MM_MODEL +" + +SORA_ARGS=" + --tensor-model-parallel-size ${TP} \ + --pipeline-model-parallel-size ${PP} \ + --context-parallel-size ${CP} \ + --micro-batch-size ${MBS} \ + --global-batch-size ${GBS} \ + --num-layers 28 \ + --hidden-size 1152 \ + --num-attention-heads 16 \ + --seq-length 1024\ + --max-position-embeddings 1024 \ + --attention-dropout 0.0 \ + --hidden-dropout 0.0 \ + --tokenizer-type NullTokenizer \ + --vocab-size 0 \ + --position-embedding-type rope \ + --rotary-base 500000 \ + --swiglu \ + --no-masked-softmax-fusion \ + --lr 2e-5 \ + --min-lr 2e-5 \ + --train-iters 5010 \ + --weight-decay 0 \ + --clip-grad 1 \ + --adam-beta1 0.9 \ + --adam-beta2 0.999 \ + --no-gradient-accumulation-fusion \ + --no-load-optim \ + --no-load-rng \ + --no-save-optim \ + --no-save-rng \ + --bf16 \ + --load $LOAD_PATH \ + --distributed-timeout-minutes 20 \ +" + +torchrun $DISTRIBUTED_ARGS inference_sora.py $MM_ARGS $SORA_ARGS \ No newline at end of file diff --git a/examples/stepvideo/i2v/inference_i2v_model.json b/examples/stepvideo/i2v/inference_i2v_model.json new file mode 100644 index 0000000000000000000000000000000000000000..0ded251d758a966332a6e1329ccb9097411f04c4 --- /dev/null +++ b/examples/stepvideo/i2v/inference_i2v_model.json @@ -0,0 +1,92 @@ +{ + "predictor": { + "model_id": "stepvideodit", + "from_pretrained": null, + "dtype": "bf16", + "num_layers" : 48, + "num_attention_heads": 48, + "attention_head_dim": 128, + "channel_split": [64, 32, 32], + "in_channels": 64, + "out_channels": 64, + "dropout": 0.0, + "patch_size": 1, + "patch_size_thw": [1, 1, 1], + "norm_type": "ada_norm_single", + "norm_elementwise_affine": false, + "norm_eps": 1e-6, + "use_additional_conditions": true, + "caption_channels": [6144, 1024], + + "attention_norm_type": "rmsnorm", + "attention_norm_elementwise_affine": true, + "attention_norm_eps": 1e-6, + "fa_layout": "bsnd" + }, + "ae": { + "model_id": "stepvideovae", + "from_pretrained": "./weights/vae/vae_xx.safetensors", + "dtype": "bf16", + "z_channels": 64, + "frame_len": 17, + "version": 2 + }, + "tokenizer":[ + { + "autotokenizer_name": "stepchat", + "hub_backend": "hf", + "from_pretrained": "./weights/step_llm/step1_chat_tokenizer.model", + "model_max_length": 320 + }, + { + "autotokenizer_name": "BertTokenizer", + "hub_backend": "hf", + 
"from_pretrained": "./weights/hunyuan_clip/tokenizer", + "model_max_length": 77 + } + ], + "text_encoder": [ + { + "model_id": "StepLLmModel", + "hub_backend": "hf", + "from_pretrained": "./weights/step_llm/", + "dtype": "bf16" + }, + { + "model_id": "BertModel", + "hub_backend": "hf", + "from_pretrained": "./weights/hunyuan_clip/clip_text_encoder", + "dtype": "float32" + } + ], + + "diffusion": { + "model_id": "flow_match_discrete_scheduler", + "num_train_timesteps":1, + "num_inference_timesteps":50, + "shift": 13.0, + "reverse": false, + "solver": "euler" + }, + "pipeline_config": { + "version": "stepvideo", + "use_attention_mask": true, + "input_size": [102, 544, 992], + "guidance_scale": 9.0, + "model_type": "i2v", + "seed": 1234, + "motion_score": 5.0 + }, + "unload_text_encoder": true, + "micro_batch_size": 1, + "frame_interval":1, + "save_path":"examples/stepvideo/i2v/i2v_result/", + "fps":25, + "prompt": "examples/stepvideo/i2v/samples_i2v_prompts.txt", + "image": "examples/stepvideo/i2v/samples_i2v_images.txt", + "use_prompt_preprocess": false, + "pipeline_class": "StepVideoPipeline", + "device":"npu", + "dtype": "bf16" +} + diff --git a/examples/stepvideo/i2v/samples_i2v_images.txt b/examples/stepvideo/i2v/samples_i2v_images.txt new file mode 100644 index 0000000000000000000000000000000000000000..a67dcc0c9b7942fdeaf3e9aca419c614e1720bc4 --- /dev/null +++ b/examples/stepvideo/i2v/samples_i2v_images.txt @@ -0,0 +1 @@ +examples/stepvideo/i2v/061.png \ No newline at end of file diff --git a/examples/stepvideo/i2v/samples_i2v_prompts.txt b/examples/stepvideo/i2v/samples_i2v_prompts.txt new file mode 100644 index 0000000000000000000000000000000000000000..4fac71a06b1d08605969f188fa0e156db9700d10 --- /dev/null +++ b/examples/stepvideo/i2v/samples_i2v_prompts.txt @@ -0,0 +1 @@ +带翅膀的小老鼠先用爪子挠了挠脑袋,随后扑扇着翅膀飞了起来。 \ No newline at end of file diff --git a/examples/stepvideo/t2v/inference_t2v.sh b/examples/stepvideo/t2v/inference_t2v.sh new file mode 100644 index 0000000000000000000000000000000000000000..80dfdd215f2b8225bba19e065a7f255923464a7e --- /dev/null +++ b/examples/stepvideo/t2v/inference_t2v.sh @@ -0,0 +1,71 @@ +source /usr/local/Ascend/ascend-toolkit/set_env.sh + +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True +export HCCL_CONNECT_TIMEOUT=1200 +export ASCEND_LAUNCH_BLOCKING=1 + +MASTER_ADDR=localhost +MASTER_PORT=6000 +NNODES=1 +NODE_RANK=0 +NPUS_PER_NODE=4 +WORLD_SIZE=$(($NPUS_PER_NODE * $NNODES)) + +TP=4 +PP=1 +CP=1 +MBS=1 +GBS=$(($WORLD_SIZE*$MBS/$CP/$TP)) + +MM_MODEL="examples/stepvideo/t2v/inference_t2v_model.json" +LOAD_PATH="your_converted_dit_ckpt_dir" + +DISTRIBUTED_ARGS=" + --nproc_per_node $NPUS_PER_NODE \ + --nnodes $NNODES \ + --node_rank $NODE_RANK \ + --master_addr $MASTER_ADDR \ + --master_port $MASTER_PORT +" +MM_ARGS=" + --mm-model $MM_MODEL +" + +SORA_ARGS=" + --tensor-model-parallel-size ${TP} \ + --pipeline-model-parallel-size ${PP} \ + --context-parallel-size ${CP} \ + --micro-batch-size ${MBS} \ + --global-batch-size ${GBS} \ + --num-layers 28 \ + --hidden-size 1152 \ + --num-attention-heads 16 \ + --seq-length 1024\ + --max-position-embeddings 1024 \ + --attention-dropout 0.0 \ + --hidden-dropout 0.0 \ + --tokenizer-type NullTokenizer \ + --vocab-size 0 \ + --position-embedding-type rope \ + --rotary-base 500000 \ + --swiglu \ + --no-masked-softmax-fusion \ + --lr 2e-5 \ + --min-lr 2e-5 \ + --train-iters 5010 \ + --weight-decay 0 \ + --clip-grad 1 \ + --adam-beta1 0.9 \ + --adam-beta2 0.999 \ + 
--no-gradient-accumulation-fusion \ + --no-load-optim \ + --no-load-rng \ + --no-save-optim \ + --no-save-rng \ + --bf16 \ + --load $LOAD_PATH \ + --distributed-timeout-minutes 20 \ +" + +torchrun $DISTRIBUTED_ARGS inference_sora.py $MM_ARGS $SORA_ARGS \ No newline at end of file diff --git a/examples/stepvideo/t2v/inference_t2v_model.json b/examples/stepvideo/t2v/inference_t2v_model.json new file mode 100644 index 0000000000000000000000000000000000000000..7b89e1c6b2624f42288ce4ca24ad0cc8f14b24cf --- /dev/null +++ b/examples/stepvideo/t2v/inference_t2v_model.json @@ -0,0 +1,85 @@ +{ + "predictor": { + "model_id": "stepvideodit", + "from_pretrained": null, + "dtype": "bf16", + "num_layers" : 48, + "num_attention_heads": 48, + "attention_head_dim": 128, + "channel_split": [64, 32, 32], + "in_channels": 64, + "out_channels": 64, + "dropout": 0.0, + "patch_size": 1, + "patch_size_thw": [1, 1, 1], + "norm_type": "ada_norm_single", + "norm_elementwise_affine": false, + "norm_eps": 1e-6, + "use_additional_conditions": false, + "caption_channels": [6144, 1024] + }, + "ae": { + "model_id": "stepvideovae", + "from_pretrained": "./weights/vae/vae_xx.safetensors", + "dtype": "bf16", + "z_channels": 64, + "frame_len": 17, + "version": 2 + }, + "tokenizer":[ + { + "autotokenizer_name": "stepchat", + "hub_backend": "hf", + "from_pretrained": "./weights/step_llm/step1_chat_tokenizer.model", + "model_max_length": 320 + }, + { + "autotokenizer_name": "BertTokenizer", + "hub_backend": "hf", + "from_pretrained": "./weights/hunyuan_clip/tokenizer", + "model_max_length": 77 + } + ], + "text_encoder": [ + { + "model_id": "StepLLmModel", + "hub_backend": "hf", + "from_pretrained": "./weights/step_llm/", + "dtype": "bf16" + }, + { + "model_id": "BertModel", + "hub_backend": "hf", + "from_pretrained": "./weights/hunyuan_clip/clip_text_encoder", + "dtype": "float32" + } + ], + + "diffusion": { + "model_id": "flow_match_discrete_scheduler", + "num_train_timesteps":1, + "num_inference_timesteps":50, + "shift": 13.0, + "reverse": false, + "solver": "euler" + }, + "pipeline_config": { + "version": "stepvideo", + "use_attention_mask": true, + "input_size": [204, 768, 768], + "guidance_scale": 9.0, + "model_type": "t2v", + "seed": 1234 + }, + "unload_text_encoder": true, + "micro_batch_size": 1, + "frame_interval":1, + "save_path":"examples/stepvideo/t2v/t2v_result/", + "fps":25, + "prompt": "examples/stepvideo/t2v/samples_prompts.txt", + "use_prompt_preprocess": false, + "pipeline_class": "StepVideoPipeline", + "device":"npu", + "dtype": "bf16" +} + diff --git a/examples/stepvideo/t2v/samples_prompt.txt b/examples/stepvideo/t2v/samples_prompt.txt new file mode 100644 index 0000000000000000000000000000000000000000..1faa56f16f1bc027e877a038b3e1a9427a688bc8 --- /dev/null +++ b/examples/stepvideo/t2v/samples_prompt.txt @@ -0,0 +1 @@ +视频中,一个身穿西装的小男孩,突然表情变得狰狞,身体逐渐被黑色的液体包裹,最终变身成为黑色毒液。这个过程在暗色调的环境中进行,背景较为模糊,突出表现了小男孩变身的每一个细节。视频采用特写镜头拍摄,具有科幻风格,清晰地展示了变身的每一个动作细节,给人以震撼感。 \ No newline at end of file diff --git a/mindspeed_mm/models/ae/base.py b/mindspeed_mm/models/ae/base.py index 56d92c93ed886d6d4d8abaee7024ef2cf8743f6b..2300c6aea415f35250a1216453f7c3fd609e12fc 100644 --- a/mindspeed_mm/models/ae/base.py +++ b/mindspeed_mm/models/ae/base.py @@ -23,6 +23,7 @@ from mindspeed_mm.models.ae.wfvae import WFVAE from mindspeed_mm.models.ae.contextparallel_causalvae import ContextParallelCasualVAE from mindspeed_mm.models.ae.autoencoder_kl_hunyuanvideo import AutoencoderKLHunyuanVideo from mindspeed_mm.models.ae.wan_video_vae import 
WanVideoVAE +from mindspeed_mm.models.ae.stepvideo_vae import StepVideoVae AE_MODEL_MAPPINGS = { @@ -32,7 +33,8 @@ AE_MODEL_MAPPINGS = { "wfvae": WFVAE, "contextparallelcasualvae": ContextParallelCasualVAE, "autoencoder_kl_hunyuanvideo": AutoencoderKLHunyuanVideo, - "wan_video_vae": WanVideoVAE + "wan_video_vae": WanVideoVAE, + "stepvideovae": StepVideoVae } diff --git a/mindspeed_mm/models/ae/stepvideo_vae.py b/mindspeed_mm/models/ae/stepvideo_vae.py index 26cfbf6b41ac7be363d5b05836d46be53c282767..09268484c71d034dec3d582105822aa1d9bcd802 100644 --- a/mindspeed_mm/models/ae/stepvideo_vae.py +++ b/mindspeed_mm/models/ae/stepvideo_vae.py @@ -132,7 +132,7 @@ def base_conv3d_channel_last(x, conv_layer, residual=None): conv_result = conv_result.clone() out_nhwci.copy_(conv_result) - if conv_result.shape == out_nhwci.shape or conv_result.dtype == out_nhwci.dtype: + if conv_result.shape != out_nhwci.shape or conv_result.dtype != out_nhwci.dtype: raise Exception(f"conv_result shape [{conv_result.shape}] must be the same as " f"out_nhwci shape [{out_nhwci.shape}], and conv_result dtype [{conv_result.dtype}] " f"must be the same as out_nhwci dtype [{out_nhwci.dtype}]") @@ -690,7 +690,20 @@ class StepVideoVae(nn.Module): self.world_size = torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1 if from_pretrained is not None: - load_checkpoint(self, from_pretrained) + weight_dict = self.init_from_ckpt(from_pretrained) + if len(weight_dict) != 0: + self.load_state_dict(weight_dict) + + def init_from_ckpt(self, model_path): + from safetensors import safe_open + p = {} + with safe_open(model_path, framework="pt", device="cpu") as f: + for k in f.keys(): + tensor = f.get_tensor(k) + if k.startswith("decoder.conv_out."): + k = k.replace("decoder.conv_out.", "decoder.conv_out.conv.") + p[k] = tensor + return p def naive_encode(self, x): length = x.size(1) diff --git a/mindspeed_mm/models/diffusion/flow_match_discrete_scheduler.py b/mindspeed_mm/models/diffusion/flow_match_discrete_scheduler.py index 7b42dac0dcf7ae5a78087b7c1bda976d18219bee..10e3791d54a6fea36a5dfc975e9b4abe15e61f83 100644 --- a/mindspeed_mm/models/diffusion/flow_match_discrete_scheduler.py +++ b/mindspeed_mm/models/diffusion/flow_match_discrete_scheduler.py @@ -306,6 +306,7 @@ class FlowMatchDiscreteScheduler: i2v_condition_type: str = "token_replace", **kwargs ) -> torch.Tensor: + extra_step_kwargs = {} if extra_step_kwargs is None else extra_step_kwargs dtype = latents.dtype # denoising loop num_inference_steps = self.num_train_timesteps if self.num_inference_timesteps is None else self.num_inference_timesteps diff --git a/mindspeed_mm/models/predictor/dits/step_video_dit.py b/mindspeed_mm/models/predictor/dits/step_video_dit.py index dd0f459239681327f0b3a23aa80fb1dddc1d2f0c..bd957a00d8c259d1ca90a725d547a2aa497a687f 100644 --- a/mindspeed_mm/models/predictor/dits/step_video_dit.py +++ b/mindspeed_mm/models/predictor/dits/step_video_dit.py @@ -1,3 +1,15 @@ +# Copyright 2025 StepFun Inc. All Rights Reserved. 
+# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# ============================================================================== from typing import Optional, Dict, Tuple from contextlib import nullcontext @@ -65,7 +77,7 @@ class StepVideoDiT(MultiModalModule): self.pos_embed = PatchEmbed( patch_size=patch_size, - in_channels=self.in_channels, + in_channels=self.in_channels if not use_additional_conditions else in_channels * 2, embed_dim=self.inner_dim ) @@ -92,7 +104,7 @@ class StepVideoDiT(MultiModalModule): # 3. Output blocks. self.norm_out = nn.LayerNorm(self.inner_dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine) - self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5) + self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim ** 0.5) self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels) self.patch_size = patch_size @@ -104,7 +116,7 @@ class StepVideoDiT(MultiModalModule): caption_channel = self.caption_channels else: caption_channel, clip_channel = self.caption_channels - self.clip_projection = nn.Linear(clip_channel, self.inner_dim) + self.clip_projection = nn.Linear(clip_channel, self.inner_dim) self.caption_norm = nn.LayerNorm(caption_channel, eps=norm_eps, elementwise_affine=norm_elementwise_affine) @@ -122,15 +134,18 @@ class StepVideoDiT(MultiModalModule): buffers = tuple(self.buffers()) return buffers[0].dtype - def patchfy(self, hidden_states): + def patchfy(self, hidden_states, condition_hidden_states): + if condition_hidden_states is not None: + hidden_states = torch.cat([hidden_states, condition_hidden_states], dim=2) hidden_states = rearrange(hidden_states, 'b f c h w -> (b f) c h w') hidden_states = self.pos_embed(hidden_states) return hidden_states def prepare_attn_mask(self, encoder_attention_mask, encoder_hidden_states, q_seqlen): kv_seqlens = encoder_attention_mask.sum(dim=1).int() - mask = torch.ones([len(kv_seqlens), q_seqlen, max(kv_seqlens)], dtype=torch.bool, device=encoder_attention_mask.device) - encoder_hidden_states = encoder_hidden_states.squeeze(1) # b 1 s d -> b s d + mask = torch.ones([len(kv_seqlens), q_seqlen, max(kv_seqlens)], dtype=torch.bool, + device=encoder_attention_mask.device) + encoder_hidden_states = encoder_hidden_states.squeeze(1) encoder_hidden_states = encoder_hidden_states[:, : max(kv_seqlens)] for i, kv_len in enumerate(kv_seqlens): mask[i, :, :kv_len] = 0 @@ -156,16 +171,20 @@ class StepVideoDiT(MultiModalModule): else: rng_context = nullcontext() - encoder_hidden_states = prompt[0]# b 1 s d - encoder_hidden_states_2 = prompt[1]# b 1 s d + encoder_hidden_states = prompt[0] + encoder_hidden_states_2 = prompt[1] + motion_score = kwargs.get("motion_score") + condition_hidden_states = kwargs.get("image_latents") # Only retain stepllm's mask if isinstance(prompt_mask, list): encoder_attention_mask = prompt_mask[0] # Padding 1 on the mask of the stepllm len_clip = encoder_hidden_states_2.shape[2] - 
encoder_attention_mask = encoder_attention_mask.squeeze(1).to(hidden_states.device) # stepchat_tokenizer_mask: b 1 s => b s - encoder_attention_mask = torch.nn.functional.pad(encoder_attention_mask, (len_clip, 0), value=1)# pad attention_mask with clip's length + encoder_attention_mask = encoder_attention_mask.squeeze(1).to( + hidden_states.device) # stepchat_tokenizer_mask: b 1 s => b s + encoder_attention_mask = torch.nn.functional.pad(encoder_attention_mask, (len_clip, 0), + value=1) # pad attention_mask with clip's length bsz, frame, _, height, width = hidden_states.shape if mpu.get_context_parallel_world_size() > 1: @@ -174,15 +193,22 @@ class StepVideoDiT(MultiModalModule): grad_scale='down') height, width = height // self.patch_size, width // self.patch_size - hidden_states = self.patchfy(hidden_states) + hidden_states = self.patchfy(hidden_states, condition_hidden_states) len_frame = hidden_states.shape[1] if self.use_additional_conditions: - added_cond_kwargs = { - "resolution": torch.tensor([(height, width)] * bsz, device=hidden_states.device, dtype=hidden_states.dtype), - "nframe": torch.tensor([frame] * bsz, device=hidden_states.device, dtype=hidden_states.dtype), - "fps": fps - } + if condition_hidden_states is not None: + added_cond_kwargs = { + "motion_score": torch.tensor([motion_score], device=hidden_states.device, + dtype=hidden_states.dtype).repeat(bsz) + } + else: + added_cond_kwargs = { + "resolution": torch.tensor([(height, width)] * bsz, device=hidden_states.device, + dtype=hidden_states.dtype), + "nframe": torch.tensor([frame] * bsz, device=hidden_states.device, dtype=hidden_states.dtype), + "fps": fps + } else: added_cond_kwargs = {} @@ -197,8 +223,9 @@ class StepVideoDiT(MultiModalModule): hidden_states = rearrange(hidden_states, '(b f) l d-> b (f l) d', b=bsz, f=frame, l=len_frame).contiguous() - encoder_hidden_states, attn_mask = self.prepare_attn_mask(encoder_attention_mask, encoder_hidden_states, q_seqlen=frame * len_frame) - + encoder_hidden_states, attn_mask = self.prepare_attn_mask(encoder_attention_mask, encoder_hidden_states, + q_seqlen=frame * len_frame) + # Rotary positional embeddings rotary_pos_emb = self.rope(bsz, frame * mpu.get_context_parallel_world_size(), height, width, hidden_states.device)# s b 1 d if mpu.get_context_parallel_world_size() > 1: @@ -451,8 +478,9 @@ class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module): self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim) self.nframe_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) self.fps_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + self.motion_score_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) - def forward(self, timestep, resolution=None, nframe=None, fps=None): + def forward(self, timestep, resolution=None, nframe=None, fps=None, motion_score=None): hidden_dtype = next(self.timestep_embedder.parameters()).dtype timesteps_proj = self.time_proj(timestep) @@ -460,11 +488,9 @@ class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module): if self.use_additional_conditions: batch_size = timestep.shape[0] - resolution_emb = self.additional_condition_proj(resolution.flatten()).to(hidden_dtype) - resolution_emb = self.resolution_embedder(resolution_emb).reshape(batch_size, -1) - nframe_emb = self.additional_condition_proj(nframe.flatten()).to(hidden_dtype) - nframe_emb = self.nframe_embedder(nframe_emb).reshape(batch_size, -1) - conditioning = timesteps_emb + 
resolution_emb + nframe_emb + motion_score_emb = self.additional_condition_proj(motion_score.flatten()).to(hidden_dtype) + motion_score_emb = self.motion_score_embedder(motion_score_emb).reshape(batch_size, -1) + conditioning = timesteps_emb + motion_score_emb if fps is not None: fps_emb = self.additional_condition_proj(fps.flatten()).to(hidden_dtype) diff --git a/mindspeed_mm/models/text_encoder/stepllm_text_encoder.py b/mindspeed_mm/models/text_encoder/stepllm_text_encoder.py index 2b2744bfc89102bf839b6d8bab2b2deb64ae7a6a..dbe8250ed55cb19537cd72a9466a3fc08c107e58 100644 --- a/mindspeed_mm/models/text_encoder/stepllm_text_encoder.py +++ b/mindspeed_mm/models/text_encoder/stepllm_text_encoder.py @@ -10,6 +10,7 @@ # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # ============================================================================== +import math from typing import Optional from einops import rearrange @@ -17,11 +18,36 @@ import torch import torch.nn as nn import torch.nn.functional as F import torch_npu +import numpy as np from transformers.modeling_utils import PretrainedConfig, PreTrainedModel from mindspeed_mm.models.common.normalize import normalize +DTYPE_FP16_MIN = float(np.finfo(np.float16).min) + + +def _get_alibi_slopes(n_heads): + n = 2 ** math.floor(math.log2(n_heads)) + m0 = torch.tensor(2.0 ** (-8.0 / n), dtype=torch.float32).to("cpu") + slopes = torch.pow(m0, torch.arange(1, n + 1, dtype=torch.float32).to("cpu")) + if n < n_heads: + m1 = torch.tensor(2.0**(-4.0 / n), dtype=torch.float32).to("cpu") + mm = torch.pow(m1, torch.arange(1, 1 + 2 * (n_heads - n), 2, dtype=torch.float32).to("cpu")) + slopes = torch.cat([slopes, mm]) + return slopes + + +def _get_mask(seq_len, b, n): + slopes = _get_alibi_slopes(n) + tril = torch.tril(torch.ones((1, 1, seq_len, seq_len), dtype=torch.bool)).to(torch.int32) + bias_row = torch.arange(seq_len).view(1, -1) + bias_cols = torch.arange(seq_len).view(-1, 1) + bias = -torch.sqrt(bias_cols - bias_row) + bias = bias.view(1, seq_len, seq_len) * slopes.view(-1, 1, 1) + bias = bias.masked_fill(tril == 0, DTYPE_FP16_MIN) + return bias + class LLaMaEmbedding(nn.Module): """Language model embeddings.""" @@ -53,11 +79,13 @@ class FlashSelfAttention(nn.Module): def forward(self, q, k, v, cu_seqlens=None): if cu_seqlens is None: - atten_mask_npu = torch.triu(torch.ones([2048, 2048], device="npu"), diagonal=1).bool() - head_num = q.shape[2] - scale = q.size(-1) ** (-0.5) - output = torch_npu.npu_fusion_attention(q, k, v, head_num, "BSND", keep_prob=1.0, - scale=scale, atten_mask=atten_mask_npu, sparse_mode=3)[0] + alibi_mask = _get_mask(q.size(1), q.size(0), q.size(2)) + alibi_mask = alibi_mask[:, :q.size(2), :, :].to(q.dtype).to(q.device) + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + output = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=alibi_mask) + output = output.transpose(1, 2) else: raise ValueError('cu_seqlens is not supported!') diff --git a/mindspeed_mm/tasks/inference/pipeline/stepvideo_pipeline.py b/mindspeed_mm/tasks/inference/pipeline/stepvideo_pipeline.py index 8ebcc9f13df095ed811d718565be2056fa82a572..f9b634c3f2d8be171587b56b2076985ab2c06f3f 100644 --- a/mindspeed_mm/tasks/inference/pipeline/stepvideo_pipeline.py +++ b/mindspeed_mm/tasks/inference/pipeline/stepvideo_pipeline.py @@ -1,15 +1,32 @@ +from dataclasses import dataclass from typing import Optional, Union, List +import PIL +import numpy as np 
import torch -from megatron.training import get_args +from torchvision import transforms from mindspeed_mm.tasks.inference.pipeline.pipeline_base import MMPipeline from mindspeed_mm.tasks.inference.pipeline.pipeline_mixin.encode_mixin import MMEncoderMixin from mindspeed_mm.tasks.inference.pipeline.pipeline_mixin.inputs_checks_mixin import InputsCheckMixin +POSITIVE_MAGIC_T2V = "超高清、HDR 视频、环境光、杜比全景声、画面稳定、流畅动作、逼真的细节、专业级构图、超现实主义、自然、生动、超细节、清晰。" +NEGATIVE_MAGIC_T2V = "画面暗、低分辨率、不良手、文本、缺少手指、多余的手指、裁剪、低质量、颗粒状、签名、水印、用户名、模糊。" -POSITIVE_MAGIC = "超高清、HDR 视频、环境光、杜比全景声、画面稳定、流畅动作、逼真的细节、专业级构图、超现实主义、自然、生动、超细节、清晰。" -NEGATIVE_MAGIC = "画面暗、低分辨率、不良手、文本、缺少手指、多余的手指、裁剪、低质量、颗粒状、签名、水印、用户名、模糊。" +POSITIVE_MAGIC_I2V = "画面中的主体动作表现生动自然、画面流畅、生动细节、光线统一柔和、超真实动态捕捉、大师级运镜、整体不变形、超高清、画面稳定、逼真的细节、专业级构图、超细节、清晰。" +NEGATIVE_MAGIC_I2V = "动画、模糊、变形、毁容、低质量、拼贴、粒状、标志、抽象、插图、计算机生成、扭曲、动作不流畅、面部有褶皱、表情僵硬、畸形手指" + + +@dataclass +class ImageLatentsConfig: + image: Optional[Union[PIL.Image.Image, List[PIL.Image.Image]]] + batch_size: int + num_channels_latents: int + height: int + width: int + num_frames: int + device: torch.device + dtype: torch.dtype class StepVideoPipeline(MMPipeline, InputsCheckMixin, MMEncoderMixin): @@ -23,7 +40,9 @@ class StepVideoPipeline(MMPipeline, InputsCheckMixin, MMEncoderMixin): config = config.to_dict() self.guidance_scale = config.get("guidance_scale", 7.5) self.num_frames, self.height, self.width = config.get("input_size", [204, 768, 768]) - self.generator = torch.Generator().manual_seed(config.get("seed", 42)) + self.generator = torch.Generator(device="npu") + self.motion_score = config.get("motion_score", 1.0) + self.model_type = config.get("model_type", "t2v") @staticmethod def apply_template(text, template): @@ -34,10 +53,81 @@ class StepVideoPipeline(MMPipeline, InputsCheckMixin, MMEncoderMixin): else: raise NotImplementedError(f"Not Support text type {type(text)}") + def check_inputs(self, num_frames, width, height): + num_frames = max(num_frames // 17 * 17, 1) + width = max(width // 16 * 16, 16) + height = max(height // 16 * 16, 16) + return num_frames, width, height + + def resize_to_desired_aspect_ratio(self, video, aspect_size): + ## video is in shape [f, c, h, w] + height, width = video.shape[-2:] + + aspect_ratio = [w / h for h, w in aspect_size] + # # resize + aspect_ratio_fact = width / height + bucket_idx = np.argmin(np.abs(aspect_ratio_fact - np.array(aspect_ratio))) + aspect_ratio = aspect_ratio[bucket_idx] + target_size_height, target_size_width = aspect_size[bucket_idx] + + if aspect_ratio_fact < aspect_ratio: + scale = target_size_width / width + else: + scale = target_size_height / height + + width_scale = int(round(width * scale)) + height_scale = int(round(height * scale)) + + # # crop + delta_h = height_scale - target_size_height + delta_w = width_scale - target_size_width + if delta_w < 0 or delta_h < 0: + raise ValueError("the delta_w and delta_h must be greater than or equal to 0.") + + top = delta_h // 2 + left = delta_w // 2 + + ## resize image and crop + resize_crop_transform = transforms.Compose([ + transforms.Resize((height_scale, width_scale)), + lambda x: transforms.functional.crop(x, top, left, target_size_height, target_size_width), + ]) + + video = torch.stack([resize_crop_transform(frame.contiguous()) for frame in video], dim=0) + return video + + def prepare_image_latents(self, params: ImageLatentsConfig): + num_frames, width, height = self.check_inputs(params.num_frames, params.width, params.height) + img_tensor = transforms.ToTensor()(params.image[0].convert('RGB')) * 2 - 1 
+ img_tensor = self.resize_to_desired_aspect_ratio(img_tensor[None], aspect_size=[(height, width)])[None] + img_tensor = img_tensor.to(params.dtype).to(params.device) + img_emb = self.vae.encode(img_tensor).repeat(params.batch_size, 1, 1, 1, 1).to(params.device) + + padding_tensor = torch.zeros((params.batch_size, max(num_frames // 17 * 3, 1) - 1, params.num_channels_latents, + int(height) // 16, + int(width) // 16,), device=params.device) + condition_hidden_states = torch.cat([img_emb, padding_tensor], dim=1) + + condition_hidden_states = condition_hidden_states.repeat(2, 1, 1, 1, 1) # for CFG + return condition_hidden_states.to(params.dtype) + + def get_positive_magic(self): + if self.model_type == "t2v": + return POSITIVE_MAGIC_T2V + else: + return POSITIVE_MAGIC_I2V + + def get_negative_magic(self): + if self.model_type == "t2v": + return NEGATIVE_MAGIC_T2V + else: + return NEGATIVE_MAGIC_I2V + @torch.no_grad() def __call__( self, prompt: Optional[Union[str, List[str]]] = None, + image: Optional[Union[PIL.Image.Image, List[PIL.Image.Image]]] = None, negative_prompt: Optional[Union[str, List[str]]] = None, latents: Optional[torch.Tensor] = None, prompt_embeds: Optional[torch.Tensor] = None, @@ -53,15 +143,16 @@ class StepVideoPipeline(MMPipeline, InputsCheckMixin, MMEncoderMixin): The call function to the pipeline for generation Inputs: - prompt (`str` or `List[str]`): + prompt (`str` or `List[str]`): The prompt or prompts to guide video/image generation. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide what to not include in video/image generation. Ignored when not using guidance (`guidance_scale < 1`) Returns: - video (`torch.Tensor` or `List[torch.Tensor]`) + video (`torch.Tensor` or `List[torch.Tensor]`) """ - args = get_args() + height = self.height + width = self.width # 1. Check inputs self.text_prompt_checks( @@ -71,6 +162,9 @@ class StepVideoPipeline(MMPipeline, InputsCheckMixin, MMEncoderMixin): negative_prompt_embeds ) + if image is not None: + self.image_prompt_checks(image) + # 2. Default call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 @@ -83,16 +177,16 @@ class StepVideoPipeline(MMPipeline, InputsCheckMixin, MMEncoderMixin): # 3. 
Encode input prompt if negative_prompt is None or negative_prompt == "": - negative_prompt = NEGATIVE_MAGIC + negative_prompt = self.get_negative_magic() if not isinstance(negative_prompt, str): raise TypeError(f"`negative_prompt` must be a string, but got {type(negative_prompt)}") if isinstance(prompt, str): - prompt = [prompt + POSITIVE_MAGIC] + prompt = [prompt + self.get_positive_magic()] elif isinstance(prompt, list) or isinstance(prompt, tuple): - prompt = [one_text + POSITIVE_MAGIC for one_text in prompt] + prompt = [one_text + self.get_positive_magic() for one_text in prompt] else: - raise NotImplementedError(f"Not Support text type {type(prompt)}") + raise NotImplementedError(f"Not Support text type {type(prompt)}") # Text Encoder load to device self.text_encoders.to(device) @@ -135,11 +229,27 @@ class StepVideoPipeline(MMPipeline, InputsCheckMixin, MMEncoderMixin): self.text_encoders.to("cpu") torch.cuda.empty_cache() + # prepare image_latents for i2v task + if image is not None: + params = ImageLatentsConfig( + image=image, + batch_size=batch_size, + num_channels_latents=self.predict_model.in_channels, + height=height, + width=width, + num_frames=self.num_frames, + device=device, + dtype=prompt_embeds.dtype + ) + image_latents = self.prepare_image_latents(params=params) + else: + image_latents = None + # 4. Prepare latents latent_channels = self.predict_model.in_channels shape = ( batch_size, - self.num_frames // self.vae.frame_len * self.vae.latent_len, + max(1, self.num_frames // self.vae.frame_len * self.vae.latent_len), latent_channels, int(self.height) // 16, int(self.width) // 16, @@ -161,8 +271,12 @@ class StepVideoPipeline(MMPipeline, InputsCheckMixin, MMEncoderMixin): device=device, do_classifier_free_guidance=do_classifier_free_guidance, guidance_scale=self.guidance_scale, - model_kwargs={"prompt": [prompt_embeds, clip_embedding], "prompt_mask": [prompt_mask, clip_mask]} + model_kwargs={"prompt": [prompt_embeds, clip_embedding], "prompt_mask": [prompt_mask, clip_mask], + "motion_score": self.motion_score, "image_latents": image_latents} ) + # predict model offload to 'cpu' + self.predict_model.to("cpu") + torch.cuda.empty_cache() # 6. Decode video = self.decode_latents(latents.to(self.vae.dtype))# b t c h w
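
Reviewer note (not part of the diff): the latent-shape arithmetic that the new pipeline and configs rely on is easy to mis-read, so below is a minimal Python sketch of it. It assumes vae.frame_len == 17 (set in both model JSONs) and vae.latent_len == 3, the latter only implied by the `num_frames // 17 * 3` padding in prepare_image_latents; the helper names are illustrative, not repo APIs.

def noise_latent_shape(batch_size, num_frames, height, width,
                       latent_channels=64, frame_len=17, latent_len=3):
    # Mirrors the shape built in StepVideoPipeline.__call__:
    # (b, max(1, f // frame_len * latent_len), c, h // 16, w // 16)
    return (batch_size,
            max(1, num_frames // frame_len * latent_len),
            latent_channels,
            height // 16,
            width // 16)


def i2v_condition_shape(batch_size, num_frames, height, width, latent_channels=64):
    # Mirrors prepare_image_latents: one encoded image latent frame plus zero
    # padding up to max(num_frames // 17 * 3, 1) frames, then repeated 2x for
    # classifier-free guidance.
    return (2 * batch_size,
            max(num_frames // 17 * 3, 1),
            latent_channels,
            height // 16,
            width // 16)


# i2v config (input_size [102, 544, 992]):
#   noise latents     -> (1, 18, 64, 34, 62)
#   condition latents -> (2, 18, 64, 34, 62)
# t2v config (input_size [204, 768, 768]):
#   noise latents     -> (1, 36, 64, 48, 48)
print(noise_latent_shape(1, 102, 544, 992))
print(i2v_condition_shape(1, 102, 544, 992))
print(noise_latent_shape(1, 204, 768, 768))

Both new launch scripts appear intended to be run from the repository root, since MM_MODEL and the prompt/image paths in the JSONs are repo-relative, after LOAD_PATH and the weight paths under ./weights/ are filled in.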