From b8ff7aa4584e89e1368df555f12878ce1fe9d02d Mon Sep 17 00:00:00 2001 From: Logan Date: Mon, 10 Feb 2025 17:02:35 +0800 Subject: [PATCH 1/2] =?UTF-8?q?=E8=BF=81=E7=A7=BBSD=E6=A8=A1=E5=9E=8B?= =?UTF-8?q?=E5=88=B0MultiModal=E8=B7=AF=E5=BE=84=E4=B8=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../CogVideoX-5B}/README.md | 0 .../CogVideoX-5B}/cogvideox_5b/__init__.py | 0 .../cogvideox_5b/models/__init__.py | 0 .../cogvideox_5b/models/activations.py | 0 .../cogvideox_5b/models/attention.py | 0 .../models/attention_processor.py | 0 .../cogvideox_5b/models/embeddings.py | 0 .../cogvideox_5b/models/normalization.py | 0 .../models/transformers/__init__.py | 0 .../transformers/cogvideox_transformer_3d.py | 0 .../cogvideox_5b/pipelines/__init__.py | 0 .../pipelines/pipeline_cogvideox.py | 0 .../cogvideox_5b/pipelines/pipeline_output.py | 0 .../cogvideox_5b/utils/__init__.py | 0 .../cogvideox_5b/utils/parallel_mgr.py | 0 .../cogvideox_5b/utils/parallel_state.py | 0 .../CogVideoX-5B}/inference.py | 0 .../CogVideoX-5B}/pta_plugin/CMakeLists.txt | 0 .../CogVideoX-5B}/pta_plugin/build.sh | 0 .../pta_plugin/extension_ops.cpp | 0 .../pta_plugin/test/test_rope.py | 0 .../CogVideoX-5B}/requirements.txt | 0 .../CogView3-Plus-3B}/README.md | 0 .../cogview3plus/__init__.py | 0 .../cogview3plus/layers/__init__.py | 4 +- .../cogview3plus/layers/embeddings.py | 606 +++---- .../cogview3plus/layers/linear.py | 94 +- .../cogview3plus/layers/normalization.py | 352 ++-- .../cogview3plus/models/__init__.py | 2 +- .../cogview3plus/models/activations.py | 324 ++-- .../cogview3plus/models/attention.py | 172 +- .../models/attention_processor.py | 694 ++++---- .../cogview3plus/models/model_load_utils.py | 82 +- .../cogview3plus/models/modeling_utils.py | 1542 ++++++++--------- .../models/transformer_cogview3plus.py | 794 ++++----- .../cogview3plus/pipeline/__init__.py | 0 .../pipeline/pipeline_cogview3plus.py | 676 ++++---- .../cogview3plus/pipeline/pipeline_output.py | 40 +- .../cogview3plus/schedulers/__init__.py | 2 +- .../schedulers/scheduling_ddim_cogvideox.py | 550 +++--- .../schedulers/scheduling_utils.py | 224 +-- .../cogview3plus/vae/__init__.py | 0 .../inference_cogview3plus.py | 0 .../CogView3-Plus-3B}/requirents.txt | 0 .../Flux.1-DEV/FLUX1dev/__init__.py | 0 .../Flux.1-DEV/FLUX1dev/layers/__init__.py | 0 .../FLUX1dev/layers/attention_processor.py | 0 .../Flux.1-DEV/FLUX1dev/layers/embedding.py | 0 .../Flux.1-DEV/FLUX1dev/models/__init__.py | 0 .../FLUX1dev/models/modeling_utils.py | 0 .../FLUX1dev/models/transformer_flux.py | 0 .../Flux.1-DEV/FLUX1dev/pipeline/__init__.py | 0 .../FLUX1dev/pipeline/pipeline_flux.py | 0 .../Flux.1-DEV/README.md | 0 .../Flux.1-DEV/inference_flux.py | 0 .../Flux.1-DEV/prompts.txt | 0 .../Flux.1-DEV/requirements.txt | 0 .../HunyuanDiT}/README.md | 0 .../HunyuanDiT}/hydit/__init__.py | 0 .../HunyuanDiT}/hydit/layers/__init__.py | 0 .../HunyuanDiT}/hydit/layers/activation.py | 0 .../HunyuanDiT}/hydit/layers/attention.py | 0 .../HunyuanDiT}/hydit/layers/embedding.py | 0 .../HunyuanDiT}/hydit/layers/mlp.py | 0 .../HunyuanDiT}/hydit/layers/norm.py | 0 .../HunyuanDiT}/hydit/layers/poolers.py | 0 .../HunyuanDiT}/hydit/models/__init__.py | 0 .../HunyuanDiT}/hydit/models/hydit.py | 0 .../hydit/models/model_load_utils.py | 0 .../HunyuanDiT}/hydit/models/model_utils.py | 0 .../HunyuanDiT}/hydit/pipeline/__init__.py | 0 .../hydit/pipeline/hydit_pipeline.py | 0 .../HunyuanDiT}/hydit/schedulers/__init__.py | 0 .../HunyuanDiT}/hydit/schedulers/ddpm.py | 0 .../HunyuanDiT}/hydit/utils/__init__.py | 0 .../HunyuanDiT}/hydit/utils/file_utils.py | 0 .../HunyuanDiT}/hydit/utils/utils.py | 0 .../HunyuanDiT}/inference_hydit.py | 0 .../HunyuanDiT}/lora/__init__.py | 0 .../HunyuanDiT}/lora/hydit_lora.py | 0 .../HunyuanDiT}/prompts/example_prompts.txt | 0 .../HunyuanDiT}/requirents.txt | 0 .../OpenSora-v1.2}/README.md | 0 .../OpenSora-v1.2}/inference_opensora12.py | 0 .../OpenSora-v1.2}/opensora/__init__.py | 0 .../OpenSora-v1.2}/opensora/layer/__init__.py | 0 .../opensora/layer/activation.py | 0 .../opensora/layer/attention.py | 0 .../OpenSora-v1.2}/opensora/layer/comm.py | 0 .../OpenSora-v1.2}/opensora/layer/conv.py | 0 .../OpenSora-v1.2}/opensora/layer/embdding.py | 0 .../OpenSora-v1.2}/opensora/layer/mlp.py | 0 .../OpenSora-v1.2}/opensora/layer/norm.py | 0 .../opensora/layer/parallel_mgr.py | 0 .../OpenSora-v1.2}/opensora/layer/utils.py | 0 .../opensora/pipeline/__init__.py | 0 .../opensora/pipeline/compile_pipe.py | 0 .../opensora/pipeline/open_sora_pipeline.py | 0 .../opensora/pipeline/pipeline_utils.py | 0 .../opensora/schedulers/__init__.py | 0 .../opensora/schedulers/rectified_flow.py | 0 .../opensora/stdit3/__init__.py | 0 .../OpenSora-v1.2}/opensora/stdit3/stdit3.py | 0 .../OpenSora-v1.2}/opensora/utils/__init__.py | 0 .../opensora/utils/patch_utils.py | 0 .../OpenSora-v1.2}/opensora/utils/utils.py | 0 .../opensora/vae/VideoAutoencoder.py | 0 .../OpenSora-v1.2}/opensora/vae/__init__.py | 0 .../opensora/vae/vae_temporal.py | 0 .../OpenSora-v1.2}/prompts/t2v_sora.txt | 0 .../OpenSora-v1.2}/requirents.txt | 0 MindIE/MultiModal/OpenSoraPlan-v1.3/README.md | 182 ++ .../inference_opensoraplan13.py | 0 .../OpenSoraPlan-v1.3}/layers/__init__.py | 0 .../OpenSoraPlan-v1.3}/layers/activation.py | 0 .../OpenSoraPlan-v1.3}/layers/attention.py | 0 .../OpenSoraPlan-v1.3}/layers/cache_mgr.py | 0 .../OpenSoraPlan-v1.3}/layers/conv.py | 0 .../OpenSoraPlan-v1.3}/layers/linear.py | 0 .../OpenSoraPlan-v1.3}/layers/mlp.py | 0 .../OpenSoraPlan-v1.3}/layers/norm.py | 0 .../OpenSoraPlan-v1.3}/layers/sampling.py | 0 .../OpenSoraPlan-v1.3}/layers/utils.py | 0 .../OpenSoraPlan-v1.3}/layers/vresnet.py | 0 .../OpenSoraPlan-v1.3}/models/comm.py | 0 .../OpenSoraPlan-v1.3}/models/model_utils.py | 0 .../OpenSoraPlan-v1.3}/models/parallel_mgr.py | 0 .../OpenSoraPlan-v1.3}/models/t2vdit.py | 0 .../OpenSoraPlan-v1.3}/models/wfvae.py | 0 .../pipeline/open_soar_plan_pipeline.py | 0 .../pipeline/pipeline_utils.py | 0 .../OpenSoraPlan-v1.3}/utils/__init__.py | 0 .../OpenSoraPlan-v1.3}/utils/utils.py | 0 133 files changed, 3261 insertions(+), 3079 deletions(-) rename MindIE/{MindIE-Torch/built-in/foundation/CogVideoX-5b => MultiModal/CogVideoX-5B}/README.md (100%) rename MindIE/{MindIE-Torch/built-in/foundation/CogVideoX-5b => MultiModal/CogVideoX-5B}/cogvideox_5b/__init__.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/CogVideoX-5b => MultiModal/CogVideoX-5B}/cogvideox_5b/models/__init__.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/CogVideoX-5b => MultiModal/CogVideoX-5B}/cogvideox_5b/models/activations.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/CogVideoX-5b => MultiModal/CogVideoX-5B}/cogvideox_5b/models/attention.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/CogVideoX-5b => MultiModal/CogVideoX-5B}/cogvideox_5b/models/attention_processor.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/CogVideoX-5b => MultiModal/CogVideoX-5B}/cogvideox_5b/models/embeddings.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/CogVideoX-5b => MultiModal/CogVideoX-5B}/cogvideox_5b/models/normalization.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/CogVideoX-5b => MultiModal/CogVideoX-5B}/cogvideox_5b/models/transformers/__init__.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/CogVideoX-5b => MultiModal/CogVideoX-5B}/cogvideox_5b/models/transformers/cogvideox_transformer_3d.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/CogVideoX-5b => MultiModal/CogVideoX-5B}/cogvideox_5b/pipelines/__init__.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/CogVideoX-5b => MultiModal/CogVideoX-5B}/cogvideox_5b/pipelines/pipeline_cogvideox.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/CogVideoX-5b => MultiModal/CogVideoX-5B}/cogvideox_5b/pipelines/pipeline_output.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/CogVideoX-5b => MultiModal/CogVideoX-5B}/cogvideox_5b/utils/__init__.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/CogVideoX-5b => MultiModal/CogVideoX-5B}/cogvideox_5b/utils/parallel_mgr.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/CogVideoX-5b => MultiModal/CogVideoX-5B}/cogvideox_5b/utils/parallel_state.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/CogVideoX-5b => MultiModal/CogVideoX-5B}/inference.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/CogVideoX-5b => MultiModal/CogVideoX-5B}/pta_plugin/CMakeLists.txt (100%) rename MindIE/{MindIE-Torch/built-in/foundation/CogVideoX-5b => MultiModal/CogVideoX-5B}/pta_plugin/build.sh (100%) rename MindIE/{MindIE-Torch/built-in/foundation/CogVideoX-5b => MultiModal/CogVideoX-5B}/pta_plugin/extension_ops.cpp (100%) rename MindIE/{MindIE-Torch/built-in/foundation/CogVideoX-5b => MultiModal/CogVideoX-5B}/pta_plugin/test/test_rope.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/CogVideoX-5b => MultiModal/CogVideoX-5B}/requirements.txt (100%) rename MindIE/{MindIE-Torch/built-in/foundation/cogview3 => MultiModal/CogView3-Plus-3B}/README.md (100%) rename MindIE/{MindIE-Torch/built-in/foundation/cogview3 => MultiModal/CogView3-Plus-3B}/cogview3plus/__init__.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/cogview3 => MultiModal/CogView3-Plus-3B}/cogview3plus/layers/__init__.py (99%) rename MindIE/{MindIE-Torch/built-in/foundation/cogview3 => MultiModal/CogView3-Plus-3B}/cogview3plus/layers/embeddings.py (97%) rename MindIE/{MindIE-Torch/built-in/foundation/cogview3 => MultiModal/CogView3-Plus-3B}/cogview3plus/layers/linear.py (97%) rename MindIE/{MindIE-Torch/built-in/foundation/cogview3 => MultiModal/CogView3-Plus-3B}/cogview3plus/layers/normalization.py (97%) rename MindIE/{MindIE-Torch/built-in/foundation/cogview3 => MultiModal/CogView3-Plus-3B}/cogview3plus/models/__init__.py (99%) rename MindIE/{MindIE-Torch/built-in/foundation/cogview3 => MultiModal/CogView3-Plus-3B}/cogview3plus/models/activations.py (97%) rename MindIE/{MindIE-Torch/built-in/foundation/cogview3 => MultiModal/CogView3-Plus-3B}/cogview3plus/models/attention.py (97%) rename MindIE/{MindIE-Torch/built-in/foundation/cogview3 => MultiModal/CogView3-Plus-3B}/cogview3plus/models/attention_processor.py (97%) rename MindIE/{MindIE-Torch/built-in/foundation/cogview3 => MultiModal/CogView3-Plus-3B}/cogview3plus/models/model_load_utils.py (97%) rename MindIE/{MindIE-Torch/built-in/foundation/cogview3 => MultiModal/CogView3-Plus-3B}/cogview3plus/models/modeling_utils.py (97%) rename MindIE/{MindIE-Torch/built-in/foundation/cogview3 => MultiModal/CogView3-Plus-3B}/cogview3plus/models/transformer_cogview3plus.py (97%) rename MindIE/{MindIE-Torch/built-in/foundation/cogview3 => MultiModal/CogView3-Plus-3B}/cogview3plus/pipeline/__init__.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/cogview3 => MultiModal/CogView3-Plus-3B}/cogview3plus/pipeline/pipeline_cogview3plus.py (97%) rename MindIE/{MindIE-Torch/built-in/foundation/cogview3 => MultiModal/CogView3-Plus-3B}/cogview3plus/pipeline/pipeline_output.py (96%) rename MindIE/{MindIE-Torch/built-in/foundation/cogview3 => MultiModal/CogView3-Plus-3B}/cogview3plus/schedulers/__init__.py (99%) rename MindIE/{MindIE-Torch/built-in/foundation/cogview3 => MultiModal/CogView3-Plus-3B}/cogview3plus/schedulers/scheduling_ddim_cogvideox.py (97%) rename MindIE/{MindIE-Torch/built-in/foundation/cogview3 => MultiModal/CogView3-Plus-3B}/cogview3plus/schedulers/scheduling_utils.py (97%) rename MindIE/{MindIE-Torch/built-in/foundation/cogview3 => MultiModal/CogView3-Plus-3B}/cogview3plus/vae/__init__.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/cogview3 => MultiModal/CogView3-Plus-3B}/inference_cogview3plus.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/cogview3 => MultiModal/CogView3-Plus-3B}/requirents.txt (100%) rename MindIE/{MindIE-Torch/built-in/foundation => MultiModal}/Flux.1-DEV/FLUX1dev/__init__.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation => MultiModal}/Flux.1-DEV/FLUX1dev/layers/__init__.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation => MultiModal}/Flux.1-DEV/FLUX1dev/layers/attention_processor.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation => MultiModal}/Flux.1-DEV/FLUX1dev/layers/embedding.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation => MultiModal}/Flux.1-DEV/FLUX1dev/models/__init__.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation => MultiModal}/Flux.1-DEV/FLUX1dev/models/modeling_utils.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation => MultiModal}/Flux.1-DEV/FLUX1dev/models/transformer_flux.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation => MultiModal}/Flux.1-DEV/FLUX1dev/pipeline/__init__.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation => MultiModal}/Flux.1-DEV/FLUX1dev/pipeline/pipeline_flux.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation => MultiModal}/Flux.1-DEV/README.md (100%) rename MindIE/{MindIE-Torch/built-in/foundation => MultiModal}/Flux.1-DEV/inference_flux.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation => MultiModal}/Flux.1-DEV/prompts.txt (100%) rename MindIE/{MindIE-Torch/built-in/foundation => MultiModal}/Flux.1-DEV/requirements.txt (100%) rename MindIE/{MindIE-Torch/built-in/foundation/hunyuan_dit => MultiModal/HunyuanDiT}/README.md (100%) rename MindIE/{MindIE-Torch/built-in/foundation/hunyuan_dit => MultiModal/HunyuanDiT}/hydit/__init__.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/hunyuan_dit => MultiModal/HunyuanDiT}/hydit/layers/__init__.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/hunyuan_dit => MultiModal/HunyuanDiT}/hydit/layers/activation.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/hunyuan_dit => MultiModal/HunyuanDiT}/hydit/layers/attention.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/hunyuan_dit => MultiModal/HunyuanDiT}/hydit/layers/embedding.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/hunyuan_dit => MultiModal/HunyuanDiT}/hydit/layers/mlp.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/hunyuan_dit => MultiModal/HunyuanDiT}/hydit/layers/norm.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/hunyuan_dit => MultiModal/HunyuanDiT}/hydit/layers/poolers.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/hunyuan_dit => MultiModal/HunyuanDiT}/hydit/models/__init__.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/hunyuan_dit => MultiModal/HunyuanDiT}/hydit/models/hydit.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/hunyuan_dit => MultiModal/HunyuanDiT}/hydit/models/model_load_utils.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/hunyuan_dit => MultiModal/HunyuanDiT}/hydit/models/model_utils.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/hunyuan_dit => MultiModal/HunyuanDiT}/hydit/pipeline/__init__.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/hunyuan_dit => MultiModal/HunyuanDiT}/hydit/pipeline/hydit_pipeline.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/hunyuan_dit => MultiModal/HunyuanDiT}/hydit/schedulers/__init__.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/hunyuan_dit => MultiModal/HunyuanDiT}/hydit/schedulers/ddpm.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/hunyuan_dit => MultiModal/HunyuanDiT}/hydit/utils/__init__.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/hunyuan_dit => MultiModal/HunyuanDiT}/hydit/utils/file_utils.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/hunyuan_dit => MultiModal/HunyuanDiT}/hydit/utils/utils.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/hunyuan_dit => MultiModal/HunyuanDiT}/inference_hydit.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/hunyuan_dit => MultiModal/HunyuanDiT}/lora/__init__.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/hunyuan_dit => MultiModal/HunyuanDiT}/lora/hydit_lora.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/hunyuan_dit => MultiModal/HunyuanDiT}/prompts/example_prompts.txt (100%) rename MindIE/{MindIE-Torch/built-in/foundation/hunyuan_dit => MultiModal/HunyuanDiT}/requirents.txt (100%) rename MindIE/{MindIE-Torch/built-in/foundation/opensora1.2 => MultiModal/OpenSora-v1.2}/README.md (100%) rename MindIE/{MindIE-Torch/built-in/foundation/opensora1.2 => MultiModal/OpenSora-v1.2}/inference_opensora12.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/opensora1.2 => MultiModal/OpenSora-v1.2}/opensora/__init__.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/opensora1.2 => MultiModal/OpenSora-v1.2}/opensora/layer/__init__.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/opensora1.2 => MultiModal/OpenSora-v1.2}/opensora/layer/activation.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/opensora1.2 => MultiModal/OpenSora-v1.2}/opensora/layer/attention.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/opensora1.2 => MultiModal/OpenSora-v1.2}/opensora/layer/comm.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/opensora1.2 => MultiModal/OpenSora-v1.2}/opensora/layer/conv.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/opensora1.2 => MultiModal/OpenSora-v1.2}/opensora/layer/embdding.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/opensora1.2 => MultiModal/OpenSora-v1.2}/opensora/layer/mlp.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/opensora1.2 => MultiModal/OpenSora-v1.2}/opensora/layer/norm.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/opensora1.2 => MultiModal/OpenSora-v1.2}/opensora/layer/parallel_mgr.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/opensora1.2 => MultiModal/OpenSora-v1.2}/opensora/layer/utils.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/opensora1.2 => MultiModal/OpenSora-v1.2}/opensora/pipeline/__init__.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/opensora1.2 => MultiModal/OpenSora-v1.2}/opensora/pipeline/compile_pipe.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/opensora1.2 => MultiModal/OpenSora-v1.2}/opensora/pipeline/open_sora_pipeline.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/opensora1.2 => MultiModal/OpenSora-v1.2}/opensora/pipeline/pipeline_utils.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/opensora1.2 => MultiModal/OpenSora-v1.2}/opensora/schedulers/__init__.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/opensora1.2 => MultiModal/OpenSora-v1.2}/opensora/schedulers/rectified_flow.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/opensora1.2 => MultiModal/OpenSora-v1.2}/opensora/stdit3/__init__.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/opensora1.2 => MultiModal/OpenSora-v1.2}/opensora/stdit3/stdit3.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/opensora1.2 => MultiModal/OpenSora-v1.2}/opensora/utils/__init__.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/opensora1.2 => MultiModal/OpenSora-v1.2}/opensora/utils/patch_utils.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/opensora1.2 => MultiModal/OpenSora-v1.2}/opensora/utils/utils.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/opensora1.2 => MultiModal/OpenSora-v1.2}/opensora/vae/VideoAutoencoder.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/opensora1.2 => MultiModal/OpenSora-v1.2}/opensora/vae/__init__.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/opensora1.2 => MultiModal/OpenSora-v1.2}/opensora/vae/vae_temporal.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/opensora1.2 => MultiModal/OpenSora-v1.2}/prompts/t2v_sora.txt (100%) rename MindIE/{MindIE-Torch/built-in/foundation/opensora1.2 => MultiModal/OpenSora-v1.2}/requirents.txt (100%) create mode 100644 MindIE/MultiModal/OpenSoraPlan-v1.3/README.md rename MindIE/{MindIE-Torch/built-in/foundation/open_sora_planv1_3 => MultiModal/OpenSoraPlan-v1.3}/inference_opensoraplan13.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/open_sora_planv1_3 => MultiModal/OpenSoraPlan-v1.3}/layers/__init__.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/open_sora_planv1_3 => MultiModal/OpenSoraPlan-v1.3}/layers/activation.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/open_sora_planv1_3 => MultiModal/OpenSoraPlan-v1.3}/layers/attention.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/open_sora_planv1_3 => MultiModal/OpenSoraPlan-v1.3}/layers/cache_mgr.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/open_sora_planv1_3 => MultiModal/OpenSoraPlan-v1.3}/layers/conv.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/open_sora_planv1_3 => MultiModal/OpenSoraPlan-v1.3}/layers/linear.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/open_sora_planv1_3 => MultiModal/OpenSoraPlan-v1.3}/layers/mlp.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/open_sora_planv1_3 => MultiModal/OpenSoraPlan-v1.3}/layers/norm.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/open_sora_planv1_3 => MultiModal/OpenSoraPlan-v1.3}/layers/sampling.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/open_sora_planv1_3 => MultiModal/OpenSoraPlan-v1.3}/layers/utils.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/open_sora_planv1_3 => MultiModal/OpenSoraPlan-v1.3}/layers/vresnet.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/open_sora_planv1_3 => MultiModal/OpenSoraPlan-v1.3}/models/comm.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/open_sora_planv1_3 => MultiModal/OpenSoraPlan-v1.3}/models/model_utils.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/open_sora_planv1_3 => MultiModal/OpenSoraPlan-v1.3}/models/parallel_mgr.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/open_sora_planv1_3 => MultiModal/OpenSoraPlan-v1.3}/models/t2vdit.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/open_sora_planv1_3 => MultiModal/OpenSoraPlan-v1.3}/models/wfvae.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/open_sora_planv1_3 => MultiModal/OpenSoraPlan-v1.3}/pipeline/open_soar_plan_pipeline.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/open_sora_planv1_3 => MultiModal/OpenSoraPlan-v1.3}/pipeline/pipeline_utils.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/open_sora_planv1_3 => MultiModal/OpenSoraPlan-v1.3}/utils/__init__.py (100%) rename MindIE/{MindIE-Torch/built-in/foundation/open_sora_planv1_3 => MultiModal/OpenSoraPlan-v1.3}/utils/utils.py (100%) diff --git a/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/README.md b/MindIE/MultiModal/CogVideoX-5B/README.md similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/README.md rename to MindIE/MultiModal/CogVideoX-5B/README.md diff --git a/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/__init__.py b/MindIE/MultiModal/CogVideoX-5B/cogvideox_5b/__init__.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/__init__.py rename to MindIE/MultiModal/CogVideoX-5B/cogvideox_5b/__init__.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/models/__init__.py b/MindIE/MultiModal/CogVideoX-5B/cogvideox_5b/models/__init__.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/models/__init__.py rename to MindIE/MultiModal/CogVideoX-5B/cogvideox_5b/models/__init__.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/models/activations.py b/MindIE/MultiModal/CogVideoX-5B/cogvideox_5b/models/activations.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/models/activations.py rename to MindIE/MultiModal/CogVideoX-5B/cogvideox_5b/models/activations.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/models/attention.py b/MindIE/MultiModal/CogVideoX-5B/cogvideox_5b/models/attention.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/models/attention.py rename to MindIE/MultiModal/CogVideoX-5B/cogvideox_5b/models/attention.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/models/attention_processor.py b/MindIE/MultiModal/CogVideoX-5B/cogvideox_5b/models/attention_processor.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/models/attention_processor.py rename to MindIE/MultiModal/CogVideoX-5B/cogvideox_5b/models/attention_processor.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/models/embeddings.py b/MindIE/MultiModal/CogVideoX-5B/cogvideox_5b/models/embeddings.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/models/embeddings.py rename to MindIE/MultiModal/CogVideoX-5B/cogvideox_5b/models/embeddings.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/models/normalization.py b/MindIE/MultiModal/CogVideoX-5B/cogvideox_5b/models/normalization.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/models/normalization.py rename to MindIE/MultiModal/CogVideoX-5B/cogvideox_5b/models/normalization.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/models/transformers/__init__.py b/MindIE/MultiModal/CogVideoX-5B/cogvideox_5b/models/transformers/__init__.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/models/transformers/__init__.py rename to MindIE/MultiModal/CogVideoX-5B/cogvideox_5b/models/transformers/__init__.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/models/transformers/cogvideox_transformer_3d.py b/MindIE/MultiModal/CogVideoX-5B/cogvideox_5b/models/transformers/cogvideox_transformer_3d.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/models/transformers/cogvideox_transformer_3d.py rename to MindIE/MultiModal/CogVideoX-5B/cogvideox_5b/models/transformers/cogvideox_transformer_3d.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/pipelines/__init__.py b/MindIE/MultiModal/CogVideoX-5B/cogvideox_5b/pipelines/__init__.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/pipelines/__init__.py rename to MindIE/MultiModal/CogVideoX-5B/cogvideox_5b/pipelines/__init__.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/pipelines/pipeline_cogvideox.py b/MindIE/MultiModal/CogVideoX-5B/cogvideox_5b/pipelines/pipeline_cogvideox.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/pipelines/pipeline_cogvideox.py rename to MindIE/MultiModal/CogVideoX-5B/cogvideox_5b/pipelines/pipeline_cogvideox.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/pipelines/pipeline_output.py b/MindIE/MultiModal/CogVideoX-5B/cogvideox_5b/pipelines/pipeline_output.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/pipelines/pipeline_output.py rename to MindIE/MultiModal/CogVideoX-5B/cogvideox_5b/pipelines/pipeline_output.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/utils/__init__.py b/MindIE/MultiModal/CogVideoX-5B/cogvideox_5b/utils/__init__.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/utils/__init__.py rename to MindIE/MultiModal/CogVideoX-5B/cogvideox_5b/utils/__init__.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/utils/parallel_mgr.py b/MindIE/MultiModal/CogVideoX-5B/cogvideox_5b/utils/parallel_mgr.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/utils/parallel_mgr.py rename to MindIE/MultiModal/CogVideoX-5B/cogvideox_5b/utils/parallel_mgr.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/utils/parallel_state.py b/MindIE/MultiModal/CogVideoX-5B/cogvideox_5b/utils/parallel_state.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/utils/parallel_state.py rename to MindIE/MultiModal/CogVideoX-5B/cogvideox_5b/utils/parallel_state.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/inference.py b/MindIE/MultiModal/CogVideoX-5B/inference.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/inference.py rename to MindIE/MultiModal/CogVideoX-5B/inference.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/pta_plugin/CMakeLists.txt b/MindIE/MultiModal/CogVideoX-5B/pta_plugin/CMakeLists.txt similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/pta_plugin/CMakeLists.txt rename to MindIE/MultiModal/CogVideoX-5B/pta_plugin/CMakeLists.txt diff --git a/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/pta_plugin/build.sh b/MindIE/MultiModal/CogVideoX-5B/pta_plugin/build.sh similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/pta_plugin/build.sh rename to MindIE/MultiModal/CogVideoX-5B/pta_plugin/build.sh diff --git a/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/pta_plugin/extension_ops.cpp b/MindIE/MultiModal/CogVideoX-5B/pta_plugin/extension_ops.cpp similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/pta_plugin/extension_ops.cpp rename to MindIE/MultiModal/CogVideoX-5B/pta_plugin/extension_ops.cpp diff --git a/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/pta_plugin/test/test_rope.py b/MindIE/MultiModal/CogVideoX-5B/pta_plugin/test/test_rope.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/pta_plugin/test/test_rope.py rename to MindIE/MultiModal/CogVideoX-5B/pta_plugin/test/test_rope.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/requirements.txt b/MindIE/MultiModal/CogVideoX-5B/requirements.txt similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/requirements.txt rename to MindIE/MultiModal/CogVideoX-5B/requirements.txt diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogview3/README.md b/MindIE/MultiModal/CogView3-Plus-3B/README.md similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/cogview3/README.md rename to MindIE/MultiModal/CogView3-Plus-3B/README.md diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/__init__.py b/MindIE/MultiModal/CogView3-Plus-3B/cogview3plus/__init__.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/__init__.py rename to MindIE/MultiModal/CogView3-Plus-3B/cogview3plus/__init__.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/layers/__init__.py b/MindIE/MultiModal/CogView3-Plus-3B/cogview3plus/layers/__init__.py similarity index 99% rename from MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/layers/__init__.py rename to MindIE/MultiModal/CogView3-Plus-3B/cogview3plus/layers/__init__.py index 09760b9fd0..602ad432a0 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/layers/__init__.py +++ b/MindIE/MultiModal/CogView3-Plus-3B/cogview3plus/layers/__init__.py @@ -1,3 +1,3 @@ -from .normalization import CogView3PlusAdaLayerNormZeroTextImage, AdaLayerNormContinuous -from .embeddings import CogView3CombinedTimestepSizeEmbeddings, CogView3PlusPatchEmbed +from .normalization import CogView3PlusAdaLayerNormZeroTextImage, AdaLayerNormContinuous +from .embeddings import CogView3CombinedTimestepSizeEmbeddings, CogView3PlusPatchEmbed from .linear import QKVLinear \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/layers/embeddings.py b/MindIE/MultiModal/CogView3-Plus-3B/cogview3plus/layers/embeddings.py similarity index 97% rename from MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/layers/embeddings.py rename to MindIE/MultiModal/CogView3-Plus-3B/cogview3plus/layers/embeddings.py index 129384dffc..fc2d3101eb 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/layers/embeddings.py +++ b/MindIE/MultiModal/CogView3-Plus-3B/cogview3plus/layers/embeddings.py @@ -1,304 +1,304 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import math -from typing import Optional - -import torch -from torch import nn -from diffusers.models.activations import get_activation - - -def get_timestep_embedding( - timesteps: torch.Tensor, - embedding_dim: int, - flip_sin_to_cos: bool = False, - downscale_freq_shift: float = 1, - max_period: int = 10000, -): - half_dim = embedding_dim // 2 - exponent = -math.log(max_period) * torch.arange( - start=0, end=half_dim, dtype=torch.float32, device=timesteps.device - ) - exponent = exponent / (half_dim - downscale_freq_shift) - - emb = torch.exp(exponent) - emb = timesteps[:, None].float() * emb[None, :] - - # concat sine and cosine embeddings - emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) - - # flip sine and cosine embeddings - if flip_sin_to_cos: - emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) - - # zero pad - if embedding_dim % 2 == 1: - emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) - return emb - - -def get_2d_sincos_pos_embed( - embed_dim, - grid_size, - interpolation_scale=1.0, - base_size=16, -): - if isinstance(grid_size, int): - grid_size = (grid_size, grid_size) - - grid_h = ( - torch.arange(grid_size[0], dtype=torch.float32) - / (grid_size[0] / base_size) - / interpolation_scale - ) - grid_w = ( - torch.arange(grid_size[1], dtype=torch.float32) - / (grid_size[1] / base_size) - / interpolation_scale - ) - grid = torch.meshgrid(grid_w, grid_h, indexing="xy") # here w goes first - grid = torch.stack(grid, dim=0) - - grid = grid.reshape([2, 1, grid_size[1], grid_size[0]]) - pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) - return pos_embed - - -def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): - r""" - This function generates 2D sinusoidal positional embeddings from a grid. - - Args: - embed_dim (`int`): The embedding dimension. - grid (`torch.Tensor`): Grid of positions with shape `(H * W,)`. - - Returns: - `torch.Tensor`: The 2D sinusoidal positional embeddings with shape `(H * W, embed_dim)` - """ - if embed_dim % 2 != 0: - raise ValueError("embed_dim must be divisible by 2") - - # use half of dimensions to encode grid_h - emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) - emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) - - emb = torch.concat([emb_h, emb_w], dim=1) # (H*W, D) - return emb - - -def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): - """ - This function generates 1D positional embeddings from a grid. - - Args: - embed_dim (`int`): The embedding dimension `D` - pos (`torch.Tensor`): 1D tensor of positions with shape `(M,)` - - Returns: - `torch.Tensor`: Sinusoidal positional embeddings of shape `(M, D)`. - """ - if embed_dim % 2 != 0: - raise ValueError("embed_dim must be divisible by 2") - - omega = torch.arange(embed_dim // 2, device=pos.device, dtype=torch.float64) - omega /= embed_dim / 2.0 - omega = 1.0 / 10000**omega # (D/2,) - - pos = pos.reshape(-1) # (M,) - out = torch.outer(pos, omega) # (M, D/2), outer product - - emb_sin = torch.sin(out) # (M, D/2) - emb_cos = torch.cos(out) # (M, D/2) - - emb = torch.concat([emb_sin, emb_cos], dim=1) # (M, D) - return emb - - -class Timesteps(nn.Module): - def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float): - super().__init__() - self.num_channels = num_channels - self.flip_sin_to_cos = flip_sin_to_cos - self.downscale_freq_shift = downscale_freq_shift - - def forward(self, timesteps): - t_emb = get_timestep_embedding( - timesteps, - self.num_channels, - flip_sin_to_cos=self.flip_sin_to_cos, - downscale_freq_shift=self.downscale_freq_shift, - ) - return t_emb - - -class TimestepEmbedding(nn.Module): - def __init__( - self, - in_channels: int, - time_embed_dim: int, - act_fn: str = "silu", - out_dim: int = None, - post_act_fn: Optional[str] = None, - cond_proj_dim=None, - sample_proj_bias=True, - ): - super().__init__() - - self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias) - - if cond_proj_dim is not None: - self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False) - else: - self.cond_proj = None - - self.act = get_activation(act_fn) - - if out_dim is not None: - time_embed_dim_out = out_dim - else: - time_embed_dim_out = time_embed_dim - self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias) - - if post_act_fn is None: - self.post_act = None - else: - self.post_act = get_activation(post_act_fn) - - def forward(self, sample, condition=None): - if condition is not None: - sample = sample + self.cond_proj(condition) - sample = self.linear_1(sample) - - if self.act is not None: - sample = self.act(sample) - - sample = self.linear_2(sample) - - if self.post_act is not None: - sample = self.post_act(sample) - return sample - - -class PixArtAlphaTextProjection(nn.Module): - """ - Projects caption embeddings. Also handles dropout for classifier-free guidance. - """ - - def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh"): - super().__init__() - if out_features is None: - out_features = hidden_size - self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True) - if act_fn == "gelu_tanh": - self.act_1 = nn.GELU(approximate="tanh") - elif act_fn == "silu": - self.act_1 = nn.SiLU() - else: - raise ValueError(f"Unknown activation function: {act_fn}") - self.linear_2 = nn.Linear(in_features=hidden_size, out_features=out_features, bias=True) - - def forward(self, caption): - hidden_states = self.linear_1(caption) - hidden_states = self.act_1(hidden_states) - hidden_states = self.linear_2(hidden_states) - return hidden_states - - -class CogView3CombinedTimestepSizeEmbeddings(nn.Module): - def __init__(self, embedding_dim: int, condition_dim: int, pooled_projection_dim: int, timesteps_dim: int = 256): - super().__init__() - - self.time_proj = Timesteps(num_channels=timesteps_dim, flip_sin_to_cos=True, downscale_freq_shift=0) - self.condition_proj = Timesteps(num_channels=condition_dim, flip_sin_to_cos=True, downscale_freq_shift=0) - self.timestep_embedder = TimestepEmbedding(in_channels=timesteps_dim, time_embed_dim=embedding_dim) - self.condition_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu") - - def forward( - self, - timestep: torch.Tensor, - original_size: torch.Tensor, - target_size: torch.Tensor, - crop_coords: torch.Tensor, - hidden_dtype: torch.dtype, - ) -> torch.Tensor: - timesteps_proj = self.time_proj(timestep) - - original_size_proj = self.condition_proj(original_size.flatten()).view(original_size.size(0), -1) - crop_coords_proj = self.condition_proj(crop_coords.flatten()).view(crop_coords.size(0), -1) - target_size_proj = self.condition_proj(target_size.flatten()).view(target_size.size(0), -1) - - condition_proj = torch.cat([original_size_proj, crop_coords_proj, target_size_proj], dim=1) - - timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (B, embedding_dim) - condition_emb = self.condition_embedder(condition_proj.to(dtype=hidden_dtype)) # (B, embedding_dim) - - conditioning = timesteps_emb + condition_emb - return conditioning - - -class CogView3PlusPatchEmbed(nn.Module): - def __init__( - self, - in_channels: int = 16, - hidden_size: int = 2560, - patch_size: int = 2, - text_hidden_size: int = 4096, - pos_embed_max_size: int = 128, - ): - super().__init__() - self.in_channels = in_channels - self.hidden_size = hidden_size - self.patch_size = patch_size - self.text_hidden_size = text_hidden_size - self.pos_embed_max_size = pos_embed_max_size - # Linear projection for image patches - self.proj = nn.Linear(in_channels * patch_size**2, hidden_size) - - # Linear projection for text embeddings - self.text_proj = nn.Linear(text_hidden_size, hidden_size) - - pos_embed = get_2d_sincos_pos_embed( - hidden_size, pos_embed_max_size, base_size=pos_embed_max_size - ) - pos_embed = pos_embed.reshape(pos_embed_max_size, pos_embed_max_size, hidden_size) - self.register_buffer("pos_embed", pos_embed.float(), persistent=False) - - def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor) -> torch.Tensor: - batch_size, channel, height, width = hidden_states.shape - - if height % self.patch_size != 0 or width % self.patch_size != 0: - raise ValueError("Height and width must be divisible by patch size") - - height = height // self.patch_size - width = width // self.patch_size - hidden_states = hidden_states.view(batch_size, channel, height, self.patch_size, width, self.patch_size) - hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5).contiguous() - hidden_states = hidden_states.view(batch_size, height * width, channel * self.patch_size * self.patch_size) - - # Project the patches - hidden_states = self.proj(hidden_states) - encoder_hidden_states = self.text_proj(encoder_hidden_states) - hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) - - # Calculate text_length - text_length = encoder_hidden_states.shape[1] - - image_pos_embed = self.pos_embed[:height, :width].reshape(height * width, -1) - text_pos_embed = torch.zeros( - (text_length, self.hidden_size), dtype=image_pos_embed.dtype, device=image_pos_embed.device - ) - pos_embed = torch.cat([text_pos_embed, image_pos_embed], dim=0)[None, ...] - +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Optional + +import torch +from torch import nn +from diffusers.models.activations import get_activation + + +def get_timestep_embedding( + timesteps: torch.Tensor, + embedding_dim: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 1, + max_period: int = 10000, +): + half_dim = embedding_dim // 2 + exponent = -math.log(max_period) * torch.arange( + start=0, end=half_dim, dtype=torch.float32, device=timesteps.device + ) + exponent = exponent / (half_dim - downscale_freq_shift) + + emb = torch.exp(exponent) + emb = timesteps[:, None].float() * emb[None, :] + + # concat sine and cosine embeddings + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) + + # flip sine and cosine embeddings + if flip_sin_to_cos: + emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) + + # zero pad + if embedding_dim % 2 == 1: + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +def get_2d_sincos_pos_embed( + embed_dim, + grid_size, + interpolation_scale=1.0, + base_size=16, +): + if isinstance(grid_size, int): + grid_size = (grid_size, grid_size) + + grid_h = ( + torch.arange(grid_size[0], dtype=torch.float32) + / (grid_size[0] / base_size) + / interpolation_scale + ) + grid_w = ( + torch.arange(grid_size[1], dtype=torch.float32) + / (grid_size[1] / base_size) + / interpolation_scale + ) + grid = torch.meshgrid(grid_w, grid_h, indexing="xy") # here w goes first + grid = torch.stack(grid, dim=0) + + grid = grid.reshape([2, 1, grid_size[1], grid_size[0]]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + r""" + This function generates 2D sinusoidal positional embeddings from a grid. + + Args: + embed_dim (`int`): The embedding dimension. + grid (`torch.Tensor`): Grid of positions with shape `(H * W,)`. + + Returns: + `torch.Tensor`: The 2D sinusoidal positional embeddings with shape `(H * W, embed_dim)` + """ + if embed_dim % 2 != 0: + raise ValueError("embed_dim must be divisible by 2") + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = torch.concat([emb_h, emb_w], dim=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + This function generates 1D positional embeddings from a grid. + + Args: + embed_dim (`int`): The embedding dimension `D` + pos (`torch.Tensor`): 1D tensor of positions with shape `(M,)` + + Returns: + `torch.Tensor`: Sinusoidal positional embeddings of shape `(M, D)`. + """ + if embed_dim % 2 != 0: + raise ValueError("embed_dim must be divisible by 2") + + omega = torch.arange(embed_dim // 2, device=pos.device, dtype=torch.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = torch.outer(pos, omega) # (M, D/2), outer product + + emb_sin = torch.sin(out) # (M, D/2) + emb_cos = torch.cos(out) # (M, D/2) + + emb = torch.concat([emb_sin, emb_cos], dim=1) # (M, D) + return emb + + +class Timesteps(nn.Module): + def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float): + super().__init__() + self.num_channels = num_channels + self.flip_sin_to_cos = flip_sin_to_cos + self.downscale_freq_shift = downscale_freq_shift + + def forward(self, timesteps): + t_emb = get_timestep_embedding( + timesteps, + self.num_channels, + flip_sin_to_cos=self.flip_sin_to_cos, + downscale_freq_shift=self.downscale_freq_shift, + ) + return t_emb + + +class TimestepEmbedding(nn.Module): + def __init__( + self, + in_channels: int, + time_embed_dim: int, + act_fn: str = "silu", + out_dim: int = None, + post_act_fn: Optional[str] = None, + cond_proj_dim=None, + sample_proj_bias=True, + ): + super().__init__() + + self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias) + + if cond_proj_dim is not None: + self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False) + else: + self.cond_proj = None + + self.act = get_activation(act_fn) + + if out_dim is not None: + time_embed_dim_out = out_dim + else: + time_embed_dim_out = time_embed_dim + self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias) + + if post_act_fn is None: + self.post_act = None + else: + self.post_act = get_activation(post_act_fn) + + def forward(self, sample, condition=None): + if condition is not None: + sample = sample + self.cond_proj(condition) + sample = self.linear_1(sample) + + if self.act is not None: + sample = self.act(sample) + + sample = self.linear_2(sample) + + if self.post_act is not None: + sample = self.post_act(sample) + return sample + + +class PixArtAlphaTextProjection(nn.Module): + """ + Projects caption embeddings. Also handles dropout for classifier-free guidance. + """ + + def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh"): + super().__init__() + if out_features is None: + out_features = hidden_size + self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True) + if act_fn == "gelu_tanh": + self.act_1 = nn.GELU(approximate="tanh") + elif act_fn == "silu": + self.act_1 = nn.SiLU() + else: + raise ValueError(f"Unknown activation function: {act_fn}") + self.linear_2 = nn.Linear(in_features=hidden_size, out_features=out_features, bias=True) + + def forward(self, caption): + hidden_states = self.linear_1(caption) + hidden_states = self.act_1(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +class CogView3CombinedTimestepSizeEmbeddings(nn.Module): + def __init__(self, embedding_dim: int, condition_dim: int, pooled_projection_dim: int, timesteps_dim: int = 256): + super().__init__() + + self.time_proj = Timesteps(num_channels=timesteps_dim, flip_sin_to_cos=True, downscale_freq_shift=0) + self.condition_proj = Timesteps(num_channels=condition_dim, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding(in_channels=timesteps_dim, time_embed_dim=embedding_dim) + self.condition_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu") + + def forward( + self, + timestep: torch.Tensor, + original_size: torch.Tensor, + target_size: torch.Tensor, + crop_coords: torch.Tensor, + hidden_dtype: torch.dtype, + ) -> torch.Tensor: + timesteps_proj = self.time_proj(timestep) + + original_size_proj = self.condition_proj(original_size.flatten()).view(original_size.size(0), -1) + crop_coords_proj = self.condition_proj(crop_coords.flatten()).view(crop_coords.size(0), -1) + target_size_proj = self.condition_proj(target_size.flatten()).view(target_size.size(0), -1) + + condition_proj = torch.cat([original_size_proj, crop_coords_proj, target_size_proj], dim=1) + + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (B, embedding_dim) + condition_emb = self.condition_embedder(condition_proj.to(dtype=hidden_dtype)) # (B, embedding_dim) + + conditioning = timesteps_emb + condition_emb + return conditioning + + +class CogView3PlusPatchEmbed(nn.Module): + def __init__( + self, + in_channels: int = 16, + hidden_size: int = 2560, + patch_size: int = 2, + text_hidden_size: int = 4096, + pos_embed_max_size: int = 128, + ): + super().__init__() + self.in_channels = in_channels + self.hidden_size = hidden_size + self.patch_size = patch_size + self.text_hidden_size = text_hidden_size + self.pos_embed_max_size = pos_embed_max_size + # Linear projection for image patches + self.proj = nn.Linear(in_channels * patch_size**2, hidden_size) + + # Linear projection for text embeddings + self.text_proj = nn.Linear(text_hidden_size, hidden_size) + + pos_embed = get_2d_sincos_pos_embed( + hidden_size, pos_embed_max_size, base_size=pos_embed_max_size + ) + pos_embed = pos_embed.reshape(pos_embed_max_size, pos_embed_max_size, hidden_size) + self.register_buffer("pos_embed", pos_embed.float(), persistent=False) + + def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, channel, height, width = hidden_states.shape + + if height % self.patch_size != 0 or width % self.patch_size != 0: + raise ValueError("Height and width must be divisible by patch size") + + height = height // self.patch_size + width = width // self.patch_size + hidden_states = hidden_states.view(batch_size, channel, height, self.patch_size, width, self.patch_size) + hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5).contiguous() + hidden_states = hidden_states.view(batch_size, height * width, channel * self.patch_size * self.patch_size) + + # Project the patches + hidden_states = self.proj(hidden_states) + encoder_hidden_states = self.text_proj(encoder_hidden_states) + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + # Calculate text_length + text_length = encoder_hidden_states.shape[1] + + image_pos_embed = self.pos_embed[:height, :width].reshape(height * width, -1) + text_pos_embed = torch.zeros( + (text_length, self.hidden_size), dtype=image_pos_embed.dtype, device=image_pos_embed.device + ) + pos_embed = torch.cat([text_pos_embed, image_pos_embed], dim=0)[None, ...] + return (hidden_states + pos_embed).to(hidden_states.dtype) \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/layers/linear.py b/MindIE/MultiModal/CogView3-Plus-3B/cogview3plus/layers/linear.py similarity index 97% rename from MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/layers/linear.py rename to MindIE/MultiModal/CogView3-Plus-3B/cogview3plus/layers/linear.py index 57fe8d55dc..d242d17c2e 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/layers/linear.py +++ b/MindIE/MultiModal/CogView3-Plus-3B/cogview3plus/layers/linear.py @@ -1,48 +1,48 @@ -#!/usr/bin/env python -# coding=utf-8 -# Copyright 2024 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import torch -import torch.nn as nn - - -class QKVLinear(nn.Module): - def __init__(self, attention_dim, hidden_size, qkv_bias=True, device=None, dtype=None): - super(QKVLinear, self).__init__() - self.attention_dim = attention_dim - self.hidden_size = hidden_size - self.qkv_bias = qkv_bias - - factory_kwargs = {"device": device, "dtype": dtype} - - self.weight = nn.Parameter(torch.empty([self.attention_dim, 3 * self.hidden_size], **factory_kwargs)) - if self.qkv_bias: - self.bias = nn.Parameter(torch.empty([3 * self.hidden_size], **factory_kwargs)) - - def forward(self, hidden_states): - - if not self.qkv_bias: - qkv = torch.matmul(hidden_states, self.weight) - else: - qkv = torch.addmm( - self.bias, - hidden_states.view(hidden_states.size(0) * hidden_states.size(1), hidden_states.size(2)), - self.weight, - beta=1, - alpha=1 - ) - +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +import torch.nn as nn + + +class QKVLinear(nn.Module): + def __init__(self, attention_dim, hidden_size, qkv_bias=True, device=None, dtype=None): + super(QKVLinear, self).__init__() + self.attention_dim = attention_dim + self.hidden_size = hidden_size + self.qkv_bias = qkv_bias + + factory_kwargs = {"device": device, "dtype": dtype} + + self.weight = nn.Parameter(torch.empty([self.attention_dim, 3 * self.hidden_size], **factory_kwargs)) + if self.qkv_bias: + self.bias = nn.Parameter(torch.empty([3 * self.hidden_size], **factory_kwargs)) + + def forward(self, hidden_states): + + if not self.qkv_bias: + qkv = torch.matmul(hidden_states, self.weight) + else: + qkv = torch.addmm( + self.bias, + hidden_states.view(hidden_states.size(0) * hidden_states.size(1), hidden_states.size(2)), + self.weight, + beta=1, + alpha=1 + ) + return qkv \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/layers/normalization.py b/MindIE/MultiModal/CogView3-Plus-3B/cogview3plus/layers/normalization.py similarity index 97% rename from MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/layers/normalization.py rename to MindIE/MultiModal/CogView3-Plus-3B/cogview3plus/layers/normalization.py index 1ec0a5b15c..c12b70c9b1 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/layers/normalization.py +++ b/MindIE/MultiModal/CogView3-Plus-3B/cogview3plus/layers/normalization.py @@ -1,177 +1,177 @@ -# coding=utf-8 -# Copyright 2024 HuggingFace Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import numbers -from typing import Optional, Tuple -from dataclasses import dataclass - -import torch -import torch.nn as nn -import torch.nn.functional as F - - -class RMSNorm(nn.Module): - def __init__(self, dim, eps: float, elementwise_affine: bool = True, bias: bool = False): - super().__init__() - - self.eps = eps - self.elementwise_affine = elementwise_affine - - if isinstance(dim, numbers.Integral): - dim = (dim,) - - self.dim = torch.Size(dim) - - self.weight = None - self.bias = None - - if elementwise_affine: - self.weight = nn.Parameter(torch.ones(dim)) - if bias: - self.bias = nn.Parameter(torch.zeros(dim)) - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.eps) - - if self.weight is not None: - # convert into half-precision if necessary - if self.weight.dtype in [torch.float16, torch.bfloat16]: - hidden_states = hidden_states.to(self.weight.dtype) - hidden_states = hidden_states * self.weight - if self.bias is not None: - hidden_states = hidden_states + self.bias - else: - hidden_states = hidden_states.to(input_dtype) - - return hidden_states - - -@dataclass -class ChunkParam: - gate_msa: torch.Tensor - shift_mlp: torch.Tensor - scale_mlp: torch.Tensor - gate_mlp: torch.Tensor - context: torch.Tensor - c_gate_msa: torch.Tensor - c_shift_mlp: torch.Tensor - c_scale_mlp: torch.Tensor - c_gate_mlp: torch.Tensor - - -class CogView3PlusAdaLayerNormZeroTextImage(nn.Module): - r""" - Norm layer adaptive layer norm zero (adaLN-Zero). - - Parameters: - embedding_dim (`int`): The size of each embedding vector. - num_embeddings (`int`): The size of the embeddings dictionary. - """ - - def __init__(self, embedding_dim: int, dim: int): - super().__init__() - - self.silu = nn.SiLU() - self.linear = nn.Linear(embedding_dim, 12 * dim, bias=True) - self.norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5) - self.norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5) - - def forward( - self, - x: torch.Tensor, - context: torch.Tensor, - emb: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - emb = self.linear(self.silu(emb)) - ( - shift_msa, - scale_msa, - gate_msa, - shift_mlp, - scale_mlp, - gate_mlp, - c_shift_msa, - c_scale_msa, - c_gate_msa, - c_shift_mlp, - c_scale_mlp, - c_gate_mlp, - ) = emb.chunk(12, dim=1) - normed_x = self.norm_x(x) - normed_context = self.norm_c(context) - x = normed_x * (1 + scale_msa[:, None]) + shift_msa[:, None] - context = normed_context * (1 + c_scale_msa[:, None]) + c_shift_msa[:, None] - return x, ChunkParam( - gate_msa, shift_mlp, scale_mlp, gate_mlp, context, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp - ) - - -class FP32LayerNorm(nn.LayerNorm): - def forward(self, inputs: torch.Tensor) -> torch.Tensor: - origin_dtype = inputs.dtype - return F.layer_norm( - inputs.float(), - self.normalized_shape, - self.weight.float() if self.weight is not None else None, - self.bias.float() if self.bias is not None else None, - self.eps, - ).to(origin_dtype) - - -class LpNorm(nn.Module): - def __init__(self, p: int = 2, dim: int = -1, eps: float = 1e-12): - super().__init__() - - self.p = p - self.dim = dim - self.eps = eps - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - return F.normalize(hidden_states, p=self.p, dim=self.dim, eps=self.eps) - - -class AdaLayerNormContinuous(nn.Module): - def __init__( - self, - embedding_dim: int, - conditioning_embedding_dim: int, - # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters - # because the output is immediately scaled and shifted by the projected conditioning embeddings. - # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters. - # However, this is how it was implemented in the original code, and it's rather likely you should - # set `elementwise_affine` to False. - elementwise_affine=True, - eps=1e-5, - bias=True, - norm_type="layer_norm", - ): - super().__init__() - self.silu = nn.SiLU() - self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias) - if norm_type == "layer_norm": - self.norm = nn.LayerNorm(embedding_dim, eps, elementwise_affine, bias) - elif norm_type == "rms_norm": - self.norm = RMSNorm(embedding_dim, eps, elementwise_affine) - else: - raise ValueError(f"unknown norm_type {norm_type}") - - def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor: - # convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT) - emb = self.linear(self.silu(conditioning_embedding).to(x.dtype)) - scale, shift = torch.chunk(emb, 2, dim=1) - x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :] +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numbers +from typing import Optional, Tuple +from dataclasses import dataclass + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class RMSNorm(nn.Module): + def __init__(self, dim, eps: float, elementwise_affine: bool = True, bias: bool = False): + super().__init__() + + self.eps = eps + self.elementwise_affine = elementwise_affine + + if isinstance(dim, numbers.Integral): + dim = (dim,) + + self.dim = torch.Size(dim) + + self.weight = None + self.bias = None + + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(dim)) + if bias: + self.bias = nn.Parameter(torch.zeros(dim)) + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.eps) + + if self.weight is not None: + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + hidden_states = hidden_states * self.weight + if self.bias is not None: + hidden_states = hidden_states + self.bias + else: + hidden_states = hidden_states.to(input_dtype) + + return hidden_states + + +@dataclass +class ChunkParam: + gate_msa: torch.Tensor + shift_mlp: torch.Tensor + scale_mlp: torch.Tensor + gate_mlp: torch.Tensor + context: torch.Tensor + c_gate_msa: torch.Tensor + c_shift_mlp: torch.Tensor + c_scale_mlp: torch.Tensor + c_gate_mlp: torch.Tensor + + +class CogView3PlusAdaLayerNormZeroTextImage(nn.Module): + r""" + Norm layer adaptive layer norm zero (adaLN-Zero). + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + num_embeddings (`int`): The size of the embeddings dictionary. + """ + + def __init__(self, embedding_dim: int, dim: int): + super().__init__() + + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, 12 * dim, bias=True) + self.norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5) + self.norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5) + + def forward( + self, + x: torch.Tensor, + context: torch.Tensor, + emb: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + emb = self.linear(self.silu(emb)) + ( + shift_msa, + scale_msa, + gate_msa, + shift_mlp, + scale_mlp, + gate_mlp, + c_shift_msa, + c_scale_msa, + c_gate_msa, + c_shift_mlp, + c_scale_mlp, + c_gate_mlp, + ) = emb.chunk(12, dim=1) + normed_x = self.norm_x(x) + normed_context = self.norm_c(context) + x = normed_x * (1 + scale_msa[:, None]) + shift_msa[:, None] + context = normed_context * (1 + c_scale_msa[:, None]) + c_shift_msa[:, None] + return x, ChunkParam( + gate_msa, shift_mlp, scale_mlp, gate_mlp, context, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp + ) + + +class FP32LayerNorm(nn.LayerNorm): + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + origin_dtype = inputs.dtype + return F.layer_norm( + inputs.float(), + self.normalized_shape, + self.weight.float() if self.weight is not None else None, + self.bias.float() if self.bias is not None else None, + self.eps, + ).to(origin_dtype) + + +class LpNorm(nn.Module): + def __init__(self, p: int = 2, dim: int = -1, eps: float = 1e-12): + super().__init__() + + self.p = p + self.dim = dim + self.eps = eps + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return F.normalize(hidden_states, p=self.p, dim=self.dim, eps=self.eps) + + +class AdaLayerNormContinuous(nn.Module): + def __init__( + self, + embedding_dim: int, + conditioning_embedding_dim: int, + # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters + # because the output is immediately scaled and shifted by the projected conditioning embeddings. + # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters. + # However, this is how it was implemented in the original code, and it's rather likely you should + # set `elementwise_affine` to False. + elementwise_affine=True, + eps=1e-5, + bias=True, + norm_type="layer_norm", + ): + super().__init__() + self.silu = nn.SiLU() + self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias) + if norm_type == "layer_norm": + self.norm = nn.LayerNorm(embedding_dim, eps, elementwise_affine, bias) + elif norm_type == "rms_norm": + self.norm = RMSNorm(embedding_dim, eps, elementwise_affine) + else: + raise ValueError(f"unknown norm_type {norm_type}") + + def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor: + # convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT) + emb = self.linear(self.silu(conditioning_embedding).to(x.dtype)) + scale, shift = torch.chunk(emb, 2, dim=1) + x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :] return x \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/models/__init__.py b/MindIE/MultiModal/CogView3-Plus-3B/cogview3plus/models/__init__.py similarity index 99% rename from MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/models/__init__.py rename to MindIE/MultiModal/CogView3-Plus-3B/cogview3plus/models/__init__.py index b3c595bfcc..ae8f24f59a 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/models/__init__.py +++ b/MindIE/MultiModal/CogView3-Plus-3B/cogview3plus/models/__init__.py @@ -1,2 +1,2 @@ -from .transformer_cogview3plus import CogView3PlusTransformer2DModel +from .transformer_cogview3plus import CogView3PlusTransformer2DModel from .modeling_utils import ModelMixin \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/models/activations.py b/MindIE/MultiModal/CogView3-Plus-3B/cogview3plus/models/activations.py similarity index 97% rename from MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/models/activations.py rename to MindIE/MultiModal/CogView3-Plus-3B/cogview3plus/models/activations.py index 5bb3783ae4..b7d7cec29d 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/models/activations.py +++ b/MindIE/MultiModal/CogView3-Plus-3B/cogview3plus/models/activations.py @@ -1,163 +1,163 @@ -# coding=utf-8 -# Copyright 2024 HuggingFace Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -import torch.nn.functional as F -from torch import nn - -from diffusers.utils import deprecate -from diffusers.utils.import_utils import is_torch_npu_available - -if is_torch_npu_available(): - import torch_npu - -ACTIVATION_FUNCTIONS = { - "swish": nn.SiLU(), - "silu": nn.SiLU(), - "mish": nn.Mish(), - "gelu": nn.GELU(), - "relu": nn.ReLU(), -} - - -def get_activation(act_fn: str) -> nn.Module: - """Helper function to get activation function from string. - - Args: - act_fn (str): Name of activation function. - - Returns: - nn.Module: Activation function. - """ - - act_fn = act_fn.lower() - if act_fn in ACTIVATION_FUNCTIONS: - return ACTIVATION_FUNCTIONS[act_fn] - else: - raise ValueError(f"Unsupported activation function: {act_fn}") - - -class FP32SiLU(nn.Module): - r""" - SiLU activation function with input upcasted to torch.float32. - """ - - def __init__(self): - super().__init__() - - def forward(self, inputs: torch.Tensor) -> torch.Tensor: - return F.silu(inputs.float(), inplace=False).to(inputs.dtype) - - -class GELU(nn.Module): - r""" - GELU activation function with tanh approximation support with `approximate="tanh"`. - - Parameters: - dim_in (`int`): The number of channels in the input. - dim_out (`int`): The number of channels in the output. - approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation. - bias (`bool`, defaults to True): Whether to use a bias in the linear layer. - """ - - def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True): - super().__init__() - self.proj = nn.Linear(dim_in, dim_out, bias=bias) - self.approximate = approximate - - def gelu(self, gate: torch.Tensor) -> torch.Tensor: - return F.gelu(gate, approximate=self.approximate) - - def forward(self, hidden_states): - hidden_states = self.proj(hidden_states) - hidden_states = self.gelu(hidden_states) - return hidden_states - - -class GEGLU(nn.Module): - r""" - Parameters: - dim_in (`int`): The number of channels in the input. - dim_out (`int`): The number of channels in the output. - bias (`bool`, defaults to True): Whether to use a bias in the linear layer. - """ - - def __init__(self, dim_in: int, dim_out: int, bias: bool = True): - super().__init__() - self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias) - - def gelu(self, gate: torch.Tensor) -> torch.Tensor: - return F.gelu(gate) - - def forward(self, hidden_states, *args, **kwargs): - if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." - deprecate("scale", "1.0.0", deprecation_message) - hidden_states = self.proj(hidden_states) - if is_torch_npu_available(): - # using torch_npu.npu_geglu can run faster and save memory on NPU. - return torch_npu.npu_geglu(hidden_states, dim=-1, approximate=1)[0] - else: - hidden_states, gate = hidden_states.chunk(2, dim=-1) - return hidden_states * self.gelu(gate) - - -class SwiGLU(nn.Module): - r""" - Parameters: - dim_in (`int`): The number of channels in the input. - dim_out (`int`): The number of channels in the output. - bias (`bool`, defaults to True): Whether to use a bias in the linear layer. - """ - - def __init__(self, dim_in: int, dim_out: int, bias: bool = True): - super().__init__() - - self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias) - self.activation = nn.SiLU() - - def forward(self, hidden_states): - hidden_states = self.proj(hidden_states) - hidden_states, gate = hidden_states.chunk(2, dim=-1) - return hidden_states * self.activation(gate) - - -class ApproximateGELU(nn.Module): - r""" - Parameters: - dim_in (`int`): The number of channels in the input. - dim_out (`int`): The number of channels in the output. - bias (`bool`, defaults to True): Whether to use a bias in the linear layer. - """ - - def __init__(self, dim_in: int, dim_out: int, bias: bool = True): - super().__init__() - self.proj = nn.Linear(dim_in, dim_out, bias=bias) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = self.proj(x) - return x * torch.sigmoid(1.702 * x) - - -class LinearActivation(nn.Module): - def __init__(self, dim_in: int, dim_out: int, bias: bool = True, activation: str = "silu"): - super().__init__() - - self.proj = nn.Linear(dim_in, dim_out, bias=bias) - self.activation = get_activation(activation) - - def forward(self, hidden_states): - hidden_states = self.proj(hidden_states) +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn.functional as F +from torch import nn + +from diffusers.utils import deprecate +from diffusers.utils.import_utils import is_torch_npu_available + +if is_torch_npu_available(): + import torch_npu + +ACTIVATION_FUNCTIONS = { + "swish": nn.SiLU(), + "silu": nn.SiLU(), + "mish": nn.Mish(), + "gelu": nn.GELU(), + "relu": nn.ReLU(), +} + + +def get_activation(act_fn: str) -> nn.Module: + """Helper function to get activation function from string. + + Args: + act_fn (str): Name of activation function. + + Returns: + nn.Module: Activation function. + """ + + act_fn = act_fn.lower() + if act_fn in ACTIVATION_FUNCTIONS: + return ACTIVATION_FUNCTIONS[act_fn] + else: + raise ValueError(f"Unsupported activation function: {act_fn}") + + +class FP32SiLU(nn.Module): + r""" + SiLU activation function with input upcasted to torch.float32. + """ + + def __init__(self): + super().__init__() + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + return F.silu(inputs.float(), inplace=False).to(inputs.dtype) + + +class GELU(nn.Module): + r""" + GELU activation function with tanh approximation support with `approximate="tanh"`. + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out, bias=bias) + self.approximate = approximate + + def gelu(self, gate: torch.Tensor) -> torch.Tensor: + return F.gelu(gate, approximate=self.approximate) + + def forward(self, hidden_states): + hidden_states = self.proj(hidden_states) + hidden_states = self.gelu(hidden_states) + return hidden_states + + +class GEGLU(nn.Module): + r""" + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + def __init__(self, dim_in: int, dim_out: int, bias: bool = True): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias) + + def gelu(self, gate: torch.Tensor) -> torch.Tensor: + return F.gelu(gate) + + def forward(self, hidden_states, *args, **kwargs): + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + hidden_states = self.proj(hidden_states) + if is_torch_npu_available(): + # using torch_npu.npu_geglu can run faster and save memory on NPU. + return torch_npu.npu_geglu(hidden_states, dim=-1, approximate=1)[0] + else: + hidden_states, gate = hidden_states.chunk(2, dim=-1) + return hidden_states * self.gelu(gate) + + +class SwiGLU(nn.Module): + r""" + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + def __init__(self, dim_in: int, dim_out: int, bias: bool = True): + super().__init__() + + self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias) + self.activation = nn.SiLU() + + def forward(self, hidden_states): + hidden_states = self.proj(hidden_states) + hidden_states, gate = hidden_states.chunk(2, dim=-1) + return hidden_states * self.activation(gate) + + +class ApproximateGELU(nn.Module): + r""" + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + def __init__(self, dim_in: int, dim_out: int, bias: bool = True): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out, bias=bias) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + return x * torch.sigmoid(1.702 * x) + + +class LinearActivation(nn.Module): + def __init__(self, dim_in: int, dim_out: int, bias: bool = True, activation: str = "silu"): + super().__init__() + + self.proj = nn.Linear(dim_in, dim_out, bias=bias) + self.activation = get_activation(activation) + + def forward(self, hidden_states): + hidden_states = self.proj(hidden_states) return self.activation(hidden_states) \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/models/attention.py b/MindIE/MultiModal/CogView3-Plus-3B/cogview3plus/models/attention.py similarity index 97% rename from MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/models/attention.py rename to MindIE/MultiModal/CogView3-Plus-3B/cogview3plus/models/attention.py index a7a559ff2f..946d829c6c 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/models/attention.py +++ b/MindIE/MultiModal/CogView3-Plus-3B/cogview3plus/models/attention.py @@ -1,87 +1,87 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Optional - -import torch -from torch import nn - -from diffusers.utils import deprecate, logging -from .activations import GEGLU, GELU, ApproximateGELU, LinearActivation, SwiGLU - - -logger = logging.get_logger(__name__) - - -class FeedForward(nn.Module): - r""" - A feed-forward layer. - - Parameters: - dim (`int`): The number of channels in the input. - dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. - mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. - final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. - bias (`bool`, defaults to True): Whether to use a bias in the linear layer. - """ - - def __init__( - self, - dim: int, - dim_out: Optional[int] = None, - mult: int = 4, - dropout: float = 0.0, - activation_fn: str = "geglu", - final_dropout: bool = False, - inner_dim=None, - bias: bool = True, - ): - super().__init__() - if inner_dim is None: - inner_dim = int(dim * mult) - dim_out = dim_out if dim_out is not None else dim - - if activation_fn == "gelu": - act_fn = GELU(dim, inner_dim, bias=bias) - if activation_fn == "gelu-approximate": - act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias) - elif activation_fn == "geglu": - act_fn = GEGLU(dim, inner_dim, bias=bias) - elif activation_fn == "geglu-approximate": - act_fn = ApproximateGELU(dim, inner_dim, bias=bias) - elif activation_fn == "swiglu": - act_fn = SwiGLU(dim, inner_dim, bias=bias) - elif activation_fn == "linear-silu": - act_fn = LinearActivation(dim, inner_dim, bias=bias, activation="silu") - - self.net = nn.ModuleList([]) - # project in - self.net.append(act_fn) - # project dropout - self.net.append(nn.Dropout(dropout)) - # project out - self.net.append(nn.Linear(inner_dim, dim_out, bias=bias)) - # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout - if final_dropout: - self.net.append(nn.Dropout(dropout)) - - def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor: - if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." - deprecate("scale", "1.0.0", deprecation_message) - for module in self.net: - hidden_states = module(hidden_states) +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import torch +from torch import nn + +from diffusers.utils import deprecate, logging +from .activations import GEGLU, GELU, ApproximateGELU, LinearActivation, SwiGLU + + +logger = logging.get_logger(__name__) + + +class FeedForward(nn.Module): + r""" + A feed-forward layer. + + Parameters: + dim (`int`): The number of channels in the input. + dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. + mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + mult: int = 4, + dropout: float = 0.0, + activation_fn: str = "geglu", + final_dropout: bool = False, + inner_dim=None, + bias: bool = True, + ): + super().__init__() + if inner_dim is None: + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + + if activation_fn == "gelu": + act_fn = GELU(dim, inner_dim, bias=bias) + if activation_fn == "gelu-approximate": + act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias) + elif activation_fn == "geglu": + act_fn = GEGLU(dim, inner_dim, bias=bias) + elif activation_fn == "geglu-approximate": + act_fn = ApproximateGELU(dim, inner_dim, bias=bias) + elif activation_fn == "swiglu": + act_fn = SwiGLU(dim, inner_dim, bias=bias) + elif activation_fn == "linear-silu": + act_fn = LinearActivation(dim, inner_dim, bias=bias, activation="silu") + + self.net = nn.ModuleList([]) + # project in + self.net.append(act_fn) + # project dropout + self.net.append(nn.Dropout(dropout)) + # project out + self.net.append(nn.Linear(inner_dim, dim_out, bias=bias)) + # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout + if final_dropout: + self.net.append(nn.Dropout(dropout)) + + def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + for module in self.net: + hidden_states = module(hidden_states) return hidden_states \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/models/attention_processor.py b/MindIE/MultiModal/CogView3-Plus-3B/cogview3plus/models/attention_processor.py similarity index 97% rename from MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/models/attention_processor.py rename to MindIE/MultiModal/CogView3-Plus-3B/cogview3plus/models/attention_processor.py index c197a989b7..d36e9265a3 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/models/attention_processor.py +++ b/MindIE/MultiModal/CogView3-Plus-3B/cogview3plus/models/attention_processor.py @@ -1,348 +1,348 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import inspect -from typing import Optional - -import torch -import torch.nn.functional as F -from torch import nn -import torch_npu - -from diffusers.utils import logging -from diffusers.utils.torch_utils import maybe_allow_in_graph - -from ..layers import QKVLinear - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -@maybe_allow_in_graph -class Attention(nn.Module): - def __init__( - self, - query_dim: int, - cross_attention_dim: Optional[int] = None, - heads: int = 8, - kv_heads: Optional[int] = None, - dim_head: int = 64, - dropout: float = 0.0, - bias: bool = False, - upcast_attention: bool = False, - upcast_softmax: bool = False, - cross_attention_norm: Optional[str] = None, - cross_attention_norm_num_groups: int = 32, - qk_norm: Optional[str] = None, - added_kv_proj_dim: Optional[int] = None, - added_proj_bias: Optional[bool] = True, - norm_num_groups: Optional[int] = None, - out_bias: bool = True, - scale_qk: bool = True, - only_cross_attention: bool = False, - eps: float = 1e-5, - rescale_output_factor: float = 1.0, - residual_connection: bool = False, - _from_deprecated_attn_block: bool = False, - processor: Optional["AttnProcessor"] = None, - out_dim: int = None, - out_context_dim: int = None, - context_pre_only=None, - pre_only=False, - elementwise_affine: bool = True, - is_causal: bool = False, - ): - super().__init__() - - # To prevent circular import. - from ..layers.normalization import FP32LayerNorm, LpNorm, RMSNorm - - self.inner_dim = out_dim if out_dim is not None else dim_head * heads - self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads - self.query_dim = query_dim - self.use_bias = bias - self.is_cross_attention = cross_attention_dim is not None - self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim - self.upcast_attention = upcast_attention - self.upcast_softmax = upcast_softmax - self.rescale_output_factor = rescale_output_factor - self.residual_connection = residual_connection - self.dropout = dropout - self.fused_projections = False - self.out_dim = out_dim if out_dim is not None else query_dim - self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim - self.context_pre_only = context_pre_only - self.pre_only = pre_only - self.is_causal = is_causal - - # we make use of this private variable to know whether this class is loaded - # with an deprecated state dict so that we can convert it on the fly - self._from_deprecated_attn_block = _from_deprecated_attn_block - - self.scale_qk = scale_qk - self.scale = dim_head**-0.5 if self.scale_qk else 1.0 - - self.heads = out_dim // dim_head if out_dim is not None else heads - # for slice_size > 0 the attention score computation - # is split across the batch axis to save memory - # You can set slice_size with `set_attention_slice` - self.sliceable_head_dim = heads - - self.added_kv_proj_dim = added_kv_proj_dim - self.only_cross_attention = only_cross_attention - - if self.added_kv_proj_dim is None and self.only_cross_attention: - raise ValueError( - "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`." - ) - - if norm_num_groups is not None: - self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True) - else: - self.group_norm = None - - self.spatial_norm = None - - if qk_norm is None: - self.norm_q = None - self.norm_k = None - elif qk_norm == "layer_norm": - self.norm_q = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) - self.norm_k = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) - elif qk_norm == "fp32_layer_norm": - self.norm_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) - self.norm_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) - elif qk_norm == "layer_norm_across_heads": - # Lumina applies qk norm across all heads - self.norm_q = nn.LayerNorm(dim_head * heads, eps=eps) - self.norm_k = nn.LayerNorm(dim_head * kv_heads, eps=eps) - elif qk_norm == "rms_norm": - self.norm_q = RMSNorm(dim_head, eps=eps) - self.norm_k = RMSNorm(dim_head, eps=eps) - elif qk_norm == "rms_norm_across_heads": - # LTX applies qk norm across all heads - self.norm_q = RMSNorm(dim_head * heads, eps=eps) - self.norm_k = RMSNorm(dim_head * kv_heads, eps=eps) - elif qk_norm == "l2": - self.norm_q = LpNorm(p=2, dim=-1, eps=eps) - self.norm_k = LpNorm(p=2, dim=-1, eps=eps) - else: - raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None,'layer_norm','fp32_layer_norm','rms_norm'") - - if cross_attention_norm is None: - self.norm_cross = None - elif cross_attention_norm == "layer_norm": - self.norm_cross = nn.LayerNorm(self.cross_attention_dim) - elif cross_attention_norm == "group_norm": - if self.added_kv_proj_dim is not None: - # The given `encoder_hidden_states` are initially of shape - # (batch_size, seq_len, added_kv_proj_dim) before being projected - # to (batch_size, seq_len, cross_attention_dim). The norm is applied - # before the projection, so we need to use `added_kv_proj_dim` as - # the number of channels for the group norm. - norm_cross_num_channels = added_kv_proj_dim - else: - norm_cross_num_channels = self.cross_attention_dim - - self.norm_cross = nn.GroupNorm( - num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True - ) - else: - raise ValueError( - f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'" - ) - - self.to_qkv = QKVLinear(self.inner_dim, query_dim) - - self.added_proj_bias = added_proj_bias - if self.added_kv_proj_dim is not None: - self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias) - self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias) - if self.context_pre_only is not None: - self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) - else: - self.add_q_proj = None - self.add_k_proj = None - self.add_v_proj = None - - if not self.pre_only: - self.to_out = nn.ModuleList([]) - self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) - self.to_out.append(nn.Dropout(dropout)) - else: - self.to_out = None - - if self.context_pre_only is not None and not self.context_pre_only: - self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias) - else: - self.to_add_out = None - - if qk_norm is not None and added_kv_proj_dim is not None: - if qk_norm == "fp32_layer_norm": - self.norm_added_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) - self.norm_added_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) - elif qk_norm == "rms_norm": - self.norm_added_q = RMSNorm(dim_head, eps=eps) - self.norm_added_k = RMSNorm(dim_head, eps=eps) - else: - raise ValueError( - f"unknown qk_norm: {qk_norm}. Should be one of `None,'layer_norm','fp32_layer_norm','rms_norm'`" - ) - else: - self.norm_added_q = None - self.norm_added_k = None - - self.set_processor(processor) - - def set_processor(self, processor: "AttnProcessor") -> None: - r""" - Set the attention processor to use. - - Args: - processor (`AttnProcessor`): - The attention processor to use. - """ - if ( - hasattr(self, "processor") - and isinstance(self.processor, torch.nn.Module) - and not isinstance(processor, torch.nn.Module) - ): - logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}") - self._modules.pop("processor") - - self.processor = processor - - def forward( - self, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - **cross_attention_kwargs, - ) -> torch.Tensor: - attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) - quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"} - unused_kwargs = [ - k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters - ] - if len(unused_kwargs) > 0: - logger.warning( - f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored." - ) - cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters} - - return self.processor( - self, - hidden_states, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - **cross_attention_kwargs, - ) - - def prepare_attention_mask( - self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3 - ) -> torch.Tensor: - head_size = self.heads - if attention_mask is None: - return attention_mask - - current_length: int = attention_mask.shape[-1] - if current_length != target_length: - if attention_mask.device.type == "mps": - padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length) - padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device) - attention_mask = torch.cat([attention_mask, padding], dim=2) - else: - attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) - - if out_dim == 3: - if attention_mask.shape[0] < batch_size * head_size: - attention_mask = attention_mask.repeat_interleave(head_size, dim=0) - elif out_dim == 4: - attention_mask = attention_mask.unsqueeze(1) - attention_mask = attention_mask.repeat_interleave(head_size, dim=1) - - return attention_mask - - -class CogVideoXAttnProcessor2_0: - r""" - Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on - query and key vectors, but does not include spatial normalization. - """ - - def __init__(self): - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") - - def __call__( - self, - attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - text_seq_length = encoder_hidden_states.size(1) - - hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) - - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - - if attention_mask is not None: - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) - - B, S, _ = hidden_states.shape - qkv = attn.to_qkv(hidden_states) - inner_dim = qkv.shape[-1] // 3 - head_dim = inner_dim // attn.heads - qkv_shape = (B, S, 3, attn.heads, head_dim) - query, key, value = qkv.view(qkv_shape).permute(2, 0, 3, 1, 4).contiguous().unbind(0) - - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) - - B, N, S, D = query.shape - dim = 48 - pad_shape = [B, N, S, D] - pad_shape[-1] = dim - pad_shape[-1] - pad = torch.zeros(pad_shape, dtype=query.dtype, device=query.device) - query = torch.cat([query, pad], dim=-1) - key = torch.cat([key, pad], dim=-1) - value = torch.cat([value, pad], dim=-1) - hidden_states = torch_npu.npu_prompt_flash_attention( - query, - key, - value, - input_layout='BNSD', - scale_value=D**-0.5, - pre_tokens=65535, - next_tokens=65535, - num_heads=N - ) - hidden_states = hidden_states[:, :, :, :D] - - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - encoder_hidden_states, hidden_states = hidden_states.split( - [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1 - ) +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Optional + +import torch +import torch.nn.functional as F +from torch import nn +import torch_npu + +from diffusers.utils import logging +from diffusers.utils.torch_utils import maybe_allow_in_graph + +from ..layers import QKVLinear + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@maybe_allow_in_graph +class Attention(nn.Module): + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + kv_heads: Optional[int] = None, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + cross_attention_norm: Optional[str] = None, + cross_attention_norm_num_groups: int = 32, + qk_norm: Optional[str] = None, + added_kv_proj_dim: Optional[int] = None, + added_proj_bias: Optional[bool] = True, + norm_num_groups: Optional[int] = None, + out_bias: bool = True, + scale_qk: bool = True, + only_cross_attention: bool = False, + eps: float = 1e-5, + rescale_output_factor: float = 1.0, + residual_connection: bool = False, + _from_deprecated_attn_block: bool = False, + processor: Optional["AttnProcessor"] = None, + out_dim: int = None, + out_context_dim: int = None, + context_pre_only=None, + pre_only=False, + elementwise_affine: bool = True, + is_causal: bool = False, + ): + super().__init__() + + # To prevent circular import. + from ..layers.normalization import FP32LayerNorm, LpNorm, RMSNorm + + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads + self.query_dim = query_dim + self.use_bias = bias + self.is_cross_attention = cross_attention_dim is not None + self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + self.upcast_attention = upcast_attention + self.upcast_softmax = upcast_softmax + self.rescale_output_factor = rescale_output_factor + self.residual_connection = residual_connection + self.dropout = dropout + self.fused_projections = False + self.out_dim = out_dim if out_dim is not None else query_dim + self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim + self.context_pre_only = context_pre_only + self.pre_only = pre_only + self.is_causal = is_causal + + # we make use of this private variable to know whether this class is loaded + # with an deprecated state dict so that we can convert it on the fly + self._from_deprecated_attn_block = _from_deprecated_attn_block + + self.scale_qk = scale_qk + self.scale = dim_head**-0.5 if self.scale_qk else 1.0 + + self.heads = out_dim // dim_head if out_dim is not None else heads + # for slice_size > 0 the attention score computation + # is split across the batch axis to save memory + # You can set slice_size with `set_attention_slice` + self.sliceable_head_dim = heads + + self.added_kv_proj_dim = added_kv_proj_dim + self.only_cross_attention = only_cross_attention + + if self.added_kv_proj_dim is None and self.only_cross_attention: + raise ValueError( + "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`." + ) + + if norm_num_groups is not None: + self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True) + else: + self.group_norm = None + + self.spatial_norm = None + + if qk_norm is None: + self.norm_q = None + self.norm_k = None + elif qk_norm == "layer_norm": + self.norm_q = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + self.norm_k = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + elif qk_norm == "fp32_layer_norm": + self.norm_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) + self.norm_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) + elif qk_norm == "layer_norm_across_heads": + # Lumina applies qk norm across all heads + self.norm_q = nn.LayerNorm(dim_head * heads, eps=eps) + self.norm_k = nn.LayerNorm(dim_head * kv_heads, eps=eps) + elif qk_norm == "rms_norm": + self.norm_q = RMSNorm(dim_head, eps=eps) + self.norm_k = RMSNorm(dim_head, eps=eps) + elif qk_norm == "rms_norm_across_heads": + # LTX applies qk norm across all heads + self.norm_q = RMSNorm(dim_head * heads, eps=eps) + self.norm_k = RMSNorm(dim_head * kv_heads, eps=eps) + elif qk_norm == "l2": + self.norm_q = LpNorm(p=2, dim=-1, eps=eps) + self.norm_k = LpNorm(p=2, dim=-1, eps=eps) + else: + raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None,'layer_norm','fp32_layer_norm','rms_norm'") + + if cross_attention_norm is None: + self.norm_cross = None + elif cross_attention_norm == "layer_norm": + self.norm_cross = nn.LayerNorm(self.cross_attention_dim) + elif cross_attention_norm == "group_norm": + if self.added_kv_proj_dim is not None: + # The given `encoder_hidden_states` are initially of shape + # (batch_size, seq_len, added_kv_proj_dim) before being projected + # to (batch_size, seq_len, cross_attention_dim). The norm is applied + # before the projection, so we need to use `added_kv_proj_dim` as + # the number of channels for the group norm. + norm_cross_num_channels = added_kv_proj_dim + else: + norm_cross_num_channels = self.cross_attention_dim + + self.norm_cross = nn.GroupNorm( + num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True + ) + else: + raise ValueError( + f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'" + ) + + self.to_qkv = QKVLinear(self.inner_dim, query_dim) + + self.added_proj_bias = added_proj_bias + if self.added_kv_proj_dim is not None: + self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias) + self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias) + if self.context_pre_only is not None: + self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + else: + self.add_q_proj = None + self.add_k_proj = None + self.add_v_proj = None + + if not self.pre_only: + self.to_out = nn.ModuleList([]) + self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) + self.to_out.append(nn.Dropout(dropout)) + else: + self.to_out = None + + if self.context_pre_only is not None and not self.context_pre_only: + self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias) + else: + self.to_add_out = None + + if qk_norm is not None and added_kv_proj_dim is not None: + if qk_norm == "fp32_layer_norm": + self.norm_added_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) + self.norm_added_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) + elif qk_norm == "rms_norm": + self.norm_added_q = RMSNorm(dim_head, eps=eps) + self.norm_added_k = RMSNorm(dim_head, eps=eps) + else: + raise ValueError( + f"unknown qk_norm: {qk_norm}. Should be one of `None,'layer_norm','fp32_layer_norm','rms_norm'`" + ) + else: + self.norm_added_q = None + self.norm_added_k = None + + self.set_processor(processor) + + def set_processor(self, processor: "AttnProcessor") -> None: + r""" + Set the attention processor to use. + + Args: + processor (`AttnProcessor`): + The attention processor to use. + """ + if ( + hasattr(self, "processor") + and isinstance(self.processor, torch.nn.Module) + and not isinstance(processor, torch.nn.Module) + ): + logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}") + self._modules.pop("processor") + + self.processor = processor + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + **cross_attention_kwargs, + ) -> torch.Tensor: + attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) + quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"} + unused_kwargs = [ + k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters + ] + if len(unused_kwargs) > 0: + logger.warning( + f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored." + ) + cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters} + + return self.processor( + self, + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + + def prepare_attention_mask( + self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3 + ) -> torch.Tensor: + head_size = self.heads + if attention_mask is None: + return attention_mask + + current_length: int = attention_mask.shape[-1] + if current_length != target_length: + if attention_mask.device.type == "mps": + padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length) + padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device) + attention_mask = torch.cat([attention_mask, padding], dim=2) + else: + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + + if out_dim == 3: + if attention_mask.shape[0] < batch_size * head_size: + attention_mask = attention_mask.repeat_interleave(head_size, dim=0) + elif out_dim == 4: + attention_mask = attention_mask.unsqueeze(1) + attention_mask = attention_mask.repeat_interleave(head_size, dim=1) + + return attention_mask + + +class CogVideoXAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on + query and key vectors, but does not include spatial normalization. + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + text_seq_length = encoder_hidden_states.size(1) + + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + B, S, _ = hidden_states.shape + qkv = attn.to_qkv(hidden_states) + inner_dim = qkv.shape[-1] // 3 + head_dim = inner_dim // attn.heads + qkv_shape = (B, S, 3, attn.heads, head_dim) + query, key, value = qkv.view(qkv_shape).permute(2, 0, 3, 1, 4).contiguous().unbind(0) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + B, N, S, D = query.shape + dim = 48 + pad_shape = [B, N, S, D] + pad_shape[-1] = dim - pad_shape[-1] + pad = torch.zeros(pad_shape, dtype=query.dtype, device=query.device) + query = torch.cat([query, pad], dim=-1) + key = torch.cat([key, pad], dim=-1) + value = torch.cat([value, pad], dim=-1) + hidden_states = torch_npu.npu_prompt_flash_attention( + query, + key, + value, + input_layout='BNSD', + scale_value=D**-0.5, + pre_tokens=65535, + next_tokens=65535, + num_heads=N + ) + hidden_states = hidden_states[:, :, :, :D] + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + encoder_hidden_states, hidden_states = hidden_states.split( + [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1 + ) return hidden_states, encoder_hidden_states \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/models/model_load_utils.py b/MindIE/MultiModal/CogView3-Plus-3B/cogview3plus/models/model_load_utils.py similarity index 97% rename from MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/models/model_load_utils.py rename to MindIE/MultiModal/CogView3-Plus-3B/cogview3plus/models/model_load_utils.py index 34a4625283..1257aad309 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/models/model_load_utils.py +++ b/MindIE/MultiModal/CogView3-Plus-3B/cogview3plus/models/model_load_utils.py @@ -1,42 +1,42 @@ -#!/usr/bin/env python -# coding=utf-8 -# Copyright(C) 2024. Huawei Technologies Co.,Ltd. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import torch -import safetensors.torch - - -SAFETENSORS_EXTENSION = "safetensors" -EMA_STATE_DICT = "ema_state_dict" -STATE_DICT = "state_dict" -CPU = "cpu" - - -def load_state_dict_sd(model_path): - name = os.path.basename(model_path).split('.')[-1] # get weights name - if name.endswith("ckpt"): - weight = torch.load(model_path, map_location=CPU) - if (EMA_STATE_DICT in weight): - weight = weight[EMA_STATE_DICT] - weight = {key.replace("module.", ""): value for key, value in weight.items()} - elif STATE_DICT in weight: - weight = weight[STATE_DICT] - return weight - elif name == SAFETENSORS_EXTENSION: # diffuser model use same name - return safetensors.torch.load_file(model_path, device=CPU) # first load on cpu - else: - # to support hf shard model weights +#!/usr/bin/env python +# coding=utf-8 +# Copyright(C) 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import torch +import safetensors.torch + + +SAFETENSORS_EXTENSION = "safetensors" +EMA_STATE_DICT = "ema_state_dict" +STATE_DICT = "state_dict" +CPU = "cpu" + + +def load_state_dict_sd(model_path): + name = os.path.basename(model_path).split('.')[-1] # get weights name + if name.endswith("ckpt"): + weight = torch.load(model_path, map_location=CPU) + if (EMA_STATE_DICT in weight): + weight = weight[EMA_STATE_DICT] + weight = {key.replace("module.", ""): value for key, value in weight.items()} + elif STATE_DICT in weight: + weight = weight[STATE_DICT] + return weight + elif name == SAFETENSORS_EXTENSION: # diffuser model use same name + return safetensors.torch.load_file(model_path, device=CPU) # first load on cpu + else: + # to support hf shard model weights return torch.load(model_path, map_location=CPU) # first load on cpu \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/models/modeling_utils.py b/MindIE/MultiModal/CogView3-Plus-3B/cogview3plus/models/modeling_utils.py similarity index 97% rename from MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/models/modeling_utils.py rename to MindIE/MultiModal/CogView3-Plus-3B/cogview3plus/models/modeling_utils.py index aa8e33daaa..fddf0ade3f 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/models/modeling_utils.py +++ b/MindIE/MultiModal/CogView3-Plus-3B/cogview3plus/models/modeling_utils.py @@ -1,771 +1,771 @@ -# coding=utf-8 -# Copyright 2024 The HuggingFace Inc. team. -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import copy -import itertools -import json -import os -import re -from collections import OrderedDict -from functools import wraps -from typing import Any, List, Optional, Tuple, Union - -import torch -from huggingface_hub.utils import validate_hf_hub_args -from torch import Tensor, nn - -from diffusers import __version__ -from diffusers.quantizers import DiffusersAutoQuantizer -from diffusers.quantizers.quantization_config import QuantizationMethod -from diffusers.utils import ( - CONFIG_NAME, - FLAX_WEIGHTS_NAME, - SAFETENSORS_WEIGHTS_NAME, - WEIGHTS_NAME, - _add_variant, - _get_checkpoint_shard_files, - _get_model_file, - deprecate, - is_accelerate_available, - is_bitsandbytes_version, - logging, -) -from diffusers.utils.hub_utils import PushToHubMixin -from diffusers.models.model_loading_utils import ( - _fetch_index_file, - _fetch_index_file_legacy, - _load_state_dict_into_model, - _merge_sharded_checkpoints, - load_model_dict_into_meta, - load_state_dict, -) - - -logger = logging.get_logger(__name__) - - -_LOW_CPU_MEM_USAGE_DEFAULT = True - - -if is_accelerate_available(): - import accelerate - - -def get_parameter_device(parameter: torch.nn.Module) -> torch.device: - try: - parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers()) - return next(parameters_and_buffers).device - except StopIteration: - # For torch.nn.DataParallel compatibility in PyTorch 1.5 - - def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]: - tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] - return tuples - - gen = parameter._named_members(get_members_fn=find_tensor_attributes) - first_tuple = next(gen) - return first_tuple[1].device - - -def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype: - """ - Returns the first found floating dtype in parameters if there is one, otherwise returns the last dtype it found. - """ - last_dtype = None - for param in parameter.parameters(): - last_dtype = param.dtype - if param.is_floating_point(): - return param.dtype - - for buffer in parameter.buffers(): - last_dtype = buffer.dtype - if buffer.is_floating_point(): - return buffer.dtype - - if last_dtype is not None: - # if no floating dtype was found return whatever the first dtype is - return last_dtype - - # For nn.DataParallel compatibility in PyTorch > 1.5 - def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]: - tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] - return tuples - - gen = parameter._named_members(get_members_fn=find_tensor_attributes) - last_tuple = None - for current_tuple in gen: - last_tuple = current_tuple - if current_tuple[1].is_floating_point(): - return current_tuple[1].dtype - - if last_tuple is not None: - # fallback to the last dtype - return last_tuple[1].dtype - - -class ModelMixin(torch.nn.Module, PushToHubMixin): - config_name = CONFIG_NAME - _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"] - _supports_gradient_checkpointing = False - _keys_to_ignore_on_load_unexpected = None - _no_split_modules = None - _keep_in_fp32_modules = None - - def __init__(self): - super().__init__() - - def __getattr__(self, name: str) -> Any: - - is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name) - is_attribute = name in self.__dict__ - - if is_in_config and not is_attribute: - deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'unet.config.{name}'." - deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False, stacklevel=3) - return self._internal_dict[name] - - return super().__getattr__(name) - - @classmethod - @validate_hf_hub_args - def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): - cache_dir = kwargs.pop("cache_dir", None) - ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False) - force_download = kwargs.pop("force_download", False) - from_flax = kwargs.pop("from_flax", False) - proxies = kwargs.pop("proxies", None) - output_loading_info = kwargs.pop("output_loading_info", False) - local_files_only = kwargs.pop("local_files_only", None) - token = kwargs.pop("token", None) - revision = kwargs.pop("revision", None) - torch_dtype = kwargs.pop("torch_dtype", None) - subfolder = kwargs.pop("subfolder", None) - device_map = kwargs.pop("device_map", None) - low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) - variant = kwargs.pop("variant", None) - use_safetensors = kwargs.pop("use_safetensors", None) - quantization_config = kwargs.pop("quantization_config", None) - - allow_pickle = False - if use_safetensors is None: - use_safetensors = True - allow_pickle = True - - if low_cpu_mem_usage and not is_accelerate_available(): - low_cpu_mem_usage = False - logger.warning( - "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the" - " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install" - " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip" - " install accelerate\n```\n." - ) - - if device_map is not None and not is_accelerate_available(): - raise NotImplementedError( - "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set" - " `device_map=None`. You can install accelerate with `pip install accelerate`." - ) - - if low_cpu_mem_usage is False and device_map is not None: - raise ValueError( - f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and" - " dispatching. Please make sure to set `low_cpu_mem_usage=True`." - ) - - if isinstance(device_map, torch.device): - device_map = {"": device_map} - elif isinstance(device_map, str) and device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]: - try: - device_map = {"": torch.device(device_map)} - except RuntimeError as e: - raise ValueError( - "When passing device_map as a string, the value needs to be a device name (e.g. cpu, cuda:0) or " - f"'auto', 'balanced', 'balanced_low_0', 'sequential' but found {device_map}." - ) from e - elif isinstance(device_map, int): - if device_map < 0: - raise ValueError( - "You can't pass device_map as a negative int. If you want to put the model on the cpu, pass device_map = 'cpu' " - ) - else: - device_map = {"": device_map} - - if device_map is not None: - if low_cpu_mem_usage is None: - low_cpu_mem_usage = True - elif not low_cpu_mem_usage: - raise ValueError("Passing along a `device_map` requires `low_cpu_mem_usage=True`") - - # Load config if we don't provide a configuration - config_path = pretrained_model_name_or_path - - user_agent = { - "diffusers": __version__, - "file_type": "model", - "framework": "pytorch", - } - - # load config - config, unused_kwargs, commit_hash = cls.load_config( - config_path, - cache_dir=cache_dir, - return_unused_kwargs=True, - return_commit_hash=True, - force_download=force_download, - proxies=proxies, - local_files_only=local_files_only, - token=token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - **kwargs, - ) - # no in-place modification of the original config. - config = copy.deepcopy(config) - - # determine initial quantization config. - ####################################### - pre_quantized = "quantization_config" in config and config["quantization_config"] is not None - if pre_quantized or quantization_config is not None: - if pre_quantized: - config["quantization_config"] = DiffusersAutoQuantizer.merge_quantization_configs( - config["quantization_config"], quantization_config - ) - else: - config["quantization_config"] = quantization_config - hf_quantizer = DiffusersAutoQuantizer.from_config( - config["quantization_config"], pre_quantized=pre_quantized - ) - else: - hf_quantizer = None - - if hf_quantizer is not None: - is_bnb_quantization_method = hf_quantizer.quantization_config.quant_method.value == "bitsandbytes" - if is_bnb_quantization_method and device_map is not None: - raise NotImplementedError( - "Currently, `device_map` is automatically inferred for quantized bitsandbytes models. Support for providing `device_map` as an input will be added in the future." - ) - - hf_quantizer.validate_environment(torch_dtype=torch_dtype, from_flax=from_flax, device_map=device_map) - torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype) - - # In order to ensure popular quantization methods are supported. Can be disable with `disable_telemetry` - user_agent["quant"] = hf_quantizer.quantization_config.quant_method.value - - # Force-set to `True` for more mem efficiency - if low_cpu_mem_usage is None: - low_cpu_mem_usage = True - logger.info("Set `low_cpu_mem_usage` to True as `hf_quantizer` is not None.") - elif not low_cpu_mem_usage: - raise ValueError("`low_cpu_mem_usage` cannot be False or None when using quantization.") - - # Check if `_keep_in_fp32_modules` is not None - use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and ( - (torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules") - ) - if use_keep_in_fp32_modules: - keep_in_fp32_modules = cls._keep_in_fp32_modules - if not isinstance(keep_in_fp32_modules, list): - keep_in_fp32_modules = [keep_in_fp32_modules] - - if low_cpu_mem_usage is None: - low_cpu_mem_usage = True - logger.info("Set `low_cpu_mem_usage` to True as `_keep_in_fp32_modules` is not None.") - elif not low_cpu_mem_usage: - raise ValueError("`low_cpu_mem_usage` cannot be False when `keep_in_fp32_modules` is True.") - else: - keep_in_fp32_modules = [] - ####################################### - - # Determine if we're loading from a directory of sharded checkpoints. - is_sharded = False - index_file = None - is_local = os.path.isdir(pretrained_model_name_or_path) - index_file_kwargs = { - "is_local": is_local, - "pretrained_model_name_or_path": pretrained_model_name_or_path, - "subfolder": subfolder or "", - "use_safetensors": use_safetensors, - "cache_dir": cache_dir, - "variant": variant, - "force_download": force_download, - "proxies": proxies, - "local_files_only": local_files_only, - "token": token, - "revision": revision, - "user_agent": user_agent, - "commit_hash": commit_hash, - } - index_file = _fetch_index_file(**index_file_kwargs) - # In case the index file was not found we still have to consider the legacy format. - # this becomes applicable when the variant is not None. - if variant is not None and (index_file is None or not os.path.exists(index_file)): - index_file = _fetch_index_file_legacy(**index_file_kwargs) - if index_file is not None and index_file.is_file(): - is_sharded = True - - if is_sharded and from_flax: - raise ValueError("Loading of sharded checkpoints is not supported when `from_flax=True`.") - - # load model - model_file = None - if from_flax: - model_file = _get_model_file( - pretrained_model_name_or_path, - weights_name=FLAX_WEIGHTS_NAME, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - local_files_only=local_files_only, - token=token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - commit_hash=commit_hash, - ) - model = cls.from_config(config, **unused_kwargs) - - # Convert the weights - from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model - - model = load_flax_checkpoint_in_pytorch_model(model, model_file) - else: - if is_sharded: - sharded_ckpt_cached_folder, sharded_metadata = _get_checkpoint_shard_files( - pretrained_model_name_or_path, - index_file, - cache_dir=cache_dir, - proxies=proxies, - local_files_only=local_files_only, - token=token, - user_agent=user_agent, - revision=revision, - subfolder=subfolder or "", - ) - if hf_quantizer is not None and is_bnb_quantization_method: - model_file = _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata) - logger.info("Merged sharded checkpoints as `hf_quantizer` is not None.") - is_sharded = False - - elif use_safetensors and not is_sharded: - try: - model_file = _get_model_file( - pretrained_model_name_or_path, - weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant), - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - local_files_only=local_files_only, - token=token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - commit_hash=commit_hash, - ) - - except IOError as e: - logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}") - if not allow_pickle: - raise - logger.warning( - "Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead." - ) - - if model_file is None and not is_sharded: - model_file = _get_model_file( - pretrained_model_name_or_path, - weights_name=_add_variant(WEIGHTS_NAME, variant), - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - local_files_only=local_files_only, - token=token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - commit_hash=commit_hash, - ) - - if low_cpu_mem_usage: - # Instantiate model with empty weights - with accelerate.init_empty_weights(): - model = cls.from_config(config, **unused_kwargs) - - if hf_quantizer is not None: - hf_quantizer.preprocess_model( - model=model, device_map=device_map, keep_in_fp32_modules=keep_in_fp32_modules - ) - - # if device_map is None, load the state dict and move the params from meta device to the cpu - if device_map is None and not is_sharded: - # `torch.cuda.current_device()` is fine here when `hf_quantizer` is not None. - # It would error out during the `validate_environment()` call above in the absence of cuda. - if hf_quantizer is None: - param_device = "cpu" - else: - param_device = torch.device(torch.cuda.current_device()) - state_dict = load_state_dict(model_file, variant=variant) - model._convert_deprecated_attention_blocks(state_dict) - - # move the params from meta device to cpu - missing_keys = set(model.state_dict().keys()) - set(state_dict.keys()) - if hf_quantizer is not None: - missing_keys = hf_quantizer.update_missing_keys(model, missing_keys, prefix="") - if len(missing_keys) > 0: - raise ValueError( - f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are" - f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass" - " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize" - " those weights or else make sure your checkpoint file is correct." - ) - - unexpected_keys = load_model_dict_into_meta( - model, - state_dict, - device=param_device, - dtype=torch_dtype, - model_name_or_path=pretrained_model_name_or_path, - hf_quantizer=hf_quantizer, - keep_in_fp32_modules=keep_in_fp32_modules, - ) - - if cls._keys_to_ignore_on_load_unexpected is not None: - for pat in cls._keys_to_ignore_on_load_unexpected: - unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] - - if len(unexpected_keys) > 0: - logger.warning( - f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}" - ) - - else: - weights_path = index_file - with open(index_file) as f: - index = json.loads(f.read()) - if "weight_map" in index: - index = index["weight_map"] - weights_path = sorted(list(set(index.values()))) - weights_path = [os.path.join(pretrained_model_name_or_path, f) for f in weights_path] - - model = cls._load_model(model, weights_path, is_sharded) - - loading_info = { - "missing_keys": [], - "unexpected_keys": [], - "mismatched_keys": [], - "error_msgs": [], - } - else: - model = cls.from_config(config, **unused_kwargs) - - state_dict = load_state_dict(model_file, variant=variant) - model._convert_deprecated_attention_blocks(state_dict) - - model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model( - model, - state_dict, - model_file, - pretrained_model_name_or_path, - ignore_mismatched_sizes=ignore_mismatched_sizes, - ) - - loading_info = { - "missing_keys": missing_keys, - "unexpected_keys": unexpected_keys, - "mismatched_keys": mismatched_keys, - "error_msgs": error_msgs, - } - - if hf_quantizer is not None: - hf_quantizer.postprocess_model(model) - model.hf_quantizer = hf_quantizer - - if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype): - raise ValueError( - f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}." - ) - # When using `use_keep_in_fp32_modules` if we do a global `to()` here, then we will - # completely lose the effectivity of `use_keep_in_fp32_modules`. - elif torch_dtype is not None and hf_quantizer is None and not use_keep_in_fp32_modules: - model = model.to(torch_dtype) - - if hf_quantizer is not None: - # We also make sure to purge `_pre_quantization_dtype` when we serialize - # the model config because `_pre_quantization_dtype` is `torch.dtype`, not JSON serializable. - model.register_to_config(_name_or_path=pretrained_model_name_or_path, _pre_quantization_dtype=torch_dtype) - else: - model.register_to_config(_name_or_path=pretrained_model_name_or_path) - - # Set model in evaluation mode to deactivate DropOut modules by default - model.eval() - if output_loading_info: - return model, loading_info - - return model - - @classmethod - def _load_model(cls, model, weights_path, is_sharded): - if not is_sharded: - state_dict = load_state_dict(weights_path) - model.load_weights(state_dict) - else: - need_key = set(model.state_dict().keys()) - state_dict = {} - cache = {} - for weight_file in weights_path: - state_dict = load_state_dict(weight_file) - state_dict.update(cache) - loadkey_cache = model.load_weights(state_dict, is_sharded) - if loadkey_cache : - if isinstance(loadkey_cache, tuple): - loaded_keys, cache = loadkey_cache - else: - loaded_keys = loadkey_cache - need_key = need_key.symmetric_difference(set(loaded_keys)) - - if len(need_key) > 0: - raise ValueError(f"The weight miss key: {need_key}") - return model - - def load_weights(self, state_dict, shard=False): - with torch.no_grad(): - if not shard: - self.load_state_dict(state_dict) - return {} - else: - self.load_state_dict(state_dict, strict=False, assign=True) - return state_dict.keys() - - # Adapted from `transformers`. - @wraps(torch.nn.Module.cuda) - def cuda(self, *args, **kwargs): - # Checks if the model has been loaded in 4-bit or 8-bit with BNB - if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES: - if getattr(self, "is_loaded_in_8bit", False): - raise ValueError( - "Calling `cuda()` is not supported for `8-bit` quantized models. " - " Please use the model as it is, since the model has already been set to the correct devices." - ) - elif is_bitsandbytes_version("<", "0.43.2"): - raise ValueError( - "Calling `cuda()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. " - f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2." - ) - return super().cuda(*args, **kwargs) - - # Adapted from `transformers`. - @wraps(torch.nn.Module.to) - def to(self, *args, **kwargs): - dtype_present_in_args = "dtype" in kwargs - - if not dtype_present_in_args: - for arg in args: - if isinstance(arg, torch.dtype): - dtype_present_in_args = True - break - - if getattr(self, "is_quantized", False): - if dtype_present_in_args: - raise ValueError( - "Casting a quantized model to a new `dtype` is unsupported. To set the dtype of unquantized layers, please " - "use the `torch_dtype` argument when loading the model using `from_pretrained` or `from_single_file`" - ) - - if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES: - if getattr(self, "is_loaded_in_8bit", False): - raise ValueError( - "`.to` is not supported for `8-bit` bitsandbytes models. Please use the model as it is, since the" - " model has already been set to the correct devices and casted to the correct `dtype`." - ) - elif is_bitsandbytes_version("<", "0.43.2"): - raise ValueError( - "Calling `to()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. " - f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2." - ) - return super().to(*args, **kwargs) - - # Taken from `transformers`. - def half(self, *args): - # Checks if the model is quantized - if getattr(self, "is_quantized", False): - raise ValueError( - "`.half()` is not supported for quantized model. Please use the model as it is, since the" - " model has already been cast to the correct `dtype`." - ) - else: - return super().half(*args) - - # Taken from `transformers`. - def float(self, *args): - # Checks if the model is quantized - if getattr(self, "is_quantized", False): - raise ValueError( - "`.float()` is not supported for quantized model. Please use the model as it is, since the" - " model has already been cast to the correct `dtype`." - ) - else: - return super().float(*args) - - @classmethod - def _load_pretrained_model( - cls, - model, - state_dict: OrderedDict, - pretrained_model_name_or_path: Union[str, os.PathLike], - ignore_mismatched_sizes: bool = False, - ): - # Retrieve missing & unexpected_keys - model_state_dict = model.state_dict() - loaded_keys = list(state_dict.keys()) - - expected_keys = list(model_state_dict.keys()) - - original_loaded_keys = loaded_keys - - missing_keys = list(set(expected_keys) - set(loaded_keys)) - unexpected_keys = list(set(loaded_keys) - set(expected_keys)) - - # Make sure we are able to load base models as well as derived models (with heads) - model_to_load = model - - def _find_mismatched_keys( - state_dict, - model_state_dict, - loaded_keys, - ignore_mismatched_sizes, - ): - mismatched_keys = [] - if ignore_mismatched_sizes: - for checkpoint_key in loaded_keys: - model_key = checkpoint_key - - if ( - model_key in model_state_dict - and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape - ): - mismatched_keys.append( - (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape) - ) - del state_dict[checkpoint_key] - return mismatched_keys - - if state_dict is not None: - # Whole checkpoint - mismatched_keys = _find_mismatched_keys( - state_dict, - model_state_dict, - original_loaded_keys, - ignore_mismatched_sizes, - ) - error_msgs = _load_state_dict_into_model(model_to_load, state_dict) - - if len(error_msgs) > 0: - error_msg = "\n\t".join(error_msgs) - if "size mismatch" in error_msg: - error_msg += ( - "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method." - ) - raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}") - - if len(unexpected_keys) > 0: - logger.warning( - f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when" - f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are" - f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task" - " or with another architecture (e.g. initializing a BertForSequenceClassification model from a" - " BertForPreTraining model).\n- This IS NOT expected if you are initializing" - f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly" - " identical (initializing a BertForSequenceClassification model from a" - " BertForSequenceClassification model)." - ) - else: - logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n") - if len(missing_keys) > 0: - logger.warning( - f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" - f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably" - " TRAIN this model on a down-stream task to be able to use it for predictions and inference." - ) - elif len(mismatched_keys) == 0: - logger.info( - f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at" - f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the" - f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions" - " without further training." - ) - if len(mismatched_keys) > 0: - mismatched_warning = "\n".join( - [ - f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated" - for key, shape1, shape2 in mismatched_keys - ] - ) - logger.warning( - f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" - f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not" - f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be" - " able to use it for predictions and inference." - ) - - return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs - - @property - def device(self) -> torch.device: - return get_parameter_device(self) - - @property - def dtype(self) -> torch.dtype: - return get_parameter_dtype(self) - - def _convert_deprecated_attention_blocks(self, state_dict: OrderedDict) -> None: - deprecated_attention_block_paths = [] - - def recursive_find_attn_block(name, module): - if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block: - deprecated_attention_block_paths.append(name) - - for sub_name, sub_module in module.named_children(): - sub_name = sub_name if name == "" else f"{name}.{sub_name}" - recursive_find_attn_block(sub_name, sub_module) - - recursive_find_attn_block("", self) - - for path in deprecated_attention_block_paths: - # group_norm path stays the same - - # query -> to_q - if f"{path}.query.weight" in state_dict: - state_dict[f"{path}.to_q.weight"] = state_dict.pop(f"{path}.query.weight") - if f"{path}.query.bias" in state_dict: - state_dict[f"{path}.to_q.bias"] = state_dict.pop(f"{path}.query.bias") - - # key -> to_k - if f"{path}.key.weight" in state_dict: - state_dict[f"{path}.to_k.weight"] = state_dict.pop(f"{path}.key.weight") - if f"{path}.key.bias" in state_dict: - state_dict[f"{path}.to_k.bias"] = state_dict.pop(f"{path}.key.bias") - - # value -> to_v - if f"{path}.value.weight" in state_dict: - state_dict[f"{path}.to_v.weight"] = state_dict.pop(f"{path}.value.weight") - if f"{path}.value.bias" in state_dict: - state_dict[f"{path}.to_v.bias"] = state_dict.pop(f"{path}.value.bias") - - # proj_attn -> to_out.0 - if f"{path}.proj_attn.weight" in state_dict: - state_dict[f"{path}.to_out.0.weight"] = state_dict.pop(f"{path}.proj_attn.weight") - if f"{path}.proj_attn.bias" in state_dict: - state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias") +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import itertools +import json +import os +import re +from collections import OrderedDict +from functools import wraps +from typing import Any, List, Optional, Tuple, Union + +import torch +from huggingface_hub.utils import validate_hf_hub_args +from torch import Tensor, nn + +from diffusers import __version__ +from diffusers.quantizers import DiffusersAutoQuantizer +from diffusers.quantizers.quantization_config import QuantizationMethod +from diffusers.utils import ( + CONFIG_NAME, + FLAX_WEIGHTS_NAME, + SAFETENSORS_WEIGHTS_NAME, + WEIGHTS_NAME, + _add_variant, + _get_checkpoint_shard_files, + _get_model_file, + deprecate, + is_accelerate_available, + is_bitsandbytes_version, + logging, +) +from diffusers.utils.hub_utils import PushToHubMixin +from diffusers.models.model_loading_utils import ( + _fetch_index_file, + _fetch_index_file_legacy, + _load_state_dict_into_model, + _merge_sharded_checkpoints, + load_model_dict_into_meta, + load_state_dict, +) + + +logger = logging.get_logger(__name__) + + +_LOW_CPU_MEM_USAGE_DEFAULT = True + + +if is_accelerate_available(): + import accelerate + + +def get_parameter_device(parameter: torch.nn.Module) -> torch.device: + try: + parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers()) + return next(parameters_and_buffers).device + except StopIteration: + # For torch.nn.DataParallel compatibility in PyTorch 1.5 + + def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]: + tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] + return tuples + + gen = parameter._named_members(get_members_fn=find_tensor_attributes) + first_tuple = next(gen) + return first_tuple[1].device + + +def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype: + """ + Returns the first found floating dtype in parameters if there is one, otherwise returns the last dtype it found. + """ + last_dtype = None + for param in parameter.parameters(): + last_dtype = param.dtype + if param.is_floating_point(): + return param.dtype + + for buffer in parameter.buffers(): + last_dtype = buffer.dtype + if buffer.is_floating_point(): + return buffer.dtype + + if last_dtype is not None: + # if no floating dtype was found return whatever the first dtype is + return last_dtype + + # For nn.DataParallel compatibility in PyTorch > 1.5 + def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]: + tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] + return tuples + + gen = parameter._named_members(get_members_fn=find_tensor_attributes) + last_tuple = None + for current_tuple in gen: + last_tuple = current_tuple + if current_tuple[1].is_floating_point(): + return current_tuple[1].dtype + + if last_tuple is not None: + # fallback to the last dtype + return last_tuple[1].dtype + + +class ModelMixin(torch.nn.Module, PushToHubMixin): + config_name = CONFIG_NAME + _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"] + _supports_gradient_checkpointing = False + _keys_to_ignore_on_load_unexpected = None + _no_split_modules = None + _keep_in_fp32_modules = None + + def __init__(self): + super().__init__() + + def __getattr__(self, name: str) -> Any: + + is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name) + is_attribute = name in self.__dict__ + + if is_in_config and not is_attribute: + deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'unet.config.{name}'." + deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False, stacklevel=3) + return self._internal_dict[name] + + return super().__getattr__(name) + + @classmethod + @validate_hf_hub_args + def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): + cache_dir = kwargs.pop("cache_dir", None) + ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False) + force_download = kwargs.pop("force_download", False) + from_flax = kwargs.pop("from_flax", False) + proxies = kwargs.pop("proxies", None) + output_loading_info = kwargs.pop("output_loading_info", False) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + torch_dtype = kwargs.pop("torch_dtype", None) + subfolder = kwargs.pop("subfolder", None) + device_map = kwargs.pop("device_map", None) + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) + variant = kwargs.pop("variant", None) + use_safetensors = kwargs.pop("use_safetensors", None) + quantization_config = kwargs.pop("quantization_config", None) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + + if low_cpu_mem_usage and not is_accelerate_available(): + low_cpu_mem_usage = False + logger.warning( + "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the" + " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install" + " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip" + " install accelerate\n```\n." + ) + + if device_map is not None and not is_accelerate_available(): + raise NotImplementedError( + "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set" + " `device_map=None`. You can install accelerate with `pip install accelerate`." + ) + + if low_cpu_mem_usage is False and device_map is not None: + raise ValueError( + f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and" + " dispatching. Please make sure to set `low_cpu_mem_usage=True`." + ) + + if isinstance(device_map, torch.device): + device_map = {"": device_map} + elif isinstance(device_map, str) and device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]: + try: + device_map = {"": torch.device(device_map)} + except RuntimeError as e: + raise ValueError( + "When passing device_map as a string, the value needs to be a device name (e.g. cpu, cuda:0) or " + f"'auto', 'balanced', 'balanced_low_0', 'sequential' but found {device_map}." + ) from e + elif isinstance(device_map, int): + if device_map < 0: + raise ValueError( + "You can't pass device_map as a negative int. If you want to put the model on the cpu, pass device_map = 'cpu' " + ) + else: + device_map = {"": device_map} + + if device_map is not None: + if low_cpu_mem_usage is None: + low_cpu_mem_usage = True + elif not low_cpu_mem_usage: + raise ValueError("Passing along a `device_map` requires `low_cpu_mem_usage=True`") + + # Load config if we don't provide a configuration + config_path = pretrained_model_name_or_path + + user_agent = { + "diffusers": __version__, + "file_type": "model", + "framework": "pytorch", + } + + # load config + config, unused_kwargs, commit_hash = cls.load_config( + config_path, + cache_dir=cache_dir, + return_unused_kwargs=True, + return_commit_hash=True, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + **kwargs, + ) + # no in-place modification of the original config. + config = copy.deepcopy(config) + + # determine initial quantization config. + ####################################### + pre_quantized = "quantization_config" in config and config["quantization_config"] is not None + if pre_quantized or quantization_config is not None: + if pre_quantized: + config["quantization_config"] = DiffusersAutoQuantizer.merge_quantization_configs( + config["quantization_config"], quantization_config + ) + else: + config["quantization_config"] = quantization_config + hf_quantizer = DiffusersAutoQuantizer.from_config( + config["quantization_config"], pre_quantized=pre_quantized + ) + else: + hf_quantizer = None + + if hf_quantizer is not None: + is_bnb_quantization_method = hf_quantizer.quantization_config.quant_method.value == "bitsandbytes" + if is_bnb_quantization_method and device_map is not None: + raise NotImplementedError( + "Currently, `device_map` is automatically inferred for quantized bitsandbytes models. Support for providing `device_map` as an input will be added in the future." + ) + + hf_quantizer.validate_environment(torch_dtype=torch_dtype, from_flax=from_flax, device_map=device_map) + torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype) + + # In order to ensure popular quantization methods are supported. Can be disable with `disable_telemetry` + user_agent["quant"] = hf_quantizer.quantization_config.quant_method.value + + # Force-set to `True` for more mem efficiency + if low_cpu_mem_usage is None: + low_cpu_mem_usage = True + logger.info("Set `low_cpu_mem_usage` to True as `hf_quantizer` is not None.") + elif not low_cpu_mem_usage: + raise ValueError("`low_cpu_mem_usage` cannot be False or None when using quantization.") + + # Check if `_keep_in_fp32_modules` is not None + use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and ( + (torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules") + ) + if use_keep_in_fp32_modules: + keep_in_fp32_modules = cls._keep_in_fp32_modules + if not isinstance(keep_in_fp32_modules, list): + keep_in_fp32_modules = [keep_in_fp32_modules] + + if low_cpu_mem_usage is None: + low_cpu_mem_usage = True + logger.info("Set `low_cpu_mem_usage` to True as `_keep_in_fp32_modules` is not None.") + elif not low_cpu_mem_usage: + raise ValueError("`low_cpu_mem_usage` cannot be False when `keep_in_fp32_modules` is True.") + else: + keep_in_fp32_modules = [] + ####################################### + + # Determine if we're loading from a directory of sharded checkpoints. + is_sharded = False + index_file = None + is_local = os.path.isdir(pretrained_model_name_or_path) + index_file_kwargs = { + "is_local": is_local, + "pretrained_model_name_or_path": pretrained_model_name_or_path, + "subfolder": subfolder or "", + "use_safetensors": use_safetensors, + "cache_dir": cache_dir, + "variant": variant, + "force_download": force_download, + "proxies": proxies, + "local_files_only": local_files_only, + "token": token, + "revision": revision, + "user_agent": user_agent, + "commit_hash": commit_hash, + } + index_file = _fetch_index_file(**index_file_kwargs) + # In case the index file was not found we still have to consider the legacy format. + # this becomes applicable when the variant is not None. + if variant is not None and (index_file is None or not os.path.exists(index_file)): + index_file = _fetch_index_file_legacy(**index_file_kwargs) + if index_file is not None and index_file.is_file(): + is_sharded = True + + if is_sharded and from_flax: + raise ValueError("Loading of sharded checkpoints is not supported when `from_flax=True`.") + + # load model + model_file = None + if from_flax: + model_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=FLAX_WEIGHTS_NAME, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + commit_hash=commit_hash, + ) + model = cls.from_config(config, **unused_kwargs) + + # Convert the weights + from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model + + model = load_flax_checkpoint_in_pytorch_model(model, model_file) + else: + if is_sharded: + sharded_ckpt_cached_folder, sharded_metadata = _get_checkpoint_shard_files( + pretrained_model_name_or_path, + index_file, + cache_dir=cache_dir, + proxies=proxies, + local_files_only=local_files_only, + token=token, + user_agent=user_agent, + revision=revision, + subfolder=subfolder or "", + ) + if hf_quantizer is not None and is_bnb_quantization_method: + model_file = _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata) + logger.info("Merged sharded checkpoints as `hf_quantizer` is not None.") + is_sharded = False + + elif use_safetensors and not is_sharded: + try: + model_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant), + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + commit_hash=commit_hash, + ) + + except IOError as e: + logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}") + if not allow_pickle: + raise + logger.warning( + "Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead." + ) + + if model_file is None and not is_sharded: + model_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=_add_variant(WEIGHTS_NAME, variant), + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + commit_hash=commit_hash, + ) + + if low_cpu_mem_usage: + # Instantiate model with empty weights + with accelerate.init_empty_weights(): + model = cls.from_config(config, **unused_kwargs) + + if hf_quantizer is not None: + hf_quantizer.preprocess_model( + model=model, device_map=device_map, keep_in_fp32_modules=keep_in_fp32_modules + ) + + # if device_map is None, load the state dict and move the params from meta device to the cpu + if device_map is None and not is_sharded: + # `torch.cuda.current_device()` is fine here when `hf_quantizer` is not None. + # It would error out during the `validate_environment()` call above in the absence of cuda. + if hf_quantizer is None: + param_device = "cpu" + else: + param_device = torch.device(torch.cuda.current_device()) + state_dict = load_state_dict(model_file, variant=variant) + model._convert_deprecated_attention_blocks(state_dict) + + # move the params from meta device to cpu + missing_keys = set(model.state_dict().keys()) - set(state_dict.keys()) + if hf_quantizer is not None: + missing_keys = hf_quantizer.update_missing_keys(model, missing_keys, prefix="") + if len(missing_keys) > 0: + raise ValueError( + f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are" + f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass" + " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize" + " those weights or else make sure your checkpoint file is correct." + ) + + unexpected_keys = load_model_dict_into_meta( + model, + state_dict, + device=param_device, + dtype=torch_dtype, + model_name_or_path=pretrained_model_name_or_path, + hf_quantizer=hf_quantizer, + keep_in_fp32_modules=keep_in_fp32_modules, + ) + + if cls._keys_to_ignore_on_load_unexpected is not None: + for pat in cls._keys_to_ignore_on_load_unexpected: + unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] + + if len(unexpected_keys) > 0: + logger.warning( + f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}" + ) + + else: + weights_path = index_file + with open(index_file) as f: + index = json.loads(f.read()) + if "weight_map" in index: + index = index["weight_map"] + weights_path = sorted(list(set(index.values()))) + weights_path = [os.path.join(pretrained_model_name_or_path, f) for f in weights_path] + + model = cls._load_model(model, weights_path, is_sharded) + + loading_info = { + "missing_keys": [], + "unexpected_keys": [], + "mismatched_keys": [], + "error_msgs": [], + } + else: + model = cls.from_config(config, **unused_kwargs) + + state_dict = load_state_dict(model_file, variant=variant) + model._convert_deprecated_attention_blocks(state_dict) + + model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model( + model, + state_dict, + model_file, + pretrained_model_name_or_path, + ignore_mismatched_sizes=ignore_mismatched_sizes, + ) + + loading_info = { + "missing_keys": missing_keys, + "unexpected_keys": unexpected_keys, + "mismatched_keys": mismatched_keys, + "error_msgs": error_msgs, + } + + if hf_quantizer is not None: + hf_quantizer.postprocess_model(model) + model.hf_quantizer = hf_quantizer + + if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype): + raise ValueError( + f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}." + ) + # When using `use_keep_in_fp32_modules` if we do a global `to()` here, then we will + # completely lose the effectivity of `use_keep_in_fp32_modules`. + elif torch_dtype is not None and hf_quantizer is None and not use_keep_in_fp32_modules: + model = model.to(torch_dtype) + + if hf_quantizer is not None: + # We also make sure to purge `_pre_quantization_dtype` when we serialize + # the model config because `_pre_quantization_dtype` is `torch.dtype`, not JSON serializable. + model.register_to_config(_name_or_path=pretrained_model_name_or_path, _pre_quantization_dtype=torch_dtype) + else: + model.register_to_config(_name_or_path=pretrained_model_name_or_path) + + # Set model in evaluation mode to deactivate DropOut modules by default + model.eval() + if output_loading_info: + return model, loading_info + + return model + + @classmethod + def _load_model(cls, model, weights_path, is_sharded): + if not is_sharded: + state_dict = load_state_dict(weights_path) + model.load_weights(state_dict) + else: + need_key = set(model.state_dict().keys()) + state_dict = {} + cache = {} + for weight_file in weights_path: + state_dict = load_state_dict(weight_file) + state_dict.update(cache) + loadkey_cache = model.load_weights(state_dict, is_sharded) + if loadkey_cache : + if isinstance(loadkey_cache, tuple): + loaded_keys, cache = loadkey_cache + else: + loaded_keys = loadkey_cache + need_key = need_key.symmetric_difference(set(loaded_keys)) + + if len(need_key) > 0: + raise ValueError(f"The weight miss key: {need_key}") + return model + + def load_weights(self, state_dict, shard=False): + with torch.no_grad(): + if not shard: + self.load_state_dict(state_dict) + return {} + else: + self.load_state_dict(state_dict, strict=False, assign=True) + return state_dict.keys() + + # Adapted from `transformers`. + @wraps(torch.nn.Module.cuda) + def cuda(self, *args, **kwargs): + # Checks if the model has been loaded in 4-bit or 8-bit with BNB + if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES: + if getattr(self, "is_loaded_in_8bit", False): + raise ValueError( + "Calling `cuda()` is not supported for `8-bit` quantized models. " + " Please use the model as it is, since the model has already been set to the correct devices." + ) + elif is_bitsandbytes_version("<", "0.43.2"): + raise ValueError( + "Calling `cuda()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. " + f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2." + ) + return super().cuda(*args, **kwargs) + + # Adapted from `transformers`. + @wraps(torch.nn.Module.to) + def to(self, *args, **kwargs): + dtype_present_in_args = "dtype" in kwargs + + if not dtype_present_in_args: + for arg in args: + if isinstance(arg, torch.dtype): + dtype_present_in_args = True + break + + if getattr(self, "is_quantized", False): + if dtype_present_in_args: + raise ValueError( + "Casting a quantized model to a new `dtype` is unsupported. To set the dtype of unquantized layers, please " + "use the `torch_dtype` argument when loading the model using `from_pretrained` or `from_single_file`" + ) + + if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES: + if getattr(self, "is_loaded_in_8bit", False): + raise ValueError( + "`.to` is not supported for `8-bit` bitsandbytes models. Please use the model as it is, since the" + " model has already been set to the correct devices and casted to the correct `dtype`." + ) + elif is_bitsandbytes_version("<", "0.43.2"): + raise ValueError( + "Calling `to()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. " + f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2." + ) + return super().to(*args, **kwargs) + + # Taken from `transformers`. + def half(self, *args): + # Checks if the model is quantized + if getattr(self, "is_quantized", False): + raise ValueError( + "`.half()` is not supported for quantized model. Please use the model as it is, since the" + " model has already been cast to the correct `dtype`." + ) + else: + return super().half(*args) + + # Taken from `transformers`. + def float(self, *args): + # Checks if the model is quantized + if getattr(self, "is_quantized", False): + raise ValueError( + "`.float()` is not supported for quantized model. Please use the model as it is, since the" + " model has already been cast to the correct `dtype`." + ) + else: + return super().float(*args) + + @classmethod + def _load_pretrained_model( + cls, + model, + state_dict: OrderedDict, + pretrained_model_name_or_path: Union[str, os.PathLike], + ignore_mismatched_sizes: bool = False, + ): + # Retrieve missing & unexpected_keys + model_state_dict = model.state_dict() + loaded_keys = list(state_dict.keys()) + + expected_keys = list(model_state_dict.keys()) + + original_loaded_keys = loaded_keys + + missing_keys = list(set(expected_keys) - set(loaded_keys)) + unexpected_keys = list(set(loaded_keys) - set(expected_keys)) + + # Make sure we are able to load base models as well as derived models (with heads) + model_to_load = model + + def _find_mismatched_keys( + state_dict, + model_state_dict, + loaded_keys, + ignore_mismatched_sizes, + ): + mismatched_keys = [] + if ignore_mismatched_sizes: + for checkpoint_key in loaded_keys: + model_key = checkpoint_key + + if ( + model_key in model_state_dict + and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape + ): + mismatched_keys.append( + (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape) + ) + del state_dict[checkpoint_key] + return mismatched_keys + + if state_dict is not None: + # Whole checkpoint + mismatched_keys = _find_mismatched_keys( + state_dict, + model_state_dict, + original_loaded_keys, + ignore_mismatched_sizes, + ) + error_msgs = _load_state_dict_into_model(model_to_load, state_dict) + + if len(error_msgs) > 0: + error_msg = "\n\t".join(error_msgs) + if "size mismatch" in error_msg: + error_msg += ( + "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method." + ) + raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}") + + if len(unexpected_keys) > 0: + logger.warning( + f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when" + f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are" + f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task" + " or with another architecture (e.g. initializing a BertForSequenceClassification model from a" + " BertForPreTraining model).\n- This IS NOT expected if you are initializing" + f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly" + " identical (initializing a BertForSequenceClassification model from a" + " BertForSequenceClassification model)." + ) + else: + logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n") + if len(missing_keys) > 0: + logger.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" + f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably" + " TRAIN this model on a down-stream task to be able to use it for predictions and inference." + ) + elif len(mismatched_keys) == 0: + logger.info( + f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at" + f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the" + f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions" + " without further training." + ) + if len(mismatched_keys) > 0: + mismatched_warning = "\n".join( + [ + f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated" + for key, shape1, shape2 in mismatched_keys + ] + ) + logger.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" + f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not" + f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be" + " able to use it for predictions and inference." + ) + + return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs + + @property + def device(self) -> torch.device: + return get_parameter_device(self) + + @property + def dtype(self) -> torch.dtype: + return get_parameter_dtype(self) + + def _convert_deprecated_attention_blocks(self, state_dict: OrderedDict) -> None: + deprecated_attention_block_paths = [] + + def recursive_find_attn_block(name, module): + if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block: + deprecated_attention_block_paths.append(name) + + for sub_name, sub_module in module.named_children(): + sub_name = sub_name if name == "" else f"{name}.{sub_name}" + recursive_find_attn_block(sub_name, sub_module) + + recursive_find_attn_block("", self) + + for path in deprecated_attention_block_paths: + # group_norm path stays the same + + # query -> to_q + if f"{path}.query.weight" in state_dict: + state_dict[f"{path}.to_q.weight"] = state_dict.pop(f"{path}.query.weight") + if f"{path}.query.bias" in state_dict: + state_dict[f"{path}.to_q.bias"] = state_dict.pop(f"{path}.query.bias") + + # key -> to_k + if f"{path}.key.weight" in state_dict: + state_dict[f"{path}.to_k.weight"] = state_dict.pop(f"{path}.key.weight") + if f"{path}.key.bias" in state_dict: + state_dict[f"{path}.to_k.bias"] = state_dict.pop(f"{path}.key.bias") + + # value -> to_v + if f"{path}.value.weight" in state_dict: + state_dict[f"{path}.to_v.weight"] = state_dict.pop(f"{path}.value.weight") + if f"{path}.value.bias" in state_dict: + state_dict[f"{path}.to_v.bias"] = state_dict.pop(f"{path}.value.bias") + + # proj_attn -> to_out.0 + if f"{path}.proj_attn.weight" in state_dict: + state_dict[f"{path}.to_out.0.weight"] = state_dict.pop(f"{path}.proj_attn.weight") + if f"{path}.proj_attn.bias" in state_dict: + state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias") diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/models/transformer_cogview3plus.py b/MindIE/MultiModal/CogView3-Plus-3B/cogview3plus/models/transformer_cogview3plus.py similarity index 97% rename from MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/models/transformer_cogview3plus.py rename to MindIE/MultiModal/CogView3-Plus-3B/cogview3plus/models/transformer_cogview3plus.py index f704e22589..37c5961586 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/models/transformer_cogview3plus.py +++ b/MindIE/MultiModal/CogView3-Plus-3B/cogview3plus/models/transformer_cogview3plus.py @@ -1,397 +1,397 @@ -# Copyright 2024 The CogView team, Tsinghua University & ZhipuAI and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Any, Dict, Union - -import torch -import torch.nn as nn -import numpy as np - -from diffusers.configuration_utils import ConfigMixin, register_to_config -from diffusers.models.attention_processor import AttentionProcessor -from diffusers.utils import logging -from diffusers.models.modeling_outputs import Transformer2DModelOutput - -from .modeling_utils import ModelMixin -from .attention import FeedForward -from .attention_processor import CogVideoXAttnProcessor2_0, Attention -from ..layers import CogView3PlusAdaLayerNormZeroTextImage, AdaLayerNormContinuous -from ..layers import CogView3CombinedTimestepSizeEmbeddings, CogView3PlusPatchEmbed - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -class CogView3PlusTransformerBlock(nn.Module): - def __init__( - self, - dim: int = 2560, - num_attention_heads: int = 64, - attention_head_dim: int = 40, - time_embed_dim: int = 512, - ): - super().__init__() - - self.norm1 = CogView3PlusAdaLayerNormZeroTextImage(embedding_dim=time_embed_dim, dim=dim) - - self.attn1 = Attention( - query_dim=dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - out_dim=dim, - bias=True, - qk_norm="layer_norm", - elementwise_affine=False, - eps=1e-6, - processor=CogVideoXAttnProcessor2_0(), - ) - - self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5) - self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5) - - self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") - - def forward( - self, - hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor, - emb: torch.Tensor, - ) -> torch.Tensor: - text_seq_length = encoder_hidden_states.size(1) - - # norm & modulate - norm_hidden_states, chunk_params = self.norm1(hidden_states, encoder_hidden_states, emb) - - gate_msa = chunk_params.gate_msa - shift_mlp = chunk_params.shift_mlp - scale_mlp = chunk_params.scale_mlp - gate_mlp = chunk_params.gate_mlp - norm_encoder_hidden_states = chunk_params.context - c_gate_msa = chunk_params.c_gate_msa - c_shift_mlp = chunk_params.c_shift_mlp - c_scale_mlp = chunk_params.c_scale_mlp - c_gate_mlp = chunk_params.c_gate_mlp - - # attention - attn_hidden_states, attn_encoder_hidden_states = self.attn1( - hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states - ) - - hidden_states = hidden_states + gate_msa.unsqueeze(1) * attn_hidden_states - encoder_hidden_states = encoder_hidden_states + c_gate_msa.unsqueeze(1) * attn_encoder_hidden_states - - # norm & modulate - norm_hidden_states = self.norm2(hidden_states) - norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] - - norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) - norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] - - # feed-forward - norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1) - ff_output = self.ff(norm_hidden_states) - - hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output[:, text_seq_length:] - encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * ff_output[:, :text_seq_length] - - if hidden_states.dtype == torch.float16: - hidden_states = hidden_states.clip(-65504, 65504) - if encoder_hidden_states.dtype == torch.float16: - encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) - return hidden_states, encoder_hidden_states - - -class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin): - _supports_gradient_checkpointing = True - - @register_to_config - def __init__( - self, - patch_size: int = 2, - in_channels: int = 16, - num_layers: int = 30, - attention_head_dim: int = 40, - num_attention_heads: int = 64, - out_channels: int = 16, - text_embed_dim: int = 4096, - time_embed_dim: int = 512, - condition_dim: int = 256, - pos_embed_max_size: int = 128, - use_cache: bool = True, - cache_interval: int = 2, - cache_start: int = 3, - num_cache_layer: int = 13, - cache_start_steps: int = 5, - ): - super().__init__() - self.out_channels = out_channels - self.inner_dim = num_attention_heads * attention_head_dim - self.num_layers = num_layers - - # CogView3 uses 3 additional SDXL-like conditions - original_size, target_size, crop_coords - # Each of these are sincos embeddings of shape 2 * condition_dim - self.pooled_projection_dim = 3 * 2 * condition_dim - - self.patch_embed = CogView3PlusPatchEmbed( - in_channels=in_channels, - hidden_size=self.inner_dim, - patch_size=patch_size, - text_hidden_size=text_embed_dim, - pos_embed_max_size=pos_embed_max_size, - ) - - self.time_condition_embed = CogView3CombinedTimestepSizeEmbeddings( - embedding_dim=time_embed_dim, - condition_dim=condition_dim, - pooled_projection_dim=self.pooled_projection_dim, - timesteps_dim=self.inner_dim, - ) - - self.transformer_blocks = nn.ModuleList( - [ - CogView3PlusTransformerBlock( - dim=self.inner_dim, - num_attention_heads=num_attention_heads, - attention_head_dim=attention_head_dim, - time_embed_dim=time_embed_dim, - ) - for _ in range(num_layers) - ] - ) - - self.norm_out = AdaLayerNormContinuous( - embedding_dim=self.inner_dim, - conditioning_embedding_dim=time_embed_dim, - elementwise_affine=False, - eps=1e-6, - ) - self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) - - self.gradient_checkpointing = False - - self.q_weight_cache = None - self.q_bias_cache = None - self.k_weight_cache = None - self.k_bias_cache = None - self.v_weight_cache = None - self.v_bias_cache = None - - self.use_cache = use_cache - self.cache_interval = cache_interval - self.cache_start = cache_start - self.num_cache_layer = num_cache_layer - self.cache_start_steps = cache_start_steps - - self.delta_cache = None - self.delta_encoder_cache = None - - @property - def attn_processors(self) -> Dict[str, AttentionProcessor]: - r""" - Returns: - `dict` of attention processors: A dictionary containing all attention processors used in the model with - indexed by its weight name. - """ - # set recursively - processors = {} - - def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): - if hasattr(module, "get_processor"): - processors[f"{name}.processor"] = module.get_processor() - - for sub_name, child in module.named_children(): - fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) - - return processors - - for name, module in self.named_children(): - fn_recursive_add_processors(name, module, processors) - - return processors - - def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): - count = len(self.attn_processors.keys()) - - if isinstance(processor, dict) and len(processor) != count: - raise ValueError( - f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" - f" number of attention layers: {count}. Please make sure to pass {count} processor classes." - ) - - def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): - if hasattr(module, "set_processor"): - if not isinstance(processor, dict): - module.set_processor(processor) - else: - module.set_processor(processor.pop(f"{name}.processor")) - - for sub_name, child in module.named_children(): - fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) - - for name, module in self.named_children(): - fn_recursive_attn_processor(name, module, processor) - - def _set_gradient_checkpointing(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - - def forward( - self, - states, - timestep: torch.LongTensor, - original_size: torch.Tensor, - target_size: torch.Tensor, - crop_coords: torch.Tensor, - ) -> Union[torch.Tensor, Transformer2DModelOutput]: - hidden_states = states[0] - encoder_hidden_states = states[1] - height, width = hidden_states.shape[-2:] - text_seq_length = encoder_hidden_states.shape[1] - - hidden_states = self.patch_embed( - hidden_states, encoder_hidden_states - ) # takes care of adding positional embeddings too. - emb = self.time_condition_embed(timestep, original_size, target_size, crop_coords, hidden_states.dtype) - - encoder_hidden_states = hidden_states[:, :text_seq_length] - hidden_states = hidden_states[:, text_seq_length:] - - hidden_states, encoder_hidden_states = self._forward_blocks(hidden_states, encoder_hidden_states, emb, states[2]) - - hidden_states = self.norm_out(hidden_states, emb) - hidden_states = self.proj_out(hidden_states) # (batch_size, height*width, patch_size*patch_size*out_channels) - - # unpatchify - patch_size = self.config.patch_size - height = height // patch_size - width = width // patch_size - - hidden_states = hidden_states.reshape( - shape=(hidden_states.shape[0], height, width, self.out_channels, patch_size, patch_size) - ) - hidden_states = torch.einsum("nhwcpq->nchpwq", hidden_states) - output = hidden_states.reshape( - shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size) - ) - - return Transformer2DModelOutput(sample=output) - - # forward blocks in range [start_idx, end_idx), then return input and output - def _forward_blocks_range(self, hidden_states, encoder_hidden_states, emb, start_idx, end_idx, **kwargs): - for _, block in enumerate(self.transformer_blocks[start_idx: end_idx]): - hidden_states, encoder_hidden_states = block( - hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, - emb=emb, - ) - - return hidden_states, encoder_hidden_states - - def _forward_blocks(self, hidden_states, encoder_hidden_states, emb, t_idx): - num_blocks = len(self.transformer_blocks) - - if not self.use_cache or (t_idx < self.cache_start_steps): - hidden_states, encoder_hidden_states = self._forward_blocks_range( - hidden_states, - encoder_hidden_states, - emb, - 0, - num_blocks - ) - else: - # infer [0, cache_start) - hidden_states, encoder_hidden_states = self._forward_blocks_range( - hidden_states, - encoder_hidden_states, - emb, - 0, - self.cache_start - ) - # infer [cache_start, cache_end) - cache_end = np.minimum(self.cache_start + self.num_cache_layer, num_blocks) - hidden_states_before_cache = hidden_states.clone() - encoder_hidden_states_before_cache = encoder_hidden_states.clone() - if t_idx % self.cache_interval == (self.cache_start_steps % self.cache_interval): - hidden_states, encoder_hidden_states = self._forward_blocks_range( - hidden_states, - encoder_hidden_states, - emb, - self.cache_start, - cache_end - ) - self.delta_cache = hidden_states - hidden_states_before_cache - self.delta_encoder_cache = encoder_hidden_states - encoder_hidden_states_before_cache - else: - hidden_states = hidden_states_before_cache + self.delta_cache - encoder_hidden_states = encoder_hidden_states_before_cache + self.delta_encoder_cache - # infer [cache_end, num_blocks) - hidden_states, encoder_hidden_states = self._forward_blocks_range( - hidden_states, - encoder_hidden_states, - emb, - cache_end, - num_blocks - ) - - return hidden_states, encoder_hidden_states - - def load_weights(self, state_dict, shard=False): - with torch.no_grad(): - if not shard: - self.load_state_dict(state_dict) - return {} - else: - weights = state_dict - - for i in range(self.num_layers): - if i != 26: - q_weight = weights.pop(f"transformer_blocks.{i}.attn1.to_q.weight", None) - q_bias = weights.pop(f"transformer_blocks.{i}.attn1.to_q.bias", None) - k_weight = weights.pop(f"transformer_blocks.{i}.attn1.to_k.weight", None) - k_bias = weights.pop(f"transformer_blocks.{i}.attn1.to_k.bias", None) - v_weight = weights.pop(f"transformer_blocks.{i}.attn1.to_v.weight", None) - v_bias = weights.pop(f"transformer_blocks.{i}.attn1.to_v.bias", None) - - # query, key, value的weight和bias权重存在同一个文件中,不会分开存储。 - if q_weight is not None and k_weight is not None and v_weight is not None: - qkv_weight = torch.cat([q_weight, k_weight, v_weight], dim=0).transpose(0, 1).contiguous() - qkv_bias = torch.cat([q_bias, k_bias, v_bias], dim=0).contiguous() - weights[f"transformer_blocks.{i}.attn1.to_qkv.weight"] = qkv_weight - weights[f"transformer_blocks.{i}.attn1.to_qkv.bias"] = qkv_bias - else: - if self.q_weight_cache is None: - self.q_weight_cache = weights.pop(f"transformer_blocks.{i}.attn1.to_q.weight", None) - if self.q_bias_cache is None: - self.q_bias_cache = weights.pop(f"transformer_blocks.{i}.attn1.to_q.bias", None) - if self.k_weight_cache is None: - self.k_weight_cache = weights.pop(f"transformer_blocks.{i}.attn1.to_k.weight", None) - if self.k_bias_cache is None: - self.k_bias_cache = weights.pop(f"transformer_blocks.{i}.attn1.to_k.bias", None) - if self.v_weight_cache is None: - self.v_weight_cache = weights.pop(f"transformer_blocks.{i}.attn1.to_v.weight", None) - if self.v_bias_cache is None: - self.v_bias_cache = weights.pop(f"transformer_blocks.{i}.attn1.to_v.bias", None) - - qk_weight_cache = self.q_weight_cache is not None and self.k_weight_cache is not None - if qk_weight_cache and self.v_weight_cache is not None: - qkv_weight = torch.cat( - [self.q_weight_cache, self.k_weight_cache, self.v_weight_cache], - dim=0 - ).transpose(0, 1).contiguous() - qkv_bias = torch.cat([self.q_bias_cache, self.k_bias_cache, self.v_bias_cache], dim=0).contiguous() - weights[f"transformer_blocks.26.attn1.to_qkv.weight"] = qkv_weight - weights[f"transformer_blocks.26.attn1.to_qkv.bias"] = qkv_bias - - self.load_state_dict(weights, strict=False, assign=True) - return weights.keys() +# Copyright 2024 The CogView team, Tsinghua University & ZhipuAI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, Union + +import torch +import torch.nn as nn +import numpy as np + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.attention_processor import AttentionProcessor +from diffusers.utils import logging +from diffusers.models.modeling_outputs import Transformer2DModelOutput + +from .modeling_utils import ModelMixin +from .attention import FeedForward +from .attention_processor import CogVideoXAttnProcessor2_0, Attention +from ..layers import CogView3PlusAdaLayerNormZeroTextImage, AdaLayerNormContinuous +from ..layers import CogView3CombinedTimestepSizeEmbeddings, CogView3PlusPatchEmbed + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class CogView3PlusTransformerBlock(nn.Module): + def __init__( + self, + dim: int = 2560, + num_attention_heads: int = 64, + attention_head_dim: int = 40, + time_embed_dim: int = 512, + ): + super().__init__() + + self.norm1 = CogView3PlusAdaLayerNormZeroTextImage(embedding_dim=time_embed_dim, dim=dim) + + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + out_dim=dim, + bias=True, + qk_norm="layer_norm", + elementwise_affine=False, + eps=1e-6, + processor=CogVideoXAttnProcessor2_0(), + ) + + self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5) + self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5) + + self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + emb: torch.Tensor, + ) -> torch.Tensor: + text_seq_length = encoder_hidden_states.size(1) + + # norm & modulate + norm_hidden_states, chunk_params = self.norm1(hidden_states, encoder_hidden_states, emb) + + gate_msa = chunk_params.gate_msa + shift_mlp = chunk_params.shift_mlp + scale_mlp = chunk_params.scale_mlp + gate_mlp = chunk_params.gate_mlp + norm_encoder_hidden_states = chunk_params.context + c_gate_msa = chunk_params.c_gate_msa + c_shift_mlp = chunk_params.c_shift_mlp + c_scale_mlp = chunk_params.c_scale_mlp + c_gate_mlp = chunk_params.c_gate_mlp + + # attention + attn_hidden_states, attn_encoder_hidden_states = self.attn1( + hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states + ) + + hidden_states = hidden_states + gate_msa.unsqueeze(1) * attn_hidden_states + encoder_hidden_states = encoder_hidden_states + c_gate_msa.unsqueeze(1) * attn_encoder_hidden_states + + # norm & modulate + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) + norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + + # feed-forward + norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1) + ff_output = self.ff(norm_hidden_states) + + hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output[:, text_seq_length:] + encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * ff_output[:, :text_seq_length] + + if hidden_states.dtype == torch.float16: + hidden_states = hidden_states.clip(-65504, 65504) + if encoder_hidden_states.dtype == torch.float16: + encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) + return hidden_states, encoder_hidden_states + + +class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin): + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + patch_size: int = 2, + in_channels: int = 16, + num_layers: int = 30, + attention_head_dim: int = 40, + num_attention_heads: int = 64, + out_channels: int = 16, + text_embed_dim: int = 4096, + time_embed_dim: int = 512, + condition_dim: int = 256, + pos_embed_max_size: int = 128, + use_cache: bool = True, + cache_interval: int = 2, + cache_start: int = 3, + num_cache_layer: int = 13, + cache_start_steps: int = 5, + ): + super().__init__() + self.out_channels = out_channels + self.inner_dim = num_attention_heads * attention_head_dim + self.num_layers = num_layers + + # CogView3 uses 3 additional SDXL-like conditions - original_size, target_size, crop_coords + # Each of these are sincos embeddings of shape 2 * condition_dim + self.pooled_projection_dim = 3 * 2 * condition_dim + + self.patch_embed = CogView3PlusPatchEmbed( + in_channels=in_channels, + hidden_size=self.inner_dim, + patch_size=patch_size, + text_hidden_size=text_embed_dim, + pos_embed_max_size=pos_embed_max_size, + ) + + self.time_condition_embed = CogView3CombinedTimestepSizeEmbeddings( + embedding_dim=time_embed_dim, + condition_dim=condition_dim, + pooled_projection_dim=self.pooled_projection_dim, + timesteps_dim=self.inner_dim, + ) + + self.transformer_blocks = nn.ModuleList( + [ + CogView3PlusTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + time_embed_dim=time_embed_dim, + ) + for _ in range(num_layers) + ] + ) + + self.norm_out = AdaLayerNormContinuous( + embedding_dim=self.inner_dim, + conditioning_embedding_dim=time_embed_dim, + elementwise_affine=False, + eps=1e-6, + ) + self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) + + self.gradient_checkpointing = False + + self.q_weight_cache = None + self.q_bias_cache = None + self.k_weight_cache = None + self.k_bias_cache = None + self.v_weight_cache = None + self.v_bias_cache = None + + self.use_cache = use_cache + self.cache_interval = cache_interval + self.cache_start = cache_start + self.num_cache_layer = num_cache_layer + self.cache_start_steps = cache_start_steps + + self.delta_cache = None + self.delta_encoder_cache = None + + @property + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def forward( + self, + states, + timestep: torch.LongTensor, + original_size: torch.Tensor, + target_size: torch.Tensor, + crop_coords: torch.Tensor, + ) -> Union[torch.Tensor, Transformer2DModelOutput]: + hidden_states = states[0] + encoder_hidden_states = states[1] + height, width = hidden_states.shape[-2:] + text_seq_length = encoder_hidden_states.shape[1] + + hidden_states = self.patch_embed( + hidden_states, encoder_hidden_states + ) # takes care of adding positional embeddings too. + emb = self.time_condition_embed(timestep, original_size, target_size, crop_coords, hidden_states.dtype) + + encoder_hidden_states = hidden_states[:, :text_seq_length] + hidden_states = hidden_states[:, text_seq_length:] + + hidden_states, encoder_hidden_states = self._forward_blocks(hidden_states, encoder_hidden_states, emb, states[2]) + + hidden_states = self.norm_out(hidden_states, emb) + hidden_states = self.proj_out(hidden_states) # (batch_size, height*width, patch_size*patch_size*out_channels) + + # unpatchify + patch_size = self.config.patch_size + height = height // patch_size + width = width // patch_size + + hidden_states = hidden_states.reshape( + shape=(hidden_states.shape[0], height, width, self.out_channels, patch_size, patch_size) + ) + hidden_states = torch.einsum("nhwcpq->nchpwq", hidden_states) + output = hidden_states.reshape( + shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size) + ) + + return Transformer2DModelOutput(sample=output) + + # forward blocks in range [start_idx, end_idx), then return input and output + def _forward_blocks_range(self, hidden_states, encoder_hidden_states, emb, start_idx, end_idx, **kwargs): + for _, block in enumerate(self.transformer_blocks[start_idx: end_idx]): + hidden_states, encoder_hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + emb=emb, + ) + + return hidden_states, encoder_hidden_states + + def _forward_blocks(self, hidden_states, encoder_hidden_states, emb, t_idx): + num_blocks = len(self.transformer_blocks) + + if not self.use_cache or (t_idx < self.cache_start_steps): + hidden_states, encoder_hidden_states = self._forward_blocks_range( + hidden_states, + encoder_hidden_states, + emb, + 0, + num_blocks + ) + else: + # infer [0, cache_start) + hidden_states, encoder_hidden_states = self._forward_blocks_range( + hidden_states, + encoder_hidden_states, + emb, + 0, + self.cache_start + ) + # infer [cache_start, cache_end) + cache_end = np.minimum(self.cache_start + self.num_cache_layer, num_blocks) + hidden_states_before_cache = hidden_states.clone() + encoder_hidden_states_before_cache = encoder_hidden_states.clone() + if t_idx % self.cache_interval == (self.cache_start_steps % self.cache_interval): + hidden_states, encoder_hidden_states = self._forward_blocks_range( + hidden_states, + encoder_hidden_states, + emb, + self.cache_start, + cache_end + ) + self.delta_cache = hidden_states - hidden_states_before_cache + self.delta_encoder_cache = encoder_hidden_states - encoder_hidden_states_before_cache + else: + hidden_states = hidden_states_before_cache + self.delta_cache + encoder_hidden_states = encoder_hidden_states_before_cache + self.delta_encoder_cache + # infer [cache_end, num_blocks) + hidden_states, encoder_hidden_states = self._forward_blocks_range( + hidden_states, + encoder_hidden_states, + emb, + cache_end, + num_blocks + ) + + return hidden_states, encoder_hidden_states + + def load_weights(self, state_dict, shard=False): + with torch.no_grad(): + if not shard: + self.load_state_dict(state_dict) + return {} + else: + weights = state_dict + + for i in range(self.num_layers): + if i != 26: + q_weight = weights.pop(f"transformer_blocks.{i}.attn1.to_q.weight", None) + q_bias = weights.pop(f"transformer_blocks.{i}.attn1.to_q.bias", None) + k_weight = weights.pop(f"transformer_blocks.{i}.attn1.to_k.weight", None) + k_bias = weights.pop(f"transformer_blocks.{i}.attn1.to_k.bias", None) + v_weight = weights.pop(f"transformer_blocks.{i}.attn1.to_v.weight", None) + v_bias = weights.pop(f"transformer_blocks.{i}.attn1.to_v.bias", None) + + # query, key, value的weight和bias权重存在同一个文件中,不会分开存储。 + if q_weight is not None and k_weight is not None and v_weight is not None: + qkv_weight = torch.cat([q_weight, k_weight, v_weight], dim=0).transpose(0, 1).contiguous() + qkv_bias = torch.cat([q_bias, k_bias, v_bias], dim=0).contiguous() + weights[f"transformer_blocks.{i}.attn1.to_qkv.weight"] = qkv_weight + weights[f"transformer_blocks.{i}.attn1.to_qkv.bias"] = qkv_bias + else: + if self.q_weight_cache is None: + self.q_weight_cache = weights.pop(f"transformer_blocks.{i}.attn1.to_q.weight", None) + if self.q_bias_cache is None: + self.q_bias_cache = weights.pop(f"transformer_blocks.{i}.attn1.to_q.bias", None) + if self.k_weight_cache is None: + self.k_weight_cache = weights.pop(f"transformer_blocks.{i}.attn1.to_k.weight", None) + if self.k_bias_cache is None: + self.k_bias_cache = weights.pop(f"transformer_blocks.{i}.attn1.to_k.bias", None) + if self.v_weight_cache is None: + self.v_weight_cache = weights.pop(f"transformer_blocks.{i}.attn1.to_v.weight", None) + if self.v_bias_cache is None: + self.v_bias_cache = weights.pop(f"transformer_blocks.{i}.attn1.to_v.bias", None) + + qk_weight_cache = self.q_weight_cache is not None and self.k_weight_cache is not None + if qk_weight_cache and self.v_weight_cache is not None: + qkv_weight = torch.cat( + [self.q_weight_cache, self.k_weight_cache, self.v_weight_cache], + dim=0 + ).transpose(0, 1).contiguous() + qkv_bias = torch.cat([self.q_bias_cache, self.k_bias_cache, self.v_bias_cache], dim=0).contiguous() + weights[f"transformer_blocks.26.attn1.to_qkv.weight"] = qkv_weight + weights[f"transformer_blocks.26.attn1.to_qkv.bias"] = qkv_bias + + self.load_state_dict(weights, strict=False, assign=True) + return weights.keys() diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/pipeline/__init__.py b/MindIE/MultiModal/CogView3-Plus-3B/cogview3plus/pipeline/__init__.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/pipeline/__init__.py rename to MindIE/MultiModal/CogView3-Plus-3B/cogview3plus/pipeline/__init__.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/pipeline/pipeline_cogview3plus.py b/MindIE/MultiModal/CogView3-Plus-3B/cogview3plus/pipeline/pipeline_cogview3plus.py similarity index 97% rename from MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/pipeline/pipeline_cogview3plus.py rename to MindIE/MultiModal/CogView3-Plus-3B/cogview3plus/pipeline/pipeline_cogview3plus.py index fe2bd5cfcd..4b07df76a6 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/pipeline/pipeline_cogview3plus.py +++ b/MindIE/MultiModal/CogView3-Plus-3B/cogview3plus/pipeline/pipeline_cogview3plus.py @@ -1,339 +1,339 @@ -# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import inspect -from typing import List, Optional, Tuple, Union - -import torch -from transformers import T5EncoderModel, T5Tokenizer - -from diffusers.image_processor import VaeImageProcessor -from diffusers.pipelines.pipeline_utils import DiffusionPipeline -from diffusers.utils import logging -from diffusers.utils.torch_utils import randn_tensor -from diffusers import AutoencoderKL - -from ..models import CogView3PlusTransformer2DModel -from ..schedulers import CogVideoXDDIMScheduler -from .pipeline_output import CogView3PipelineOutput - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -def retrieve_timesteps( - scheduler, - num_inference_steps: Optional[int] = None, - device: Optional[Union[str, torch.device]] = None, - timesteps: Optional[List[int]] = None, - sigmas: Optional[List[float]] = None, - **kwargs, -): - if timesteps is not None and sigmas is not None: - raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") - if timesteps is not None: - accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) - if not accepts_timesteps: - raise ValueError( - f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" - f" timestep schedules. Please check whether you are using the correct scheduler." - ) - scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) - timesteps = scheduler.timesteps - num_inference_steps = len(timesteps) - elif sigmas is not None: - accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) - if not accept_sigmas: - raise ValueError( - f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" - f" sigmas schedules. Please check whether you are using the correct scheduler." - ) - scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) - timesteps = scheduler.timesteps - num_inference_steps = len(timesteps) - else: - scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) - timesteps = scheduler.timesteps - return timesteps, num_inference_steps - - -class CogView3PlusPipeline(DiffusionPipeline): - _optional_components = [] - model_cpu_offload_seq = "text_encoder->transformer->vae" - - _callback_tensor_inputs = [ - "latents", - "prompt_embeds", - "negative_prompt_embeds", - ] - - def __init__( - self, - tokenizer: T5Tokenizer, - text_encoder: T5EncoderModel, - vae: AutoencoderKL, - transformer: CogView3PlusTransformer2DModel, - scheduler: CogVideoXDDIMScheduler, - ): - super().__init__() - - self.register_modules( - tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler - ) - self.vae_scale_factor = ( - 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 - ) - - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) - - def _get_t5_prompt_embeds( - self, - prompt: Union[str, List[str]] = None, - num_images_per_prompt: int = 1, - max_sequence_length: int = 226, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ): - device = device or self._execution_device - dtype = dtype or self.text_encoder.dtype - - prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) - - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=max_sequence_length, - truncation=True, - add_special_tokens=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because `max_sequence_length` is set to " - f" {max_sequence_length} tokens: {removed_text}" - ) - - prompt_embeds = self.text_encoder(text_input_ids.to(device))[0] - prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) - - # duplicate text embeddings for each generation per prompt, using mps friendly method - _, seq_len, _ = prompt_embeds.shape - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - return prompt_embeds - - def encode_prompt( - self, - prompt: Union[str, List[str]], - num_images_per_prompt: int = 1, - max_sequence_length: int = 224, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ): - device = device or self._execution_device - - prompt = [prompt] if isinstance(prompt, str) else prompt - prompt_embeds = self._get_t5_prompt_embeds( - prompt=prompt, - num_images_per_prompt=num_images_per_prompt, - max_sequence_length=max_sequence_length, - device=device, - dtype=dtype, - ) - negative_prompt_embeds = prompt_embeds.new_zeros(prompt_embeds.shape) - - return prompt_embeds, negative_prompt_embeds - - def prepare_latents(self, batch_size, num_channels_latents, image_size, dtype, device): - height = image_size[0] - width = image_size[1] - shape = ( - batch_size, - num_channels_latents, - int(height) // self.vae_scale_factor, - int(width) // self.vae_scale_factor, - ) - latents = randn_tensor(shape, device=device, dtype=dtype) - - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma - return latents - - def prepare_extra_step_kwargs(self, generator, eta): - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) - if accepts_generator: - extra_step_kwargs["generator"] = generator - return extra_step_kwargs - - def check_inputs( - self, - prompt, - height, - width, - ): - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - - if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - @property - def guidance_scale(self): - return self._guidance_scale - - @property - def do_classifier_free_guidance(self): - return self._guidance_scale > 1 - - @property - def num_timesteps(self): - return self._num_timesteps - - @property - def interrupt(self): - return self._interrupt - - @torch.no_grad() - def __call__( - self, - prompt: Optional[Union[str, List[str]]] = None, - image_size: Tuple[int, int] = None, - num_inference_steps: int = 50, - guidance_scale: float = 5.0, - num_images_per_prompt: int = 1, - ) -> Union[CogView3PipelineOutput, Tuple]: - if image_size is None: - height = self.transformer.config.sample_size * self.vae_scale_factor - width = self.transformer.config.sample_size * self.vae_scale_factor - else: - height = image_size[0] - width = image_size[1] - - original_size = (height, width) - target_size = (height, width) - - # 1. Check inputs. Raise error if not correct - self.check_inputs( - prompt, - height, - width, - ) - self._guidance_scale = guidance_scale - self._interrupt = False - - # 2. Default call parameters - if isinstance(prompt, str): - batch_size = 1 - else: - batch_size = len(prompt) - - device = self._execution_device - - # 3. Encode input prompt - prompt_embeds, negative_prompt_embeds = self.encode_prompt( - prompt, - num_images_per_prompt=num_images_per_prompt, - max_sequence_length=224, - device=device, - ) - if self.do_classifier_free_guidance: - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) - - # 4. Prepare timesteps - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device) - self._num_timesteps = len(timesteps) - - # 5. Prepare latents. - latent_channels = self.transformer.config.in_channels - latents = self.prepare_latents( - batch_size * num_images_per_prompt, - latent_channels, - (height, width), - prompt_embeds.dtype, - device, - ) - - extra_step_kwargs = self.prepare_extra_step_kwargs(None, 0.0) - - # 7. Prepare additional timestep conditions - original_size = torch.tensor([original_size], dtype=prompt_embeds.dtype) - target_size = torch.tensor([target_size], dtype=prompt_embeds.dtype) - crops_coords_top_left = torch.tensor([(0, 0)], dtype=prompt_embeds.dtype) - - if self.do_classifier_free_guidance: - original_size = torch.cat([original_size, original_size]) - target_size = torch.cat([target_size, target_size]) - crops_coords_top_left = torch.cat([crops_coords_top_left, crops_coords_top_left]) - - original_size = original_size.to(device).repeat(batch_size * num_images_per_prompt, 1) - target_size = target_size.to(device).repeat(batch_size * num_images_per_prompt, 1) - crops_coords_top_left = crops_coords_top_left.to(device).repeat(batch_size * num_images_per_prompt, 1) - - # 8. Denoising loop - num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) - - with self.progress_bar(total=num_inference_steps) as progress_bar: - # for DPM-solver++ - old_pred_original_sample = None - for i, t in enumerate(timesteps): - if self.interrupt: - continue - - latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timestep = t.expand(latent_model_input.shape[0]) - - # predict noise model_output - noise_pred = self.transformer( - states=(latent_model_input, prompt_embeds, i), - timestep=timestep, - original_size=original_size, - target_size=target_size, - crop_coords=crops_coords_top_left, - )[0] - noise_pred = noise_pred.float() - - # perform guidance - if self.do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] - latents = latents.to(prompt_embeds.dtype) - - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - - image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=None)[0] - image = self.image_processor.postprocess(image, output_type="pil") - - # Offload all models - self.maybe_free_model_hooks() - +# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import List, Optional, Tuple, Union + +import torch +from transformers import T5EncoderModel, T5Tokenizer + +from diffusers.image_processor import VaeImageProcessor +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.utils import logging +from diffusers.utils.torch_utils import randn_tensor +from diffusers import AutoencoderKL + +from ..models import CogView3PlusTransformer2DModel +from ..schedulers import CogVideoXDDIMScheduler +from .pipeline_output import CogView3PipelineOutput + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class CogView3PlusPipeline(DiffusionPipeline): + _optional_components = [] + model_cpu_offload_seq = "text_encoder->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + vae: AutoencoderKL, + transformer: CogView3PlusTransformer2DModel, + scheduler: CogVideoXDDIMScheduler, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder(text_input_ids.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + def encode_prompt( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + max_sequence_length: int = 224, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + negative_prompt_embeds = prompt_embeds.new_zeros(prompt_embeds.shape) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents(self, batch_size, num_channels_latents, image_size, dtype, device): + height = image_size[0] + width = image_size[1] + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + latents = randn_tensor(shape, device=device, dtype=dtype) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def prepare_extra_step_kwargs(self, generator, eta): + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + image_size: Tuple[int, int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + num_images_per_prompt: int = 1, + ) -> Union[CogView3PipelineOutput, Tuple]: + if image_size is None: + height = self.transformer.config.sample_size * self.vae_scale_factor + width = self.transformer.config.sample_size * self.vae_scale_factor + else: + height = image_size[0] + width = image_size[1] + + original_size = (height, width) + target_size = (height, width) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + ) + self._guidance_scale = guidance_scale + self._interrupt = False + + # 2. Default call parameters + if isinstance(prompt, str): + batch_size = 1 + else: + batch_size = len(prompt) + + device = self._execution_device + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=224, + device=device, + ) + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device) + self._num_timesteps = len(timesteps) + + # 5. Prepare latents. + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + latent_channels, + (height, width), + prompt_embeds.dtype, + device, + ) + + extra_step_kwargs = self.prepare_extra_step_kwargs(None, 0.0) + + # 7. Prepare additional timestep conditions + original_size = torch.tensor([original_size], dtype=prompt_embeds.dtype) + target_size = torch.tensor([target_size], dtype=prompt_embeds.dtype) + crops_coords_top_left = torch.tensor([(0, 0)], dtype=prompt_embeds.dtype) + + if self.do_classifier_free_guidance: + original_size = torch.cat([original_size, original_size]) + target_size = torch.cat([target_size, target_size]) + crops_coords_top_left = torch.cat([crops_coords_top_left, crops_coords_top_left]) + + original_size = original_size.to(device).repeat(batch_size * num_images_per_prompt, 1) + target_size = target_size.to(device).repeat(batch_size * num_images_per_prompt, 1) + crops_coords_top_left = crops_coords_top_left.to(device).repeat(batch_size * num_images_per_prompt, 1) + + # 8. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + # for DPM-solver++ + old_pred_original_sample = None + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + # predict noise model_output + noise_pred = self.transformer( + states=(latent_model_input, prompt_embeds, i), + timestep=timestep, + original_size=original_size, + target_size=target_size, + crop_coords=crops_coords_top_left, + )[0] + noise_pred = noise_pred.float() + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + latents = latents.to(prompt_embeds.dtype) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=None)[0] + image = self.image_processor.postprocess(image, output_type="pil") + + # Offload all models + self.maybe_free_model_hooks() + return CogView3PipelineOutput(images=image) \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/pipeline/pipeline_output.py b/MindIE/MultiModal/CogView3-Plus-3B/cogview3plus/pipeline/pipeline_output.py similarity index 96% rename from MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/pipeline/pipeline_output.py rename to MindIE/MultiModal/CogView3-Plus-3B/cogview3plus/pipeline/pipeline_output.py index 11f8976f0e..e56a4485d7 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/pipeline/pipeline_output.py +++ b/MindIE/MultiModal/CogView3-Plus-3B/cogview3plus/pipeline/pipeline_output.py @@ -1,21 +1,21 @@ -from dataclasses import dataclass -from typing import List, Union - -import numpy as np -import PIL.Image - -from diffusers.utils import BaseOutput - - -@dataclass -class CogView3PipelineOutput(BaseOutput): - """ - Output class for CogView3 pipelines. - - Args: - images (`List[PIL.Image.Image]` or `np.ndarray`) - List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, - num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. - """ - +from dataclasses import dataclass +from typing import List, Union + +import numpy as np +import PIL.Image + +from diffusers.utils import BaseOutput + + +@dataclass +class CogView3PipelineOutput(BaseOutput): + """ + Output class for CogView3 pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + images: Union[List[PIL.Image.Image], np.ndarray] \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/schedulers/__init__.py b/MindIE/MultiModal/CogView3-Plus-3B/cogview3plus/schedulers/__init__.py similarity index 99% rename from MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/schedulers/__init__.py rename to MindIE/MultiModal/CogView3-Plus-3B/cogview3plus/schedulers/__init__.py index 7a8f559a28..f98b6e1dec 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/schedulers/__init__.py +++ b/MindIE/MultiModal/CogView3-Plus-3B/cogview3plus/schedulers/__init__.py @@ -1,2 +1,2 @@ -from .scheduling_ddim_cogvideox import CogVideoXDDIMScheduler +from .scheduling_ddim_cogvideox import CogVideoXDDIMScheduler from .scheduling_utils import SchedulerMixin \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/schedulers/scheduling_ddim_cogvideox.py b/MindIE/MultiModal/CogView3-Plus-3B/cogview3plus/schedulers/scheduling_ddim_cogvideox.py similarity index 97% rename from MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/schedulers/scheduling_ddim_cogvideox.py rename to MindIE/MultiModal/CogView3-Plus-3B/cogview3plus/schedulers/scheduling_ddim_cogvideox.py index b3f6ce229b..b4f81796e9 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/schedulers/scheduling_ddim_cogvideox.py +++ b/MindIE/MultiModal/CogView3-Plus-3B/cogview3plus/schedulers/scheduling_ddim_cogvideox.py @@ -1,276 +1,276 @@ -# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import math -from dataclasses import dataclass -from typing import List, Optional, Tuple, Union - -import numpy as np -import torch - -from diffusers.configuration_utils import ConfigMixin, register_to_config -from diffusers.utils import BaseOutput -from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin - - -@dataclass -class DDIMSchedulerOutput(BaseOutput): - prev_sample: torch.Tensor - pred_original_sample: Optional[torch.Tensor] = None - - -def betas_for_alpha_bar( - num_diffusion_timesteps, - max_beta=0.999, - alpha_transform_type="cosine", -): - if alpha_transform_type == "cosine": - - def alpha_bar_fn(t): - return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 - - elif alpha_transform_type == "exp": - - def alpha_bar_fn(t): - return math.exp(t * -12.0) - - else: - raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}") - - betas = [] - for i in range(num_diffusion_timesteps): - t1 = i / num_diffusion_timesteps - t2 = (i + 1) / num_diffusion_timesteps - betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) - return torch.tensor(betas, dtype=torch.float32) - - -def rescale_zero_terminal_snr(alphas_cumprod): - alphas_bar_sqrt = alphas_cumprod.sqrt() - - # Store old values. - alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() - alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() - - # Shift so the last timestep is zero. - alphas_bar_sqrt -= alphas_bar_sqrt_T - - # Scale so the first timestep is back to the old value. - alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) - - # Convert alphas_bar_sqrt to betas - alphas_bar = alphas_bar_sqrt**2 # Revert sqrt - - return alphas_bar - - -class CogVideoXDDIMScheduler(SchedulerMixin, ConfigMixin): - _compatibles = [e.name for e in KarrasDiffusionSchedulers] - order = 1 - - @register_to_config - def __init__( - self, - num_train_timesteps: int = 1000, - beta_start: float = 0.00085, - beta_end: float = 0.0120, - beta_schedule: str = "scaled_linear", - trained_betas: Optional[Union[np.ndarray, List[float]]] = None, - set_alpha_to_one: bool = True, - rescale_betas_zero_snr: bool = False, - snr_shift_scale: float = 3.0, - ): - if trained_betas is not None: - self.betas = torch.tensor(trained_betas, dtype=torch.float32) - elif beta_schedule == "linear": - self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) - elif beta_schedule == "scaled_linear": - # this schedule is very specific to the latent diffusion model. - self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float64) ** 2 - elif beta_schedule == "squaredcos_cap_v2": - # Glide cosine schedule - self.betas = betas_for_alpha_bar(num_train_timesteps) - else: - raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") - - self.alphas = 1.0 - self.betas - self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) - - # Modify: SNR shift following SD3 - self.alphas_cumprod = self.alphas_cumprod / (snr_shift_scale + (1 - snr_shift_scale) * self.alphas_cumprod) - - # Rescale for zero SNR - if rescale_betas_zero_snr: - self.alphas_cumprod = rescale_zero_terminal_snr(self.alphas_cumprod) - - self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] - - # standard deviation of the initial noise distribution - self.init_noise_sigma = 1.0 - - # setable values - self.num_inference_steps = None - self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64)) - - def _get_variance(self, timestep, prev_timestep): - alpha_prod_t = self.alphas_cumprod[timestep] - alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod - beta_prod_t = 1 - alpha_prod_t - beta_prod_t_prev = 1 - alpha_prod_t_prev - - variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) - - return variance - - def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor: - return sample - - def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): - """ - Sets the discrete timesteps used for the diffusion chain (to be run before inference). - - Args: - num_inference_steps (`int`): - The number of diffusion steps used when generating samples with a pre-trained model. - """ - - if num_inference_steps > self.config.num_train_timesteps: - raise ValueError( - f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:" - f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle" - f" maximal {self.config.num_train_timesteps} timesteps." - ) - - self.num_inference_steps = num_inference_steps - - if self.config.timestep_spacing == "linspace": - timesteps = ( - np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps) - .round()[::-1] - .copy() - .astype(np.int64) - ) - elif self.config.timestep_spacing == "leading": - step_ratio = self.config.num_train_timesteps // self.num_inference_steps - # creates integer timesteps by multiplying by ratio - # casting to int to avoid issues when num_inference_step is power of 3 - timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) - timesteps += self.config.steps_offset - elif self.config.timestep_spacing == "trailing": - step_ratio = self.config.num_train_timesteps / self.num_inference_steps - # creates integer timesteps by multiplying by ratio - # casting to int to avoid issues when num_inference_step is power of 3 - timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64) - timesteps -= 1 - else: - raise ValueError( - f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'leading' or 'trailing'." - ) - - self.timesteps = torch.from_numpy(timesteps).to(device) - - def step( - self, - model_output: torch.Tensor, - timestep: int, - sample: torch.Tensor, - return_dict: bool = True, - ) -> Union[DDIMSchedulerOutput, Tuple]: - if self.num_inference_steps is None: - raise ValueError( - "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" - ) - - prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps - - # 2. compute alphas, betas - alpha_prod_t = self.alphas_cumprod[timestep] - alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod - - beta_prod_t = 1 - alpha_prod_t - - # 3. compute predicted original sample from predicted noise also called - if self.config.prediction_type == "epsilon": - pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) - elif self.config.prediction_type == "sample": - pred_original_sample = model_output - elif self.config.prediction_type == "v_prediction": - pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output - else: - raise ValueError( - f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" - " `v_prediction`" - ) - - a_t = ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t)) ** 0.5 - b_t = alpha_prod_t_prev**0.5 - alpha_prod_t**0.5 * a_t - - prev_sample = a_t * sample + b_t * pred_original_sample - - if not return_dict: - return ( - prev_sample, - pred_original_sample, - ) - - return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) - - def add_noise( - self, - original_samples: torch.Tensor, - noise: torch.Tensor, - timesteps: torch.IntTensor, - ) -> torch.Tensor: - # Make sure alphas_cumprod and timestep have same device and dtype as original_samples - # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement - # for the subsequent add_noise calls - self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device) - alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype) - timesteps = timesteps.to(original_samples.device) - - sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 - sqrt_alpha_prod = sqrt_alpha_prod.flatten() - while len(sqrt_alpha_prod.shape) < len(original_samples.shape): - sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) - - sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() - while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) - - noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise - return noisy_samples - - def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor: - # Make sure alphas_cumprod and timestep have same device and dtype as sample - self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device) - alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype) - timesteps = timesteps.to(sample.device) - - sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 - sqrt_alpha_prod = sqrt_alpha_prod.flatten() - while len(sqrt_alpha_prod.shape) < len(sample.shape): - sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) - - sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() - while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape): - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) - - velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample - return velocity - - def __len__(self): +# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils import BaseOutput +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin + + +@dataclass +class DDIMSchedulerOutput(BaseOutput): + prev_sample: torch.Tensor + pred_original_sample: Optional[torch.Tensor] = None + + +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +def rescale_zero_terminal_snr(alphas_cumprod): + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + + return alphas_bar + + +class CogVideoXDDIMScheduler(SchedulerMixin, ConfigMixin): + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.00085, + beta_end: float = 0.0120, + beta_schedule: str = "scaled_linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + set_alpha_to_one: bool = True, + rescale_betas_zero_snr: bool = False, + snr_shift_scale: float = 3.0, + ): + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float64) ** 2 + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + # Modify: SNR shift following SD3 + self.alphas_cumprod = self.alphas_cumprod / (snr_shift_scale + (1 - snr_shift_scale) * self.alphas_cumprod) + + # Rescale for zero SNR + if rescale_betas_zero_snr: + self.alphas_cumprod = rescale_zero_terminal_snr(self.alphas_cumprod) + + self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # setable values + self.num_inference_steps = None + self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64)) + + def _get_variance(self, timestep, prev_timestep): + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) + + return variance + + def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor: + return sample + + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + """ + + if num_inference_steps > self.config.num_train_timesteps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:" + f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.config.num_train_timesteps} timesteps." + ) + + self.num_inference_steps = num_inference_steps + + if self.config.timestep_spacing == "linspace": + timesteps = ( + np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps) + .round()[::-1] + .copy() + .astype(np.int64) + ) + elif self.config.timestep_spacing == "leading": + step_ratio = self.config.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'leading' or 'trailing'." + ) + + self.timesteps = torch.from_numpy(timesteps).to(device) + + def step( + self, + model_output: torch.Tensor, + timestep: int, + sample: torch.Tensor, + return_dict: bool = True, + ) -> Union[DDIMSchedulerOutput, Tuple]: + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps + + # 2. compute alphas, betas + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + + beta_prod_t = 1 - alpha_prod_t + + # 3. compute predicted original sample from predicted noise also called + if self.config.prediction_type == "epsilon": + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + elif self.config.prediction_type == "sample": + pred_original_sample = model_output + elif self.config.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction`" + ) + + a_t = ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t)) ** 0.5 + b_t = alpha_prod_t_prev**0.5 - alpha_prod_t**0.5 * a_t + + prev_sample = a_t * sample + b_t * pred_original_sample + + if not return_dict: + return ( + prev_sample, + pred_original_sample, + ) + + return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) + + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.IntTensor, + ) -> torch.Tensor: + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement + # for the subsequent add_noise calls + self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device) + alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor: + # Make sure alphas_cumprod and timestep have same device and dtype as sample + self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device) + alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype) + timesteps = timesteps.to(sample.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(sample.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample + return velocity + + def __len__(self): return self.config.num_train_timesteps \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/schedulers/scheduling_utils.py b/MindIE/MultiModal/CogView3-Plus-3B/cogview3plus/schedulers/scheduling_utils.py similarity index 97% rename from MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/schedulers/scheduling_utils.py rename to MindIE/MultiModal/CogView3-Plus-3B/cogview3plus/schedulers/scheduling_utils.py index d854366c77..eeb6e77dee 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/schedulers/scheduling_utils.py +++ b/MindIE/MultiModal/CogView3-Plus-3B/cogview3plus/schedulers/scheduling_utils.py @@ -1,113 +1,113 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import importlib -import os -from dataclasses import dataclass -from enum import Enum -from typing import Optional, Union - -import torch -from huggingface_hub.utils import validate_hf_hub_args - -from diffusers.utils import BaseOutput, PushToHubMixin - - -SCHEDULER_CONFIG_NAME = "scheduler_config.json" - - -class KarrasDiffusionSchedulers(Enum): - DDIMScheduler = 1 - DDPMScheduler = 2 - PNDMScheduler = 3 - LMSDiscreteScheduler = 4 - EulerDiscreteScheduler = 5 - HeunDiscreteScheduler = 6 - EulerAncestralDiscreteScheduler = 7 - DPMSolverMultistepScheduler = 8 - DPMSolverSinglestepScheduler = 9 - KDPM2DiscreteScheduler = 10 - KDPM2AncestralDiscreteScheduler = 11 - DEISMultistepScheduler = 12 - UniPCMultistepScheduler = 13 - DPMSolverSDEScheduler = 14 - EDMEulerScheduler = 15 - - -AysSchedules = { - "StableDiffusionTimesteps": [999, 850, 736, 645, 545, 455, 343, 233, 124, 24], - "StableDiffusionSigmas": [14.615, 6.475, 3.861, 2.697, 1.886, 1.396, 0.963, 0.652, 0.399, 0.152, 0.0], - "StableDiffusionXLTimesteps": [999, 845, 730, 587, 443, 310, 193, 116, 53, 13], - "StableDiffusionXLSigmas": [14.615, 6.315, 3.771, 2.181, 1.342, 0.862, 0.555, 0.380, 0.234, 0.113, 0.0], - "StableDiffusionVideoSigmas": [700.00, 54.5, 15.886, 7.977, 4.248, 1.789, 0.981, 0.403, 0.173, 0.034, 0.0], -} - - -@dataclass -class SchedulerOutput(BaseOutput): - """ - Base class for the output of a scheduler's `step` function. - - Args: - prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): - Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the - denoising loop. - """ - - prev_sample: torch.Tensor - - -class SchedulerMixin(PushToHubMixin): - - config_name = SCHEDULER_CONFIG_NAME - _compatibles = [] - has_compatibles = True - - @classmethod - @validate_hf_hub_args - def from_pretrained( - cls, - pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None, - subfolder: Optional[str] = None, - return_unused_kwargs=False, - **kwargs, - ): - - config, kwargs, _ = cls.load_config( - pretrained_model_name_or_path=pretrained_model_name_or_path, - subfolder=subfolder, - return_unused_kwargs=True, - return_commit_hash=True, - **kwargs, - ) - return cls.from_config(config, return_unused_kwargs=return_unused_kwargs, **kwargs) - - @property - def compatibles(self): - """ - Returns all schedulers that are compatible with this scheduler - - Returns: - `List[SchedulerMixin]`: List of compatible schedulers - """ - return self._get_compatibles() - - @classmethod - def _get_compatibles(cls): - compatible_classes_str = list(set([cls.__name__] + cls._compatibles)) - diffusers_library = importlib.import_module(__name__.split(".")[0]) - compatible_classes = [ - getattr(diffusers_library, c) for c in compatible_classes_str if hasattr(diffusers_library, c) - ] +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib +import os +from dataclasses import dataclass +from enum import Enum +from typing import Optional, Union + +import torch +from huggingface_hub.utils import validate_hf_hub_args + +from diffusers.utils import BaseOutput, PushToHubMixin + + +SCHEDULER_CONFIG_NAME = "scheduler_config.json" + + +class KarrasDiffusionSchedulers(Enum): + DDIMScheduler = 1 + DDPMScheduler = 2 + PNDMScheduler = 3 + LMSDiscreteScheduler = 4 + EulerDiscreteScheduler = 5 + HeunDiscreteScheduler = 6 + EulerAncestralDiscreteScheduler = 7 + DPMSolverMultistepScheduler = 8 + DPMSolverSinglestepScheduler = 9 + KDPM2DiscreteScheduler = 10 + KDPM2AncestralDiscreteScheduler = 11 + DEISMultistepScheduler = 12 + UniPCMultistepScheduler = 13 + DPMSolverSDEScheduler = 14 + EDMEulerScheduler = 15 + + +AysSchedules = { + "StableDiffusionTimesteps": [999, 850, 736, 645, 545, 455, 343, 233, 124, 24], + "StableDiffusionSigmas": [14.615, 6.475, 3.861, 2.697, 1.886, 1.396, 0.963, 0.652, 0.399, 0.152, 0.0], + "StableDiffusionXLTimesteps": [999, 845, 730, 587, 443, 310, 193, 116, 53, 13], + "StableDiffusionXLSigmas": [14.615, 6.315, 3.771, 2.181, 1.342, 0.862, 0.555, 0.380, 0.234, 0.113, 0.0], + "StableDiffusionVideoSigmas": [700.00, 54.5, 15.886, 7.977, 4.248, 1.789, 0.981, 0.403, 0.173, 0.034, 0.0], +} + + +@dataclass +class SchedulerOutput(BaseOutput): + """ + Base class for the output of a scheduler's `step` function. + + Args: + prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + """ + + prev_sample: torch.Tensor + + +class SchedulerMixin(PushToHubMixin): + + config_name = SCHEDULER_CONFIG_NAME + _compatibles = [] + has_compatibles = True + + @classmethod + @validate_hf_hub_args + def from_pretrained( + cls, + pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None, + subfolder: Optional[str] = None, + return_unused_kwargs=False, + **kwargs, + ): + + config, kwargs, _ = cls.load_config( + pretrained_model_name_or_path=pretrained_model_name_or_path, + subfolder=subfolder, + return_unused_kwargs=True, + return_commit_hash=True, + **kwargs, + ) + return cls.from_config(config, return_unused_kwargs=return_unused_kwargs, **kwargs) + + @property + def compatibles(self): + """ + Returns all schedulers that are compatible with this scheduler + + Returns: + `List[SchedulerMixin]`: List of compatible schedulers + """ + return self._get_compatibles() + + @classmethod + def _get_compatibles(cls): + compatible_classes_str = list(set([cls.__name__] + cls._compatibles)) + diffusers_library = importlib.import_module(__name__.split(".")[0]) + compatible_classes = [ + getattr(diffusers_library, c) for c in compatible_classes_str if hasattr(diffusers_library, c) + ] return compatible_classes \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/vae/__init__.py b/MindIE/MultiModal/CogView3-Plus-3B/cogview3plus/vae/__init__.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/vae/__init__.py rename to MindIE/MultiModal/CogView3-Plus-3B/cogview3plus/vae/__init__.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogview3/inference_cogview3plus.py b/MindIE/MultiModal/CogView3-Plus-3B/inference_cogview3plus.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/cogview3/inference_cogview3plus.py rename to MindIE/MultiModal/CogView3-Plus-3B/inference_cogview3plus.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogview3/requirents.txt b/MindIE/MultiModal/CogView3-Plus-3B/requirents.txt similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/cogview3/requirents.txt rename to MindIE/MultiModal/CogView3-Plus-3B/requirents.txt diff --git a/MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/FLUX1dev/__init__.py b/MindIE/MultiModal/Flux.1-DEV/FLUX1dev/__init__.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/FLUX1dev/__init__.py rename to MindIE/MultiModal/Flux.1-DEV/FLUX1dev/__init__.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/FLUX1dev/layers/__init__.py b/MindIE/MultiModal/Flux.1-DEV/FLUX1dev/layers/__init__.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/FLUX1dev/layers/__init__.py rename to MindIE/MultiModal/Flux.1-DEV/FLUX1dev/layers/__init__.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/FLUX1dev/layers/attention_processor.py b/MindIE/MultiModal/Flux.1-DEV/FLUX1dev/layers/attention_processor.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/FLUX1dev/layers/attention_processor.py rename to MindIE/MultiModal/Flux.1-DEV/FLUX1dev/layers/attention_processor.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/FLUX1dev/layers/embedding.py b/MindIE/MultiModal/Flux.1-DEV/FLUX1dev/layers/embedding.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/FLUX1dev/layers/embedding.py rename to MindIE/MultiModal/Flux.1-DEV/FLUX1dev/layers/embedding.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/FLUX1dev/models/__init__.py b/MindIE/MultiModal/Flux.1-DEV/FLUX1dev/models/__init__.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/FLUX1dev/models/__init__.py rename to MindIE/MultiModal/Flux.1-DEV/FLUX1dev/models/__init__.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/FLUX1dev/models/modeling_utils.py b/MindIE/MultiModal/Flux.1-DEV/FLUX1dev/models/modeling_utils.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/FLUX1dev/models/modeling_utils.py rename to MindIE/MultiModal/Flux.1-DEV/FLUX1dev/models/modeling_utils.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/FLUX1dev/models/transformer_flux.py b/MindIE/MultiModal/Flux.1-DEV/FLUX1dev/models/transformer_flux.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/FLUX1dev/models/transformer_flux.py rename to MindIE/MultiModal/Flux.1-DEV/FLUX1dev/models/transformer_flux.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/FLUX1dev/pipeline/__init__.py b/MindIE/MultiModal/Flux.1-DEV/FLUX1dev/pipeline/__init__.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/FLUX1dev/pipeline/__init__.py rename to MindIE/MultiModal/Flux.1-DEV/FLUX1dev/pipeline/__init__.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/FLUX1dev/pipeline/pipeline_flux.py b/MindIE/MultiModal/Flux.1-DEV/FLUX1dev/pipeline/pipeline_flux.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/FLUX1dev/pipeline/pipeline_flux.py rename to MindIE/MultiModal/Flux.1-DEV/FLUX1dev/pipeline/pipeline_flux.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/README.md b/MindIE/MultiModal/Flux.1-DEV/README.md similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/README.md rename to MindIE/MultiModal/Flux.1-DEV/README.md diff --git a/MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/inference_flux.py b/MindIE/MultiModal/Flux.1-DEV/inference_flux.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/inference_flux.py rename to MindIE/MultiModal/Flux.1-DEV/inference_flux.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/prompts.txt b/MindIE/MultiModal/Flux.1-DEV/prompts.txt similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/prompts.txt rename to MindIE/MultiModal/Flux.1-DEV/prompts.txt diff --git a/MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/requirements.txt b/MindIE/MultiModal/Flux.1-DEV/requirements.txt similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/requirements.txt rename to MindIE/MultiModal/Flux.1-DEV/requirements.txt diff --git a/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/README.md b/MindIE/MultiModal/HunyuanDiT/README.md similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/README.md rename to MindIE/MultiModal/HunyuanDiT/README.md diff --git a/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/__init__.py b/MindIE/MultiModal/HunyuanDiT/hydit/__init__.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/__init__.py rename to MindIE/MultiModal/HunyuanDiT/hydit/__init__.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/layers/__init__.py b/MindIE/MultiModal/HunyuanDiT/hydit/layers/__init__.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/layers/__init__.py rename to MindIE/MultiModal/HunyuanDiT/hydit/layers/__init__.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/layers/activation.py b/MindIE/MultiModal/HunyuanDiT/hydit/layers/activation.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/layers/activation.py rename to MindIE/MultiModal/HunyuanDiT/hydit/layers/activation.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/layers/attention.py b/MindIE/MultiModal/HunyuanDiT/hydit/layers/attention.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/layers/attention.py rename to MindIE/MultiModal/HunyuanDiT/hydit/layers/attention.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/layers/embedding.py b/MindIE/MultiModal/HunyuanDiT/hydit/layers/embedding.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/layers/embedding.py rename to MindIE/MultiModal/HunyuanDiT/hydit/layers/embedding.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/layers/mlp.py b/MindIE/MultiModal/HunyuanDiT/hydit/layers/mlp.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/layers/mlp.py rename to MindIE/MultiModal/HunyuanDiT/hydit/layers/mlp.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/layers/norm.py b/MindIE/MultiModal/HunyuanDiT/hydit/layers/norm.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/layers/norm.py rename to MindIE/MultiModal/HunyuanDiT/hydit/layers/norm.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/layers/poolers.py b/MindIE/MultiModal/HunyuanDiT/hydit/layers/poolers.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/layers/poolers.py rename to MindIE/MultiModal/HunyuanDiT/hydit/layers/poolers.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/models/__init__.py b/MindIE/MultiModal/HunyuanDiT/hydit/models/__init__.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/models/__init__.py rename to MindIE/MultiModal/HunyuanDiT/hydit/models/__init__.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/models/hydit.py b/MindIE/MultiModal/HunyuanDiT/hydit/models/hydit.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/models/hydit.py rename to MindIE/MultiModal/HunyuanDiT/hydit/models/hydit.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/models/model_load_utils.py b/MindIE/MultiModal/HunyuanDiT/hydit/models/model_load_utils.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/models/model_load_utils.py rename to MindIE/MultiModal/HunyuanDiT/hydit/models/model_load_utils.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/models/model_utils.py b/MindIE/MultiModal/HunyuanDiT/hydit/models/model_utils.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/models/model_utils.py rename to MindIE/MultiModal/HunyuanDiT/hydit/models/model_utils.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/pipeline/__init__.py b/MindIE/MultiModal/HunyuanDiT/hydit/pipeline/__init__.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/pipeline/__init__.py rename to MindIE/MultiModal/HunyuanDiT/hydit/pipeline/__init__.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/pipeline/hydit_pipeline.py b/MindIE/MultiModal/HunyuanDiT/hydit/pipeline/hydit_pipeline.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/pipeline/hydit_pipeline.py rename to MindIE/MultiModal/HunyuanDiT/hydit/pipeline/hydit_pipeline.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/schedulers/__init__.py b/MindIE/MultiModal/HunyuanDiT/hydit/schedulers/__init__.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/schedulers/__init__.py rename to MindIE/MultiModal/HunyuanDiT/hydit/schedulers/__init__.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/schedulers/ddpm.py b/MindIE/MultiModal/HunyuanDiT/hydit/schedulers/ddpm.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/schedulers/ddpm.py rename to MindIE/MultiModal/HunyuanDiT/hydit/schedulers/ddpm.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/utils/__init__.py b/MindIE/MultiModal/HunyuanDiT/hydit/utils/__init__.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/utils/__init__.py rename to MindIE/MultiModal/HunyuanDiT/hydit/utils/__init__.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/utils/file_utils.py b/MindIE/MultiModal/HunyuanDiT/hydit/utils/file_utils.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/utils/file_utils.py rename to MindIE/MultiModal/HunyuanDiT/hydit/utils/file_utils.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/utils/utils.py b/MindIE/MultiModal/HunyuanDiT/hydit/utils/utils.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/utils/utils.py rename to MindIE/MultiModal/HunyuanDiT/hydit/utils/utils.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/inference_hydit.py b/MindIE/MultiModal/HunyuanDiT/inference_hydit.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/inference_hydit.py rename to MindIE/MultiModal/HunyuanDiT/inference_hydit.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/lora/__init__.py b/MindIE/MultiModal/HunyuanDiT/lora/__init__.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/lora/__init__.py rename to MindIE/MultiModal/HunyuanDiT/lora/__init__.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/lora/hydit_lora.py b/MindIE/MultiModal/HunyuanDiT/lora/hydit_lora.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/lora/hydit_lora.py rename to MindIE/MultiModal/HunyuanDiT/lora/hydit_lora.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/prompts/example_prompts.txt b/MindIE/MultiModal/HunyuanDiT/prompts/example_prompts.txt similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/prompts/example_prompts.txt rename to MindIE/MultiModal/HunyuanDiT/prompts/example_prompts.txt diff --git a/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/requirents.txt b/MindIE/MultiModal/HunyuanDiT/requirents.txt similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/requirents.txt rename to MindIE/MultiModal/HunyuanDiT/requirents.txt diff --git a/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/README.md b/MindIE/MultiModal/OpenSora-v1.2/README.md similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/opensora1.2/README.md rename to MindIE/MultiModal/OpenSora-v1.2/README.md diff --git a/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/inference_opensora12.py b/MindIE/MultiModal/OpenSora-v1.2/inference_opensora12.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/opensora1.2/inference_opensora12.py rename to MindIE/MultiModal/OpenSora-v1.2/inference_opensora12.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/__init__.py b/MindIE/MultiModal/OpenSora-v1.2/opensora/__init__.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/__init__.py rename to MindIE/MultiModal/OpenSora-v1.2/opensora/__init__.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/layer/__init__.py b/MindIE/MultiModal/OpenSora-v1.2/opensora/layer/__init__.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/layer/__init__.py rename to MindIE/MultiModal/OpenSora-v1.2/opensora/layer/__init__.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/layer/activation.py b/MindIE/MultiModal/OpenSora-v1.2/opensora/layer/activation.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/layer/activation.py rename to MindIE/MultiModal/OpenSora-v1.2/opensora/layer/activation.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/layer/attention.py b/MindIE/MultiModal/OpenSora-v1.2/opensora/layer/attention.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/layer/attention.py rename to MindIE/MultiModal/OpenSora-v1.2/opensora/layer/attention.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/layer/comm.py b/MindIE/MultiModal/OpenSora-v1.2/opensora/layer/comm.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/layer/comm.py rename to MindIE/MultiModal/OpenSora-v1.2/opensora/layer/comm.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/layer/conv.py b/MindIE/MultiModal/OpenSora-v1.2/opensora/layer/conv.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/layer/conv.py rename to MindIE/MultiModal/OpenSora-v1.2/opensora/layer/conv.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/layer/embdding.py b/MindIE/MultiModal/OpenSora-v1.2/opensora/layer/embdding.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/layer/embdding.py rename to MindIE/MultiModal/OpenSora-v1.2/opensora/layer/embdding.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/layer/mlp.py b/MindIE/MultiModal/OpenSora-v1.2/opensora/layer/mlp.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/layer/mlp.py rename to MindIE/MultiModal/OpenSora-v1.2/opensora/layer/mlp.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/layer/norm.py b/MindIE/MultiModal/OpenSora-v1.2/opensora/layer/norm.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/layer/norm.py rename to MindIE/MultiModal/OpenSora-v1.2/opensora/layer/norm.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/layer/parallel_mgr.py b/MindIE/MultiModal/OpenSora-v1.2/opensora/layer/parallel_mgr.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/layer/parallel_mgr.py rename to MindIE/MultiModal/OpenSora-v1.2/opensora/layer/parallel_mgr.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/layer/utils.py b/MindIE/MultiModal/OpenSora-v1.2/opensora/layer/utils.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/layer/utils.py rename to MindIE/MultiModal/OpenSora-v1.2/opensora/layer/utils.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/pipeline/__init__.py b/MindIE/MultiModal/OpenSora-v1.2/opensora/pipeline/__init__.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/pipeline/__init__.py rename to MindIE/MultiModal/OpenSora-v1.2/opensora/pipeline/__init__.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/pipeline/compile_pipe.py b/MindIE/MultiModal/OpenSora-v1.2/opensora/pipeline/compile_pipe.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/pipeline/compile_pipe.py rename to MindIE/MultiModal/OpenSora-v1.2/opensora/pipeline/compile_pipe.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/pipeline/open_sora_pipeline.py b/MindIE/MultiModal/OpenSora-v1.2/opensora/pipeline/open_sora_pipeline.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/pipeline/open_sora_pipeline.py rename to MindIE/MultiModal/OpenSora-v1.2/opensora/pipeline/open_sora_pipeline.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/pipeline/pipeline_utils.py b/MindIE/MultiModal/OpenSora-v1.2/opensora/pipeline/pipeline_utils.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/pipeline/pipeline_utils.py rename to MindIE/MultiModal/OpenSora-v1.2/opensora/pipeline/pipeline_utils.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/schedulers/__init__.py b/MindIE/MultiModal/OpenSora-v1.2/opensora/schedulers/__init__.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/schedulers/__init__.py rename to MindIE/MultiModal/OpenSora-v1.2/opensora/schedulers/__init__.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/schedulers/rectified_flow.py b/MindIE/MultiModal/OpenSora-v1.2/opensora/schedulers/rectified_flow.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/schedulers/rectified_flow.py rename to MindIE/MultiModal/OpenSora-v1.2/opensora/schedulers/rectified_flow.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/stdit3/__init__.py b/MindIE/MultiModal/OpenSora-v1.2/opensora/stdit3/__init__.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/stdit3/__init__.py rename to MindIE/MultiModal/OpenSora-v1.2/opensora/stdit3/__init__.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/stdit3/stdit3.py b/MindIE/MultiModal/OpenSora-v1.2/opensora/stdit3/stdit3.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/stdit3/stdit3.py rename to MindIE/MultiModal/OpenSora-v1.2/opensora/stdit3/stdit3.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/utils/__init__.py b/MindIE/MultiModal/OpenSora-v1.2/opensora/utils/__init__.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/utils/__init__.py rename to MindIE/MultiModal/OpenSora-v1.2/opensora/utils/__init__.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/utils/patch_utils.py b/MindIE/MultiModal/OpenSora-v1.2/opensora/utils/patch_utils.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/utils/patch_utils.py rename to MindIE/MultiModal/OpenSora-v1.2/opensora/utils/patch_utils.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/utils/utils.py b/MindIE/MultiModal/OpenSora-v1.2/opensora/utils/utils.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/utils/utils.py rename to MindIE/MultiModal/OpenSora-v1.2/opensora/utils/utils.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/vae/VideoAutoencoder.py b/MindIE/MultiModal/OpenSora-v1.2/opensora/vae/VideoAutoencoder.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/vae/VideoAutoencoder.py rename to MindIE/MultiModal/OpenSora-v1.2/opensora/vae/VideoAutoencoder.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/vae/__init__.py b/MindIE/MultiModal/OpenSora-v1.2/opensora/vae/__init__.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/vae/__init__.py rename to MindIE/MultiModal/OpenSora-v1.2/opensora/vae/__init__.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/vae/vae_temporal.py b/MindIE/MultiModal/OpenSora-v1.2/opensora/vae/vae_temporal.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/vae/vae_temporal.py rename to MindIE/MultiModal/OpenSora-v1.2/opensora/vae/vae_temporal.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/prompts/t2v_sora.txt b/MindIE/MultiModal/OpenSora-v1.2/prompts/t2v_sora.txt similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/opensora1.2/prompts/t2v_sora.txt rename to MindIE/MultiModal/OpenSora-v1.2/prompts/t2v_sora.txt diff --git a/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/requirents.txt b/MindIE/MultiModal/OpenSora-v1.2/requirents.txt similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/opensora1.2/requirents.txt rename to MindIE/MultiModal/OpenSora-v1.2/requirents.txt diff --git a/MindIE/MultiModal/OpenSoraPlan-v1.3/README.md b/MindIE/MultiModal/OpenSoraPlan-v1.3/README.md new file mode 100644 index 0000000000..407937acad --- /dev/null +++ b/MindIE/MultiModal/OpenSoraPlan-v1.3/README.md @@ -0,0 +1,182 @@ +--- +license: apache-2.0 +--- + + + +# Opensoraplan1.3 + +## 一、介绍 +此仓库是开源模型Opensoraplan1.3, 基于MindIE SD 的实现。 可以实现更高效的推理性能。运行此仓库代码,需要安装MindIE SD 及其依赖。 + + +## 二、安装依赖 + +MindIE SD是MindIE的视图生成推理模型套件,其目标是为稳定扩散(Stable Diffusion, SD)系列大模型推理任务提供在昇腾硬件及其软件栈上的端到端解决方案,软件系统内部集成各功能模块,对外呈现统一的编程接口。 + +MindIE-SD其依赖组件为driver驱动包、firmware固件包、CANN开发套件包、推理引擎MindIE包,使用MindIE-SD前请提前安装这些依赖。 + +| 简称 | 安装包全名 | 默认安装路径 | 版本约束 | +| --------------- |---------------------------------------------------------------------------|--------------------------------------|-----------------------------------| +| driver驱动包 | 昇腾310P处理器对应驱动软件包:Ascend-hdk-310p-npu-driver_\{version\}\_{os}\-{arch}.run | /usr/local/Ascend | 24.0.rc1及以上 | +| firmware固件包 | 昇腾310P处理器对应固件软件包:Ascend-hdk-310p-npu-firmware_\{version\}.run | /usr/local/Ascend | 24.0.rc1及以上 | +| CANN开发套件包 | Ascend-cann-toolkit\_{version}_linux-{arch}.run | /usr/local/Ascend/ascend-toolkit/latest | 8.0.RC1及以上 | +| 推理引擎MindIE包 | Ascend-mindie\_\{version}_linux-\{arch}.run | /usr/local/Ascend/mindie/latest | 和mindietorch严格配套使用 | +| torch | Python的whl包:torch-{version}-cp310-cp310-{os}_{arch}.whl | - | Python版本3.10.x,torch版本支持2.1.0 | + +- {version}为软件包版本 +- {os}为系统名称,如Linux +- {arch}为架构名称,如x86_64 + +### 2.1 安装驱动和固件 + +1. 获取地址 +- [Atlas 800I A2(8*64G)](https://www.hiascend.com/hardware/firmware-drivers/community?product=4&model=32&cann=8.0.RC1.beta1&driver=1.0.RC1.alpha) +2. [安装指导手册](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha002/softwareinst/instg/instg_0019.html) +### 2.2 CANN开发套件包+kernel包+MindIE包下载 +1. 下载: +- [Atlas 800I A2(8*64G)](https://www.hiascend.com/developer/download/community/result?module=pt+ie+cann&product=4&model=32) +2. [环境准备指导](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha002/softwareinst/instg/instg_0001.html) + +3. 快速安装: +- CANN开发套件包+kernel包安装 +```commandline +# 增加软件包可执行权限,{version}表示软件版本号,{arch}表示CPU架构,{soc}表示昇腾AI处理器的版本。 +chmod +x ./Ascend-cann-toolkit_{version}_linux-{arch}.run +chmod +x ./Ascend-cann-kernels-{soc}_{version}_linux.run +# 校验软件包安装文件的一致性和完整性 +./Ascend-cann-toolkit_{version}_linux-{arch}.run --check +./Ascend-cann-kernels-{soc}_{version}_linux.run --check +# 安装 +./Ascend-cann-toolkit_{version}_linux-{arch}.run --install +./Ascend-cann-kernels-{soc}_{version}_linux.run --install + +# 设置环境变量 +source /usr/local/Ascend/ascend-toolkit/set_env.sh +``` +- MindIE包安装 +```commandline +# 增加软件包可执行权限,{version}表示软件版本号,{arch}表示CPU架构。 +chmod +x ./Ascend-mindie_${version}_linux-${arch}.run +./Ascend-mindie_${version}_linux-${arch}.run --check + +# 方式一:默认路径安装 +./Ascend-mindie_${version}_linux-${arch}.run --install +# 设置环境变量 +cd /usr/local/Ascend/mindie && source set_env.sh + +# 方式二:指定路径安装 +./Ascend-mindie_${version}_linux-${arch}.run --install-path=${AieInstallPath} +# 设置环境变量 +cd ${AieInstallPath}/mindie && source set_env.sh +``` + +- MindIE SD不需要单独安装,安装MindIE时将会自动安装 +- torch_npu 安装: +下载 pytorch_v{pytorchversion}_py{pythonversion}.tar.gz +```commandline +tar -xzvf pytorch_v{pytorchversion}_py{pythonversion}.tar.gz +# 解压后,会有whl包 +pip install torch_npu-{pytorchversion}.xxxx.{arch}.whl +``` + +### 2.3 pytorch框架(支持版本为:2.1.0) +[安装包下载](https://download.pytorch.org/whl/cpu/torch/) + +使用pip安装 +```shell +# {version}表示软件版本号,{arch}表示CPU架构。 +pip install torch-${version}-cp310-cp310-linux_${arch}.whl +``` + +### 2.4 安装依赖库 +安装MindIE-SD的依赖库。 +``` +pip install -r requirements.txt +``` + +## 三、Opensoraplan1.3 + +### 3.1 权重及配置文件说明 + +1. text_encoder和tokenizer: +- 配置文件和权重文件 +```shell +https://huggingface.co/google/mt5-xxl +``` +2. transformer: +- 配置文件和权重文件 +```shell + https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.3.0/tree/main/any93x640x640 +``` +3. VAE: +- 配置文件和权重文件 +```shell +https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.3.0/tree/main +``` + +### 3.2 执行推理脚本 +```shell +ASCEND_RT_VISIBLE_DEVICES=0,1,2,3 torchrun --nnodes=1 --nproc_per_node 4 --master_port 29516 \ + inference_opensoraplan13.py \ + --model_path /path/to/transformer/ \ + --num_frames 93 \ + --height 640 \ + --width 640 \ + --text_encoder_name_1 "/path/to/text/encoder" \ + --text_prompt prompt.txt \ + --ae WFVAEModel_D8_4x8x8 \ + --ae_path "/home/liuyaofu/planweight/vae" \ + --save_img_path "./video/save/path" \ + --fps 24 \ + --guidance_scale 7.5 \ + --num_sampling_steps 100 \ + --max_sequence_length 512 \ + --seed 1234 \ + --num_samples_per_prompt 1 \ + --rescale_betas_zero_snr \ + --prediction_type "v_prediction" \ + --save_memory \ + --sp \ + --use_cache +``` +ASCEND_RT_VISIBLE_DEVICES 指定特定的NPU进行计算 +--nproc_per_node 控制总NPU卡数进行计算 + +--model_path 指定transformers(DiT)模型权重配置路径, 下面包含config文件和权重文件 +--num_frames 设置生成的总帧数 +--height 设置输出图像的高度为多少像素 +--width 设置输出图像的宽度为多少像素 +--text_encoder_name_1 指定text_encoder权重配置路径 +--text_prompt 指定输入的文本提示, 可以是一个txt文件或者一个prompt字符 +--ae VAE的对视频的压缩规格 +--ae_path 指定VAE模型权重配置路径 +--save_img_path 指定视频保存的路径 +--fps 设置帧率 +--guidance_scale 设置引导比例,用于控制negative文本对视频生成的影响程度 +--num_sampling_steps 设置采样步骤的数量 +--max_sequence_length 512 设置prompt的最大长度为512 +--num_samples_per_prompt 1 设置每个提示生成的样本数为1 +--rescale_betas_zero_snr schedular 的配置 +--prediction_type schedular 的配置 +--save_memory 运行VAE时尽量节省内存, 当生成视频较大时,要开启 +--sp 是否开启序列并行 +--use_cache 是否开启dit cache算法 + + +## 四、reference + +### 4.1 EulerAncestralDiscreteScheduler + +本项目中的 `EulerAncestralDiscreteScheduler` 是从 [Hugging Face diffusers 库](https://github.com/huggingface/diffusers) 中引用的 `EulerAncestralDiscreteScheduler`。diffusers 库是一个用于生成扩散模型(diffusion models)的工具库,它提供了多种调度器(schedulers)来控制扩散过程。 + +在 `EulerAncestralDiscreteScheduler` 中,使用了 "linspace", "leading", 和 "trailing" 这几个概念。这些概念在下述论文中有所描述。 + +- 链接:[https://arxiv.org/abs/2305.08891](https://arxiv.org/abs/2305.08891) + + +### 许可证 + +本项目遵循 Apache License 2.0。有关详细信息,请参阅 [LICENSE](LICENSE) 文件。 + + diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/inference_opensoraplan13.py b/MindIE/MultiModal/OpenSoraPlan-v1.3/inference_opensoraplan13.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/inference_opensoraplan13.py rename to MindIE/MultiModal/OpenSoraPlan-v1.3/inference_opensoraplan13.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/layers/__init__.py b/MindIE/MultiModal/OpenSoraPlan-v1.3/layers/__init__.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/layers/__init__.py rename to MindIE/MultiModal/OpenSoraPlan-v1.3/layers/__init__.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/layers/activation.py b/MindIE/MultiModal/OpenSoraPlan-v1.3/layers/activation.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/layers/activation.py rename to MindIE/MultiModal/OpenSoraPlan-v1.3/layers/activation.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/layers/attention.py b/MindIE/MultiModal/OpenSoraPlan-v1.3/layers/attention.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/layers/attention.py rename to MindIE/MultiModal/OpenSoraPlan-v1.3/layers/attention.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/layers/cache_mgr.py b/MindIE/MultiModal/OpenSoraPlan-v1.3/layers/cache_mgr.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/layers/cache_mgr.py rename to MindIE/MultiModal/OpenSoraPlan-v1.3/layers/cache_mgr.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/layers/conv.py b/MindIE/MultiModal/OpenSoraPlan-v1.3/layers/conv.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/layers/conv.py rename to MindIE/MultiModal/OpenSoraPlan-v1.3/layers/conv.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/layers/linear.py b/MindIE/MultiModal/OpenSoraPlan-v1.3/layers/linear.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/layers/linear.py rename to MindIE/MultiModal/OpenSoraPlan-v1.3/layers/linear.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/layers/mlp.py b/MindIE/MultiModal/OpenSoraPlan-v1.3/layers/mlp.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/layers/mlp.py rename to MindIE/MultiModal/OpenSoraPlan-v1.3/layers/mlp.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/layers/norm.py b/MindIE/MultiModal/OpenSoraPlan-v1.3/layers/norm.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/layers/norm.py rename to MindIE/MultiModal/OpenSoraPlan-v1.3/layers/norm.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/layers/sampling.py b/MindIE/MultiModal/OpenSoraPlan-v1.3/layers/sampling.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/layers/sampling.py rename to MindIE/MultiModal/OpenSoraPlan-v1.3/layers/sampling.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/layers/utils.py b/MindIE/MultiModal/OpenSoraPlan-v1.3/layers/utils.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/layers/utils.py rename to MindIE/MultiModal/OpenSoraPlan-v1.3/layers/utils.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/layers/vresnet.py b/MindIE/MultiModal/OpenSoraPlan-v1.3/layers/vresnet.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/layers/vresnet.py rename to MindIE/MultiModal/OpenSoraPlan-v1.3/layers/vresnet.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/models/comm.py b/MindIE/MultiModal/OpenSoraPlan-v1.3/models/comm.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/models/comm.py rename to MindIE/MultiModal/OpenSoraPlan-v1.3/models/comm.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/models/model_utils.py b/MindIE/MultiModal/OpenSoraPlan-v1.3/models/model_utils.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/models/model_utils.py rename to MindIE/MultiModal/OpenSoraPlan-v1.3/models/model_utils.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/models/parallel_mgr.py b/MindIE/MultiModal/OpenSoraPlan-v1.3/models/parallel_mgr.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/models/parallel_mgr.py rename to MindIE/MultiModal/OpenSoraPlan-v1.3/models/parallel_mgr.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/models/t2vdit.py b/MindIE/MultiModal/OpenSoraPlan-v1.3/models/t2vdit.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/models/t2vdit.py rename to MindIE/MultiModal/OpenSoraPlan-v1.3/models/t2vdit.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/models/wfvae.py b/MindIE/MultiModal/OpenSoraPlan-v1.3/models/wfvae.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/models/wfvae.py rename to MindIE/MultiModal/OpenSoraPlan-v1.3/models/wfvae.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/pipeline/open_soar_plan_pipeline.py b/MindIE/MultiModal/OpenSoraPlan-v1.3/pipeline/open_soar_plan_pipeline.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/pipeline/open_soar_plan_pipeline.py rename to MindIE/MultiModal/OpenSoraPlan-v1.3/pipeline/open_soar_plan_pipeline.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/pipeline/pipeline_utils.py b/MindIE/MultiModal/OpenSoraPlan-v1.3/pipeline/pipeline_utils.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/pipeline/pipeline_utils.py rename to MindIE/MultiModal/OpenSoraPlan-v1.3/pipeline/pipeline_utils.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/utils/__init__.py b/MindIE/MultiModal/OpenSoraPlan-v1.3/utils/__init__.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/utils/__init__.py rename to MindIE/MultiModal/OpenSoraPlan-v1.3/utils/__init__.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/utils/utils.py b/MindIE/MultiModal/OpenSoraPlan-v1.3/utils/utils.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/utils/utils.py rename to MindIE/MultiModal/OpenSoraPlan-v1.3/utils/utils.py -- Gitee From 580bdf1b1f8dfaed825be2e62e5089d681ab75dd Mon Sep 17 00:00:00 2001 From: Logan Date: Mon, 10 Feb 2025 20:08:52 +0800 Subject: [PATCH 2/2] =?UTF-8?q?=E8=BF=81=E7=A7=BBSD=E6=A8=A1=E5=9E=8B?= =?UTF-8?q?=E5=88=B0MultiModal=E8=B7=AF=E5=BE=84=E4=B8=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../foundation/CogVideoX-5b}/README.md | 0 .../CogVideoX-5b}/cogvideox_5b/__init__.py | 0 .../cogvideox_5b/models/__init__.py | 0 .../cogvideox_5b/models/activations.py | 0 .../cogvideox_5b/models/attention.py | 0 .../models/attention_processor.py | 0 .../cogvideox_5b/models/embeddings.py | 0 .../cogvideox_5b/models/normalization.py | 0 .../models/transformers/__init__.py | 0 .../transformers/cogvideox_transformer_3d.py | 0 .../cogvideox_5b/pipelines/__init__.py | 0 .../pipelines/pipeline_cogvideox.py | 0 .../cogvideox_5b/pipelines/pipeline_output.py | 0 .../cogvideox_5b/utils/__init__.py | 0 .../cogvideox_5b/utils/parallel_mgr.py | 0 .../cogvideox_5b/utils/parallel_state.py | 0 .../foundation/CogVideoX-5b}/inference.py | 0 .../CogVideoX-5b}/pta_plugin/CMakeLists.txt | 0 .../CogVideoX-5b}/pta_plugin/build.sh | 0 .../pta_plugin/extension_ops.cpp | 0 .../pta_plugin/test/test_rope.py | 0 .../foundation/CogVideoX-5b}/requirements.txt | 0 .../Flux.1-DEV/FLUX1dev/__init__.py | 17 + .../Flux.1-DEV/FLUX1dev/layers/__init__.py | 17 + .../FLUX1dev/layers/attention_processor.py | 215 + .../Flux.1-DEV/FLUX1dev/layers/embedding.py | 115 + .../Flux.1-DEV/FLUX1dev/models/__init__.py | 17 + .../FLUX1dev/models/modeling_utils.py | 16 + .../FLUX1dev/models/transformer_flux.py | 457 ++ .../Flux.1-DEV/FLUX1dev/pipeline/__init__.py | 16 + .../FLUX1dev/pipeline/pipeline_flux.py | 759 +++ .../built-in/foundation/Flux.1-DEV/README.md | 148 + .../foundation/Flux.1-DEV/inference_flux.py | 149 + .../foundation/Flux.1-DEV/prompts.txt | 16 + .../foundation/Flux.1-DEV/requirements.txt | 9 + .../built-in/foundation/cogview3/README.md | 167 + .../cogview3/cogview3plus/__init__.py | 3 + .../cogview3/cogview3plus/layers/__init__.py | 3 + .../cogview3plus/layers/embeddings.py | 304 ++ .../cogview3/cogview3plus/layers/linear.py | 48 + .../cogview3plus/layers/normalization.py | 177 + .../cogview3/cogview3plus/models/__init__.py | 2 + .../cogview3plus/models/activations.py | 163 + .../cogview3/cogview3plus/models/attention.py | 87 + .../models/attention_processor.py | 348 ++ .../cogview3plus/models/model_load_utils.py | 42 + .../cogview3plus/models/modeling_utils.py | 771 +++ .../models/transformer_cogview3plus.py | 397 ++ .../cogview3plus/pipeline/__init__.py | 1 + .../pipeline/pipeline_cogview3plus.py | 339 ++ .../cogview3plus/pipeline/pipeline_output.py | 21 + .../cogview3plus/schedulers/__init__.py | 2 + .../schedulers/scheduling_ddim_cogvideox.py | 276 ++ .../schedulers/scheduling_utils.py | 113 + .../cogview3/cogview3plus/vae/__init__.py | 0 .../cogview3/inference_cogview3plus.py | 105 + .../foundation/cogview3/requirents.txt | 8 + .../built-in/foundation/hunyuan_dit/README.md | 411 ++ .../foundation/hunyuan_dit/hydit/__init__.py | 21 + .../hunyuan_dit/hydit/layers/__init__.py | 23 + .../hunyuan_dit/hydit/layers/activation.py | 49 + .../hunyuan_dit/hydit/layers/attention.py | 141 + .../hunyuan_dit/hydit/layers/embedding.py | 713 +++ .../hunyuan_dit/hydit/layers/mlp.py | 64 + .../hunyuan_dit/hydit/layers/norm.py | 54 + .../hunyuan_dit/hydit/layers/poolers.py | 56 + .../hunyuan_dit/hydit/models/__init__.py | 18 + .../hunyuan_dit/hydit/models/hydit.py | 407 ++ .../hydit/models/model_load_utils.py | 43 + .../hunyuan_dit/hydit/models/model_utils.py | 77 + .../hunyuan_dit/hydit/pipeline/__init__.py | 18 + .../hydit/pipeline/hydit_pipeline.py | 399 ++ .../hunyuan_dit/hydit/schedulers/__init__.py | 18 + .../hunyuan_dit/hydit/schedulers/ddpm.py | 115 + .../hunyuan_dit/hydit/utils/__init__.py | 18 + .../hunyuan_dit/hydit/utils/file_utils.py | 139 + .../hunyuan_dit/hydit/utils/utils.py | 107 + .../foundation/hunyuan_dit/inference_hydit.py | 335 ++ .../foundation/hunyuan_dit/lora/__init__.py | 18 + .../foundation/hunyuan_dit/lora/hydit_lora.py | 61 + .../hunyuan_dit/prompts/example_prompts.txt | 28 + .../foundation/hunyuan_dit/requirents.txt | 18 + .../inference_opensoraplan13.py | 162 + .../open_sora_planv1_3/layers/__init__.py | 6 + .../open_sora_planv1_3/layers/activation.py | 57 + .../open_sora_planv1_3/layers/attention.py | 392 ++ .../open_sora_planv1_3/layers/cache_mgr.py | 172 + .../open_sora_planv1_3/layers/conv.py | 152 + .../open_sora_planv1_3/layers/linear.py | 96 + .../open_sora_planv1_3/layers/mlp.py | 61 + .../open_sora_planv1_3/layers/norm.py | 65 + .../open_sora_planv1_3/layers/sampling.py | 131 + .../open_sora_planv1_3/layers/utils.py | 55 + .../open_sora_planv1_3/layers/vresnet.py | 127 + .../open_sora_planv1_3/models/comm.py | 180 + .../open_sora_planv1_3/models/model_utils.py | 65 + .../open_sora_planv1_3/models/parallel_mgr.py | 67 + .../open_sora_planv1_3/models/t2vdit.py | 458 ++ .../open_sora_planv1_3/models/wfvae.py | 576 +++ .../pipeline/open_soar_plan_pipeline.py | 463 ++ .../pipeline/pipeline_utils.py | 131 + .../open_sora_planv1_3/utils/__init__.py | 1 + .../open_sora_planv1_3/utils/utils.py | 19 + .../built-in/foundation/opensora1.2/README.md | 213 + .../opensora1.2/inference_opensora12.py | 182 + .../opensora1.2/opensora/__init__.py | 26 + .../opensora1.2/opensora/layer/__init__.py | 33 + .../opensora1.2/opensora/layer/activation.py | 21 + .../opensora1.2/opensora/layer/attention.py | 274 ++ .../opensora1.2/opensora/layer/comm.py | 121 + .../opensora1.2/opensora/layer/conv.py | 247 + .../opensora1.2/opensora/layer/embdding.py | 423 ++ .../opensora1.2/opensora/layer/mlp.py | 62 + .../opensora1.2/opensora/layer/norm.py | 151 + .../opensora/layer/parallel_mgr.py | 59 + .../opensora1.2/opensora/layer/utils.py | 28 + .../opensora1.2/opensora/pipeline/__init__.py | 18 + .../opensora/pipeline/compile_pipe.py | 33 + .../opensora/pipeline/open_sora_pipeline.py | 262 + .../opensora/pipeline/pipeline_utils.py | 169 + .../opensora/schedulers/__init__.py | 18 + .../opensora/schedulers/rectified_flow.py | 102 + .../opensora1.2/opensora/stdit3/__init__.py | 18 + .../opensora1.2/opensora/stdit3/stdit3.py | 563 +++ .../opensora1.2/opensora/utils/__init__.py | 24 + .../opensora1.2/opensora/utils/patch_utils.py | 76 + .../opensora1.2/opensora/utils/utils.py | 140 + .../opensora/vae/VideoAutoencoder.py | 199 + .../opensora1.2/opensora/vae/__init__.py | 17 + .../opensora1.2/opensora/vae/vae_temporal.py | 472 ++ .../opensora1.2/prompts/t2v_sora.txt | 48 + .../foundation/opensora1.2/requirents.txt | 16 + MindIE/MultiModal/CogVideoX/README.md | 181 + .../CogVideoX/cogvideox_5b/__init__.py | 4 + .../CogVideoX/cogvideox_5b/models/__init__.py | 1 + .../cogvideox_5b/models/activations.py | 165 + .../cogvideox_5b/models/attention.py | 1228 +++++ .../models/attention_processor.py | 4320 +++++++++++++++++ .../cogvideox_5b/models/embeddings.py | 1808 +++++++ .../cogvideox_5b/models/normalization.py | 527 ++ .../models/transformers/__init__.py | 1 + .../transformers/cogvideox_transformer_3d.py | 551 +++ .../cogvideox_5b/pipelines/__init__.py | 1 + .../pipelines/pipeline_cogvideox.py | 759 +++ .../cogvideox_5b/pipelines/pipeline_output.py | 20 + .../CogVideoX/cogvideox_5b/utils/__init__.py | 2 + .../cogvideox_5b/utils/parallel_mgr.py | 76 + .../cogvideox_5b/utils/parallel_state.py | 168 + MindIE/MultiModal/CogVideoX/inference.py | 134 + .../CogVideoX/pta_plugin/CMakeLists.txt | 30 + .../MultiModal/CogVideoX/pta_plugin/build.sh | 19 + .../CogVideoX/pta_plugin/extension_ops.cpp | 69 + .../CogVideoX/pta_plugin/test/test_rope.py | 25 + MindIE/MultiModal/CogVideoX/requirements.txt | 13 + 154 files changed, 26232 insertions(+) rename MindIE/{MultiModal/CogVideoX-5B => MindIE-Torch/built-in/foundation/CogVideoX-5b}/README.md (100%) rename MindIE/{MultiModal/CogVideoX-5B => MindIE-Torch/built-in/foundation/CogVideoX-5b}/cogvideox_5b/__init__.py (100%) rename MindIE/{MultiModal/CogVideoX-5B => MindIE-Torch/built-in/foundation/CogVideoX-5b}/cogvideox_5b/models/__init__.py (100%) rename MindIE/{MultiModal/CogVideoX-5B => MindIE-Torch/built-in/foundation/CogVideoX-5b}/cogvideox_5b/models/activations.py (100%) rename MindIE/{MultiModal/CogVideoX-5B => MindIE-Torch/built-in/foundation/CogVideoX-5b}/cogvideox_5b/models/attention.py (100%) rename MindIE/{MultiModal/CogVideoX-5B => MindIE-Torch/built-in/foundation/CogVideoX-5b}/cogvideox_5b/models/attention_processor.py (100%) rename MindIE/{MultiModal/CogVideoX-5B => MindIE-Torch/built-in/foundation/CogVideoX-5b}/cogvideox_5b/models/embeddings.py (100%) rename MindIE/{MultiModal/CogVideoX-5B => MindIE-Torch/built-in/foundation/CogVideoX-5b}/cogvideox_5b/models/normalization.py (100%) rename MindIE/{MultiModal/CogVideoX-5B => MindIE-Torch/built-in/foundation/CogVideoX-5b}/cogvideox_5b/models/transformers/__init__.py (100%) rename MindIE/{MultiModal/CogVideoX-5B => MindIE-Torch/built-in/foundation/CogVideoX-5b}/cogvideox_5b/models/transformers/cogvideox_transformer_3d.py (100%) rename MindIE/{MultiModal/CogVideoX-5B => MindIE-Torch/built-in/foundation/CogVideoX-5b}/cogvideox_5b/pipelines/__init__.py (100%) rename MindIE/{MultiModal/CogVideoX-5B => MindIE-Torch/built-in/foundation/CogVideoX-5b}/cogvideox_5b/pipelines/pipeline_cogvideox.py (100%) rename MindIE/{MultiModal/CogVideoX-5B => MindIE-Torch/built-in/foundation/CogVideoX-5b}/cogvideox_5b/pipelines/pipeline_output.py (100%) rename MindIE/{MultiModal/CogVideoX-5B => MindIE-Torch/built-in/foundation/CogVideoX-5b}/cogvideox_5b/utils/__init__.py (100%) rename MindIE/{MultiModal/CogVideoX-5B => MindIE-Torch/built-in/foundation/CogVideoX-5b}/cogvideox_5b/utils/parallel_mgr.py (100%) rename MindIE/{MultiModal/CogVideoX-5B => MindIE-Torch/built-in/foundation/CogVideoX-5b}/cogvideox_5b/utils/parallel_state.py (100%) rename MindIE/{MultiModal/CogVideoX-5B => MindIE-Torch/built-in/foundation/CogVideoX-5b}/inference.py (100%) rename MindIE/{MultiModal/CogVideoX-5B => MindIE-Torch/built-in/foundation/CogVideoX-5b}/pta_plugin/CMakeLists.txt (100%) rename MindIE/{MultiModal/CogVideoX-5B => MindIE-Torch/built-in/foundation/CogVideoX-5b}/pta_plugin/build.sh (100%) rename MindIE/{MultiModal/CogVideoX-5B => MindIE-Torch/built-in/foundation/CogVideoX-5b}/pta_plugin/extension_ops.cpp (100%) rename MindIE/{MultiModal/CogVideoX-5B => MindIE-Torch/built-in/foundation/CogVideoX-5b}/pta_plugin/test/test_rope.py (100%) rename MindIE/{MultiModal/CogVideoX-5B => MindIE-Torch/built-in/foundation/CogVideoX-5b}/requirements.txt (100%) create mode 100644 MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/FLUX1dev/__init__.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/FLUX1dev/layers/__init__.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/FLUX1dev/layers/attention_processor.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/FLUX1dev/layers/embedding.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/FLUX1dev/models/__init__.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/FLUX1dev/models/modeling_utils.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/FLUX1dev/models/transformer_flux.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/FLUX1dev/pipeline/__init__.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/FLUX1dev/pipeline/pipeline_flux.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/README.md create mode 100644 MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/inference_flux.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/prompts.txt create mode 100644 MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/requirements.txt create mode 100644 MindIE/MindIE-Torch/built-in/foundation/cogview3/README.md create mode 100644 MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/__init__.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/layers/__init__.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/layers/embeddings.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/layers/linear.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/layers/normalization.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/models/__init__.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/models/activations.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/models/attention.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/models/attention_processor.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/models/model_load_utils.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/models/modeling_utils.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/models/transformer_cogview3plus.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/pipeline/__init__.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/pipeline/pipeline_cogview3plus.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/pipeline/pipeline_output.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/schedulers/__init__.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/schedulers/scheduling_ddim_cogvideox.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/schedulers/scheduling_utils.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/vae/__init__.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/cogview3/inference_cogview3plus.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/cogview3/requirents.txt create mode 100644 MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/README.md create mode 100644 MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/__init__.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/layers/__init__.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/layers/activation.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/layers/attention.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/layers/embedding.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/layers/mlp.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/layers/norm.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/layers/poolers.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/models/__init__.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/models/hydit.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/models/model_load_utils.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/models/model_utils.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/pipeline/__init__.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/pipeline/hydit_pipeline.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/schedulers/__init__.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/schedulers/ddpm.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/utils/__init__.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/utils/file_utils.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/utils/utils.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/inference_hydit.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/lora/__init__.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/lora/hydit_lora.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/prompts/example_prompts.txt create mode 100644 MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/requirents.txt create mode 100644 MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/inference_opensoraplan13.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/layers/__init__.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/layers/activation.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/layers/attention.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/layers/cache_mgr.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/layers/conv.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/layers/linear.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/layers/mlp.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/layers/norm.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/layers/sampling.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/layers/utils.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/layers/vresnet.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/models/comm.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/models/model_utils.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/models/parallel_mgr.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/models/t2vdit.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/models/wfvae.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/pipeline/open_soar_plan_pipeline.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/pipeline/pipeline_utils.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/utils/__init__.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/utils/utils.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/opensora1.2/README.md create mode 100644 MindIE/MindIE-Torch/built-in/foundation/opensora1.2/inference_opensora12.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/__init__.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/layer/__init__.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/layer/activation.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/layer/attention.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/layer/comm.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/layer/conv.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/layer/embdding.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/layer/mlp.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/layer/norm.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/layer/parallel_mgr.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/layer/utils.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/pipeline/__init__.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/pipeline/compile_pipe.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/pipeline/open_sora_pipeline.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/pipeline/pipeline_utils.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/schedulers/__init__.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/schedulers/rectified_flow.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/stdit3/__init__.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/stdit3/stdit3.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/utils/__init__.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/utils/patch_utils.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/utils/utils.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/vae/VideoAutoencoder.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/vae/__init__.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/vae/vae_temporal.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/opensora1.2/prompts/t2v_sora.txt create mode 100644 MindIE/MindIE-Torch/built-in/foundation/opensora1.2/requirents.txt create mode 100644 MindIE/MultiModal/CogVideoX/README.md create mode 100644 MindIE/MultiModal/CogVideoX/cogvideox_5b/__init__.py create mode 100644 MindIE/MultiModal/CogVideoX/cogvideox_5b/models/__init__.py create mode 100644 MindIE/MultiModal/CogVideoX/cogvideox_5b/models/activations.py create mode 100644 MindIE/MultiModal/CogVideoX/cogvideox_5b/models/attention.py create mode 100644 MindIE/MultiModal/CogVideoX/cogvideox_5b/models/attention_processor.py create mode 100644 MindIE/MultiModal/CogVideoX/cogvideox_5b/models/embeddings.py create mode 100644 MindIE/MultiModal/CogVideoX/cogvideox_5b/models/normalization.py create mode 100644 MindIE/MultiModal/CogVideoX/cogvideox_5b/models/transformers/__init__.py create mode 100644 MindIE/MultiModal/CogVideoX/cogvideox_5b/models/transformers/cogvideox_transformer_3d.py create mode 100644 MindIE/MultiModal/CogVideoX/cogvideox_5b/pipelines/__init__.py create mode 100644 MindIE/MultiModal/CogVideoX/cogvideox_5b/pipelines/pipeline_cogvideox.py create mode 100644 MindIE/MultiModal/CogVideoX/cogvideox_5b/pipelines/pipeline_output.py create mode 100644 MindIE/MultiModal/CogVideoX/cogvideox_5b/utils/__init__.py create mode 100644 MindIE/MultiModal/CogVideoX/cogvideox_5b/utils/parallel_mgr.py create mode 100644 MindIE/MultiModal/CogVideoX/cogvideox_5b/utils/parallel_state.py create mode 100644 MindIE/MultiModal/CogVideoX/inference.py create mode 100644 MindIE/MultiModal/CogVideoX/pta_plugin/CMakeLists.txt create mode 100644 MindIE/MultiModal/CogVideoX/pta_plugin/build.sh create mode 100644 MindIE/MultiModal/CogVideoX/pta_plugin/extension_ops.cpp create mode 100644 MindIE/MultiModal/CogVideoX/pta_plugin/test/test_rope.py create mode 100644 MindIE/MultiModal/CogVideoX/requirements.txt diff --git a/MindIE/MultiModal/CogVideoX-5B/README.md b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/README.md similarity index 100% rename from MindIE/MultiModal/CogVideoX-5B/README.md rename to MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/README.md diff --git a/MindIE/MultiModal/CogVideoX-5B/cogvideox_5b/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/__init__.py similarity index 100% rename from MindIE/MultiModal/CogVideoX-5B/cogvideox_5b/__init__.py rename to MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/__init__.py diff --git a/MindIE/MultiModal/CogVideoX-5B/cogvideox_5b/models/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/models/__init__.py similarity index 100% rename from MindIE/MultiModal/CogVideoX-5B/cogvideox_5b/models/__init__.py rename to MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/models/__init__.py diff --git a/MindIE/MultiModal/CogVideoX-5B/cogvideox_5b/models/activations.py b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/models/activations.py similarity index 100% rename from MindIE/MultiModal/CogVideoX-5B/cogvideox_5b/models/activations.py rename to MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/models/activations.py diff --git a/MindIE/MultiModal/CogVideoX-5B/cogvideox_5b/models/attention.py b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/models/attention.py similarity index 100% rename from MindIE/MultiModal/CogVideoX-5B/cogvideox_5b/models/attention.py rename to MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/models/attention.py diff --git a/MindIE/MultiModal/CogVideoX-5B/cogvideox_5b/models/attention_processor.py b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/models/attention_processor.py similarity index 100% rename from MindIE/MultiModal/CogVideoX-5B/cogvideox_5b/models/attention_processor.py rename to MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/models/attention_processor.py diff --git a/MindIE/MultiModal/CogVideoX-5B/cogvideox_5b/models/embeddings.py b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/models/embeddings.py similarity index 100% rename from MindIE/MultiModal/CogVideoX-5B/cogvideox_5b/models/embeddings.py rename to MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/models/embeddings.py diff --git a/MindIE/MultiModal/CogVideoX-5B/cogvideox_5b/models/normalization.py b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/models/normalization.py similarity index 100% rename from MindIE/MultiModal/CogVideoX-5B/cogvideox_5b/models/normalization.py rename to MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/models/normalization.py diff --git a/MindIE/MultiModal/CogVideoX-5B/cogvideox_5b/models/transformers/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/models/transformers/__init__.py similarity index 100% rename from MindIE/MultiModal/CogVideoX-5B/cogvideox_5b/models/transformers/__init__.py rename to MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/models/transformers/__init__.py diff --git a/MindIE/MultiModal/CogVideoX-5B/cogvideox_5b/models/transformers/cogvideox_transformer_3d.py b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/models/transformers/cogvideox_transformer_3d.py similarity index 100% rename from MindIE/MultiModal/CogVideoX-5B/cogvideox_5b/models/transformers/cogvideox_transformer_3d.py rename to MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/models/transformers/cogvideox_transformer_3d.py diff --git a/MindIE/MultiModal/CogVideoX-5B/cogvideox_5b/pipelines/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/pipelines/__init__.py similarity index 100% rename from MindIE/MultiModal/CogVideoX-5B/cogvideox_5b/pipelines/__init__.py rename to MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/pipelines/__init__.py diff --git a/MindIE/MultiModal/CogVideoX-5B/cogvideox_5b/pipelines/pipeline_cogvideox.py b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/pipelines/pipeline_cogvideox.py similarity index 100% rename from MindIE/MultiModal/CogVideoX-5B/cogvideox_5b/pipelines/pipeline_cogvideox.py rename to MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/pipelines/pipeline_cogvideox.py diff --git a/MindIE/MultiModal/CogVideoX-5B/cogvideox_5b/pipelines/pipeline_output.py b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/pipelines/pipeline_output.py similarity index 100% rename from MindIE/MultiModal/CogVideoX-5B/cogvideox_5b/pipelines/pipeline_output.py rename to MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/pipelines/pipeline_output.py diff --git a/MindIE/MultiModal/CogVideoX-5B/cogvideox_5b/utils/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/utils/__init__.py similarity index 100% rename from MindIE/MultiModal/CogVideoX-5B/cogvideox_5b/utils/__init__.py rename to MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/utils/__init__.py diff --git a/MindIE/MultiModal/CogVideoX-5B/cogvideox_5b/utils/parallel_mgr.py b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/utils/parallel_mgr.py similarity index 100% rename from MindIE/MultiModal/CogVideoX-5B/cogvideox_5b/utils/parallel_mgr.py rename to MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/utils/parallel_mgr.py diff --git a/MindIE/MultiModal/CogVideoX-5B/cogvideox_5b/utils/parallel_state.py b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/utils/parallel_state.py similarity index 100% rename from MindIE/MultiModal/CogVideoX-5B/cogvideox_5b/utils/parallel_state.py rename to MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/utils/parallel_state.py diff --git a/MindIE/MultiModal/CogVideoX-5B/inference.py b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/inference.py similarity index 100% rename from MindIE/MultiModal/CogVideoX-5B/inference.py rename to MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/inference.py diff --git a/MindIE/MultiModal/CogVideoX-5B/pta_plugin/CMakeLists.txt b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/pta_plugin/CMakeLists.txt similarity index 100% rename from MindIE/MultiModal/CogVideoX-5B/pta_plugin/CMakeLists.txt rename to MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/pta_plugin/CMakeLists.txt diff --git a/MindIE/MultiModal/CogVideoX-5B/pta_plugin/build.sh b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/pta_plugin/build.sh similarity index 100% rename from MindIE/MultiModal/CogVideoX-5B/pta_plugin/build.sh rename to MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/pta_plugin/build.sh diff --git a/MindIE/MultiModal/CogVideoX-5B/pta_plugin/extension_ops.cpp b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/pta_plugin/extension_ops.cpp similarity index 100% rename from MindIE/MultiModal/CogVideoX-5B/pta_plugin/extension_ops.cpp rename to MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/pta_plugin/extension_ops.cpp diff --git a/MindIE/MultiModal/CogVideoX-5B/pta_plugin/test/test_rope.py b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/pta_plugin/test/test_rope.py similarity index 100% rename from MindIE/MultiModal/CogVideoX-5B/pta_plugin/test/test_rope.py rename to MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/pta_plugin/test/test_rope.py diff --git a/MindIE/MultiModal/CogVideoX-5B/requirements.txt b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/requirements.txt similarity index 100% rename from MindIE/MultiModal/CogVideoX-5B/requirements.txt rename to MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/requirements.txt diff --git a/MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/FLUX1dev/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/FLUX1dev/__init__.py new file mode 100644 index 0000000000..ad414be943 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/FLUX1dev/__init__.py @@ -0,0 +1,17 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .models import FluxTransformer2DModel, ModelMixin +from .pipeline import FluxPipeline, DiffusionPipeline \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/FLUX1dev/layers/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/FLUX1dev/layers/__init__.py new file mode 100644 index 0000000000..80f6776c05 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/FLUX1dev/layers/__init__.py @@ -0,0 +1,17 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .attention_processor import (FluxAttnProcessor2_0, FluxSingleAttnProcessor2_0) +from .embedding import FluxPosEmbed \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/FLUX1dev/layers/attention_processor.py b/MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/FLUX1dev/layers/attention_processor.py new file mode 100644 index 0000000000..02d59c8aac --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/FLUX1dev/layers/attention_processor.py @@ -0,0 +1,215 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional + +import torch +import torch_npu +import torch.nn.functional as F + +from diffusers.models.attention_processor import Attention +from mindspeed.ops.npu_rotary_position_embedding import npu_rotary_position_embedding + + +def apply_rotary_emb_mindspeed(x, freqs_cis): + cos, sin = freqs_cis + cos = cos[None, None] + sin = sin[None, None] + cos, sin = cos.to(x.device), sin.to(x.device) + + return npu_rotary_position_embedding(x, cos, sin, 1) + + +def apply_fa(query, key, value, attention_mask): + if attention_mask is not None: + attention_mask = ~attention_mask + attention_mask = attention_mask.to(torch.bool) + batch_size = query.shape[0] + heads = query.shape[1] + head_dim = query.shape[-1] + actseqlen = query.shape[-2] + actseqlenkv = key.shape[-2] + + hidden_states, _ = torch_npu.npu_fused_infer_attention_score(query, key, value, + actual_seq_lengths=[actseqlen], actual_seq_lengths_kv=[actseqlenkv], + num_heads=heads, input_layout="BNSD", scale=head_dim ** -0.5, pre_tokens=65535, next_tokens=65535) + return hidden_states.transpose(1, 2).reshape(batch_size, -1, head_dim * heads) + + +def rms_norm_npu(hidden_states, weight, eps): + return torch_npu.npu_rms_norm(hidden_states, weight, eps)[0] + + +# YiYi to-do: refactor rope related functions/classes +def apply_rope(xq, xk, freqs_cis): + xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) + xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) + xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] + xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] + return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) + + +class FluxSingleAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + + query = attn.to_q(hidden_states) + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = rms_norm_npu(query, attn.norm_q.weight, attn.norm_q.eps) + if attn.norm_k is not None: + key = rms_norm_npu(key, attn.norm_k.weight, attn.norm_k.eps) + + # Apply RoPE if needed + if image_rotary_emb is not None: + query = apply_rotary_emb_mindspeed(query, image_rotary_emb) + key = apply_rotary_emb_mindspeed(key, image_rotary_emb) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + hidden_states = apply_fa(query, key, value, attention_mask) + hidden_states = hidden_states.to(query.dtype) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + return hidden_states + + +class FluxAttnProcessor2_0: + """Attention processor used typically in processing the SD3-like self-attention projections.""" + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + input_ndim = hidden_states.ndim + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + context_input_ndim = encoder_hidden_states.ndim + if context_input_ndim == 4: + batch_size, channel, height, width = encoder_hidden_states.shape + encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size = encoder_hidden_states.shape[0] + + # `sample` projections. + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = rms_norm_npu(query, attn.norm_q.weight, attn.norm_q.eps) + if attn.norm_k is not None: + key = rms_norm_npu(key, attn.norm_k.weight, attn.norm_k.eps) + + # `context` projections. + encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + + encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + + if attn.norm_added_q is not None: + encoder_hidden_states_query_proj = rms_norm_npu(encoder_hidden_states_query_proj, attn.norm_added_q.weight, attn.norm_added_q.eps) + if attn.norm_added_k is not None: + encoder_hidden_states_key_proj = rms_norm_npu(encoder_hidden_states_key_proj, attn.norm_added_k.weight, attn.norm_added_k.eps) + + # attention + query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) + key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) + + if image_rotary_emb is not None: + query = apply_rotary_emb_mindspeed(query, image_rotary_emb) + key = apply_rotary_emb_mindspeed(key, image_rotary_emb) + + + hidden_states = apply_fa(query, key, value, attention_mask) + hidden_states = hidden_states.to(query.dtype) + + encoder_hidden_states, hidden_states = ( + hidden_states[:, : encoder_hidden_states.shape[1]], + hidden_states[:, encoder_hidden_states.shape[1] :], + ) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + if context_input_ndim == 4: + encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + return hidden_states, encoder_hidden_states + diff --git a/MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/FLUX1dev/layers/embedding.py b/MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/FLUX1dev/layers/embedding.py new file mode 100644 index 0000000000..70f85a04f1 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/FLUX1dev/layers/embedding.py @@ -0,0 +1,115 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Union, List +import torch +import torch_npu +import numpy as np +from torch import nn + + +class FluxPosEmbed(nn.Module): + # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11 + def __init__(self, theta: int, axes_dim: List[int]): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + + def forward(self, ids: torch.Tensor) -> torch.Tensor: + n_axes = ids.shape[-1] + cos_out = [] + sin_out = [] + pos = ids.float() + is_mps = ids.device.type == "mps" + freqs_dtype = torch.float32 if is_mps else torch.float64 + for i in range(n_axes): + cos, sin = get_1d_rotary_pos_embed( + self.axes_dim[i], + pos[:, i], + theta=self.theta, + repeat_interleave_real=True, + use_real=True, + freqs_dtype=freqs_dtype, + ) + cos_out.append(cos) + sin_out.append(sin) + freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device) + freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device) + return freqs_cos, freqs_sin + + +def get_1d_rotary_pos_embed( + dim: int, + pos: Union[np.ndarray, int], + theta: float = 10000.0, + use_real=False, + linear_factor=1.0, + ntk_factor=1.0, + repeat_interleave_real=True, + freqs_dtype=torch.float32, +): + """ + Precompute the frequency tensor for complex exponentials (cis) with given dimensions. + + This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end + index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64 + data type. + + Args: + dim (`int`): Dimension of the frequency tensor. + pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar + theta (`float`, *optional*, defaults to 10000.0): + Scaling factor for frequency computation. Defaults to 10000.0. + use_real (`bool`, *optional*): + If True, return real part and imaginary part separately. Otherwise, return complex numbers. + linear_factor (`float`, *optional*, defaults to 1.0): + Scaling factor for the context extrapolation. Defaults to 1.0. + ntk_factor (`float`, *optional*, defaults to 1.0): + Scaling factor for the NTK-Aware RoPE. Defaults to 1.0. + repeat_interleave_real (`bool`, *optional*, defaults to `True`): + If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`. + Otherwise, they are concateanted with themselves. + freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`): + the dtype of the frequency tensor. + Returns: + `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2] + """ + + if isinstance(pos, int): + pos = torch.arange(pos) + if isinstance(pos, np.ndarray): + pos = torch.from_numpy(pos) # type: ignore # [S] + + theta = theta * ntk_factor + freqs = ( + 1.0 + / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim)) + / linear_factor + ) # [D/2] + freqs = torch.outer(pos, freqs).float() # type: ignore # [S, D/2] + if use_real and repeat_interleave_real: + # flux, hunyuan-dit, cogvideox + freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D] + freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D] + return freqs_cos, freqs_sin + elif use_real: + # stable audio, allegro + freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D] + freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D] + return freqs_cos, freqs_sin + else: + # lumina + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2] + return freqs_cis \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/FLUX1dev/models/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/FLUX1dev/models/__init__.py new file mode 100644 index 0000000000..095ba2fce2 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/FLUX1dev/models/__init__.py @@ -0,0 +1,17 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .modeling_utils import ModelMixin +from .transformer_flux import FluxTransformer2DModel \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/FLUX1dev/models/modeling_utils.py b/MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/FLUX1dev/models/modeling_utils.py new file mode 100644 index 0000000000..1c9a94f801 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/FLUX1dev/models/modeling_utils.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from diffusers.models.modeling_utils import ModelMixin \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/FLUX1dev/models/transformer_flux.py b/MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/FLUX1dev/models/transformer_flux.py new file mode 100644 index 0000000000..bf58ff520d --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/FLUX1dev/models/transformer_flux.py @@ -0,0 +1,457 @@ +# Copyright 2024 Black Forest Labs, The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Any, Dict, List, Optional, Union + +import torch +import torch_npu +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin +from diffusers.models.attention import FeedForward +from diffusers.models.attention_processor import Attention +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle +from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers +from diffusers.utils.torch_utils import maybe_allow_in_graph +from diffusers.models.embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings +from diffusers.models.modeling_outputs import Transformer2DModelOutput + +from ..layers import FluxAttnProcessor2_0, FluxSingleAttnProcessor2_0 +from .modeling_utils import ModelMixin +from ..layers import FluxPosEmbed + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# YiYi to-do: refactor rope related functions/classes +def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor: + + scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim + omega = 1.0 / (theta**scale) + + batch_size, seq_length = pos.shape + out = torch.einsum("...n,d->...nd", pos, omega) + cos_out = torch.cos(out) + sin_out = torch.sin(out) + + stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1) + out = stacked_out.view(batch_size, -1, dim // 2, 2, 2) + return out.float() + + +# YiYi to-do: refactor rope related functions/classes +class EmbedND(nn.Module): + def __init__(self, dim: int, theta: int, axes_dim: List[int]): + super().__init__() + self.dim = dim + self.theta = theta + self.axes_dim = axes_dim + + def forward(self, ids: torch.Tensor) -> torch.Tensor: + n_axes = ids.shape[-1] + emb = torch.cat( + [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], + dim=-3, + ) + return emb.unsqueeze(1) + + +@maybe_allow_in_graph +class FluxSingleTransformerBlock(nn.Module): + r""" + A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3. + + Reference: https://arxiv.org/abs/2403.03206 + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the + processing of `context` conditions. + """ + + def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0): + super().__init__() + self.mlp_hidden_dim = int(dim * mlp_ratio) + + self.norm = AdaLayerNormZeroSingle(dim) + self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim) + self.act_mlp = nn.GELU(approximate="tanh") + self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim) + + processor = FluxSingleAttnProcessor2_0() + self.attn = Attention( + query_dim=dim, + cross_attention_dim=None, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + bias=True, + processor=processor, + qk_norm="rms_norm", + eps=1e-6, + pre_only=True, + ) + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: torch.FloatTensor, + image_rotary_emb=None, + ): + residual = hidden_states + norm_hidden_states, gate = self.norm(hidden_states, emb=temb) + mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) + + attn_output = self.attn( + hidden_states=norm_hidden_states, + image_rotary_emb=image_rotary_emb, + ) + + hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2) + gate = gate.unsqueeze(1) + hidden_states = gate * self.proj_out(hidden_states) + hidden_states = residual + hidden_states + if hidden_states.dtype == torch.float16: + hidden_states = hidden_states.clip(-65504, 65504) + + return hidden_states + + +@maybe_allow_in_graph +class FluxTransformerBlock(nn.Module): + r""" + A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3. + + Reference: https://arxiv.org/abs/2403.03206 + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the + processing of `context` conditions. + """ + + def __init__(self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6): + super().__init__() + + self.norm1 = AdaLayerNormZero(dim) + + self.norm1_context = AdaLayerNormZero(dim) + + if hasattr(F, "scaled_dot_product_attention"): + processor = FluxAttnProcessor2_0() + else: + raise ValueError( + "The current PyTorch version does not support the `scaled_dot_product_attention` function." + ) + self.attn = Attention( + query_dim=dim, + cross_attention_dim=None, + added_kv_proj_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + context_pre_only=False, + bias=True, + processor=processor, + qk_norm=qk_norm, + eps=eps, + ) + + self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + + self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + def forward( + self, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor, + temb: torch.FloatTensor, + image_rotary_emb=None, + ): + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) + + norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( + encoder_hidden_states, emb=temb + ) + + # Attention. + attn_output, context_attn_output = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + ) + + # Process attention outputs for the `hidden_states`. + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = hidden_states + attn_output + + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + ff_output = self.ff(norm_hidden_states) + ff_output = gate_mlp.unsqueeze(1) * ff_output + + hidden_states = hidden_states + ff_output + + # Process attention outputs for the `encoder_hidden_states`. + + context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output + encoder_hidden_states = encoder_hidden_states + context_attn_output + + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) + norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + + context_ff_output = self.ff_context(norm_encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output + if encoder_hidden_states.dtype == torch.float16: + encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) + + return encoder_hidden_states, hidden_states + + +class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): + """ + The Transformer model introduced in Flux. + + Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ + + Parameters: + patch_size (`int`): Patch size to turn the input data into small patches. + in_channels (`int`, *optional*, defaults to 16): The number of channels in the input. + num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use. + num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use. + attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head. + num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention. + joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`. + guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + patch_size: int = 1, + in_channels: int = 64, + num_layers: int = 19, + num_single_layers: int = 38, + attention_head_dim: int = 128, + num_attention_heads: int = 24, + joint_attention_dim: int = 4096, + pooled_projection_dim: int = 768, + guidance_embeds: bool = False, + axes_dims_rope: List[int] = [16, 56, 56], + ): + super().__init__() + self.out_channels = in_channels + self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim + + self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope) + text_time_guidance_cls = ( + CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings + ) + self.time_text_embed = text_time_guidance_cls( + embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim + ) + + self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim) + self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim) + + self.transformer_blocks = nn.ModuleList( + [ + FluxTransformerBlock( + dim=self.inner_dim, + num_attention_heads=self.config.num_attention_heads, + attention_head_dim=self.config.attention_head_dim, + ) + for i in range(self.config.num_layers) + ] + ) + + self.single_transformer_blocks = nn.ModuleList( + [ + FluxSingleTransformerBlock( + dim=self.inner_dim, + num_attention_heads=self.config.num_attention_heads, + attention_head_dim=self.config.attention_head_dim, + ) + for i in range(self.config.num_single_layers) + ] + ) + + self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) + + self.gradient_checkpointing = False + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + pooled_projections: torch.Tensor = None, + timestep: torch.LongTensor = None, + img_ids: torch.Tensor = None, + txt_ids: torch.Tensor = None, + image_rotary_emb: torch.Tensor = None, + guidance: torch.Tensor = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: + """ + The [`FluxTransformer2DModel`] forward method. + + Args: + hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`): + Input `hidden_states`. + encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`): + Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. + pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected + from the embeddings of input conditions. + timestep ( `torch.LongTensor`): + Used to indicate denoising step. + block_controlnet_hidden_states: (`list` of `torch.Tensor`): + A list of tensors that if specified are added to the residuals of transformer blocks. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + if joint_attention_kwargs is not None: + joint_attention_kwargs = joint_attention_kwargs.copy() + lora_scale = joint_attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." + ) + hidden_states = self.x_embedder(hidden_states) + + timestep = timestep.to(hidden_states.dtype) * 1000 + if guidance is not None: + guidance = guidance.to(hidden_states.dtype) * 1000 + else: + guidance = None + temb = ( + self.time_text_embed(timestep, pooled_projections) + if guidance is None + else self.time_text_embed(timestep, guidance, pooled_projections) + ) + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + + for index_block, block in enumerate(self.transformer_blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + temb, + image_rotary_emb, + **ckpt_kwargs, + ) + + else: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + ) + + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + for index_block, block in enumerate(self.single_transformer_blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + temb, + image_rotary_emb, + **ckpt_kwargs, + ) + + else: + hidden_states = block( + hidden_states=hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + ) + + hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...] + + hidden_states = self.norm_out(hidden_states, temb) + output = self.proj_out(hidden_states) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/FLUX1dev/pipeline/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/FLUX1dev/pipeline/__init__.py new file mode 100644 index 0000000000..2385615593 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/FLUX1dev/pipeline/__init__.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .pipeline_flux import FluxPipeline, DiffusionPipeline \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/FLUX1dev/pipeline/pipeline_flux.py b/MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/FLUX1dev/pipeline/pipeline_flux.py new file mode 100644 index 0000000000..359b0bb3cc --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/FLUX1dev/pipeline/pipeline_flux.py @@ -0,0 +1,759 @@ +# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast + +from diffusers.image_processor import VaeImageProcessor +from diffusers.loaders import FluxLoraLoaderMixin +from diffusers.models.autoencoders import AutoencoderKL +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.utils.torch_utils import randn_tensor +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput +from diffusers.utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) + +from ..models import FluxTransformer2DModel + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import FluxPipeline + + >>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + >>> prompt = "A cat holding a sign that says hello world" + >>> # Depending on the variant being used, the pipeline call will slightly vary. + >>> # Refer to the pipeline documentation for more details. + >>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0] + >>> image.save("flux.png") + ``` +""" + + +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.16, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin): + r""" + The Flux pipeline for text-to-image generation. + + Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ + + Args: + transformer ([`FluxTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`T5TokenizerFast`): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + text_encoder_2: T5EncoderModel, + tokenizer_2: T5TokenizerFast, + transformer: FluxTransformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16 + ) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.default_sample_size = 64 + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] + + dtype = self.text_encoder_2.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + def _get_clip_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + ): + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) + + # Use pooled output of CLIPTextModel + prompt_embeds = prompt_embeds.pooler_output + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds + + def encode_prompt( + self, + prompt: Union[str, List[str]], + prompt_2: Union[str, List[str]], + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 512, + lora_scale: Optional[float] = None, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # We only use the pooled prompt output from the CLIPTextModel + pooled_prompt_embeds = self._get_clip_prompt_embeds( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + ) + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt_2, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + if self.text_encoder is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype + text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + text_ids = text_ids.repeat(num_images_per_prompt, 1, 1) + + return prompt_embeds, pooled_prompt_embeds, text_ids + + def check_inputs( + self, + prompt, + prompt_2, + height, + width, + prompt_embeds=None, + pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + @staticmethod + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height // 2, width // 2, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1) + latent_image_ids = latent_image_ids.reshape( + batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + @staticmethod + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + height = height // vae_scale_factor + width = width // vae_scale_factor + + latents = latents.view(batch_size, height, width, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2) + + return latents + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + height = 2 * (int(height) // self.vae_scale_factor) + width = 2 * (int(width) // self.vae_scale_factor) + + shape = (batch_size, num_channels_latents, height, width) + + if latents is not None: + latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) + return latents.to(device=device, dtype=dtype), latent_image_ids + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + + latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) + + return latents, latent_image_ids + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 28, + timesteps: List[int] = None, + guidance_scale: float = 7.0, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` + is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated + images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + ( + prompt_embeds, + pooled_prompt_embeds, + text_ids, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + latents, latent_image_ids = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.base_image_seq_len, + self.scheduler.config.max_image_seq_len, + self.scheduler.config.base_shift, + self.scheduler.config.max_shift, + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + if text_ids.ndim == 3: + text_ids = text_ids[0] + if latent_image_ids.ndim == 3: + latent_image_ids = latent_image_ids[0] + ids = torch.cat((text_ids, latent_image_ids), dim=0) + image_rotary_emb = self.transformer.pos_embed(ids) + image_rotary_emb = [freq.to(torch.bfloat16) for freq in image_rotary_emb] + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + # handle guidance + if self.transformer.config.guidance_embeds: + guidance = torch.tensor([guidance_scale], device=device) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + noise_pred = self.transformer( + hidden_states=latents, + # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing) + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + image_rotary_emb=image_rotary_emb, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return FluxPipelineOutput(images=image) diff --git a/MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/README.md b/MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/README.md new file mode 100644 index 0000000000..886d74b658 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/README.md @@ -0,0 +1,148 @@ +## 一、准备运行环境 + + **表 1** 版本配套表 + + | 配套 | 版本 | 环境准备指导 | + | ----- | ----- |-----| + | Python | 3.10.2 | - | + | torch | 2.1.0 | - | + +### 1.1 获取CANN&MindIE安装包&环境准备 +- 设备支持: +Atlas 800I A2推理设备:支持的卡数为1 +- [Atlas 800I A2](https://www.hiascend.com/developer/download/community/result?module=pt+ie+cann&product=4&model=32) +- [环境准备指导](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha002/softwareinst/instg/instg_0001.html) + +### 1.2 CANN安装 +```shell +# 增加软件包可执行权限,{version}表示软件版本号,{arch}表示CPU架构,{soc}表示昇腾AI处理器的版本。 +chmod +x ./Ascend-cann-toolkit_{version}_linux-{arch}.run +chmod +x ./Ascend-cann-kernels-{soc}_{version}_linux.run +# 校验软件包安装文件的一致性和完整性 +./Ascend-cann-toolkit_{version}_linux-{arch}.run --check +./Ascend-cann-kernels-{soc}_{version}_linux.run --check +# 安装 +./Ascend-cann-toolkit_{version}_linux-{arch}.run --install +./Ascend-cann-kernels-{soc}_{version}_linux.run --install + +# 设置环境变量 +source /usr/local/Ascend/ascend-toolkit/set_env.sh +``` + +### 1.3 环境依赖安装 +```shell +pip3 install -r requirements.txt +``` + +### 1.4 MindIE安装 +```shell +# 增加软件包可执行权限,{version}表示软件版本号,{arch}表示CPU架构。 +chmod +x ./Ascend-mindie_${version}_linux-${arch}.run +./Ascend-mindie_${version}_linux-${arch}.run --check + +# 方式一:默认路径安装 +./Ascend-mindie_${version}_linux-${arch}.run --install +# 设置环境变量 +cd /usr/local/Ascend/mindie && source set_env.sh + +# 方式二:指定路径安装 +./Ascend-mindie_${version}_linux-${arch}.run --install-path=${AieInstallPath} +# 设置环境变量 +cd ${AieInstallPath}/mindie && source set_env.sh +``` + +### 1.5 Torch_npu安装 +下载 pytorch_v{pytorchversion}_py{pythonversion}.tar.gz +```shell +tar -xzvf pytorch_v{pytorchversion}_py{pythonversion}.tar.gz +# 解压后,会有whl包 +pip install torch_npu-{pytorchversion}.xxxx.{arch}.whl +``` + +### 安装mindspeed依赖 +```shell +# 下载mindspeed源码仓: +git clone https://gitee.com/ascend/MindSpeed.git +# 执行如下命令进行安装: +pip install -e MindSpeed +``` +## 二、下载本仓库 + +### 2.1 下载到本地 +```shell +git clone https://modelers.cn/MindIE/FLUX.1-dev.git +``` +## 三、Flux.1-DEV使用 + +### 3.1 准备权重 +Flux.1-DEV权重下载地址 +```shell +https://huggingface.co/black-forest-labs/FLUX.1-dev/tree/main +``` + +设置模型权重路径环境变量: +```bash +export model_path="your local flux model path" +``` +修改权重配置文件: +```bash +vi ${model_path}/model_index.json +```` +做如下修改: +```json +{ + "_class_name": "FluxPipeline", + "_diffusers_version": "0.30.0.dev0", + "scheduler": [ + "diffusers", + "FlowMatchEulerDiscreteScheduler" + ], + "text_encoder": [ + "transformers", + "CLIPTextModel" + ], + "text_encoder_2": [ + "transformers", + "T5EncoderModel" + ], + "tokenizer": [ + "transformers", + "CLIPTokenizer" + ], + "tokenizer_2": [ + "transformers", + "T5TokenizerFast" + ], + "transformer": [ + "FLUX1dev", + "FluxTransformer2DModel" + ], + "vae": [ + "diffusers", + "AutoencoderKL" + ] +} +``` +### 3.2 运行Flux +```shell +python inference_flux.py \ + --path ${model_path} \ + --save_path "./res" \ + --device_id 0 \ + --device "npu" \ + --prompt_path "./prompts.txt" \ + --width 1024 \ + --height 1024 \ + --infer_steps 50 \ + --seed 42 +``` +参数说明: +- path: Flux本地模型权重路径,默认读取当前文件夹下的flux文件夹 +- save_path: 保存图像路径,默认当前文件夹下的res文件夹 +- device_id: 推理设备ID,默认值设置为0 +- device: 推理设备类型,默认为npu +- prompt_path: 用于图像生成的文字描述提示的列表文件路径 +- width: 图像生成的宽度,默认1024 +- height: 图像生成的高度,默认1024 +- infer_steps: Flux图像推理步数,默认值为50 +- seed: 设置随机种子,默认值为42 \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/inference_flux.py b/MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/inference_flux.py new file mode 100644 index 0000000000..b2bab68afe --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/inference_flux.py @@ -0,0 +1,149 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import argparse +import time +import torch +import torch_npu +from FLUX1dev import FluxPipeline + +from torch_npu.contrib import transfer_to_npu + +torch_npu.npu.set_compile_mode(jit_compile=False) + + +class PromptLoader: + def __init__( + self, + prompt_file: str, + batch_size: int = 1, + num_images_per_prompt: int = 1, + max_num_prompts: int = 0 + ): + self.prompts = [] + self.catagories = ['Not_specified'] + self.batch_size = batch_size + self.num_images_per_prompt = num_images_per_prompt + + self.load_prompts(prompt_file, max_num_prompts) + + self.current_id = 0 + self.inner_id = 0 + + def __len__(self): + return len(self.prompts) * self.num_images_per_prompt + + def __iter__(self): + return self + + def __next__(self): + if self.current_id == len(self.prompts): + raise StopIteration + + ret = { + 'prompts': [], + 'catagories': [], + 'save_names': [], + 'n_prompts': self.batch_size, + } + for _ in range(self.batch_size): + if self.current_id == len(self.prompts): + ret['prompts'].append('') + ret['save_names'].append('') + ret['catagories'].append('') + ret['n_prompts'] -= 1 + + else: + prompt, catagory_id = self.prompts[self.current_id] + ret['prompts'].append(prompt) + ret['catagories'].append(self.catagories[catagory_id]) + ret['save_names'].append(f'{self.current_id}_{self.inner_id}') + + self.inner_id += 1 + if self.inner_id == self.num_images_per_prompt: + self.inner_id = 0 + self.current_id += 1 + + return ret + + def load_prompts(self, file_path: str, max_num_prompts: int): + with os.fdopen(os.open(file_path, os.O_RDONLY), "r") as f: + for i, line in enumerate(f): + if max_num_prompts and i == max_num_prompts: + break + + prompt = line.strip() + self.prompts.append((prompt, 0)) + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument("--path", type=str, default="./flux", help="Path to the flux model directory") + parser.add_argument("--save_path", type=str, default="./res", help="ouput image path") + parser.add_argument("--device_id", type=int, default=0, help="NPU device id") + parser.add_argument("--device", type=str, default="npu", help="NPU") + parser.add_argument("--prompt_path", type=str, default="./prompts.txt", help="input prompt text path") + parser.add_argument("--width", type=int, default=1024, help='Image size width') + parser.add_argument("--height", type=int, default=1024, help='Image size height') + parser.add_argument("--infer_steps", type=int, default=50, help="Inference steps") + parser.add_argument('--seed', type=int, default=42, help="A seed for all the prompts") + return parser.parse_args() + + +def infer(args): + torch.npu.set_device(args.device_id) + pipe = FluxPipeline.from_pretrained(args.path, torch_dtype=torch.bfloat16) + pipe.enable_model_cpu_offload() + + if not os.path.exists(args.save_path): + os.makedirs(args.save_path, mode=0o640) + + infer_num = 0 + time_consume = 0 + prompt_loader = PromptLoader(args.prompt_path) + for _, input_info in enumerate(prompt_loader): + prompts = input_info['prompts'] + save_names = input_info['save_names'] + + print(f"[{infer_num}/{len(prompt_loader)}]: {prompts}") + infer_num += 1 + if infer_num > 3: + start_time = time.time() + + image = pipe( + prompts, + height=args.width, + width=args.height, + guidance_scale=3.5, + num_inference_steps=args.infer_steps, + max_sequence_length=512, + generator=torch.Generator().manual_seed(args.seed) + ).images[0] + + if infer_num > 3: + end_time = time.time() - start_time + time_consume += end_time + image_save_path = os.path.join(args.save_path, f"{save_names[0]}.png") + image.save(image_save_path) + + image_time_count = len(prompt_loader) - 3 + print(f"flux pipeline time is:{time_consume/image_time_count}") + return + + +if __name__ == "__main__": + inference_args = parse_arguments() + infer(inference_args) \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/prompts.txt b/MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/prompts.txt new file mode 100644 index 0000000000..a375a0bb63 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/prompts.txt @@ -0,0 +1,16 @@ +Beautiful illustration of The ocean. in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper +Beautiful illustration of Islands in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper +Beautiful illustration of Seaports in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper +Beautiful illustration of The waves. in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper +Beautiful illustration of Grassland. in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper +Beautiful illustration of Wheat. in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper +Beautiful illustration of Hut Tong. in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper +Beautiful illustration of The boat. in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper +Beautiful illustration of Pine trees. in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper +Beautiful illustration of Bamboo. in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper +Beautiful illustration of The temple. in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper +Beautiful illustration of Cloud in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper +Beautiful illustration of Sun in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper +Beautiful illustration of Spring. in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper +Beautiful illustration of Lotus. in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper +Beautiful illustration of Snow piles. in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/requirements.txt b/MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/requirements.txt new file mode 100644 index 0000000000..7ab1879205 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/Flux.1-DEV/requirements.txt @@ -0,0 +1,9 @@ +accelerate==1.2.1 +torch==2.1.0 +torchvision==0.16.0 +ftfy +diffusers==0.32.1 +transformers==4.46.3 +tensorboard +Jinja2 +peft==0.11.1 \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogview3/README.md b/MindIE/MindIE-Torch/built-in/foundation/cogview3/README.md new file mode 100644 index 0000000000..831e779566 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogview3/README.md @@ -0,0 +1,167 @@ +## 一、准备运行环境 + + **表 1** 版本配套表 + + | 配套 | 版本 | 环境准备指导 | + | ----- | ----- |-----| + | Python | 3.10.12 | - | + | torch | 2.4.0 | - | + +### 1.1 获取CANN&MindIE安装包&环境准备 +- 设备支持: +Atlas 800I A2/Atlas 800T A2设备:支持的卡数为1 +- [Atlas 800I A2/Atlas 800T A2](https://www.hiascend.com/developer/download/community/result?module=pt+ie+cann&product=4&model=32) +- [环境准备指导](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha002/softwareinst/instg/instg_0001.html) + +### 1.2 CANN安装 +```shell +# 增加软件包可执行权限,{version}表示软件版本号,{arch}表示CPU架构,{soc}表示昇腾AI处理器的版本。 +chmod +x ./Ascend-cann-toolkit_{version}_linux-{arch}.run +chmod +x ./Ascend-cann-kernels-{soc}_{version}_linux.run +# 校验软件包安装文件的一致性和完整性 +./Ascend-cann-toolkit_{version}_linux-{arch}.run --check +./Ascend-cann-kernels-{soc}_{version}_linux.run --check +# 安装 +./Ascend-cann-toolkit_{version}_linux-{arch}.run --install +./Ascend-cann-kernels-{soc}_{version}_linux.run --install + +# 设置环境变量 +source /usr/local/Ascend/ascend-toolkit/set_env.sh +``` + +### 1.3 MindIE安装 +```shell +# 增加软件包可执行权限,{version}表示软件版本号,{arch}表示CPU架构。 +chmod +x ./Ascend-mindie_${version}_linux-${arch}.run +./Ascend-mindie_${version}_linux-${arch}.run --check + +# 方式一:默认路径安装 +./Ascend-mindie_${version}_linux-${arch}.run --install +# 设置环境变量 +cd /usr/local/Ascend/mindie && source set_env.sh + +# 方式二:指定路径安装 +./Ascend-mindie_${version}_linux-${arch}.run --install-path=${AieInstallPath} +# 设置环境变量 +cd ${AieInstallPath}/mindie && source set_env.sh +``` + +### 1.4 Torch_npu安装 +安装pytorch框架 版本2.4.0 +[安装包下载](https://download.pytorch.org/whl/cpu/torch/) + +使用pip安装 +```shell +# {version}表示软件版本号,{arch}表示CPU架构。 +pip install torch-${version}-cp310-cp310-linux_${arch}.whl +``` +下载 pytorch_v{pytorchversion}_py{pythonversion}.tar.gz +```shell +tar -xzvf pytorch_v{pytorchversion}_py{pythonversion}.tar.gz +# 解压后,会有whl包 +pip install torch_npu-{pytorchversion}.xxxx.{arch}.whl +``` +## 二、下载本仓库 + +### 2.1 下载到本地 +```shell + git clone https://gitee.com/ascend/ModelZoo-PyTorch.git +``` + +## 三、CogView3使用 + +### 3.1 权重及配置文件说明 +1. CogView3权重路径: +```shell +https://huggingface.co/THUDM/CogView3-Plus-3B/tree/main +``` +- 修改该权重的model_index.json +```shell +{ + "_class_name": "CogView3PlusPipeline", + "_diffusers_version": "0.31.0", + "scheduler": [ + "cogview3plus", + "CogVideoXDDIMScheduler" + ], + "text_encoder": [ + "transformers", + "T5EncoderModel" + ], + "tokenizer": [ + "transformers", + "T5Tokenizer" + ], + "transformer": [ + "cogview3plus", + "CogView3PlusTransformer2DModel" + ], + "vae": [ + "diffusers", + "AutoencoderKL" + ] +} +``` +2. scheduler权重链接: +```shell +https://huggingface.co/THUDM/CogView3-Plus-3B/tree/main/scheduler +``` +3. text_encoder权重链接: +```shell +https://huggingface.co/THUDM/CogView3-Plus-3B/tree/main/text_encoder +``` +4. tokenizer权重链接: +```shell +https://huggingface.co/THUDM/CogView3-Plus-3B/tree/main/tokenizer +``` +5. transformer权重链接: +```shell +https://huggingface.co/THUDM/CogView3-Plus-3B/tree/main/transformer +``` +6. vae权重链接: +```shell +https://huggingface.co/THUDM/CogView3-Plus-3B/tree/main/vae +``` +7. 各模型的配置文件、权重文件的层级样例如下所示。 +```commandline +|----CogView3B +| |---- configuration.json +| |---- model_index.json +| |---- scheduler +| | |---- scheduler_config.json +| |---- text_encoder +| | |---- config.json +| | |---- 模型权重 +| |---- tokenizer +| | |---- config.json +| | |---- 模型权重 +| |---- transformer +| | |---- config.json +| | |---- 模型权重 +| |---- vae +| | |---- config.json +| | |---- 模型权重 +``` + +### 3.2 单卡单prompt功能测试 +设置权重路径 +```shell +model_path='/data/CogView3B' +``` +执行命令: +```shell +python inference_cogview3plus.py \ + --model_path ${model_path} \ + --device_id 0 \ + --width 1024 \ + --height 1024 \ + --num_inference_steps 50 \ + --dtype bf16 +``` +参数说明: +- model_path:权重路径,包含scheduler、text_encoder、tokenizer、transformer、vae,5个模型的配置文件及权重。 +- device_id:推理设备ID。 +- width:需要生成的图像的宽。 +- height: 需要生成的图像的高。 +- num_inference_steps:推理迭代步数。 +- dtype: 数据类型。目前只支持bf16。 diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/__init__.py new file mode 100644 index 0000000000..1139593a36 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/__init__.py @@ -0,0 +1,3 @@ +from .pipeline import CogView3PlusPipeline, DiffusionPipeline +from .schedulers import CogVideoXDDIMScheduler, SchedulerMixin +from .models import CogView3PlusTransformer2DModel, ModelMixin \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/layers/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/layers/__init__.py new file mode 100644 index 0000000000..602ad432a0 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/layers/__init__.py @@ -0,0 +1,3 @@ +from .normalization import CogView3PlusAdaLayerNormZeroTextImage, AdaLayerNormContinuous +from .embeddings import CogView3CombinedTimestepSizeEmbeddings, CogView3PlusPatchEmbed +from .linear import QKVLinear \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/layers/embeddings.py b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/layers/embeddings.py new file mode 100644 index 0000000000..fc2d3101eb --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/layers/embeddings.py @@ -0,0 +1,304 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Optional + +import torch +from torch import nn +from diffusers.models.activations import get_activation + + +def get_timestep_embedding( + timesteps: torch.Tensor, + embedding_dim: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 1, + max_period: int = 10000, +): + half_dim = embedding_dim // 2 + exponent = -math.log(max_period) * torch.arange( + start=0, end=half_dim, dtype=torch.float32, device=timesteps.device + ) + exponent = exponent / (half_dim - downscale_freq_shift) + + emb = torch.exp(exponent) + emb = timesteps[:, None].float() * emb[None, :] + + # concat sine and cosine embeddings + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) + + # flip sine and cosine embeddings + if flip_sin_to_cos: + emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) + + # zero pad + if embedding_dim % 2 == 1: + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +def get_2d_sincos_pos_embed( + embed_dim, + grid_size, + interpolation_scale=1.0, + base_size=16, +): + if isinstance(grid_size, int): + grid_size = (grid_size, grid_size) + + grid_h = ( + torch.arange(grid_size[0], dtype=torch.float32) + / (grid_size[0] / base_size) + / interpolation_scale + ) + grid_w = ( + torch.arange(grid_size[1], dtype=torch.float32) + / (grid_size[1] / base_size) + / interpolation_scale + ) + grid = torch.meshgrid(grid_w, grid_h, indexing="xy") # here w goes first + grid = torch.stack(grid, dim=0) + + grid = grid.reshape([2, 1, grid_size[1], grid_size[0]]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + r""" + This function generates 2D sinusoidal positional embeddings from a grid. + + Args: + embed_dim (`int`): The embedding dimension. + grid (`torch.Tensor`): Grid of positions with shape `(H * W,)`. + + Returns: + `torch.Tensor`: The 2D sinusoidal positional embeddings with shape `(H * W, embed_dim)` + """ + if embed_dim % 2 != 0: + raise ValueError("embed_dim must be divisible by 2") + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = torch.concat([emb_h, emb_w], dim=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + This function generates 1D positional embeddings from a grid. + + Args: + embed_dim (`int`): The embedding dimension `D` + pos (`torch.Tensor`): 1D tensor of positions with shape `(M,)` + + Returns: + `torch.Tensor`: Sinusoidal positional embeddings of shape `(M, D)`. + """ + if embed_dim % 2 != 0: + raise ValueError("embed_dim must be divisible by 2") + + omega = torch.arange(embed_dim // 2, device=pos.device, dtype=torch.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = torch.outer(pos, omega) # (M, D/2), outer product + + emb_sin = torch.sin(out) # (M, D/2) + emb_cos = torch.cos(out) # (M, D/2) + + emb = torch.concat([emb_sin, emb_cos], dim=1) # (M, D) + return emb + + +class Timesteps(nn.Module): + def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float): + super().__init__() + self.num_channels = num_channels + self.flip_sin_to_cos = flip_sin_to_cos + self.downscale_freq_shift = downscale_freq_shift + + def forward(self, timesteps): + t_emb = get_timestep_embedding( + timesteps, + self.num_channels, + flip_sin_to_cos=self.flip_sin_to_cos, + downscale_freq_shift=self.downscale_freq_shift, + ) + return t_emb + + +class TimestepEmbedding(nn.Module): + def __init__( + self, + in_channels: int, + time_embed_dim: int, + act_fn: str = "silu", + out_dim: int = None, + post_act_fn: Optional[str] = None, + cond_proj_dim=None, + sample_proj_bias=True, + ): + super().__init__() + + self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias) + + if cond_proj_dim is not None: + self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False) + else: + self.cond_proj = None + + self.act = get_activation(act_fn) + + if out_dim is not None: + time_embed_dim_out = out_dim + else: + time_embed_dim_out = time_embed_dim + self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias) + + if post_act_fn is None: + self.post_act = None + else: + self.post_act = get_activation(post_act_fn) + + def forward(self, sample, condition=None): + if condition is not None: + sample = sample + self.cond_proj(condition) + sample = self.linear_1(sample) + + if self.act is not None: + sample = self.act(sample) + + sample = self.linear_2(sample) + + if self.post_act is not None: + sample = self.post_act(sample) + return sample + + +class PixArtAlphaTextProjection(nn.Module): + """ + Projects caption embeddings. Also handles dropout for classifier-free guidance. + """ + + def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh"): + super().__init__() + if out_features is None: + out_features = hidden_size + self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True) + if act_fn == "gelu_tanh": + self.act_1 = nn.GELU(approximate="tanh") + elif act_fn == "silu": + self.act_1 = nn.SiLU() + else: + raise ValueError(f"Unknown activation function: {act_fn}") + self.linear_2 = nn.Linear(in_features=hidden_size, out_features=out_features, bias=True) + + def forward(self, caption): + hidden_states = self.linear_1(caption) + hidden_states = self.act_1(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +class CogView3CombinedTimestepSizeEmbeddings(nn.Module): + def __init__(self, embedding_dim: int, condition_dim: int, pooled_projection_dim: int, timesteps_dim: int = 256): + super().__init__() + + self.time_proj = Timesteps(num_channels=timesteps_dim, flip_sin_to_cos=True, downscale_freq_shift=0) + self.condition_proj = Timesteps(num_channels=condition_dim, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding(in_channels=timesteps_dim, time_embed_dim=embedding_dim) + self.condition_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu") + + def forward( + self, + timestep: torch.Tensor, + original_size: torch.Tensor, + target_size: torch.Tensor, + crop_coords: torch.Tensor, + hidden_dtype: torch.dtype, + ) -> torch.Tensor: + timesteps_proj = self.time_proj(timestep) + + original_size_proj = self.condition_proj(original_size.flatten()).view(original_size.size(0), -1) + crop_coords_proj = self.condition_proj(crop_coords.flatten()).view(crop_coords.size(0), -1) + target_size_proj = self.condition_proj(target_size.flatten()).view(target_size.size(0), -1) + + condition_proj = torch.cat([original_size_proj, crop_coords_proj, target_size_proj], dim=1) + + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (B, embedding_dim) + condition_emb = self.condition_embedder(condition_proj.to(dtype=hidden_dtype)) # (B, embedding_dim) + + conditioning = timesteps_emb + condition_emb + return conditioning + + +class CogView3PlusPatchEmbed(nn.Module): + def __init__( + self, + in_channels: int = 16, + hidden_size: int = 2560, + patch_size: int = 2, + text_hidden_size: int = 4096, + pos_embed_max_size: int = 128, + ): + super().__init__() + self.in_channels = in_channels + self.hidden_size = hidden_size + self.patch_size = patch_size + self.text_hidden_size = text_hidden_size + self.pos_embed_max_size = pos_embed_max_size + # Linear projection for image patches + self.proj = nn.Linear(in_channels * patch_size**2, hidden_size) + + # Linear projection for text embeddings + self.text_proj = nn.Linear(text_hidden_size, hidden_size) + + pos_embed = get_2d_sincos_pos_embed( + hidden_size, pos_embed_max_size, base_size=pos_embed_max_size + ) + pos_embed = pos_embed.reshape(pos_embed_max_size, pos_embed_max_size, hidden_size) + self.register_buffer("pos_embed", pos_embed.float(), persistent=False) + + def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, channel, height, width = hidden_states.shape + + if height % self.patch_size != 0 or width % self.patch_size != 0: + raise ValueError("Height and width must be divisible by patch size") + + height = height // self.patch_size + width = width // self.patch_size + hidden_states = hidden_states.view(batch_size, channel, height, self.patch_size, width, self.patch_size) + hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5).contiguous() + hidden_states = hidden_states.view(batch_size, height * width, channel * self.patch_size * self.patch_size) + + # Project the patches + hidden_states = self.proj(hidden_states) + encoder_hidden_states = self.text_proj(encoder_hidden_states) + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + # Calculate text_length + text_length = encoder_hidden_states.shape[1] + + image_pos_embed = self.pos_embed[:height, :width].reshape(height * width, -1) + text_pos_embed = torch.zeros( + (text_length, self.hidden_size), dtype=image_pos_embed.dtype, device=image_pos_embed.device + ) + pos_embed = torch.cat([text_pos_embed, image_pos_embed], dim=0)[None, ...] + + return (hidden_states + pos_embed).to(hidden_states.dtype) \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/layers/linear.py b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/layers/linear.py new file mode 100644 index 0000000000..d242d17c2e --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/layers/linear.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +import torch.nn as nn + + +class QKVLinear(nn.Module): + def __init__(self, attention_dim, hidden_size, qkv_bias=True, device=None, dtype=None): + super(QKVLinear, self).__init__() + self.attention_dim = attention_dim + self.hidden_size = hidden_size + self.qkv_bias = qkv_bias + + factory_kwargs = {"device": device, "dtype": dtype} + + self.weight = nn.Parameter(torch.empty([self.attention_dim, 3 * self.hidden_size], **factory_kwargs)) + if self.qkv_bias: + self.bias = nn.Parameter(torch.empty([3 * self.hidden_size], **factory_kwargs)) + + def forward(self, hidden_states): + + if not self.qkv_bias: + qkv = torch.matmul(hidden_states, self.weight) + else: + qkv = torch.addmm( + self.bias, + hidden_states.view(hidden_states.size(0) * hidden_states.size(1), hidden_states.size(2)), + self.weight, + beta=1, + alpha=1 + ) + + return qkv \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/layers/normalization.py b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/layers/normalization.py new file mode 100644 index 0000000000..c12b70c9b1 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/layers/normalization.py @@ -0,0 +1,177 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numbers +from typing import Optional, Tuple +from dataclasses import dataclass + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class RMSNorm(nn.Module): + def __init__(self, dim, eps: float, elementwise_affine: bool = True, bias: bool = False): + super().__init__() + + self.eps = eps + self.elementwise_affine = elementwise_affine + + if isinstance(dim, numbers.Integral): + dim = (dim,) + + self.dim = torch.Size(dim) + + self.weight = None + self.bias = None + + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(dim)) + if bias: + self.bias = nn.Parameter(torch.zeros(dim)) + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.eps) + + if self.weight is not None: + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + hidden_states = hidden_states * self.weight + if self.bias is not None: + hidden_states = hidden_states + self.bias + else: + hidden_states = hidden_states.to(input_dtype) + + return hidden_states + + +@dataclass +class ChunkParam: + gate_msa: torch.Tensor + shift_mlp: torch.Tensor + scale_mlp: torch.Tensor + gate_mlp: torch.Tensor + context: torch.Tensor + c_gate_msa: torch.Tensor + c_shift_mlp: torch.Tensor + c_scale_mlp: torch.Tensor + c_gate_mlp: torch.Tensor + + +class CogView3PlusAdaLayerNormZeroTextImage(nn.Module): + r""" + Norm layer adaptive layer norm zero (adaLN-Zero). + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + num_embeddings (`int`): The size of the embeddings dictionary. + """ + + def __init__(self, embedding_dim: int, dim: int): + super().__init__() + + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, 12 * dim, bias=True) + self.norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5) + self.norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5) + + def forward( + self, + x: torch.Tensor, + context: torch.Tensor, + emb: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + emb = self.linear(self.silu(emb)) + ( + shift_msa, + scale_msa, + gate_msa, + shift_mlp, + scale_mlp, + gate_mlp, + c_shift_msa, + c_scale_msa, + c_gate_msa, + c_shift_mlp, + c_scale_mlp, + c_gate_mlp, + ) = emb.chunk(12, dim=1) + normed_x = self.norm_x(x) + normed_context = self.norm_c(context) + x = normed_x * (1 + scale_msa[:, None]) + shift_msa[:, None] + context = normed_context * (1 + c_scale_msa[:, None]) + c_shift_msa[:, None] + return x, ChunkParam( + gate_msa, shift_mlp, scale_mlp, gate_mlp, context, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp + ) + + +class FP32LayerNorm(nn.LayerNorm): + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + origin_dtype = inputs.dtype + return F.layer_norm( + inputs.float(), + self.normalized_shape, + self.weight.float() if self.weight is not None else None, + self.bias.float() if self.bias is not None else None, + self.eps, + ).to(origin_dtype) + + +class LpNorm(nn.Module): + def __init__(self, p: int = 2, dim: int = -1, eps: float = 1e-12): + super().__init__() + + self.p = p + self.dim = dim + self.eps = eps + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return F.normalize(hidden_states, p=self.p, dim=self.dim, eps=self.eps) + + +class AdaLayerNormContinuous(nn.Module): + def __init__( + self, + embedding_dim: int, + conditioning_embedding_dim: int, + # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters + # because the output is immediately scaled and shifted by the projected conditioning embeddings. + # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters. + # However, this is how it was implemented in the original code, and it's rather likely you should + # set `elementwise_affine` to False. + elementwise_affine=True, + eps=1e-5, + bias=True, + norm_type="layer_norm", + ): + super().__init__() + self.silu = nn.SiLU() + self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias) + if norm_type == "layer_norm": + self.norm = nn.LayerNorm(embedding_dim, eps, elementwise_affine, bias) + elif norm_type == "rms_norm": + self.norm = RMSNorm(embedding_dim, eps, elementwise_affine) + else: + raise ValueError(f"unknown norm_type {norm_type}") + + def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor: + # convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT) + emb = self.linear(self.silu(conditioning_embedding).to(x.dtype)) + scale, shift = torch.chunk(emb, 2, dim=1) + x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :] + return x \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/models/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/models/__init__.py new file mode 100644 index 0000000000..ae8f24f59a --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/models/__init__.py @@ -0,0 +1,2 @@ +from .transformer_cogview3plus import CogView3PlusTransformer2DModel +from .modeling_utils import ModelMixin \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/models/activations.py b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/models/activations.py new file mode 100644 index 0000000000..b7d7cec29d --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/models/activations.py @@ -0,0 +1,163 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn.functional as F +from torch import nn + +from diffusers.utils import deprecate +from diffusers.utils.import_utils import is_torch_npu_available + +if is_torch_npu_available(): + import torch_npu + +ACTIVATION_FUNCTIONS = { + "swish": nn.SiLU(), + "silu": nn.SiLU(), + "mish": nn.Mish(), + "gelu": nn.GELU(), + "relu": nn.ReLU(), +} + + +def get_activation(act_fn: str) -> nn.Module: + """Helper function to get activation function from string. + + Args: + act_fn (str): Name of activation function. + + Returns: + nn.Module: Activation function. + """ + + act_fn = act_fn.lower() + if act_fn in ACTIVATION_FUNCTIONS: + return ACTIVATION_FUNCTIONS[act_fn] + else: + raise ValueError(f"Unsupported activation function: {act_fn}") + + +class FP32SiLU(nn.Module): + r""" + SiLU activation function with input upcasted to torch.float32. + """ + + def __init__(self): + super().__init__() + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + return F.silu(inputs.float(), inplace=False).to(inputs.dtype) + + +class GELU(nn.Module): + r""" + GELU activation function with tanh approximation support with `approximate="tanh"`. + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out, bias=bias) + self.approximate = approximate + + def gelu(self, gate: torch.Tensor) -> torch.Tensor: + return F.gelu(gate, approximate=self.approximate) + + def forward(self, hidden_states): + hidden_states = self.proj(hidden_states) + hidden_states = self.gelu(hidden_states) + return hidden_states + + +class GEGLU(nn.Module): + r""" + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + def __init__(self, dim_in: int, dim_out: int, bias: bool = True): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias) + + def gelu(self, gate: torch.Tensor) -> torch.Tensor: + return F.gelu(gate) + + def forward(self, hidden_states, *args, **kwargs): + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + hidden_states = self.proj(hidden_states) + if is_torch_npu_available(): + # using torch_npu.npu_geglu can run faster and save memory on NPU. + return torch_npu.npu_geglu(hidden_states, dim=-1, approximate=1)[0] + else: + hidden_states, gate = hidden_states.chunk(2, dim=-1) + return hidden_states * self.gelu(gate) + + +class SwiGLU(nn.Module): + r""" + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + def __init__(self, dim_in: int, dim_out: int, bias: bool = True): + super().__init__() + + self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias) + self.activation = nn.SiLU() + + def forward(self, hidden_states): + hidden_states = self.proj(hidden_states) + hidden_states, gate = hidden_states.chunk(2, dim=-1) + return hidden_states * self.activation(gate) + + +class ApproximateGELU(nn.Module): + r""" + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + def __init__(self, dim_in: int, dim_out: int, bias: bool = True): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out, bias=bias) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + return x * torch.sigmoid(1.702 * x) + + +class LinearActivation(nn.Module): + def __init__(self, dim_in: int, dim_out: int, bias: bool = True, activation: str = "silu"): + super().__init__() + + self.proj = nn.Linear(dim_in, dim_out, bias=bias) + self.activation = get_activation(activation) + + def forward(self, hidden_states): + hidden_states = self.proj(hidden_states) + return self.activation(hidden_states) \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/models/attention.py b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/models/attention.py new file mode 100644 index 0000000000..946d829c6c --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/models/attention.py @@ -0,0 +1,87 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import torch +from torch import nn + +from diffusers.utils import deprecate, logging +from .activations import GEGLU, GELU, ApproximateGELU, LinearActivation, SwiGLU + + +logger = logging.get_logger(__name__) + + +class FeedForward(nn.Module): + r""" + A feed-forward layer. + + Parameters: + dim (`int`): The number of channels in the input. + dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. + mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + mult: int = 4, + dropout: float = 0.0, + activation_fn: str = "geglu", + final_dropout: bool = False, + inner_dim=None, + bias: bool = True, + ): + super().__init__() + if inner_dim is None: + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + + if activation_fn == "gelu": + act_fn = GELU(dim, inner_dim, bias=bias) + if activation_fn == "gelu-approximate": + act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias) + elif activation_fn == "geglu": + act_fn = GEGLU(dim, inner_dim, bias=bias) + elif activation_fn == "geglu-approximate": + act_fn = ApproximateGELU(dim, inner_dim, bias=bias) + elif activation_fn == "swiglu": + act_fn = SwiGLU(dim, inner_dim, bias=bias) + elif activation_fn == "linear-silu": + act_fn = LinearActivation(dim, inner_dim, bias=bias, activation="silu") + + self.net = nn.ModuleList([]) + # project in + self.net.append(act_fn) + # project dropout + self.net.append(nn.Dropout(dropout)) + # project out + self.net.append(nn.Linear(inner_dim, dim_out, bias=bias)) + # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout + if final_dropout: + self.net.append(nn.Dropout(dropout)) + + def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + for module in self.net: + hidden_states = module(hidden_states) + return hidden_states \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/models/attention_processor.py b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/models/attention_processor.py new file mode 100644 index 0000000000..d36e9265a3 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/models/attention_processor.py @@ -0,0 +1,348 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Optional + +import torch +import torch.nn.functional as F +from torch import nn +import torch_npu + +from diffusers.utils import logging +from diffusers.utils.torch_utils import maybe_allow_in_graph + +from ..layers import QKVLinear + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@maybe_allow_in_graph +class Attention(nn.Module): + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + kv_heads: Optional[int] = None, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + cross_attention_norm: Optional[str] = None, + cross_attention_norm_num_groups: int = 32, + qk_norm: Optional[str] = None, + added_kv_proj_dim: Optional[int] = None, + added_proj_bias: Optional[bool] = True, + norm_num_groups: Optional[int] = None, + out_bias: bool = True, + scale_qk: bool = True, + only_cross_attention: bool = False, + eps: float = 1e-5, + rescale_output_factor: float = 1.0, + residual_connection: bool = False, + _from_deprecated_attn_block: bool = False, + processor: Optional["AttnProcessor"] = None, + out_dim: int = None, + out_context_dim: int = None, + context_pre_only=None, + pre_only=False, + elementwise_affine: bool = True, + is_causal: bool = False, + ): + super().__init__() + + # To prevent circular import. + from ..layers.normalization import FP32LayerNorm, LpNorm, RMSNorm + + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads + self.query_dim = query_dim + self.use_bias = bias + self.is_cross_attention = cross_attention_dim is not None + self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + self.upcast_attention = upcast_attention + self.upcast_softmax = upcast_softmax + self.rescale_output_factor = rescale_output_factor + self.residual_connection = residual_connection + self.dropout = dropout + self.fused_projections = False + self.out_dim = out_dim if out_dim is not None else query_dim + self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim + self.context_pre_only = context_pre_only + self.pre_only = pre_only + self.is_causal = is_causal + + # we make use of this private variable to know whether this class is loaded + # with an deprecated state dict so that we can convert it on the fly + self._from_deprecated_attn_block = _from_deprecated_attn_block + + self.scale_qk = scale_qk + self.scale = dim_head**-0.5 if self.scale_qk else 1.0 + + self.heads = out_dim // dim_head if out_dim is not None else heads + # for slice_size > 0 the attention score computation + # is split across the batch axis to save memory + # You can set slice_size with `set_attention_slice` + self.sliceable_head_dim = heads + + self.added_kv_proj_dim = added_kv_proj_dim + self.only_cross_attention = only_cross_attention + + if self.added_kv_proj_dim is None and self.only_cross_attention: + raise ValueError( + "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`." + ) + + if norm_num_groups is not None: + self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True) + else: + self.group_norm = None + + self.spatial_norm = None + + if qk_norm is None: + self.norm_q = None + self.norm_k = None + elif qk_norm == "layer_norm": + self.norm_q = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + self.norm_k = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + elif qk_norm == "fp32_layer_norm": + self.norm_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) + self.norm_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) + elif qk_norm == "layer_norm_across_heads": + # Lumina applies qk norm across all heads + self.norm_q = nn.LayerNorm(dim_head * heads, eps=eps) + self.norm_k = nn.LayerNorm(dim_head * kv_heads, eps=eps) + elif qk_norm == "rms_norm": + self.norm_q = RMSNorm(dim_head, eps=eps) + self.norm_k = RMSNorm(dim_head, eps=eps) + elif qk_norm == "rms_norm_across_heads": + # LTX applies qk norm across all heads + self.norm_q = RMSNorm(dim_head * heads, eps=eps) + self.norm_k = RMSNorm(dim_head * kv_heads, eps=eps) + elif qk_norm == "l2": + self.norm_q = LpNorm(p=2, dim=-1, eps=eps) + self.norm_k = LpNorm(p=2, dim=-1, eps=eps) + else: + raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None,'layer_norm','fp32_layer_norm','rms_norm'") + + if cross_attention_norm is None: + self.norm_cross = None + elif cross_attention_norm == "layer_norm": + self.norm_cross = nn.LayerNorm(self.cross_attention_dim) + elif cross_attention_norm == "group_norm": + if self.added_kv_proj_dim is not None: + # The given `encoder_hidden_states` are initially of shape + # (batch_size, seq_len, added_kv_proj_dim) before being projected + # to (batch_size, seq_len, cross_attention_dim). The norm is applied + # before the projection, so we need to use `added_kv_proj_dim` as + # the number of channels for the group norm. + norm_cross_num_channels = added_kv_proj_dim + else: + norm_cross_num_channels = self.cross_attention_dim + + self.norm_cross = nn.GroupNorm( + num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True + ) + else: + raise ValueError( + f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'" + ) + + self.to_qkv = QKVLinear(self.inner_dim, query_dim) + + self.added_proj_bias = added_proj_bias + if self.added_kv_proj_dim is not None: + self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias) + self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias) + if self.context_pre_only is not None: + self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + else: + self.add_q_proj = None + self.add_k_proj = None + self.add_v_proj = None + + if not self.pre_only: + self.to_out = nn.ModuleList([]) + self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) + self.to_out.append(nn.Dropout(dropout)) + else: + self.to_out = None + + if self.context_pre_only is not None and not self.context_pre_only: + self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias) + else: + self.to_add_out = None + + if qk_norm is not None and added_kv_proj_dim is not None: + if qk_norm == "fp32_layer_norm": + self.norm_added_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) + self.norm_added_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) + elif qk_norm == "rms_norm": + self.norm_added_q = RMSNorm(dim_head, eps=eps) + self.norm_added_k = RMSNorm(dim_head, eps=eps) + else: + raise ValueError( + f"unknown qk_norm: {qk_norm}. Should be one of `None,'layer_norm','fp32_layer_norm','rms_norm'`" + ) + else: + self.norm_added_q = None + self.norm_added_k = None + + self.set_processor(processor) + + def set_processor(self, processor: "AttnProcessor") -> None: + r""" + Set the attention processor to use. + + Args: + processor (`AttnProcessor`): + The attention processor to use. + """ + if ( + hasattr(self, "processor") + and isinstance(self.processor, torch.nn.Module) + and not isinstance(processor, torch.nn.Module) + ): + logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}") + self._modules.pop("processor") + + self.processor = processor + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + **cross_attention_kwargs, + ) -> torch.Tensor: + attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) + quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"} + unused_kwargs = [ + k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters + ] + if len(unused_kwargs) > 0: + logger.warning( + f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored." + ) + cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters} + + return self.processor( + self, + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + + def prepare_attention_mask( + self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3 + ) -> torch.Tensor: + head_size = self.heads + if attention_mask is None: + return attention_mask + + current_length: int = attention_mask.shape[-1] + if current_length != target_length: + if attention_mask.device.type == "mps": + padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length) + padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device) + attention_mask = torch.cat([attention_mask, padding], dim=2) + else: + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + + if out_dim == 3: + if attention_mask.shape[0] < batch_size * head_size: + attention_mask = attention_mask.repeat_interleave(head_size, dim=0) + elif out_dim == 4: + attention_mask = attention_mask.unsqueeze(1) + attention_mask = attention_mask.repeat_interleave(head_size, dim=1) + + return attention_mask + + +class CogVideoXAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on + query and key vectors, but does not include spatial normalization. + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + text_seq_length = encoder_hidden_states.size(1) + + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + B, S, _ = hidden_states.shape + qkv = attn.to_qkv(hidden_states) + inner_dim = qkv.shape[-1] // 3 + head_dim = inner_dim // attn.heads + qkv_shape = (B, S, 3, attn.heads, head_dim) + query, key, value = qkv.view(qkv_shape).permute(2, 0, 3, 1, 4).contiguous().unbind(0) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + B, N, S, D = query.shape + dim = 48 + pad_shape = [B, N, S, D] + pad_shape[-1] = dim - pad_shape[-1] + pad = torch.zeros(pad_shape, dtype=query.dtype, device=query.device) + query = torch.cat([query, pad], dim=-1) + key = torch.cat([key, pad], dim=-1) + value = torch.cat([value, pad], dim=-1) + hidden_states = torch_npu.npu_prompt_flash_attention( + query, + key, + value, + input_layout='BNSD', + scale_value=D**-0.5, + pre_tokens=65535, + next_tokens=65535, + num_heads=N + ) + hidden_states = hidden_states[:, :, :, :D] + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + encoder_hidden_states, hidden_states = hidden_states.split( + [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1 + ) + return hidden_states, encoder_hidden_states \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/models/model_load_utils.py b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/models/model_load_utils.py new file mode 100644 index 0000000000..1257aad309 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/models/model_load_utils.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright(C) 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import torch +import safetensors.torch + + +SAFETENSORS_EXTENSION = "safetensors" +EMA_STATE_DICT = "ema_state_dict" +STATE_DICT = "state_dict" +CPU = "cpu" + + +def load_state_dict_sd(model_path): + name = os.path.basename(model_path).split('.')[-1] # get weights name + if name.endswith("ckpt"): + weight = torch.load(model_path, map_location=CPU) + if (EMA_STATE_DICT in weight): + weight = weight[EMA_STATE_DICT] + weight = {key.replace("module.", ""): value for key, value in weight.items()} + elif STATE_DICT in weight: + weight = weight[STATE_DICT] + return weight + elif name == SAFETENSORS_EXTENSION: # diffuser model use same name + return safetensors.torch.load_file(model_path, device=CPU) # first load on cpu + else: + # to support hf shard model weights + return torch.load(model_path, map_location=CPU) # first load on cpu \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/models/modeling_utils.py b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/models/modeling_utils.py new file mode 100644 index 0000000000..fddf0ade3f --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/models/modeling_utils.py @@ -0,0 +1,771 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import itertools +import json +import os +import re +from collections import OrderedDict +from functools import wraps +from typing import Any, List, Optional, Tuple, Union + +import torch +from huggingface_hub.utils import validate_hf_hub_args +from torch import Tensor, nn + +from diffusers import __version__ +from diffusers.quantizers import DiffusersAutoQuantizer +from diffusers.quantizers.quantization_config import QuantizationMethod +from diffusers.utils import ( + CONFIG_NAME, + FLAX_WEIGHTS_NAME, + SAFETENSORS_WEIGHTS_NAME, + WEIGHTS_NAME, + _add_variant, + _get_checkpoint_shard_files, + _get_model_file, + deprecate, + is_accelerate_available, + is_bitsandbytes_version, + logging, +) +from diffusers.utils.hub_utils import PushToHubMixin +from diffusers.models.model_loading_utils import ( + _fetch_index_file, + _fetch_index_file_legacy, + _load_state_dict_into_model, + _merge_sharded_checkpoints, + load_model_dict_into_meta, + load_state_dict, +) + + +logger = logging.get_logger(__name__) + + +_LOW_CPU_MEM_USAGE_DEFAULT = True + + +if is_accelerate_available(): + import accelerate + + +def get_parameter_device(parameter: torch.nn.Module) -> torch.device: + try: + parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers()) + return next(parameters_and_buffers).device + except StopIteration: + # For torch.nn.DataParallel compatibility in PyTorch 1.5 + + def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]: + tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] + return tuples + + gen = parameter._named_members(get_members_fn=find_tensor_attributes) + first_tuple = next(gen) + return first_tuple[1].device + + +def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype: + """ + Returns the first found floating dtype in parameters if there is one, otherwise returns the last dtype it found. + """ + last_dtype = None + for param in parameter.parameters(): + last_dtype = param.dtype + if param.is_floating_point(): + return param.dtype + + for buffer in parameter.buffers(): + last_dtype = buffer.dtype + if buffer.is_floating_point(): + return buffer.dtype + + if last_dtype is not None: + # if no floating dtype was found return whatever the first dtype is + return last_dtype + + # For nn.DataParallel compatibility in PyTorch > 1.5 + def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]: + tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] + return tuples + + gen = parameter._named_members(get_members_fn=find_tensor_attributes) + last_tuple = None + for current_tuple in gen: + last_tuple = current_tuple + if current_tuple[1].is_floating_point(): + return current_tuple[1].dtype + + if last_tuple is not None: + # fallback to the last dtype + return last_tuple[1].dtype + + +class ModelMixin(torch.nn.Module, PushToHubMixin): + config_name = CONFIG_NAME + _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"] + _supports_gradient_checkpointing = False + _keys_to_ignore_on_load_unexpected = None + _no_split_modules = None + _keep_in_fp32_modules = None + + def __init__(self): + super().__init__() + + def __getattr__(self, name: str) -> Any: + + is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name) + is_attribute = name in self.__dict__ + + if is_in_config and not is_attribute: + deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'unet.config.{name}'." + deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False, stacklevel=3) + return self._internal_dict[name] + + return super().__getattr__(name) + + @classmethod + @validate_hf_hub_args + def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): + cache_dir = kwargs.pop("cache_dir", None) + ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False) + force_download = kwargs.pop("force_download", False) + from_flax = kwargs.pop("from_flax", False) + proxies = kwargs.pop("proxies", None) + output_loading_info = kwargs.pop("output_loading_info", False) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + torch_dtype = kwargs.pop("torch_dtype", None) + subfolder = kwargs.pop("subfolder", None) + device_map = kwargs.pop("device_map", None) + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) + variant = kwargs.pop("variant", None) + use_safetensors = kwargs.pop("use_safetensors", None) + quantization_config = kwargs.pop("quantization_config", None) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + + if low_cpu_mem_usage and not is_accelerate_available(): + low_cpu_mem_usage = False + logger.warning( + "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the" + " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install" + " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip" + " install accelerate\n```\n." + ) + + if device_map is not None and not is_accelerate_available(): + raise NotImplementedError( + "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set" + " `device_map=None`. You can install accelerate with `pip install accelerate`." + ) + + if low_cpu_mem_usage is False and device_map is not None: + raise ValueError( + f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and" + " dispatching. Please make sure to set `low_cpu_mem_usage=True`." + ) + + if isinstance(device_map, torch.device): + device_map = {"": device_map} + elif isinstance(device_map, str) and device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]: + try: + device_map = {"": torch.device(device_map)} + except RuntimeError as e: + raise ValueError( + "When passing device_map as a string, the value needs to be a device name (e.g. cpu, cuda:0) or " + f"'auto', 'balanced', 'balanced_low_0', 'sequential' but found {device_map}." + ) from e + elif isinstance(device_map, int): + if device_map < 0: + raise ValueError( + "You can't pass device_map as a negative int. If you want to put the model on the cpu, pass device_map = 'cpu' " + ) + else: + device_map = {"": device_map} + + if device_map is not None: + if low_cpu_mem_usage is None: + low_cpu_mem_usage = True + elif not low_cpu_mem_usage: + raise ValueError("Passing along a `device_map` requires `low_cpu_mem_usage=True`") + + # Load config if we don't provide a configuration + config_path = pretrained_model_name_or_path + + user_agent = { + "diffusers": __version__, + "file_type": "model", + "framework": "pytorch", + } + + # load config + config, unused_kwargs, commit_hash = cls.load_config( + config_path, + cache_dir=cache_dir, + return_unused_kwargs=True, + return_commit_hash=True, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + **kwargs, + ) + # no in-place modification of the original config. + config = copy.deepcopy(config) + + # determine initial quantization config. + ####################################### + pre_quantized = "quantization_config" in config and config["quantization_config"] is not None + if pre_quantized or quantization_config is not None: + if pre_quantized: + config["quantization_config"] = DiffusersAutoQuantizer.merge_quantization_configs( + config["quantization_config"], quantization_config + ) + else: + config["quantization_config"] = quantization_config + hf_quantizer = DiffusersAutoQuantizer.from_config( + config["quantization_config"], pre_quantized=pre_quantized + ) + else: + hf_quantizer = None + + if hf_quantizer is not None: + is_bnb_quantization_method = hf_quantizer.quantization_config.quant_method.value == "bitsandbytes" + if is_bnb_quantization_method and device_map is not None: + raise NotImplementedError( + "Currently, `device_map` is automatically inferred for quantized bitsandbytes models. Support for providing `device_map` as an input will be added in the future." + ) + + hf_quantizer.validate_environment(torch_dtype=torch_dtype, from_flax=from_flax, device_map=device_map) + torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype) + + # In order to ensure popular quantization methods are supported. Can be disable with `disable_telemetry` + user_agent["quant"] = hf_quantizer.quantization_config.quant_method.value + + # Force-set to `True` for more mem efficiency + if low_cpu_mem_usage is None: + low_cpu_mem_usage = True + logger.info("Set `low_cpu_mem_usage` to True as `hf_quantizer` is not None.") + elif not low_cpu_mem_usage: + raise ValueError("`low_cpu_mem_usage` cannot be False or None when using quantization.") + + # Check if `_keep_in_fp32_modules` is not None + use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and ( + (torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules") + ) + if use_keep_in_fp32_modules: + keep_in_fp32_modules = cls._keep_in_fp32_modules + if not isinstance(keep_in_fp32_modules, list): + keep_in_fp32_modules = [keep_in_fp32_modules] + + if low_cpu_mem_usage is None: + low_cpu_mem_usage = True + logger.info("Set `low_cpu_mem_usage` to True as `_keep_in_fp32_modules` is not None.") + elif not low_cpu_mem_usage: + raise ValueError("`low_cpu_mem_usage` cannot be False when `keep_in_fp32_modules` is True.") + else: + keep_in_fp32_modules = [] + ####################################### + + # Determine if we're loading from a directory of sharded checkpoints. + is_sharded = False + index_file = None + is_local = os.path.isdir(pretrained_model_name_or_path) + index_file_kwargs = { + "is_local": is_local, + "pretrained_model_name_or_path": pretrained_model_name_or_path, + "subfolder": subfolder or "", + "use_safetensors": use_safetensors, + "cache_dir": cache_dir, + "variant": variant, + "force_download": force_download, + "proxies": proxies, + "local_files_only": local_files_only, + "token": token, + "revision": revision, + "user_agent": user_agent, + "commit_hash": commit_hash, + } + index_file = _fetch_index_file(**index_file_kwargs) + # In case the index file was not found we still have to consider the legacy format. + # this becomes applicable when the variant is not None. + if variant is not None and (index_file is None or not os.path.exists(index_file)): + index_file = _fetch_index_file_legacy(**index_file_kwargs) + if index_file is not None and index_file.is_file(): + is_sharded = True + + if is_sharded and from_flax: + raise ValueError("Loading of sharded checkpoints is not supported when `from_flax=True`.") + + # load model + model_file = None + if from_flax: + model_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=FLAX_WEIGHTS_NAME, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + commit_hash=commit_hash, + ) + model = cls.from_config(config, **unused_kwargs) + + # Convert the weights + from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model + + model = load_flax_checkpoint_in_pytorch_model(model, model_file) + else: + if is_sharded: + sharded_ckpt_cached_folder, sharded_metadata = _get_checkpoint_shard_files( + pretrained_model_name_or_path, + index_file, + cache_dir=cache_dir, + proxies=proxies, + local_files_only=local_files_only, + token=token, + user_agent=user_agent, + revision=revision, + subfolder=subfolder or "", + ) + if hf_quantizer is not None and is_bnb_quantization_method: + model_file = _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata) + logger.info("Merged sharded checkpoints as `hf_quantizer` is not None.") + is_sharded = False + + elif use_safetensors and not is_sharded: + try: + model_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant), + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + commit_hash=commit_hash, + ) + + except IOError as e: + logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}") + if not allow_pickle: + raise + logger.warning( + "Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead." + ) + + if model_file is None and not is_sharded: + model_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=_add_variant(WEIGHTS_NAME, variant), + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + commit_hash=commit_hash, + ) + + if low_cpu_mem_usage: + # Instantiate model with empty weights + with accelerate.init_empty_weights(): + model = cls.from_config(config, **unused_kwargs) + + if hf_quantizer is not None: + hf_quantizer.preprocess_model( + model=model, device_map=device_map, keep_in_fp32_modules=keep_in_fp32_modules + ) + + # if device_map is None, load the state dict and move the params from meta device to the cpu + if device_map is None and not is_sharded: + # `torch.cuda.current_device()` is fine here when `hf_quantizer` is not None. + # It would error out during the `validate_environment()` call above in the absence of cuda. + if hf_quantizer is None: + param_device = "cpu" + else: + param_device = torch.device(torch.cuda.current_device()) + state_dict = load_state_dict(model_file, variant=variant) + model._convert_deprecated_attention_blocks(state_dict) + + # move the params from meta device to cpu + missing_keys = set(model.state_dict().keys()) - set(state_dict.keys()) + if hf_quantizer is not None: + missing_keys = hf_quantizer.update_missing_keys(model, missing_keys, prefix="") + if len(missing_keys) > 0: + raise ValueError( + f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are" + f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass" + " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize" + " those weights or else make sure your checkpoint file is correct." + ) + + unexpected_keys = load_model_dict_into_meta( + model, + state_dict, + device=param_device, + dtype=torch_dtype, + model_name_or_path=pretrained_model_name_or_path, + hf_quantizer=hf_quantizer, + keep_in_fp32_modules=keep_in_fp32_modules, + ) + + if cls._keys_to_ignore_on_load_unexpected is not None: + for pat in cls._keys_to_ignore_on_load_unexpected: + unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] + + if len(unexpected_keys) > 0: + logger.warning( + f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}" + ) + + else: + weights_path = index_file + with open(index_file) as f: + index = json.loads(f.read()) + if "weight_map" in index: + index = index["weight_map"] + weights_path = sorted(list(set(index.values()))) + weights_path = [os.path.join(pretrained_model_name_or_path, f) for f in weights_path] + + model = cls._load_model(model, weights_path, is_sharded) + + loading_info = { + "missing_keys": [], + "unexpected_keys": [], + "mismatched_keys": [], + "error_msgs": [], + } + else: + model = cls.from_config(config, **unused_kwargs) + + state_dict = load_state_dict(model_file, variant=variant) + model._convert_deprecated_attention_blocks(state_dict) + + model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model( + model, + state_dict, + model_file, + pretrained_model_name_or_path, + ignore_mismatched_sizes=ignore_mismatched_sizes, + ) + + loading_info = { + "missing_keys": missing_keys, + "unexpected_keys": unexpected_keys, + "mismatched_keys": mismatched_keys, + "error_msgs": error_msgs, + } + + if hf_quantizer is not None: + hf_quantizer.postprocess_model(model) + model.hf_quantizer = hf_quantizer + + if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype): + raise ValueError( + f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}." + ) + # When using `use_keep_in_fp32_modules` if we do a global `to()` here, then we will + # completely lose the effectivity of `use_keep_in_fp32_modules`. + elif torch_dtype is not None and hf_quantizer is None and not use_keep_in_fp32_modules: + model = model.to(torch_dtype) + + if hf_quantizer is not None: + # We also make sure to purge `_pre_quantization_dtype` when we serialize + # the model config because `_pre_quantization_dtype` is `torch.dtype`, not JSON serializable. + model.register_to_config(_name_or_path=pretrained_model_name_or_path, _pre_quantization_dtype=torch_dtype) + else: + model.register_to_config(_name_or_path=pretrained_model_name_or_path) + + # Set model in evaluation mode to deactivate DropOut modules by default + model.eval() + if output_loading_info: + return model, loading_info + + return model + + @classmethod + def _load_model(cls, model, weights_path, is_sharded): + if not is_sharded: + state_dict = load_state_dict(weights_path) + model.load_weights(state_dict) + else: + need_key = set(model.state_dict().keys()) + state_dict = {} + cache = {} + for weight_file in weights_path: + state_dict = load_state_dict(weight_file) + state_dict.update(cache) + loadkey_cache = model.load_weights(state_dict, is_sharded) + if loadkey_cache : + if isinstance(loadkey_cache, tuple): + loaded_keys, cache = loadkey_cache + else: + loaded_keys = loadkey_cache + need_key = need_key.symmetric_difference(set(loaded_keys)) + + if len(need_key) > 0: + raise ValueError(f"The weight miss key: {need_key}") + return model + + def load_weights(self, state_dict, shard=False): + with torch.no_grad(): + if not shard: + self.load_state_dict(state_dict) + return {} + else: + self.load_state_dict(state_dict, strict=False, assign=True) + return state_dict.keys() + + # Adapted from `transformers`. + @wraps(torch.nn.Module.cuda) + def cuda(self, *args, **kwargs): + # Checks if the model has been loaded in 4-bit or 8-bit with BNB + if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES: + if getattr(self, "is_loaded_in_8bit", False): + raise ValueError( + "Calling `cuda()` is not supported for `8-bit` quantized models. " + " Please use the model as it is, since the model has already been set to the correct devices." + ) + elif is_bitsandbytes_version("<", "0.43.2"): + raise ValueError( + "Calling `cuda()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. " + f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2." + ) + return super().cuda(*args, **kwargs) + + # Adapted from `transformers`. + @wraps(torch.nn.Module.to) + def to(self, *args, **kwargs): + dtype_present_in_args = "dtype" in kwargs + + if not dtype_present_in_args: + for arg in args: + if isinstance(arg, torch.dtype): + dtype_present_in_args = True + break + + if getattr(self, "is_quantized", False): + if dtype_present_in_args: + raise ValueError( + "Casting a quantized model to a new `dtype` is unsupported. To set the dtype of unquantized layers, please " + "use the `torch_dtype` argument when loading the model using `from_pretrained` or `from_single_file`" + ) + + if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES: + if getattr(self, "is_loaded_in_8bit", False): + raise ValueError( + "`.to` is not supported for `8-bit` bitsandbytes models. Please use the model as it is, since the" + " model has already been set to the correct devices and casted to the correct `dtype`." + ) + elif is_bitsandbytes_version("<", "0.43.2"): + raise ValueError( + "Calling `to()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. " + f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2." + ) + return super().to(*args, **kwargs) + + # Taken from `transformers`. + def half(self, *args): + # Checks if the model is quantized + if getattr(self, "is_quantized", False): + raise ValueError( + "`.half()` is not supported for quantized model. Please use the model as it is, since the" + " model has already been cast to the correct `dtype`." + ) + else: + return super().half(*args) + + # Taken from `transformers`. + def float(self, *args): + # Checks if the model is quantized + if getattr(self, "is_quantized", False): + raise ValueError( + "`.float()` is not supported for quantized model. Please use the model as it is, since the" + " model has already been cast to the correct `dtype`." + ) + else: + return super().float(*args) + + @classmethod + def _load_pretrained_model( + cls, + model, + state_dict: OrderedDict, + pretrained_model_name_or_path: Union[str, os.PathLike], + ignore_mismatched_sizes: bool = False, + ): + # Retrieve missing & unexpected_keys + model_state_dict = model.state_dict() + loaded_keys = list(state_dict.keys()) + + expected_keys = list(model_state_dict.keys()) + + original_loaded_keys = loaded_keys + + missing_keys = list(set(expected_keys) - set(loaded_keys)) + unexpected_keys = list(set(loaded_keys) - set(expected_keys)) + + # Make sure we are able to load base models as well as derived models (with heads) + model_to_load = model + + def _find_mismatched_keys( + state_dict, + model_state_dict, + loaded_keys, + ignore_mismatched_sizes, + ): + mismatched_keys = [] + if ignore_mismatched_sizes: + for checkpoint_key in loaded_keys: + model_key = checkpoint_key + + if ( + model_key in model_state_dict + and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape + ): + mismatched_keys.append( + (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape) + ) + del state_dict[checkpoint_key] + return mismatched_keys + + if state_dict is not None: + # Whole checkpoint + mismatched_keys = _find_mismatched_keys( + state_dict, + model_state_dict, + original_loaded_keys, + ignore_mismatched_sizes, + ) + error_msgs = _load_state_dict_into_model(model_to_load, state_dict) + + if len(error_msgs) > 0: + error_msg = "\n\t".join(error_msgs) + if "size mismatch" in error_msg: + error_msg += ( + "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method." + ) + raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}") + + if len(unexpected_keys) > 0: + logger.warning( + f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when" + f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are" + f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task" + " or with another architecture (e.g. initializing a BertForSequenceClassification model from a" + " BertForPreTraining model).\n- This IS NOT expected if you are initializing" + f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly" + " identical (initializing a BertForSequenceClassification model from a" + " BertForSequenceClassification model)." + ) + else: + logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n") + if len(missing_keys) > 0: + logger.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" + f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably" + " TRAIN this model on a down-stream task to be able to use it for predictions and inference." + ) + elif len(mismatched_keys) == 0: + logger.info( + f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at" + f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the" + f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions" + " without further training." + ) + if len(mismatched_keys) > 0: + mismatched_warning = "\n".join( + [ + f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated" + for key, shape1, shape2 in mismatched_keys + ] + ) + logger.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" + f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not" + f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be" + " able to use it for predictions and inference." + ) + + return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs + + @property + def device(self) -> torch.device: + return get_parameter_device(self) + + @property + def dtype(self) -> torch.dtype: + return get_parameter_dtype(self) + + def _convert_deprecated_attention_blocks(self, state_dict: OrderedDict) -> None: + deprecated_attention_block_paths = [] + + def recursive_find_attn_block(name, module): + if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block: + deprecated_attention_block_paths.append(name) + + for sub_name, sub_module in module.named_children(): + sub_name = sub_name if name == "" else f"{name}.{sub_name}" + recursive_find_attn_block(sub_name, sub_module) + + recursive_find_attn_block("", self) + + for path in deprecated_attention_block_paths: + # group_norm path stays the same + + # query -> to_q + if f"{path}.query.weight" in state_dict: + state_dict[f"{path}.to_q.weight"] = state_dict.pop(f"{path}.query.weight") + if f"{path}.query.bias" in state_dict: + state_dict[f"{path}.to_q.bias"] = state_dict.pop(f"{path}.query.bias") + + # key -> to_k + if f"{path}.key.weight" in state_dict: + state_dict[f"{path}.to_k.weight"] = state_dict.pop(f"{path}.key.weight") + if f"{path}.key.bias" in state_dict: + state_dict[f"{path}.to_k.bias"] = state_dict.pop(f"{path}.key.bias") + + # value -> to_v + if f"{path}.value.weight" in state_dict: + state_dict[f"{path}.to_v.weight"] = state_dict.pop(f"{path}.value.weight") + if f"{path}.value.bias" in state_dict: + state_dict[f"{path}.to_v.bias"] = state_dict.pop(f"{path}.value.bias") + + # proj_attn -> to_out.0 + if f"{path}.proj_attn.weight" in state_dict: + state_dict[f"{path}.to_out.0.weight"] = state_dict.pop(f"{path}.proj_attn.weight") + if f"{path}.proj_attn.bias" in state_dict: + state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias") diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/models/transformer_cogview3plus.py b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/models/transformer_cogview3plus.py new file mode 100644 index 0000000000..37c5961586 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/models/transformer_cogview3plus.py @@ -0,0 +1,397 @@ +# Copyright 2024 The CogView team, Tsinghua University & ZhipuAI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, Union + +import torch +import torch.nn as nn +import numpy as np + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.attention_processor import AttentionProcessor +from diffusers.utils import logging +from diffusers.models.modeling_outputs import Transformer2DModelOutput + +from .modeling_utils import ModelMixin +from .attention import FeedForward +from .attention_processor import CogVideoXAttnProcessor2_0, Attention +from ..layers import CogView3PlusAdaLayerNormZeroTextImage, AdaLayerNormContinuous +from ..layers import CogView3CombinedTimestepSizeEmbeddings, CogView3PlusPatchEmbed + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class CogView3PlusTransformerBlock(nn.Module): + def __init__( + self, + dim: int = 2560, + num_attention_heads: int = 64, + attention_head_dim: int = 40, + time_embed_dim: int = 512, + ): + super().__init__() + + self.norm1 = CogView3PlusAdaLayerNormZeroTextImage(embedding_dim=time_embed_dim, dim=dim) + + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + out_dim=dim, + bias=True, + qk_norm="layer_norm", + elementwise_affine=False, + eps=1e-6, + processor=CogVideoXAttnProcessor2_0(), + ) + + self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5) + self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5) + + self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + emb: torch.Tensor, + ) -> torch.Tensor: + text_seq_length = encoder_hidden_states.size(1) + + # norm & modulate + norm_hidden_states, chunk_params = self.norm1(hidden_states, encoder_hidden_states, emb) + + gate_msa = chunk_params.gate_msa + shift_mlp = chunk_params.shift_mlp + scale_mlp = chunk_params.scale_mlp + gate_mlp = chunk_params.gate_mlp + norm_encoder_hidden_states = chunk_params.context + c_gate_msa = chunk_params.c_gate_msa + c_shift_mlp = chunk_params.c_shift_mlp + c_scale_mlp = chunk_params.c_scale_mlp + c_gate_mlp = chunk_params.c_gate_mlp + + # attention + attn_hidden_states, attn_encoder_hidden_states = self.attn1( + hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states + ) + + hidden_states = hidden_states + gate_msa.unsqueeze(1) * attn_hidden_states + encoder_hidden_states = encoder_hidden_states + c_gate_msa.unsqueeze(1) * attn_encoder_hidden_states + + # norm & modulate + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) + norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + + # feed-forward + norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1) + ff_output = self.ff(norm_hidden_states) + + hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output[:, text_seq_length:] + encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * ff_output[:, :text_seq_length] + + if hidden_states.dtype == torch.float16: + hidden_states = hidden_states.clip(-65504, 65504) + if encoder_hidden_states.dtype == torch.float16: + encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) + return hidden_states, encoder_hidden_states + + +class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin): + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + patch_size: int = 2, + in_channels: int = 16, + num_layers: int = 30, + attention_head_dim: int = 40, + num_attention_heads: int = 64, + out_channels: int = 16, + text_embed_dim: int = 4096, + time_embed_dim: int = 512, + condition_dim: int = 256, + pos_embed_max_size: int = 128, + use_cache: bool = True, + cache_interval: int = 2, + cache_start: int = 3, + num_cache_layer: int = 13, + cache_start_steps: int = 5, + ): + super().__init__() + self.out_channels = out_channels + self.inner_dim = num_attention_heads * attention_head_dim + self.num_layers = num_layers + + # CogView3 uses 3 additional SDXL-like conditions - original_size, target_size, crop_coords + # Each of these are sincos embeddings of shape 2 * condition_dim + self.pooled_projection_dim = 3 * 2 * condition_dim + + self.patch_embed = CogView3PlusPatchEmbed( + in_channels=in_channels, + hidden_size=self.inner_dim, + patch_size=patch_size, + text_hidden_size=text_embed_dim, + pos_embed_max_size=pos_embed_max_size, + ) + + self.time_condition_embed = CogView3CombinedTimestepSizeEmbeddings( + embedding_dim=time_embed_dim, + condition_dim=condition_dim, + pooled_projection_dim=self.pooled_projection_dim, + timesteps_dim=self.inner_dim, + ) + + self.transformer_blocks = nn.ModuleList( + [ + CogView3PlusTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + time_embed_dim=time_embed_dim, + ) + for _ in range(num_layers) + ] + ) + + self.norm_out = AdaLayerNormContinuous( + embedding_dim=self.inner_dim, + conditioning_embedding_dim=time_embed_dim, + elementwise_affine=False, + eps=1e-6, + ) + self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) + + self.gradient_checkpointing = False + + self.q_weight_cache = None + self.q_bias_cache = None + self.k_weight_cache = None + self.k_bias_cache = None + self.v_weight_cache = None + self.v_bias_cache = None + + self.use_cache = use_cache + self.cache_interval = cache_interval + self.cache_start = cache_start + self.num_cache_layer = num_cache_layer + self.cache_start_steps = cache_start_steps + + self.delta_cache = None + self.delta_encoder_cache = None + + @property + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def forward( + self, + states, + timestep: torch.LongTensor, + original_size: torch.Tensor, + target_size: torch.Tensor, + crop_coords: torch.Tensor, + ) -> Union[torch.Tensor, Transformer2DModelOutput]: + hidden_states = states[0] + encoder_hidden_states = states[1] + height, width = hidden_states.shape[-2:] + text_seq_length = encoder_hidden_states.shape[1] + + hidden_states = self.patch_embed( + hidden_states, encoder_hidden_states + ) # takes care of adding positional embeddings too. + emb = self.time_condition_embed(timestep, original_size, target_size, crop_coords, hidden_states.dtype) + + encoder_hidden_states = hidden_states[:, :text_seq_length] + hidden_states = hidden_states[:, text_seq_length:] + + hidden_states, encoder_hidden_states = self._forward_blocks(hidden_states, encoder_hidden_states, emb, states[2]) + + hidden_states = self.norm_out(hidden_states, emb) + hidden_states = self.proj_out(hidden_states) # (batch_size, height*width, patch_size*patch_size*out_channels) + + # unpatchify + patch_size = self.config.patch_size + height = height // patch_size + width = width // patch_size + + hidden_states = hidden_states.reshape( + shape=(hidden_states.shape[0], height, width, self.out_channels, patch_size, patch_size) + ) + hidden_states = torch.einsum("nhwcpq->nchpwq", hidden_states) + output = hidden_states.reshape( + shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size) + ) + + return Transformer2DModelOutput(sample=output) + + # forward blocks in range [start_idx, end_idx), then return input and output + def _forward_blocks_range(self, hidden_states, encoder_hidden_states, emb, start_idx, end_idx, **kwargs): + for _, block in enumerate(self.transformer_blocks[start_idx: end_idx]): + hidden_states, encoder_hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + emb=emb, + ) + + return hidden_states, encoder_hidden_states + + def _forward_blocks(self, hidden_states, encoder_hidden_states, emb, t_idx): + num_blocks = len(self.transformer_blocks) + + if not self.use_cache or (t_idx < self.cache_start_steps): + hidden_states, encoder_hidden_states = self._forward_blocks_range( + hidden_states, + encoder_hidden_states, + emb, + 0, + num_blocks + ) + else: + # infer [0, cache_start) + hidden_states, encoder_hidden_states = self._forward_blocks_range( + hidden_states, + encoder_hidden_states, + emb, + 0, + self.cache_start + ) + # infer [cache_start, cache_end) + cache_end = np.minimum(self.cache_start + self.num_cache_layer, num_blocks) + hidden_states_before_cache = hidden_states.clone() + encoder_hidden_states_before_cache = encoder_hidden_states.clone() + if t_idx % self.cache_interval == (self.cache_start_steps % self.cache_interval): + hidden_states, encoder_hidden_states = self._forward_blocks_range( + hidden_states, + encoder_hidden_states, + emb, + self.cache_start, + cache_end + ) + self.delta_cache = hidden_states - hidden_states_before_cache + self.delta_encoder_cache = encoder_hidden_states - encoder_hidden_states_before_cache + else: + hidden_states = hidden_states_before_cache + self.delta_cache + encoder_hidden_states = encoder_hidden_states_before_cache + self.delta_encoder_cache + # infer [cache_end, num_blocks) + hidden_states, encoder_hidden_states = self._forward_blocks_range( + hidden_states, + encoder_hidden_states, + emb, + cache_end, + num_blocks + ) + + return hidden_states, encoder_hidden_states + + def load_weights(self, state_dict, shard=False): + with torch.no_grad(): + if not shard: + self.load_state_dict(state_dict) + return {} + else: + weights = state_dict + + for i in range(self.num_layers): + if i != 26: + q_weight = weights.pop(f"transformer_blocks.{i}.attn1.to_q.weight", None) + q_bias = weights.pop(f"transformer_blocks.{i}.attn1.to_q.bias", None) + k_weight = weights.pop(f"transformer_blocks.{i}.attn1.to_k.weight", None) + k_bias = weights.pop(f"transformer_blocks.{i}.attn1.to_k.bias", None) + v_weight = weights.pop(f"transformer_blocks.{i}.attn1.to_v.weight", None) + v_bias = weights.pop(f"transformer_blocks.{i}.attn1.to_v.bias", None) + + # query, key, value的weight和bias权重存在同一个文件中,不会分开存储。 + if q_weight is not None and k_weight is not None and v_weight is not None: + qkv_weight = torch.cat([q_weight, k_weight, v_weight], dim=0).transpose(0, 1).contiguous() + qkv_bias = torch.cat([q_bias, k_bias, v_bias], dim=0).contiguous() + weights[f"transformer_blocks.{i}.attn1.to_qkv.weight"] = qkv_weight + weights[f"transformer_blocks.{i}.attn1.to_qkv.bias"] = qkv_bias + else: + if self.q_weight_cache is None: + self.q_weight_cache = weights.pop(f"transformer_blocks.{i}.attn1.to_q.weight", None) + if self.q_bias_cache is None: + self.q_bias_cache = weights.pop(f"transformer_blocks.{i}.attn1.to_q.bias", None) + if self.k_weight_cache is None: + self.k_weight_cache = weights.pop(f"transformer_blocks.{i}.attn1.to_k.weight", None) + if self.k_bias_cache is None: + self.k_bias_cache = weights.pop(f"transformer_blocks.{i}.attn1.to_k.bias", None) + if self.v_weight_cache is None: + self.v_weight_cache = weights.pop(f"transformer_blocks.{i}.attn1.to_v.weight", None) + if self.v_bias_cache is None: + self.v_bias_cache = weights.pop(f"transformer_blocks.{i}.attn1.to_v.bias", None) + + qk_weight_cache = self.q_weight_cache is not None and self.k_weight_cache is not None + if qk_weight_cache and self.v_weight_cache is not None: + qkv_weight = torch.cat( + [self.q_weight_cache, self.k_weight_cache, self.v_weight_cache], + dim=0 + ).transpose(0, 1).contiguous() + qkv_bias = torch.cat([self.q_bias_cache, self.k_bias_cache, self.v_bias_cache], dim=0).contiguous() + weights[f"transformer_blocks.26.attn1.to_qkv.weight"] = qkv_weight + weights[f"transformer_blocks.26.attn1.to_qkv.bias"] = qkv_bias + + self.load_state_dict(weights, strict=False, assign=True) + return weights.keys() diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/pipeline/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/pipeline/__init__.py new file mode 100644 index 0000000000..626e0d588b --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/pipeline/__init__.py @@ -0,0 +1 @@ +from .pipeline_cogview3plus import CogView3PlusPipeline, DiffusionPipeline \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/pipeline/pipeline_cogview3plus.py b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/pipeline/pipeline_cogview3plus.py new file mode 100644 index 0000000000..4b07df76a6 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/pipeline/pipeline_cogview3plus.py @@ -0,0 +1,339 @@ +# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import List, Optional, Tuple, Union + +import torch +from transformers import T5EncoderModel, T5Tokenizer + +from diffusers.image_processor import VaeImageProcessor +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.utils import logging +from diffusers.utils.torch_utils import randn_tensor +from diffusers import AutoencoderKL + +from ..models import CogView3PlusTransformer2DModel +from ..schedulers import CogVideoXDDIMScheduler +from .pipeline_output import CogView3PipelineOutput + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class CogView3PlusPipeline(DiffusionPipeline): + _optional_components = [] + model_cpu_offload_seq = "text_encoder->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + vae: AutoencoderKL, + transformer: CogView3PlusTransformer2DModel, + scheduler: CogVideoXDDIMScheduler, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder(text_input_ids.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + def encode_prompt( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + max_sequence_length: int = 224, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + negative_prompt_embeds = prompt_embeds.new_zeros(prompt_embeds.shape) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents(self, batch_size, num_channels_latents, image_size, dtype, device): + height = image_size[0] + width = image_size[1] + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + latents = randn_tensor(shape, device=device, dtype=dtype) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def prepare_extra_step_kwargs(self, generator, eta): + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + image_size: Tuple[int, int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + num_images_per_prompt: int = 1, + ) -> Union[CogView3PipelineOutput, Tuple]: + if image_size is None: + height = self.transformer.config.sample_size * self.vae_scale_factor + width = self.transformer.config.sample_size * self.vae_scale_factor + else: + height = image_size[0] + width = image_size[1] + + original_size = (height, width) + target_size = (height, width) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + ) + self._guidance_scale = guidance_scale + self._interrupt = False + + # 2. Default call parameters + if isinstance(prompt, str): + batch_size = 1 + else: + batch_size = len(prompt) + + device = self._execution_device + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=224, + device=device, + ) + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device) + self._num_timesteps = len(timesteps) + + # 5. Prepare latents. + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + latent_channels, + (height, width), + prompt_embeds.dtype, + device, + ) + + extra_step_kwargs = self.prepare_extra_step_kwargs(None, 0.0) + + # 7. Prepare additional timestep conditions + original_size = torch.tensor([original_size], dtype=prompt_embeds.dtype) + target_size = torch.tensor([target_size], dtype=prompt_embeds.dtype) + crops_coords_top_left = torch.tensor([(0, 0)], dtype=prompt_embeds.dtype) + + if self.do_classifier_free_guidance: + original_size = torch.cat([original_size, original_size]) + target_size = torch.cat([target_size, target_size]) + crops_coords_top_left = torch.cat([crops_coords_top_left, crops_coords_top_left]) + + original_size = original_size.to(device).repeat(batch_size * num_images_per_prompt, 1) + target_size = target_size.to(device).repeat(batch_size * num_images_per_prompt, 1) + crops_coords_top_left = crops_coords_top_left.to(device).repeat(batch_size * num_images_per_prompt, 1) + + # 8. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + # for DPM-solver++ + old_pred_original_sample = None + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + # predict noise model_output + noise_pred = self.transformer( + states=(latent_model_input, prompt_embeds, i), + timestep=timestep, + original_size=original_size, + target_size=target_size, + crop_coords=crops_coords_top_left, + )[0] + noise_pred = noise_pred.float() + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + latents = latents.to(prompt_embeds.dtype) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=None)[0] + image = self.image_processor.postprocess(image, output_type="pil") + + # Offload all models + self.maybe_free_model_hooks() + + return CogView3PipelineOutput(images=image) \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/pipeline/pipeline_output.py b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/pipeline/pipeline_output.py new file mode 100644 index 0000000000..e56a4485d7 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/pipeline/pipeline_output.py @@ -0,0 +1,21 @@ +from dataclasses import dataclass +from typing import List, Union + +import numpy as np +import PIL.Image + +from diffusers.utils import BaseOutput + + +@dataclass +class CogView3PipelineOutput(BaseOutput): + """ + Output class for CogView3 pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/schedulers/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/schedulers/__init__.py new file mode 100644 index 0000000000..f98b6e1dec --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/schedulers/__init__.py @@ -0,0 +1,2 @@ +from .scheduling_ddim_cogvideox import CogVideoXDDIMScheduler +from .scheduling_utils import SchedulerMixin \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/schedulers/scheduling_ddim_cogvideox.py b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/schedulers/scheduling_ddim_cogvideox.py new file mode 100644 index 0000000000..b4f81796e9 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/schedulers/scheduling_ddim_cogvideox.py @@ -0,0 +1,276 @@ +# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils import BaseOutput +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin + + +@dataclass +class DDIMSchedulerOutput(BaseOutput): + prev_sample: torch.Tensor + pred_original_sample: Optional[torch.Tensor] = None + + +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +def rescale_zero_terminal_snr(alphas_cumprod): + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + + return alphas_bar + + +class CogVideoXDDIMScheduler(SchedulerMixin, ConfigMixin): + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.00085, + beta_end: float = 0.0120, + beta_schedule: str = "scaled_linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + set_alpha_to_one: bool = True, + rescale_betas_zero_snr: bool = False, + snr_shift_scale: float = 3.0, + ): + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float64) ** 2 + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + # Modify: SNR shift following SD3 + self.alphas_cumprod = self.alphas_cumprod / (snr_shift_scale + (1 - snr_shift_scale) * self.alphas_cumprod) + + # Rescale for zero SNR + if rescale_betas_zero_snr: + self.alphas_cumprod = rescale_zero_terminal_snr(self.alphas_cumprod) + + self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # setable values + self.num_inference_steps = None + self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64)) + + def _get_variance(self, timestep, prev_timestep): + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) + + return variance + + def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor: + return sample + + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + """ + + if num_inference_steps > self.config.num_train_timesteps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:" + f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.config.num_train_timesteps} timesteps." + ) + + self.num_inference_steps = num_inference_steps + + if self.config.timestep_spacing == "linspace": + timesteps = ( + np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps) + .round()[::-1] + .copy() + .astype(np.int64) + ) + elif self.config.timestep_spacing == "leading": + step_ratio = self.config.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'leading' or 'trailing'." + ) + + self.timesteps = torch.from_numpy(timesteps).to(device) + + def step( + self, + model_output: torch.Tensor, + timestep: int, + sample: torch.Tensor, + return_dict: bool = True, + ) -> Union[DDIMSchedulerOutput, Tuple]: + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps + + # 2. compute alphas, betas + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + + beta_prod_t = 1 - alpha_prod_t + + # 3. compute predicted original sample from predicted noise also called + if self.config.prediction_type == "epsilon": + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + elif self.config.prediction_type == "sample": + pred_original_sample = model_output + elif self.config.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction`" + ) + + a_t = ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t)) ** 0.5 + b_t = alpha_prod_t_prev**0.5 - alpha_prod_t**0.5 * a_t + + prev_sample = a_t * sample + b_t * pred_original_sample + + if not return_dict: + return ( + prev_sample, + pred_original_sample, + ) + + return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) + + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.IntTensor, + ) -> torch.Tensor: + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement + # for the subsequent add_noise calls + self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device) + alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor: + # Make sure alphas_cumprod and timestep have same device and dtype as sample + self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device) + alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype) + timesteps = timesteps.to(sample.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(sample.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample + return velocity + + def __len__(self): + return self.config.num_train_timesteps \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/schedulers/scheduling_utils.py b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/schedulers/scheduling_utils.py new file mode 100644 index 0000000000..eeb6e77dee --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/schedulers/scheduling_utils.py @@ -0,0 +1,113 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib +import os +from dataclasses import dataclass +from enum import Enum +from typing import Optional, Union + +import torch +from huggingface_hub.utils import validate_hf_hub_args + +from diffusers.utils import BaseOutput, PushToHubMixin + + +SCHEDULER_CONFIG_NAME = "scheduler_config.json" + + +class KarrasDiffusionSchedulers(Enum): + DDIMScheduler = 1 + DDPMScheduler = 2 + PNDMScheduler = 3 + LMSDiscreteScheduler = 4 + EulerDiscreteScheduler = 5 + HeunDiscreteScheduler = 6 + EulerAncestralDiscreteScheduler = 7 + DPMSolverMultistepScheduler = 8 + DPMSolverSinglestepScheduler = 9 + KDPM2DiscreteScheduler = 10 + KDPM2AncestralDiscreteScheduler = 11 + DEISMultistepScheduler = 12 + UniPCMultistepScheduler = 13 + DPMSolverSDEScheduler = 14 + EDMEulerScheduler = 15 + + +AysSchedules = { + "StableDiffusionTimesteps": [999, 850, 736, 645, 545, 455, 343, 233, 124, 24], + "StableDiffusionSigmas": [14.615, 6.475, 3.861, 2.697, 1.886, 1.396, 0.963, 0.652, 0.399, 0.152, 0.0], + "StableDiffusionXLTimesteps": [999, 845, 730, 587, 443, 310, 193, 116, 53, 13], + "StableDiffusionXLSigmas": [14.615, 6.315, 3.771, 2.181, 1.342, 0.862, 0.555, 0.380, 0.234, 0.113, 0.0], + "StableDiffusionVideoSigmas": [700.00, 54.5, 15.886, 7.977, 4.248, 1.789, 0.981, 0.403, 0.173, 0.034, 0.0], +} + + +@dataclass +class SchedulerOutput(BaseOutput): + """ + Base class for the output of a scheduler's `step` function. + + Args: + prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + """ + + prev_sample: torch.Tensor + + +class SchedulerMixin(PushToHubMixin): + + config_name = SCHEDULER_CONFIG_NAME + _compatibles = [] + has_compatibles = True + + @classmethod + @validate_hf_hub_args + def from_pretrained( + cls, + pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None, + subfolder: Optional[str] = None, + return_unused_kwargs=False, + **kwargs, + ): + + config, kwargs, _ = cls.load_config( + pretrained_model_name_or_path=pretrained_model_name_or_path, + subfolder=subfolder, + return_unused_kwargs=True, + return_commit_hash=True, + **kwargs, + ) + return cls.from_config(config, return_unused_kwargs=return_unused_kwargs, **kwargs) + + @property + def compatibles(self): + """ + Returns all schedulers that are compatible with this scheduler + + Returns: + `List[SchedulerMixin]`: List of compatible schedulers + """ + return self._get_compatibles() + + @classmethod + def _get_compatibles(cls): + compatible_classes_str = list(set([cls.__name__] + cls._compatibles)) + diffusers_library = importlib.import_module(__name__.split(".")[0]) + compatible_classes = [ + getattr(diffusers_library, c) for c in compatible_classes_str if hasattr(diffusers_library, c) + ] + return compatible_classes \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/vae/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/vae/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogview3/inference_cogview3plus.py b/MindIE/MindIE-Torch/built-in/foundation/cogview3/inference_cogview3plus.py new file mode 100644 index 0000000000..c3bb1f2ebb --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogview3/inference_cogview3plus.py @@ -0,0 +1,105 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging +import time + +import torch + +from cogview3plus import CogView3PlusPipeline + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def parse_arguments(): + parser = argparse.ArgumentParser(description="Generate an image using the CogView3-Plus-3B model.") + + # Define arguments for prompt, model path, etc. + parser.add_argument( + "--prompt", + type=list, + default=[ + "A vibrant cherry red sports car sits proudly under the gleaming sun, \ + its polished exterior smooth and flawless, casting a mirror-like reflection. \ + The car features a low, aerodynamic body, angular headlights that gaze forward like predatory eyes, \ + and a set of black, high-gloss racing rims that contrast starkly with the red. \ + A subtle hint of chrome embellishes the grille and exhaust, \ + while the tinted windows suggest a luxurious and private interior. \ + he scene conveys a sense of speed and elegance, \ + the car appearing as if it's about to burst into a sprint along a coastal road, \ + with the ocean's azure waves crashing in the background." + ], + help="The text description for generating the image." + ) + parser.add_argument( + "--model_path", type=str, default="/data/CogView3B", help="Path to the pre-trained model." + ) + parser.add_argument( + "--guidance_scale", type=float, default=7.0, help="The guidance scale for classifier-free guidance." + ) + parser.add_argument( + "--num_images_per_prompt", type=int, default=1, help="Number of images to generate per prompt." + ) + parser.add_argument("--num_inference_steps", type=int, default=50, help="Number of denoising steps for inference.") + parser.add_argument("--width", type=int, default=1024, help="Width of the generated image.") + parser.add_argument("--height", type=int, default=1024, help="Height of the generated image.") + parser.add_argument("--output_path", type=str, default="cogview3.png", help="Path to save the generated image.") + parser.add_argument("--dtype", type=str, default="bf16", help="bf16 or fp16") + parser.add_argument("--device_id", type=int, default=7, help="NPU device id") + + return parser.parse_args() + + +def infer(args): + torch.npu.set_device(args.device_id) + dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float16 + + # Load the pre-trained model with the specified precision + pipe = CogView3PlusPipeline.from_pretrained(args.model_path, torch_dtype=dtype).to("npu") + + use_time = 0 + loops = 5 + for i in range(loops): + start_time = time.time() + # Generate the image based on the prompt + image = pipe( + prompt=args.prompt[0], + guidance_scale=args.guidance_scale, + num_images_per_prompt=args.num_images_per_prompt, + num_inference_steps=args.num_inference_steps, + image_size=(args.height, args.width), + ).images[0] + + if i >= 2: + use_time += time.time() - start_time + logger.info("current_time is %.3f )", time.time() - start_time) + + torch.npu.empty_cache() + + logger.info("use_time is %.3f)", use_time / 3) + + # Save the generated image to the local file system + image.save(args.output_path) + + print(f"Image saved to {args.output_path}") + + +if __name__ == "__main__": + inference_args = parse_arguments() + infer(inference_args) + diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogview3/requirents.txt b/MindIE/MindIE-Torch/built-in/foundation/cogview3/requirents.txt new file mode 100644 index 0000000000..1600434700 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogview3/requirents.txt @@ -0,0 +1,8 @@ +deepspeed==0.16.1 +transformers==4.47.1 +gradio==5.9.1 +accelerate==1.0.1 +diffusers==0.31.0 +sentencepiece==0.2.0 +torch==2.4.0 +openai==1.58.1 \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/README.md b/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/README.md new file mode 100644 index 0000000000..660490c189 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/README.md @@ -0,0 +1,411 @@ +## 一、准备运行环境 + + **表 1** 版本配套表 + + | 配套 | 版本 | 环境准备指导 | + | ----- | ----- |-----| + | Python | 3.10.2 | - | + | torch | 2.1.0 | - | + +### 1.1 获取CANN&MindIE安装包&环境准备 +- 设备支持: +Atlas 800I A2推理设备:支持的卡数为1 +- [Atlas 800I A2](https://www.hiascend.com/developer/download/community/result?module=pt+ie+cann&product=4&model=32) +- [环境准备指导](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha002/softwareinst/instg/instg_0001.html) + +### 1.2 CANN安装 +```shell +# 增加软件包可执行权限,{version}表示软件版本号,{arch}表示CPU架构,{soc}表示昇腾AI处理器的版本。 +chmod +x ./Ascend-cann-toolkit_{version}_linux-{arch}.run +chmod +x ./Ascend-cann-kernels-{soc}_{version}_linux.run +# 校验软件包安装文件的一致性和完整性 +./Ascend-cann-toolkit_{version}_linux-{arch}.run --check +./Ascend-cann-kernels-{soc}_{version}_linux.run --check +# 安装 +./Ascend-cann-toolkit_{version}_linux-{arch}.run --install +./Ascend-cann-kernels-{soc}_{version}_linux.run --install + +# 设置环境变量 +source /usr/local/Ascend/ascend-toolkit/set_env.sh +``` + +### 1.3 MindIE安装 +```shell +# 增加软件包可执行权限,{version}表示软件版本号,{arch}表示CPU架构。 +chmod +x ./Ascend-mindie_${version}_linux-${arch}.run +./Ascend-mindie_${version}_linux-${arch}.run --check + +# 方式一:默认路径安装 +./Ascend-mindie_${version}_linux-${arch}.run --install +# 设置环境变量 +cd /usr/local/Ascend/mindie && source set_env.sh + +# 方式二:指定路径安装 +./Ascend-mindie_${version}_linux-${arch}.run --install-path=${AieInstallPath} +# 设置环境变量 +cd ${AieInstallPath}/mindie && source set_env.sh +``` + +### 1.4 Torch_npu安装 +安装pytorch框架 版本2.1.0 +[安装包下载](https://download.pytorch.org/whl/cpu/torch/) + +使用pip安装 +```shell +# {version}表示软件版本号,{arch}表示CPU架构。 +pip install torch-${version}-cp310-cp310-linux_${arch}.whl +``` +下载 pytorch_v{pytorchversion}_py{pythonversion}.tar.gz +```shell +tar -xzvf pytorch_v{pytorchversion}_py{pythonversion}.tar.gz +# 解压后,会有whl包 +pip install torch_npu-{pytorchversion}.xxxx.{arch}.whl +``` + +### 1.5 安装mindspeed +```shell +# 下载mindspeed源码仓 +git clone https://gitee.com/ascend/MindSpeed.git +# 使用pip安装 +pip install -e MindSpeed +``` + +## 二、下载本仓库 + +### 2.1 下载到本地 +```shell +git clone https://gitee.com/ascend/ModelZoo-PyTorch.git +``` + +### 2.2 安装依赖 +使用pip安装 +```shell +pip install -r requirents.txt +``` +若要使用hpsv2验证精度,则还需要按照以下步骤安装hpsv2 +```shell +git clone https://github.com/tgxs002/HPSv2.git +pip install -e HPSv2 +``` + +## 三、HunyuanDiT使用 + +### 3.1 模型权重及配置文件说明 +1. 权重链接: +```shell +https://huggingface.co/Tencent-Hunyuan/HunyuanDiT-v1.2/tree/main/t2i +``` +- 在t2i/model路径下,新增HunyuanDiT模型权重的配置文件,命名为config.json +```shell +{ + "_class_name": "HunyuanDiT2DModel", + "_mindiesd_version": "2.0.RC1", + "input_size": [ + null, + null + ], + "patch_size": 2, + "in_channels": 4, + "hidden_size": 1408, + "depth": 40, + "num_heads": 16, + "mlp_ratio": 4.3637, + "text_states_dim": 1024, + "text_states_dim_t5": 2048, + "text_len": 77, + "text_len_t5": 256, + "size_cond": null, + "use_style_cond": false +} +``` +2. 各模型的配置文件、权重文件的路径层级样例如下所示。 +```commandline +|----hunyuan_dit +| |---- ckpts +| | |---- t2i +| | | |---- clip_text_encoder +| | | |---- model +| | | | |---- config.json +| | | | |---- 模型权重 +| | | |---- mt5 +| | | |---- sdxl-vae-fp16-fix +| | | |---- tokenizer +``` + +### 3.2 模型单卡推理适配的测试 +设置权重路径 +```shell +path="ckpts/t2i" +``` +修改权重文件夹权限为安全权限 +```shell +chmod -R 640 ckpts/t2i/ +``` +执行命令: +```shell +python inference_hydit.py \ + --path ${path} \ + --device_id 0 \ + --prompt "渔舟唱晚" \ + --input_size 1024 1024 \ + --seed 42 \ + --infer_steps 100 +``` +参数说明: +- path:权重路径,包含clip_text_encoder、model、mt5、sdxl-vae-fp16-fix、tokenizer的权重及配置文件。 +- device_id:推理设备ID。 +- prompt:用于图像生成的文字描述提示。 +- input_size:生成的图像尺寸,宽高要求是8的倍数。 +- seed:设置随机种子,默认值为42。 +- infer_steps:推理迭代步数,默认值为100。 + +执行完成后在`results`目录下生成一张推理图像。 + +### 3.3 模型单卡等价优化的性能测试 +设置权重路径 +```shell +path="ckpts/hydit" +``` +修改权重文件夹权限为安全权限 +```shell +chmod -R 640 ckpts/t2i/ +``` +执行命令: +```shell +python inference_hydit.py \ + --path ${path} \ + --device_id 0 \ + --test_acc \ + --prompt_file "prompts/example_prompts.txt" \ + --input_size 1024 1024 \ + --seed 42 \ + --infer_steps 100 +``` +参数说明: +- path:权重路径,包含clip_text_encoder、model、mt5、sdxl-vae-fp16-fix、tokenizer的权重及配置文件。 +- device_id:推理设备ID。 +- test_acc:使用 --test_acc 开启prompt_file列表中的图像生成,用于性能/精度测试。 +- prompt_file:用于图像生成的文字描述提示的列表文件路径。 +- input_size:生成的图像尺寸,宽高要求是8的倍数。 +- seed:设置随机种子,默认值为42。 +- infer_steps:推理迭代步数,默认值为100。 + +执行完成后在`results`目录下生成推理图像,图像生成顺序与prompt顺序保持一致,并在终端显示推理时间。 + +### 3.4 模型单卡算法优化的性能测试 +设置权重路径 +```shell +path="ckpts/hydit" +``` +修改权重文件夹权限为安全权限 +```shell +chmod -R 640 ckpts/t2i/ +``` +执行命令: +```shell +python inference_hydit.py \ + --path ${path} \ + --device_id 0 \ + --test_acc \ + --prompt_file "prompts/example_prompts.txt" \ + --use_cache \ + --input_size 1024 1024 \ + --seed 42 \ + --infer_steps 100 +``` +参数说明: +- path:权重路径,包含clip_text_encoder、model、mt5、sdxl-vae-fp16-fix、tokenizer的权重及配置文件。 +- device_id:推理设备ID。 +- test_acc:使用 --test_acc 开启prompt_file列表中的图像生成,用于性能/精度测试。 +- prompt_file:用于图像生成的文字描述提示的列表文件路径。 +- use_cache:使用 --use_cache 开启算法策略优化的测试。 +- input_size:生成的图像尺寸,宽高要求是8的倍数。 +- seed:设置随机种子,默认值为42。 +- infer_steps:推理迭代步数,默认值为100。 + +执行完成后在`results`目录下生成推理图像,图像生成顺序与prompt顺序保持一致,并在终端显示推理时间。 + +### 3.5 模型单卡多batch推理适配测试 +设置权重路径 +```shell +path="ckpts/hydit" +``` +修改权重文件夹权限为安全权限 +```shell +chmod -R 640 ckpts/t2i/ +``` +执行命令: +```shell +python inference_hydit.py \ + --path ${path} \ + --device_id 0 \ + --test_acc \ + --prompt_file "prompts/example_prompts.txt" \ + --use_cache \ + --input_size 1024 1024 \ + --batch_size 2 \ + --seed 42 \ + --infer_steps 100 +``` +参数说明: +- path:权重路径,包含clip_text_encoder、model、mt5、sdxl-vae-fp16-fix、tokenizer的权重及配置文件。 +- device_id:推理设备ID。 +- test_acc:使用 --test_acc 开启prompt_file列表中的图像生成,用于性能/精度测试。 +- prompt_file:用于图像生成的文字描述提示的列表文件路径。 +- use_cache:使用 --use_cache 开启算法策略优化的测试。 +- input_size:生成的图像尺寸,宽高要求是8的倍数。 +- batch_size:每个prompt生成的图像数量,根据设备显存,batch_size最大设置为2。 +- seed:设置随机种子,默认值为42。 +- infer_steps:推理迭代步数,默认值为100。 + +执行完成后在`results`目录下生成推理图像,图像生成顺序与prompt顺序保持一致,并在终端显示推理时间。 + +## 四、精度验证 +由于生成的图像存在随机性,提供两种精度验证方法: +1. CLIP-score(文图匹配度量):评估图像和输入文本的相关性,分数的取值范围为[-1, 1],越高越好。使用Parti数据集进行验证。 +2. HPSv2(图像美学度量):评估生成图像的人类偏好评分,分数的取值范围为[0, 1],越高越好。使用HPSv2数据集进行验证 + +【注意】由于要生成的图像数量较多,进行完整的精度验证需要耗费很长的时间。 + +### 4.1 下载Parti数据集和hpsv2数据集 +```shell +# 下载Parti数据集 +wget https://raw.githubusercontent.com/google-research/parti/main/PartiPrompts.tsv --no-check-certificate +``` +hpsv2数据集下载链接:https://gitee.com/ascend/ModelZoo-PyTorch/blob/master/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_xl/hpsv2_benchmark_prompts.json + +建议将`PartiPrompts.tsv`和`hpsv2_benchmark_prompts.json`文件放到`prompts/`路径下。 + +### 4.2 下载模型权重 +```shell +# Clip Score和HPSv2均需要使用的权重 +GIT_LFS_SKIP_SMUDGE=1 +git clone https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K +cd ./CLIP-ViT-H-14-laion2B-s32B-b79K +# HPSv2权重 +wget https://huggingface.co/spaces/xswu/HPSv2/resolve/main/HPS_v2_compressed.pt --no-check-certificate +``` +也可手动下载[Clip Score权重](https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K/blob/main/open_clip_pytorch_model.bin),将权重放到`CLIP-ViT-H-14-laion2B-s32B-b79K`目录下,手动下载[HPSv2权重](https://huggingface.co/spaces/xswu/HPSv2/resolve/main/HPS_v2_compressed.pt)放到当前路径。 + +### 4.3 使用推理脚本读取Parti数据集,生成图像 +设置权重路径 +```shell +path="ckpts/hydit" +``` +修改权重文件夹权限为安全权限 +```shell +chmod -R 640 ckpts/t2i/ +``` +执行命令: +```shell +# 使用算法优化 +python inference_hydit.py \ + --path ${path} \ + --device_id 0 \ + --test_acc \ + --prompt_file "prompts/PartiPrompts.tsv" \ + --prompt_file_type parti \ + --max_num_prompts 0 \ + --info_file_save_path ./image_info_parti.json \ + --save_result_path ./results_parti \ + --use_cache \ + --input_size 1024 1024 \ + --seed 42 \ + --infer_steps 100 +``` +参数说明: +- path:权重路径,包含clip_text_encoder、model、mt5、sdxl-vae-fp16-fix、tokenizer的权重及配置文件。 +- device_id:推理设备ID。 +- test_acc:使用 --test_acc 开启prompt_file列表中的图像生成,用于性能/精度测试。 +- prompt_file:用于图像生成的文字描述提示的列表文件路径。 +- prompt_file_type:prompt文件类型,用于指定读取方式,可选范围:plain,parti,hpsv2。默认值为plain。 +- max_num_prompts:限制prompt数量为前X个,0表示不限制。 +- info_file_save_path:生成图像信息的json文件路径。 +- save_result_path:生成图像的存放目录。 +- use_cache:使用 --use_cache 开启算法策略优化的测试。 +- input_size:生成的图像尺寸,宽高要求是8的倍数。 +- seed:设置随机种子,默认值为42。 +- infer_steps:推理迭代步数,默认值为100。 + +执行完成后在`./results_parti`目录下生成推理图像。在当前目录下生成一个`image_info_parti.json`文件,记录着图像和prompt的对应关系,并在终端显示推理时间。 + +### 4.4 使用推理脚本读取hpsv2数据集,生成图像 +设置权重路径 +```shell +path="ckpts/hydit" +``` +修改权重文件夹权限为安全权限 +```shell +chmod -R 640 ckpts/t2i/ +``` +执行命令: +```shell +# 使用算法优化 +python inference_hydit.py \ + --path ${path} \ + --device_id 0 \ + --test_acc \ + --prompt_file "prompts/hpsv2_benchmark_prompts.json" \ + --prompt_file_type hpsv2 \ + --max_num_prompts 0 \ + --info_file_save_path ./image_info_hpsv2.json \ + --save_result_path ./results_hpsv2 \ + --use_cache \ + --input_size 1024 1024 \ + --seed 42 \ + --infer_steps 100 +``` +参数说明: +- path:权重路径,包含clip_text_encoder、model、mt5、sdxl-vae-fp16-fix、tokenizer的权重及配置文件。 +- device_id:推理设备ID。 +- test_acc:使用 --test_acc 开启prompt_file列表中的图像生成,用于性能/精度测试。 +- prompt_file:用于图像生成的文字描述提示的列表文件路径。 +- prompt_file_type:prompt文件类型,用于指定读取方式,可选范围:plain,parti,hpsv2。默认值为plain。 +- max_num_prompts:限制prompt数量为前X个,0表示不限制。 +- info_file_save_path:生成图像信息的json文件路径。 +- save_result_path:生成图像的存放目录。 +- use_cache:使用 --use_cache 开启算法策略优化的测试。 +- input_size:生成的图像尺寸,宽高要求是8的倍数。 +- seed:设置随机种子,默认值为42。 +- infer_steps:推理迭代步数,默认值为100。 + +执行完成后在`./results_hpsv2`目录下生成推理图像。在当前目录下生成一个`image_info_hpsv2.json`文件,记录着图像和prompt的对应关系,并在终端显示推理时间。 + +### 4.5 计算精度指标 +1. CLIP-score +```bash +python clip_score.py \ + --device=cpu \ + --image_info="image_info_parti.json" \ + --model_name="ViT-H-14" \ + --model_weights_path="./CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin" +``` +参数说明: +- device: 推理设备,默认为"cpu",如果是cuda设备可设置为"cuda"。 +- image_info: 上一步生成的`image_info_parti.json`文件。 +- model_name: Clip模型名称。 +- model_weights_path: Clip模型权重文件路径。 + +clip_score.py脚本可参考[SDXL](https://gitee.com/ascend/ModelZoo-PyTorch/blob/master/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_xl/clip_score.py),执行完成后会在屏幕打印出精度计算结果。 + +2. HPSv2 +```bash +python hpsv2_score.py \ + --image_info="image_info_hpsv2.json" \ + --HPSv2_checkpoint="./HPS_v2_compressed.pt" \ + --clip_checkpoint="./CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin" +``` +参数说明: +- image_info: 上一步生成的`image_info_hpsv2.json`文件。 +- HPSv2_checkpoint: HPSv2模型权重文件路径。 +- clip_checkpointh: Clip模型权重文件路径。 + +hpsv2_score.py脚本可参考[SDXL](https://gitee.com/ascend/ModelZoo-PyTorch/blob/master/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_xl/hpsv2_score.py),执行完成后会在屏幕打印出精度计算结果。 + +## 五、模型推理性能结果参考 +### HunyuanDiT +| 硬件形态 | cpu规格 | batch size | 迭代次数 | 等价优化平均耗时 | 算法优化平均耗时 | +| :------: | :------: | :------: |:----:| :------: |:-----:| +| Atlas 800I A2(8*32G) | 64核(arm) | 1 | 100 | 43.404s | 29.208s | + +性能测试需要独占npu和cpu \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/__init__.py new file mode 100644 index 0000000000..3153d838c6 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/__init__.py @@ -0,0 +1,21 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from .models import HunyuanDiT2DModel +from .pipeline import HunyuanDiTPipeline +from .schedulers import DDPMScheduler +from .utils import is_npu_available, postprocess_pil, set_seeds_generator, randn_tensor \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/layers/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/layers/__init__.py new file mode 100644 index 0000000000..ea86bb5c7e --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/layers/__init__.py @@ -0,0 +1,23 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from .activation import get_activation_fn +from .attention import Attention +from .embedding import timestep_embedding, TimestepEmbedder, PatchEmbed, RotaryPositionEmbedding +from .mlp import Mlp +from .norm import get_normalization_helper +from .poolers import AttentionPool \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/layers/activation.py b/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/layers/activation.py new file mode 100644 index 0000000000..521ce8dd4e --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/layers/activation.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch.nn as nn + + +ACTIVATION_FUNCTIONS = { + "swish": nn.SiLU(), + "silu": nn.SiLU(), + "mish": nn.Mish(), + "gelu": nn.GELU(), + "relu": nn.ReLU(), + "gelu-approximate": nn.GELU(approximate="tanh") +} + + +def approx_gelu(): + return nn.GELU(approximate="tanh") + + +def get_activation_fn(act_fn: str) -> nn.Module: + """Helper function to get activation function from string. + + Args: + act_fn (str): Name of activation function. + + Returns: + nn.Module: Activation function. + """ + + act_fn = act_fn.lower() + if act_fn in ACTIVATION_FUNCTIONS: + return ACTIVATION_FUNCTIONS[act_fn] + else: + raise ValueError(f"Unsupported activation function: {act_fn}") \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/layers/attention.py b/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/layers/attention.py new file mode 100644 index 0000000000..0dec6e453f --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/layers/attention.py @@ -0,0 +1,141 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Tuple +import math + +import torch +import torch.nn as nn +import torch_npu +from mindspeed.ops.npu_rotary_position_embedding import npu_rotary_position_embedding + +from .norm import get_normalization_helper + +EPS_DEFAULT = 1e-6 +EPS_FP16 = 1 / 65530 + + +def reshape_for_broadcast(x: torch.Tensor, freqs_cis: Tuple[torch.Tensor], head_first: bool = False): + ndim = x.ndim + if head_first: + shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + else: + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape) + + +def apply_rotary_emb(x: torch.Tensor, freqs_cis: Tuple[torch.Tensor], head_first: bool = False): + """ + x dtype: float16; cos/sin dtype: float32 + x_out dtype: float16 + """ + cos, sin = reshape_for_broadcast(x, freqs_cis, head_first) # [S, D] + x_out = npu_rotary_position_embedding(x.float(), cos, sin, 1).type_as(x) + return x_out + + +class Attention(nn.Module): + + def __init__(self, + hidden_size: int, + cross_attention_dim: int, + num_heads: int = 16, + attention_norm: str = None, + rope_type: str = "rope", + qkv_bias: bool = True): + super().__init__() + + self.is_cross_attention = cross_attention_dim is not None + self.num_heads = num_heads + self.head_dim = hidden_size // num_heads + self.rope_type = rope_type + if (self.rope_type != "rope" and self.rope_type != "atb"): + raise ValueError(f"The 'rope_type' must be 'rope' or 'atb', but got {self.rope_type}.") + + self.q_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias) + if not self.is_cross_attention: + self.kv_proj = nn.Linear(hidden_size, 2 * hidden_size, bias=qkv_bias) + else: + self.kv_proj = nn.Linear(cross_attention_dim, 2 * hidden_size, bias=qkv_bias) + + # If using fp16, eps should be 1 / 65530; else default 1e-6 + self.q_norm = get_normalization_helper(attention_norm, self.head_dim, eps=EPS_FP16) + self.k_norm = get_normalization_helper(attention_norm, self.head_dim, eps=EPS_FP16) + + self.out_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias) + + + def forward(self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + freqs_cis_img: torch.Tensor = None, + layer: int = 0): + # hidden_states, encoder_hidden_states dtype: float16 + if hidden_states is None: + raise ValueError("Input hidden_states should not be none.") + if freqs_cis_img is None: + raise ValueError("Input freqs_cis_img should not be none.") + + # only support BNC now. + if hidden_states.ndim != 3: # 3: BNC + raise ValueError(f"The dimensions of hidden_states should be 3, but got {hidden_states.ndim}") + + batch_size = hidden_states.shape[0] + + query = self.q_proj(hidden_states) + query = query.reshape(batch_size, -1, self.num_heads, self.head_dim) + if not self.is_cross_attention: + kv = self.kv_proj(hidden_states) + else: + kv = self.kv_proj(encoder_hidden_states) + key, value = kv.reshape(batch_size, -1, 2, self.num_heads, self.head_dim).unbind(2) + # query, key, value dtype: float16 + + query = self.q_norm(query) + key = self.k_norm(key) + + # position embedding q and k, and flash attention + if not self.is_cross_attention: + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + if self.rope_type == "rope": + query = apply_rotary_emb(query, freqs_cis_img, head_first=True) + key = apply_rotary_emb(key, freqs_cis_img, head_first=True) + + hidden_states = torch_npu.npu_fusion_attention( + query, key, value, + head_num=self.num_heads, + input_layout="BNSD", + scale=1.0 / math.sqrt(self.head_dim), + )[0] + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim) + else: + if self.rope_type == "rope": + query = apply_rotary_emb(query, freqs_cis_img, head_first=False) + + hidden_states = torch_npu.npu_fusion_attention( + query, key, value, + head_num=self.num_heads, + input_layout="BSND", + scale=1.0 / math.sqrt(self.head_dim), + )[0] + hidden_states = hidden_states.reshape(batch_size, -1, self.num_heads * self.head_dim) + + hidden_states = self.out_proj(hidden_states) + return hidden_states \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/layers/embedding.py b/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/layers/embedding.py new file mode 100644 index 0000000000..5fc21579fc --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/layers/embedding.py @@ -0,0 +1,713 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +import functools +import math +from typing import Union, Tuple + +import torch +import torch.nn as nn +import numpy as np +import torch_npu + + +def get_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, rope_type: str = "adjacent"): + """ + Apply rotary embeddings to input tensors using the given frequency tensor. + + Args: + x (torch.Tensor): Query or key tensor to apply rotary embeddings. BSND or BNSD. + cos (torch.Tensor): Precomputed cos frequency tensor for complex exponentials. + sin (torch.Tensor): Precomputed sin frequency tensor for complex exponentials. + rope_type (str): + if "adjacent": rotate q to [-q_1, q_0, -q_3, q_2, ... , -q_d-1, q_d-2]. + Could to be used for HunyuanDiT, OpenSora, Flux, CogVideox. + if "symmetric": rotate q to [-q_d/2, -q_d/2+1, ... , -q_d-1, q_0, q_1, ... , q_d/2-1]. + Could to be used for OpenSoraPlan, Stable Audio. + if "symmetric_fuse": is equivalent to "symmetric" but has better performance in torch_npu. + + Returns: + (torch.Tensor): modified query or key tensor with rotary embeddings. + """ + if not isinstance(x, torch.Tensor): + raise ValueError(f"The type of input x must be torch.Tensor, but got {type(x)}.") + if not isinstance(cos, torch.Tensor): + raise ValueError(f"The type of input cos must be torch.Tensor, but got {type(cos)}.") + if not isinstance(sin, torch.Tensor): + raise ValueError(f"The type of input sin must be torch.Tensor, but got {type(sin)}.") + if not isinstance(rope_type, str): + raise ValueError(f"The type of input rope_type must be strings, but got {type(rope_type)}.") + + match rope_type: + case "adjacent": + # Used for HunyuanDiT, OpenSora, Flux, CogVideox + x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2] + x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) + return (x * cos + x_rotated * sin) + case "symmetric": + # Used for OpenSoraPlan, Stable Audio + x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2] + x_rotated = torch.cat([-x_imag, x_real], dim=-1) + return (x * cos + x_rotated * sin) + case "symmetric_fuse": + return torch_npu.npu_rotary_mul(x, cos, sin) + case _: + raise ValueError(f"Unsupported rope_type: {rope_type}.") + + +def get_embedding_helper(embedding_type: str, embdding_dim: int): + match embedding_type: + case None: + return nn.Identity() + case 'rope': + return RotaryPositionEmbedding(embed_dim=embdding_dim) + case _: + raise ValueError(f"Unsupported embedding_type:{embedding_type}.") + + +def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. + """ + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32) + / half + ).to(device=t.device) # size: [dim/2], 一个指数衰减的曲线 + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat( + [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 + ) + return embedding + + +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + def __init__(self, hidden_size, frequency_embedding_size=256, out_size=None): + super().__init__() + if out_size is None: + out_size = hidden_size + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, out_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + def forward(self, t): + t_freq = timestep_embedding(t, self.frequency_embedding_size).type(self.mlp[0].weight.dtype) + t_emb = self.mlp(t_freq) + return t_emb + + +def cal_1d_sincos_embed( + items: torch.Tensor, + embed_dim: int, + max_period: int = 10000, + step: int = 1, + flip: bool = False + ): + """ + Calculate 1d sinusoidal embeddings. + Args: + items (torch.Tensor): Items includes N indices. Must be a 1D tensor (N,). + embed_dim (int): The dimension of the embeddings. + max_period (int): Controls the minimum frequency of the embeddings. + step (int): The step of frequences. + flip (bool): If true, return [cos, cos, ..., sin, sin], else return [sin, sin ..., cos, cos]. + Return: + embed (torch.Tensor): An (N, embed_dim//step) tensor of item embeddings. + """ + + if not isinstance(embed_dim, int) or embed_dim <= 0: + raise ValueError(f"Embed_dim should be a positive integer, but receive {embed_dim}.") + if step not in [1, 2]: + raise ValueError(f"Step must be in [1, 2], but receive {step}.") + if embed_dim % (2 * step) != 0: + raise ValueError(f"Embed_dim must be divisible by {2 * step}, but receive {embed_dim}.") + + half_of_dim = embed_dim // 2 + # generate frequency vectors + freqs = torch.arange(start=0, end=half_of_dim, step=step, dtype=torch.float32, device=items.device) + freqs = torch.exp(-math.log(max_period) * freqs / half_of_dim) # (embed_dim//(2*step)) + # (N, 1) * (1, embed_dim//(2*step)) -> (N, embed_dim//(2*step)) + freqs = items[:, None].float() * freqs[None, :] + cos, sin = torch.cos(freqs), torch.sin(freqs) + # (N, embed_dim//step) + if flip: + embed = torch.cat([cos, sin], dim=-1) + else: + embed = torch.cat([sin, cos], dim=-1) + + return embed + + +class SinCosPositionEmbed1D(nn.Module): + def __init__( + self, + embed_dim: int, + step: int = 1, + flip: bool = False, + max_period: int = 10000, + cache1d: bool = True, + size: int = 128 + ): + """ + Create 1d sinusoidal embeddings. + Args: + embed_dim (int): The dimension of the embeddings. + step (int): The step of frequences. + flip (bool): If true, return [cos, cos, ..., sin, sin], else return [sin, sin ..., cos, cos]. + max_period (int): Controls the minimum frequency of the embeddings. + cache1d (bool): If true, use cache. + size (int): The size of cache. + """ + + super().__init__() + self.embed_dim = embed_dim + self.step = step + self.flip = flip + self.max_period = max_period + self.cache1d = cache1d + self.size = size + if self.cache1d: + items = torch.arange(self.size) + # (size, embed_dim//step) + embed = cal_1d_sincos_embed(items, self.embed_dim, self.max_period, self.step, self.flip) + self.register_buffer("embed", embed, persistent=False) + else: + self.embed = None + + def get_1d_sincos_embed(self, items: torch.Tensor): + """ + Calculate 1d sinusoidal embeddings. + Args: + items (torch.Tensor): Items includes N indices. Must be a 1D tensor (N,). + Return: + embed (torch.Tensor): An (N, embed_dim//step) tensor of item embeddings. + """ + + if len(items.shape) != 1: + raise ValueError(f"Items should be a 1D tensor, but receive a {len(items.shape)}D tensor.") + + items_max = torch.max(items) + dytpes = [torch.int, torch.long] + if self.cache1d and items_max < self.size and items.dtype in dytpes: + embed = self.embed[items] + else: + embed = cal_1d_sincos_embed(items, self.embed_dim, self.max_period, self.step, self.flip) + + return embed + + +class SinCosPositionEmbed2D(SinCosPositionEmbed1D): + def __init__( + self, + embed_dim: int = 256, + step: int = 1, + flip: bool = False, + max_period: int = 10000, + cache2d: bool = True, + grid_size: Union[Tuple[int, int], int] = (224, 224), + base_size: Union[int, None] = None, + interpolation_scale: float = 1.0, + persistent = False, + ): + """ + Create 2d sinusoidal embeddings. + Args: + embed_dim (int): The dimension of the embeddings. + step (int): The step of frequences. + flip (bool): If true, return [cos, cos, ..., sin, sin], else return [sin, sin ..., cos, cos]. + max_period (int): Controls the minimum frequency of the embeddings. + cache2d (bood): If true, use cache + grid_size (Tuple[int, int] or int): The size of grid. + base_size (int or None): The size of basic patches. + interpolation_scale (float): The scale parameter. + persistent (bool): If true, save the cache in dict. + """ + + self.embed_dim = embed_dim + self.step = step + self.flip = flip + self.max_period = max_period + self.cache2d = cache2d + self.interpolation_scale = interpolation_scale + + if isinstance(grid_size, int): + self.grid_size = (grid_size, grid_size) + else: + self.grid_size = grid_size + if base_size is None: + self.base_size = round((self.grid_size[0] * self.grid_size[1]) ** 0.5) + else: + self.base_size = base_size + + if not isinstance(self.embed_dim, int) or self.embed_dim <= 0: + raise ValueError(f"Embed_dim should be a positive integer, but receive {self.embed_dim}.") + if self.step not in [1, 2]: + raise ValueError(f"Step must be in [1, 2], but receive {self.step}.") + if self.embed_dim % (2 * self.step) != 0: + raise ValueError(f"Embed_dim must be divisible by {2 * self.step}, but receive {self.embed_dim}.") + + self.dim = self.embed_dim // (2 // self.step) + super().__init__(self.dim, self.step, self.flip, self.max_period, cache1d=False) + + if self.cache2d: + pos_embed = self._get_2d_sincos_embed(self.grid_size, self.base_size, self.interpolation_scale) + self.register_buffer("pos_embed", pos_embed, persistent=persistent) + else: + self.pos_embed = None + + def get_2d_sincos_embed(self, grid_size, base_size=None, interpolation_scale=1.0, device="cpu"): + """ + Initialize frequences. + Args: + grid_size (Tuple[int, int] or int): The size of grid. + base_size (int or None): The size of basic patches. + interpolation_scale (float): The scale parameter. + Return: + emb (torch.Tensor): An (1, H*W, embed_dim) tensor of embeddings. + """ + + if isinstance(grid_size, int): + grid_size = (grid_size, grid_size) + + is_shape_same = grid_size[0] == self.grid_size[0] and grid_size[1] == self.grid_size[1] \ + and base_size == self.base_size + if self.cache2d and is_shape_same and interpolation_scale == self.interpolation_scale: + embed = self.pos_embed + else: + embed = self._get_2d_sincos_embed(grid_size, base_size, interpolation_scale, device) + + return embed + + @functools.lru_cache(maxsize=512) + def _get_2d_sincos_embed(self, grid_size, base_size, interpolation_scale, device="cpu"): + """ + Initialize frequences. + Args: + grid_size (Tuple[int, int]): The size of grid. + base_size (int or None): The size of basic patches. + interpolation_scale (float): The scale parameter. + Return: + emb (torch.Tensor): An (H*W, embed_dim) tensor of embeddings. + """ + + grid_h = torch.arange(grid_size[0], dtype=torch.float32, device=device) / interpolation_scale + grid_w = torch.arange(grid_size[1], dtype=torch.float32, device=device) / interpolation_scale + + if base_size is not None: + grid_h *= base_size / grid_size[0] + grid_w *= base_size / grid_size[1] + + grid_h, grid_w = torch.meshgrid(grid_w, grid_h, indexing="ij") # here w goes first + grid = torch.stack([grid_h.t().reshape(-1), grid_w.t().reshape(-1)], dim=0) # (2, H*W) + emb_h = self.get_1d_sincos_embed(grid[0]) # (H*W, embed_dim//2) + emb_w = self.get_1d_sincos_embed(grid[1]) # (H*W, embed_dim//2) + emb = torch.cat([emb_h, emb_w], dim=-1) # (H*W, embed_dim) + return emb + + +class PatchEmbed(SinCosPositionEmbed2D): + def __init__( + self, + height=224, + width=224, + patch_size=16, + in_channels=3, + embed_dim=768, + layer_norm=False, + flatten=True, + bias=True, + interpolation_scale=1, + pos_embed_type="sincos", + pos_embed_max_size=None, # For SD3 cropping + ): + """ + 2D Image to Patch Embedding with support for position embedding. + Args: + height (int): Height of images. + width (int): Weight of images. + patch_size (int): The size of patches. + in_channels (int): Number of input image channels. + embed_dim (int): Number of linear projection output channels. + layer_norm (bool): If true, use layernorm. + flatten (bool): If true, flatten the latent. + bias (bool): If true, use bias. + interpolation_scale: Scale coefficient. + pos_embed_type (str): The type of postion embedding. + pos_embed_max_size: The size of max postion embedding. + Adapted Models: SD3, HuanyuanDit + """ + + num_patches = (height // patch_size) * (width // patch_size) + self.flatten = flatten + self.layer_norm = layer_norm + self.pos_embed_max_size = pos_embed_max_size + self.patch_size = patch_size + self.height, self.width = height // patch_size, width // patch_size + self.base_size = height // patch_size + self.interpolation_scale = interpolation_scale + + # Calculate positional embeddings based on max size or default + if pos_embed_max_size: + grid_size = pos_embed_max_size + else: + grid_size = int(num_patches**0.5) + + if pos_embed_type is None: + self.cache2d = False + elif pos_embed_type == "sincos": + self.cache2d = True + else: + raise ValueError(f"Unsupported pos_embed_type: {pos_embed_type}") + + super().__init__( + embed_dim=embed_dim, + step=1, + cache2d=self.cache2d, + grid_size=grid_size, + base_size=self.base_size, + interpolation_scale=self.interpolation_scale, + persistent=True if pos_embed_max_size else False, + ) + + self.proj = nn.Conv2d( + in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias + ) + if layer_norm: + self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6) + else: + self.norm = None + + def cropped_pos_embed(self, height, width): + """Crops positional embeddings for SD3 compatibility.""" + if self.pos_embed_max_size is None: + raise ValueError(f"Parameter:`pos_embed_max_size` must be set for cropping.") + + height = height // self.patch_size + width = width // self.patch_size + if height > self.pos_embed_max_size: + raise ValueError( + f"Height ({height}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}." + ) + if width > self.pos_embed_max_size: + raise ValueError( + f"Width ({width}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}." + ) + + top = (self.pos_embed_max_size - height) // 2 + left = (self.pos_embed_max_size - width) // 2 + spatial_pos_embed = self.pos_embed.reshape(1, self.pos_embed_max_size, self.pos_embed_max_size, -1) + spatial_pos_embed = spatial_pos_embed[:, top : top + height, left : left + width, :] + spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1]) + return spatial_pos_embed + + @property + def dtype(self): + return next(self.parameters()).dtype + + def forward(self, latent): + if self.pos_embed_max_size is not None: + height, width = latent.shape[-2:] + else: + height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size + + dtype_latent = latent.dtype + latent = self.proj(latent.to(self.dtype)) + if self.flatten: + latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC + if self.layer_norm: + latent = self.norm(latent) + if self.pos_embed is None: + return latent.to(dtype_latent) + # Interpolate or crop positional embeddings as needed + if self.pos_embed_max_size: + pos_embed = self.cropped_pos_embed(height, width) + else: + pos_embed = self.get_2d_sincos_embed( + (height, width), + self.base_size, + interpolation_scale=self.interpolation_scale, + device=latent.device + ).unsqueeze(0) + + return (latent + pos_embed).to(dtype_latent) + + +class RotaryCosSinEmbed: + """ + RotaryCosSinEmbed get cos_sin tables of rope. + """ + def __init__( + self, + embed_dim: int, + use_real: bool = True, + repeat_interleave_real: bool = True, + theta: float = 10000.0, + linear_factor: float = 1.0, + ntk_factor: float = 1.0, + freqs_dtype = torch.float32, + ): + """ + Args: + embed_dim (int): The embedding dimension size. + use_real (bool): If `True`, return real part and imaginary part separately. Otherwise, return complex numbers. + repeat_interleave_real (bool): + If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`. + Otherwise, they are concateanted with themselves. + theta (float): Scaling factor for frequency computation. Defaults to 10000.0. + linear_factor (float): Scaling factor for the context extrapolation. Defaults to 1.0. Use for `lumina`. + ntk_factor (float): Scaling factor for the NTK-Aware RoPE. Defaults to 1.0. Use for `lumina`. + freqs_dtype: Defaults to torch.float32. Only be torch.float64 for Flux. + """ + super().__init__() + + self.embed_dim = embed_dim + self.use_real = use_real + self.repeat_interleave_real = repeat_interleave_real + self.theta = theta + self.linear_factor = linear_factor # Use for lumina. + self.ntk_factor = ntk_factor # Use for lumina. + self.freqs_dtype = freqs_dtype # Flux: torch.float64 + + + def get_resize_crop_region_for_grid(self, src_h: int, src_w: int, base_size: int): + """ + Get resize and crop region for grid. + + Args: + src_h (int): The grid height of the positional embedding. + src_w (int): The grid width of the positional embedding. + base_size (int): The target size of resizing and cropping region for grid. + + Returns: + Tuple[int]: The top-left and bottom-right coordinates of the crop. + """ + if not isinstance(src_h, int): + raise ValueError(f"The type of input src_h must be int, but got {type(src_h)}.") + if not isinstance(src_w, int): + raise ValueError(f"The type of input src_w must be int, but got {type(src_w)}.") + if not isinstance(base_size, int): + raise ValueError(f"The type of input base_size must be int, but got {type(base_size)}.") + if src_h <= 0: + raise ValueError(f"Input src_h must be greater than 0, but got {src_h}.") + if src_w <= 0: + raise ValueError(f"Input src_w must be greater than 0, but got {src_w}.") + if base_size <= 0: + raise ValueError(f"Input base_size must be greater than 0, but got {base_size}.") + + ratio = src_h / src_w + # resize + if ratio > 1: + resize_height = base_size + resize_width = int(round(base_size / src_h * src_w)) + else: + resize_width = base_size + resize_height = int(round(base_size / src_w * src_h)) + crop_top = int(round((base_size - resize_height) / 2.0)) + crop_left = int(round((base_size - resize_width) / 2.0)) + return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) + + + def get_1d_rotary_pos_embed(self, pos: Union[np.ndarray, int]) -> torch.Tensor: + """ + Precompute the frequency tensor for complex exponentials (cis) with given dimensions. + + Args: + pos (np.ndarray or int): Position indices for the frequency tensor. [S] or scalar. + + Returns: + torch.Tensor: Precomputed frequency tensor with complex exponentials. [S, D/2]. + """ + if isinstance(pos, int): + pos = torch.arange(pos) + elif isinstance(pos, np.ndarray): + pos = torch.from_numpy(pos) # type: ignore # [S] + else: + raise ValueError(f"The type of input pos must be np.ndarray or int, but got {type(pos)}.") + + half_of_dim = self.embed_dim // 2 + + theta = self.theta * self.ntk_factor + freqs = torch.arange(start=0, end=half_of_dim, step=2, dtype=self.freqs_dtype, device=pos.device) # [D/4] + freqs = (1.0 / (theta ** (freqs[: (half_of_dim // 2)] / half_of_dim)) / self.linear_factor) # [D/4] + freqs = torch.outer(pos, freqs) # [S, D/4] + + if self.use_real and self.repeat_interleave_real: + # HunyuanDiT, Flux, CogVideox + freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D/2] + freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D/2] + return freqs_cos, freqs_sin + elif self.use_real: + # Stable Audio, Allegro + freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1) # [S, D/2] + freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1) # [S, D/2] + return freqs_cos, freqs_sin + else: + # lumina + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/4] + return freqs_cis + + + def get_2d_rotary_pos_embed(self, grid_h: int, grid_w: int, base_size: int): + """ + RoPE for image tokens with 2d structure. + + Args: + grid_h (int): The grid height of the positional embedding. + grid_w (int): The grid width of the positional embedding. + base_size (int): The target size of resizing and cropping region for grid. + + Returns: + torch.Tensor: positional embedding with shape (grid_size * grid_size, embed_dim/2). + """ + if not isinstance(grid_h, int): + raise ValueError(f"The type of input grid_h must be int, but got {type(grid_h)}.") + if not isinstance(grid_w, int): + raise ValueError(f"The type of input grid_w must be int, but got {type(grid_w)}.") + if not isinstance(base_size, int): + raise ValueError(f"The type of input base_size must be int, but got {type(base_size)}.") + if grid_h <= 0: + raise ValueError(f"Input grid_h must be greater than 0, but got {grid_h}.") + if grid_w <= 0: + raise ValueError(f"Input grid_w must be greater than 0, but got {grid_w}.") + if base_size <= 0: + raise ValueError(f"Input base_size must be greater than 0, but got {base_size}.") + + start, stop = self.get_resize_crop_region_for_grid(grid_h, grid_w, base_size) + grid_h = np.linspace(start[0], stop[0], grid_h, endpoint=False, dtype=np.float32) + grid_w = np.linspace(start[1], stop[1], grid_w, endpoint=False, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) # [2, W, H] + + grid = grid.reshape([2, 1, *grid.shape[1:]]) + # use half of dimensions to encode grid_h and grid_w + emb_h = self.get_1d_rotary_pos_embed(grid[0].reshape(-1)) # (H*W, D/2) if use_real else (H*W, D/4) + emb_w = self.get_1d_rotary_pos_embed(grid[1].reshape(-1)) # (H*W, D/2) if use_real else (H*W, D/4) + + if self.use_real: + cos = torch.cat([emb_h[0], emb_w[0]], dim=1) # (H*W, D) + sin = torch.cat([emb_h[1], emb_w[1]], dim=1) # (H*W, D) + pos_embed = (cos, sin) + else: + pos_embed = torch.cat([emb_h, emb_w], dim=1) # (H*W, D/2) + + return pos_embed + + +class RotaryPositionEmbedding(RotaryCosSinEmbed, nn.Module): + """ + RotaryPositionEmbedding apply rotary embeddings to input tensors using the given frequency tensor. + """ + def __init__( + self, + embed_dim: int, + grid_h: int = 64, + grid_w: int = 64, + base_size: int = 32, + rope_type: str = "adjacent", + use_real: bool = True, + repeat_interleave_real: bool = True, + theta: float = 10000.0, + linear_factor: float = 1.0, + ntk_factor: float = 1.0, + ): + """ + Args: + embed_dim (int): The embedding dimension size. + grid_h (int): The grid height of the positional embedding. + grid_w (int): The grid width of the positional embedding. + base_size (int): The target size of resizing and cropping region for grid. + rope_type (str): + if "adjacent": rotate q to [-q_1, q_0, -q_3, q_2, ... , -q_d-1, q_d-2]. + Could to be used for HunyuanDiT, Flux, CogVideox. + if "symmetric": rotate q to [-q_d/2, -q_d/2+1, ... , -q_d-1, q_0, q_1, ... , q_d/2-1]. + Could to be used for Stable Audio. + if "symmetric-npu": is equivalent to "symmetric" but has better performance in torch_npu. + use_real (bool): If `True`, return real part and imaginary part separately. Otherwise, return complex numbers. + repeat_interleave_real (bool): + If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`. + Otherwise, they are concateanted with themselves. + theta (float): Scaling factor for frequency computation. Defaults to 10000.0. + linear_factor (float): Scaling factor for the context extrapolation. Defaults to 1.0. Use for `lumina`. + ntk_factor (float): Scaling factor for the NTK-Aware RoPE. Defaults to 1.0. Use for `lumina`. + """ + # check inputs + if embed_dim % 4 != 0 or embed_dim <= 2: + raise ValueError(f"Input embed_dim must be divisible by 4 and greater than 2, but got {embed_dim}.") + if grid_h <= 0 or grid_w <= 0: + raise ValueError(f"Input grid_size must be greater than 0, but got ({grid_h}, {grid_w}).") + if base_size <= 0: + raise ValueError(f"Input base_size must be greater than 0, but got {base_size}.") + if theta <= 0.: + raise ValueError(f"Input theta must be greater than 0, but got {theta}.") + if linear_factor <= 0.: + raise ValueError(f"Input linear_factor must be greater than 0, but got {linear_factor}.") + if ntk_factor <= 0.: + raise ValueError(f"Input ntk_factor must be greater than 0, but got {ntk_factor}.") + + self.rope_type = rope_type + self.use_real = use_real + super().__init__(embed_dim, use_real, repeat_interleave_real, theta, linear_factor, ntk_factor) + + self.freqs_cis_img = self.get_2d_rotary_pos_embed(grid_h, grid_w, base_size) + + + def forward(self, x: torch.Tensor, freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None): + """ + The input tensors are reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting + compatibility. The resulting tensors contain rotary embeddings and are returned as real tensors. + + Args: + x (`torch.Tensor`): Query or key tensor to apply rotary embeddings. [B, H, S, D]. + freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],) + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. + """ + freqs_cis = freqs_cis if freqs_cis is not None else self.freqs_cis_img + + if self.use_real: + cos, sin = freqs_cis # [S, D] + cos = cos[None, None].to(x.dtype) + sin = sin[None, None].to(x.dtype) + cos, sin = cos.to(x.device), sin.to(x.device) + + x_out = get_rotary_emb(x, cos, sin, self.rope_type) + return x_out + + else: + # used for lumina + x_rotated = torch.view_as_complex(x.reshape(*x.shape[:-1], -1, 2)) + freqs_cis = freqs_cis.unsqueeze(2) + x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3) + return x_out.type_as(x) \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/layers/mlp.py b/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/layers/mlp.py new file mode 100644 index 0000000000..01193759c9 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/layers/mlp.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import collections.abc +from itertools import repeat +from functools import partial + +import torch.nn as nn +from .activation import get_activation_fn + + +class Mlp(nn.Module): + """ MLP as used in Vision Transformer, MLP-Mixer and related networks + """ + + def __init__( + self, + features_in, + features_hidden=None, + features_out=None, + act_layer=nn.GELU, + norm_layer=None, + bias=True, + use_conv=False, + ): + super().__init__() + features_out = features_out or features_in + features_hidden = features_hidden or features_in + to_2tuple = self._ntuple(2) + bias = to_2tuple(bias) + linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear + + self.fc1 = linear_layer(features_in, features_hidden, bias=bias[0]) + self.act = act_layer() if not isinstance(act_layer, str) else get_activation_fn(act_layer) + self.norm = norm_layer(features_hidden) if norm_layer is not None else nn.Identity() + self.fc2 = linear_layer(features_hidden, features_out, bias=bias[1]) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.norm(x) + x = self.fc2(x) + return x + + def _ntuple(self, n): + def parse(x): + if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): + return tuple(x) + return tuple(repeat(x, n)) + return parse \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/layers/norm.py b/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/layers/norm.py new file mode 100644 index 0000000000..e35eae02f2 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/layers/norm.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +import torch.nn as nn +import torch_npu + +from ..utils import is_npu_available + + +def get_normalization_helper(norm_type: str, norm_dim: int, eps: float = 1e-5): + match norm_type: + case None: + return nn.Identity() + case 'layer_norm': + return nn.LayerNorm(norm_dim, eps=eps) + case 'llama_rms_norm': + return LlamaRMSNorm(norm_dim, eps=eps) + case _: + raise ValueError(f"Unsupported norm_type:{norm_type}.") + + +class LlamaRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + if is_npu_available(): + return torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.variance_epsilon)[0] + else: + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/layers/poolers.py b/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/layers/poolers.py new file mode 100644 index 0000000000..c73d2cc0d2 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/layers/poolers.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class AttentionPool(nn.Module): + def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim ** 0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.permute(1, 0, 2) # NLC -> LNC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (L+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (L+1)NC + x, _ = F.multi_head_attention_forward( + query=x[:1], key=x, value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False + ) + return x.squeeze(0) diff --git a/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/models/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/models/__init__.py new file mode 100644 index 0000000000..a86f4f327c --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/models/__init__.py @@ -0,0 +1,18 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from .hydit import HunyuanDiT2DModel \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/models/hydit.py b/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/models/hydit.py new file mode 100644 index 0000000000..de4a06fc27 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/models/hydit.py @@ -0,0 +1,407 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Tuple, List + +import torch +import torch.nn as nn + +from mindiesd import ConfigMixin +from .model_utils import DiffusionModel +from ..layers import get_activation_fn, get_normalization_helper, timestep_embedding +from ..layers import Mlp, PatchEmbed, TimestepEmbedder, Attention, AttentionPool + + +class HunyuanDiTBlock(nn.Module): + """ + A HunYuanDiT block with `add` conditioning. + """ + def __init__(self, + hidden_size, + c_emb_size, + num_heads, + mlp_ratio=4.0, + text_states_dim=1024, + skip=False, + ): + super().__init__() + + norm_type = "layer_norm" + + # ========================= Self-Attention ========================= + self.norm1 = get_normalization_helper(norm_type, hidden_size, eps=1e-6) + self.attn1 = Attention(hidden_size=hidden_size, + cross_attention_dim=None, + num_heads=num_heads, + attention_norm=norm_type, + rope_type="rope") + + # ========================= FFN ========================= + self.norm2 = get_normalization_helper(norm_type, hidden_size, eps=1e-6) + self.mlp = Mlp( + features_in=hidden_size, features_hidden=int(hidden_size * mlp_ratio), act_layer="gelu-approximate") + + # ========================= Add ========================= + # Simply use add like SDXL. + self.default_modulation = nn.Sequential( + get_activation_fn("silu"), + nn.Linear(c_emb_size, hidden_size, bias=True) + ) + + # ========================= Cross-Attention ========================= + self.attn2 = Attention(hidden_size=hidden_size, + cross_attention_dim=text_states_dim, + num_heads=num_heads, + attention_norm=norm_type, + rope_type="rope") + self.norm3 = get_normalization_helper(norm_type, hidden_size, eps=1e-6) + + # ========================= Skip Connection ========================= + if skip: + self.skip_norm = get_normalization_helper(norm_type, 2 * hidden_size, eps=1e-6) + self.skip_linear = nn.Linear(2 * hidden_size, hidden_size) + else: + self.skip_linear = None + + + def forward(self, x, tensor_input, skip=None, layer=0): + c, text_states, freqs_cis_img = tensor_input + # Long Skip Connection + if self.skip_linear is not None: + cat = torch.cat([x, skip], dim=-1) + cat = self.skip_norm(cat) + x = self.skip_linear(cat) + # Self-Attention + shift_msa = self.default_modulation(c).unsqueeze(dim=1) + x = x + self.attn1(hidden_states=self.norm1(x) + shift_msa, + freqs_cis_img=freqs_cis_img, + layer=layer) + # Cross-Attention + x = x + self.attn2(hidden_states=self.norm3(x), + encoder_hidden_states=text_states, + freqs_cis_img=freqs_cis_img, + layer=layer) + # FFN Layer + mlp_inputs = self.norm2(x) + x = x + self.mlp(mlp_inputs) + return x + + +class FinalLayer(nn.Module): + """ + The final layer of HunYuanDiT. + """ + def __init__(self, final_hidden_size, c_emb_size, patch_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(final_hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(final_hidden_size, patch_size * patch_size * out_channels, bias=True) + self.adaLN_modulation = nn.Sequential( + get_activation_fn("silu"), + nn.Linear(c_emb_size, 2 * final_hidden_size, bias=True) + ) + + @staticmethod + def modulate(x, shift, scale): + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) + x = self.modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + + +class HunyuanDiTConfig(ConfigMixin): + config_name = 'config.json' + + def __init__( + self, + input_size: Tuple[int, int] = (None, None), + patch_size: int = 2, + in_channels: int = 4, + hidden_size: int = 1152, + depth: int = 28, + num_heads: int = 16, + mlp_ratio: float = 4.0, + text_states_dim: int = 1024, + text_states_dim_t5: int = 2048, + text_len: int = 77, + text_len_t5: int = 256, + size_cond: List = None, + use_style_cond: bool = False, + ) -> None: + super().__init__() + + self.input_size = input_size + self.patch_size = patch_size + self.in_channels = in_channels + self.hidden_size = hidden_size + self.depth = depth + self.num_heads = num_heads + self.mlp_ratio = mlp_ratio + self.text_states_dim = text_states_dim + self.text_states_dim_t5 = text_states_dim_t5 + self.text_len = text_len + self.text_len_t5 = text_len_t5 + self.size_cond = size_cond + self.use_style_cond = use_style_cond + + +class HunyuanDiT2DModel(DiffusionModel): + + config_class = HunyuanDiTConfig + weigths_name = "pytorch_model_ema.pt" + + def __init__(self, config): + super().__init__(config) + self.config = config + self._check_config_params() + + # learn_sigma is True + self.out_channels = self.config.in_channels * 2 + + self.mlp_t5 = Mlp(features_in=self.config.text_states_dim_t5, + features_hidden=self.config.text_states_dim_t5 * 4, + features_out=self.config.text_states_dim, + act_layer="silu", + bias=True) + + # learnable replace + self.text_embedding_padding = nn.Parameter( + torch.randn(self.config.text_len + self.config.text_len_t5, + self.config.text_states_dim, + dtype=torch.float32)) + + # Attention pooling + pooler_out_dim = 1024 + self.pooler = AttentionPool(self.config.text_len_t5, + self.config.text_states_dim_t5, + num_heads=8, + output_dim=pooler_out_dim) + + # Dimension of the extra input vectors + self.extra_in_dim = pooler_out_dim + + # Only for hydit <= 1.1 + if self.config.size_cond: + # Image size and crop size conditions + self.extra_in_dim += 6 * 256 + if self.config.use_style_cond: + # Here we use a default learned embedder layer for future extension. + self.style_embedder = nn.Embedding(1, self.config.hidden_size) + self.extra_in_dim += self.config.hidden_size + + # Text embedding for `add` + height = self.config.input_size[0] // 8 + width = self.config.input_size[1] // 8 + self.x_embedder = PatchEmbed(height, + width, + self.config.patch_size, + self.config.in_channels, + self.config.hidden_size, + pos_embed_type=None) + + self.t_embedder = TimestepEmbedder(self.config.hidden_size) + self.extra_embedder = Mlp(features_in=self.extra_in_dim, + features_hidden=self.config.hidden_size * 4, + features_out=self.config.hidden_size, + act_layer="silu", + bias=True) + + # HUnYuanDiT Blocks + self.blocks = nn.ModuleList([ + HunyuanDiTBlock(hidden_size=self.config.hidden_size, + c_emb_size=self.config.hidden_size, + num_heads=self.config.num_heads, + mlp_ratio=self.config.mlp_ratio, + text_states_dim=self.config.text_states_dim, + skip=layer > self.config.depth // 2) + for layer in range(self.config.depth) + ]) + self.final_layer = FinalLayer(self.config.hidden_size, + self.config.hidden_size, + self.config.patch_size, + self.out_channels) + self.unpatchify_channels = self.out_channels + + + def forward(self, + tensor_input=None, + use_cache: bool = False, + cache_params=None, + if_skip: int = 0): + + x, t, encoder_hidden_states, embeds_and_mask_input, freqs_cis_img = tensor_input + if use_cache: + block_start, num_blocks, delta_cache = cache_params + + text_embedding_mask, encoder_hidden_states_t5, text_embedding_mask_t5, image_meta_size, style = \ + embeds_and_mask_input + text_states = encoder_hidden_states + text_states_t5 = encoder_hidden_states_t5 + text_states_mask = text_embedding_mask.bool() + text_states_t5_mask = text_embedding_mask_t5.bool() + b_t5, l_t5, c_t5 = text_states_t5.shape + text_states_t5 = self.mlp_t5(text_states_t5.view(-1, c_t5)) + text_states = torch.cat([text_states, text_states_t5.view(b_t5, l_t5, -1)], dim=1) # 2,205,1024 + clip_t5_mask = torch.cat([text_states_mask, text_states_t5_mask], dim=-1) + + clip_t5_mask = clip_t5_mask + text_states = torch.where(clip_t5_mask.unsqueeze(2), text_states, self.text_embedding_padding.to(text_states)) + + # The input x shape is [2, 4, 128, 128] + height, width = x.shape[-2:] + th, tw = height // self.config.patch_size, width // self.config.patch_size + + # Build time and image embedding + t = self.t_embedder(t) + x = self.x_embedder(x) + # The x shape after x_embedder is [2, 4096, 1408] + + # Build text tokens with pooling + extra_vec = self.pooler(encoder_hidden_states_t5) + + # Only for hydit <= 1.1 + if image_meta_size is not None: + image_meta_size = timestep_embedding(image_meta_size.half().view(-1), 256) # [B * 6, 256] + image_meta_size = image_meta_size.half().view(-1, 6 * 256) + extra_vec = torch.cat([extra_vec, image_meta_size], dim=1) # [B, D + 6 * 256] + if style is not None: + style_embedding = self.style_embedder(style) + extra_vec = torch.cat([extra_vec, style_embedding], dim=1) + + # Concatenate all extra vectors + c = t + self.extra_embedder(extra_vec) # [B, D] + + # Forward pass through HunYuanDiT blocks + tensor_input = (c, text_states, freqs_cis_img) + if not use_cache: + skips = [] + for layer, block in enumerate(self.blocks): + if layer > self.config.depth // 2: + skip = skips.pop() + x = block(x, tensor_input, skip=skip, layer=layer) # (N, L, D) + else: + x = block(x, tensor_input, skip=None, layer=layer) # (N, L, D) + + if layer < (self.config.depth // 2 - 1): + skips.append(x) + else: + cache_params = (use_cache, if_skip, block_start, num_blocks) + x, delta_cache = self._forward_blocks(x, tensor_input, cache_params, delta_cache) + + # Final layer + x = self.final_layer(x, c) # (N, L, patch_size ** 2 * out_channels) + x = self._unpatchify(x, th, tw) + + if use_cache: + return x, delta_cache + + return x + + + def _forward_blocks_range(self, x, tensor_input, skips, start_idx, end_idx): + for layer, block in zip(range(start_idx, end_idx), self.blocks[start_idx : end_idx]): + if layer > self.config.depth // 2: + skip = skips.pop() + x = block(x, tensor_input, skip=skip, layer=layer) # (N, L, D) + else: + x = block(x, tensor_input, skip=None, layer=layer) # (N, L, D) + + if layer < (self.config.depth // 2 - 1): + skips.append(x) + return x, skips + + + def _forward_blocks(self, x, tensor_input, cache_params, delta_cache): + use_cache, if_skip, block_start, num_blocks = cache_params + skips = [] + if not use_cache: + x, skips = self._forward_blocks_range(x, tensor_input, skips, 0, len(self.blocks)) + else: + x, skips = self._forward_blocks_range(x, tensor_input, skips, 0, block_start) + + cache_end = block_start + num_blocks + x_before_cache = x.clone() + if not if_skip: + x, skips = self._forward_blocks_range(x, tensor_input, skips, block_start, cache_end) + delta_cache = x - x_before_cache + else: + x = x_before_cache + delta_cache + + x, skips = self._forward_blocks_range(x, tensor_input, skips, cache_end, len(self.blocks)) + return x, delta_cache + + + def _load_weights(self, state_dict): + weights = state_dict + + weights['mlp_t5.fc1.weight'] = weights.pop('mlp_t5.0.weight') + weights['mlp_t5.fc1.bias'] = weights.pop('mlp_t5.0.bias') + weights['mlp_t5.fc2.weight'] = weights.pop('mlp_t5.2.weight') + weights['mlp_t5.fc2.bias'] = weights.pop('mlp_t5.2.bias') + + weights['extra_embedder.fc1.weight'] = weights.pop('extra_embedder.0.weight') + weights['extra_embedder.fc1.bias'] = weights.pop('extra_embedder.0.bias') + weights['extra_embedder.fc2.weight'] = weights.pop('extra_embedder.2.weight') + weights['extra_embedder.fc2.bias'] = weights.pop('extra_embedder.2.bias') + + for i in range(self.config.depth): + prefix_key = 'blocks.' + str(i) + '.' + + qkv_proj_weights = weights.pop(prefix_key + 'attn1.Wqkv.weight') + qkv_proj_bias = weights.pop(prefix_key + 'attn1.Wqkv.bias') + to_q_weights, to_k_weights, to_v_weights = torch.chunk(qkv_proj_weights, 3, dim=0) + to_q_bias, to_k_bias, to_v_bias = torch.chunk(qkv_proj_bias, 3, dim=0) + weights[prefix_key + 'attn1.q_proj.weight'] = to_q_weights + weights[prefix_key + 'attn1.q_proj.bias'] = to_q_bias + weights[prefix_key + 'attn1.kv_proj.weight'] = torch.cat([to_k_weights, to_v_weights], dim=0) + weights[prefix_key + 'attn1.kv_proj.bias'] = torch.cat([to_k_bias, to_v_bias], dim=0) + + self.load_state_dict(weights) + + + def _unpatchify(self, x, h, w): + c = self.unpatchify_channels + p = self.config.patch_size + x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) + x = torch.einsum('nhwpqc->nchpwq', x) + imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p)) + return imgs + + + def _check_config_params(self): + params_checks = { + "patch_size": int, + "in_channels": int, + "hidden_size": int, + "depth": int, + "num_heads": int, + "mlp_ratio": float, + "text_states_dim": int, + "text_states_dim_t5": int, + "text_len": int, + "text_len_t5": int + } + for attr, expected_type in params_checks.items(): + if hasattr(self.config, attr) and not isinstance(getattr(self.config, attr), expected_type): + raise TypeError(f"The type of {attr} in config must be {expected_type.name}, but got {type(attr)}.") + if getattr(self.config, attr) < 0: + raise ValueError(f"The {attr} in config must be greater than 0, but got {attr}.") + if self.config.hidden_size < self.config.num_heads: + raise ValueError(f"The hidden_size must be greater than num_heads.") \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/models/model_load_utils.py b/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/models/model_load_utils.py new file mode 100644 index 0000000000..11dfa959e9 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/models/model_load_utils.py @@ -0,0 +1,43 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright(C) 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + + +import os +import torch +import safetensors.torch + + +SAFETENSORS_EXTENSION = "safetensors" +EMA_STATE_DICT = "ema_state_dict" +STATE_DICT = "state_dict" +CPU = "cpu" + + +def load_state_dict(model_path): + name = os.path.basename(model_path).split('.')[-1] # get weights name + if name.endswith("ckpt"): + weight = torch.load(model_path, map_location=CPU) + if (EMA_STATE_DICT in weight): + weight = weight[EMA_STATE_DICT] + weight = {key.replace("module.", ""): value for key, value in weight.items()} + elif STATE_DICT in weight: + weight = weight[STATE_DICT] + return weight + elif name == SAFETENSORS_EXTENSION: # diffuser model use same name + return safetensors.torch.load_file(model_path, device=CPU) # first load on cpu + else: + # to support hf shard model weights + return torch.load(model_path, map_location=CPU) # first load on cpu \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/models/model_utils.py b/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/models/model_utils.py new file mode 100644 index 0000000000..b083911b74 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/models/model_utils.py @@ -0,0 +1,77 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright(C) 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + + +import os + +import torch +import torch.nn as nn +from mindiesd import ConfigMixin +from .model_load_utils import load_state_dict + + +DIFFUSER_SAFETENSORS_WEIGHTS_NAME = "diffusion_pytorch_model.safetensors" +WEIGHTS_NAME = "diffusion_pytorch_model.bin" + + +class DiffusionModel(nn.Module): + config_class = ConfigMixin + weigths_name = DIFFUSER_SAFETENSORS_WEIGHTS_NAME + + def __init__(self, config): + super().__init__() + self.config = config + + @classmethod + def from_pretrained(cls, model_path, **kwargs): + dtype = kwargs.pop('dtype', None) # get dtype from kwargs + if not (dtype in {torch.bfloat16, torch.float16}): + raise ValueError("dtype should be a torch.bfloat16 or torch.float16") + + # 1. check model_path and weights_path + real_path = os.path.abspath(model_path) + if not (os.path.exists(real_path) and os.path.isdir(real_path)): + raise ValueError(f"{real_path} is invalid!") + + if not issubclass(cls.config_class, ConfigMixin): + raise ValueError("config_class is not subclass of ConfigMixin.") + + if cls.weigths_name is None: + raise ValueError("weigths_name is not defined.") + + weights_path = os.path.join(real_path, cls.weigths_name) + if not (os.path.exists(weights_path) and os.path.isfile(weights_path)): + raise ValueError(f"'{cls.weigths_name}' is not found in '{model_path}'!") + + # 2. load config_class from json + init_dict, _ = cls.config_class.load_config(real_path, **kwargs) + config = cls.config_class(**init_dict) + + # 3. init model with config + model = cls(config) + + # 4. load model weights + state_dict = load_state_dict(weights_path) + model._load_weights(state_dict) + + # 5. model to dtype + if dtype is not None: + model.to(dtype) + return model + + def _load_weights(self, state_dict): + with torch.no_grad(): + self.load_state_dict(state_dict) \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/pipeline/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/pipeline/__init__.py new file mode 100644 index 0000000000..27abdff09b --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/pipeline/__init__.py @@ -0,0 +1,18 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from .hydit_pipeline import HunyuanDiTPipeline \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/pipeline/hydit_pipeline.py b/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/pipeline/hydit_pipeline.py new file mode 100644 index 0000000000..8fe68de323 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/pipeline/hydit_pipeline.py @@ -0,0 +1,399 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import List, Union, Tuple +import logging + +import torch +from tqdm import tqdm +import numpy as np + +from ..layers import RotaryPositionEmbedding +from ..utils import postprocess_pil, randn_tensor + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +TOKENIZER_MAX_LENGTH = 256 +MAX_PROMPT_LENGTH = 1024 +NEGATIVE_PROMPT = '错误的眼睛,糟糕的人脸,毁容,糟糕的艺术,变形,多余的肢体,模糊的颜色,模糊,重复,病态,残缺,' +STANDARD_RATIO = np.array( + [ + 1.0, # 1:1 + 4.0 / 3.0, # 4:3 + 3.0 / 4.0, # 3:4 + 16.0 / 9.0, # 16:9 + 9.0 / 16.0, # 9:16 + ] +) +STANDARD_SHAPE = [ + [(1024, 1024), (1280, 1280)], # 1:1 + [(1024, 768), (1152, 864), (1280, 960)], # 4:3 + [(768, 1024), (864, 1152), (960, 1280)], # 3:4 + [(1280, 768)], # 16:9 + [(768, 1280)], # 9:16 +] +STANDARD_AREA = [np.array([w * h for w, h in shapes]) for shapes in STANDARD_SHAPE] +SUPPORTED_SHAPE = [ + (1024, 1024), + (1280, 1280), # 1:1 + (1024, 768), + (1152, 864), + (1280, 960), # 4:3 + (768, 1024), + (864, 1152), + (960, 1280), # 3:4 + (1280, 768), # 16:9 + (768, 1280), # 9:16 +] + + +class HunyuanDiTPipeline: + + def __init__( + self, + scheduler, + text_encoder, + tokenizer, + text_encoder_2, + tokenizer_2, + transformer, + vae, + args, + input_size: Tuple[int, int] = (1024, 1024) + ): + super().__init__() + torch.set_grad_enabled(False) + + self.scheduler = scheduler + self.text_encoder = text_encoder + self.tokenizer = tokenizer + self.text_encoder_2 = text_encoder_2 + self.tokenizer_2 = tokenizer_2 + self.transformer = transformer + self.vae = vae + self.input_size = input_size + self._check_init_input() + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.hidden_states_batch = 2 + self.device = torch.device("npu") + self.guidance_scale = args.guidance_scale + + # Set image height and width. + height = self.input_size[0] + width = self.input_size[1] + self.height = int((height // 16) * 16) + self.width = int((width // 16) * 16) + if (self.height, self.width) not in SUPPORTED_SHAPE: + width, height = map_to_standard_shapes(self.width, self.height) + self.height = int(height) + self.width = int(width) + logger.warning(f"Reshaped to ({self.height}, {self.width}), Supported shapes are {SUPPORTED_SHAPE}") + + # Create image rotary position embedding + self.rotary_pos_emb = self._get_rotary_pos_emb() + + # Only for hydit <= 1.1 + self.image_meta_size, self.style = self._get_v1_params(args) + + # Use DiT Cache + self.use_cache = args.use_cache + if self.use_cache: + self.step_start = args.step_start + self.step_interval = args.step_interval + self.block_start = args.block_start + self.num_blocks = args.num_blocks + self.step_contrast = 9 % 2 + self.skip_flag_true = torch.ones([1], dtype=torch.int64).to(self.device) + self.skip_flag_false = torch.zeros([1], dtype=torch.int64).to(self.device) + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + num_inference_steps: int = 100, + seed_generator: torch.Generator = None + ): + # 1. Check inputs. Raise error if not correct + check_call_input(prompt, num_images_per_prompt, num_inference_steps, seed_generator) + + # 2. Define prompt and negative_prompt + if prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = 1 + negative_prompt = NEGATIVE_PROMPT + if prompt is not None and not isinstance(prompt, type(negative_prompt)): + raise ValueError( + f"negative_prompt should be the same type to prompt, " + f"but got {type(negative_prompt)} != {type(prompt)}." + ) + prompt_info = (prompt, negative_prompt, num_images_per_prompt) + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds, attention_mask, uncond_attention_mask = \ + self._encode_prompt(prompt_info, batch_size, embedder_t5=False) + prompt_embeds_t5, negative_prompt_embeds_t5, attention_mask_t5, uncond_attention_mask_t5 = \ + self._encode_prompt(prompt_info, batch_size, embedder_t5=True) + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + attention_mask = torch.cat([uncond_attention_mask, attention_mask]) + prompt_embeds_t5 = torch.cat([negative_prompt_embeds_t5, prompt_embeds_t5]) + attention_mask_t5 = torch.cat([uncond_attention_mask_t5, attention_mask_t5]) + transformer_input = (attention_mask, prompt_embeds_t5, attention_mask_t5, self.image_meta_size, self.style) + torch.npu.empty_cache() + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=self.device) + timesteps = self.scheduler.timesteps + step = (timesteps, num_inference_steps) + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + shape = (batch_size * num_images_per_prompt, + num_channels_latents, + self.height // self.vae_scale_factor, + self.width // self.vae_scale_factor) + latents = randn_tensor(shape, generator=seed_generator, device=self.device, dtype=prompt_embeds.dtype) * 1.0 + + # 6. Denoising loop + latents = self._sampling(latents, step, prompt_embeds, transformer_input, seed_generator) + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image = postprocess_pil(image) + + return (image, None) + + + def _check_init_input(self): + if not isinstance(self.input_size, tuple): + raise ValueError(f"The type of input_size must be tuple, but got {type(self.input_size)}.") + if len(self.input_size) != 2: + raise ValueError(f"The length of input_size must be 2, but got {len(self.input_size)}.") + if self.input_size[0] % 8 != 0 or self.input_size[0] <= 0: + raise ValueError( + f"The height of input_size must be divisible by 8 and greater than 0, but got {self.input_size[0]}.") + if self.input_size[1] % 8 != 0 or self.input_size[1] <= 0: + raise ValueError( + f"The width of input_size must be divisible by 8 and greater than 0, but got {self.input_size[1]}.") + + + def _get_rotary_pos_emb(self): + grid_height = self.height // 8 // self.transformer.config.patch_size + grid_width = self.width // 8 // self.transformer.config.patch_size + base_size = 512 // 8 // self.transformer.config.patch_size + head_dim = self.transformer.config.hidden_size // self.transformer.config.num_heads + + rope = RotaryPositionEmbedding(head_dim) + freqs_cis_img = rope.get_2d_rotary_pos_embed(grid_height, grid_width, base_size) + if isinstance(freqs_cis_img, tuple) and len(freqs_cis_img) == 2: + return (freqs_cis_img[0].to(self.device), freqs_cis_img[1].to(self.device)) + else: + raise ValueError(f"The type of rotary_pos_emb must be tuple and the length must be 2.") + + + def _get_v1_params(self, args): + if args.use_style_cond and args.size_cond is not None: + src_size_cond = args.size_cond + if isinstance(src_size_cond, int): + src_size_cond = [src_size_cond, src_size_cond] + if not isinstance(src_size_cond, (list, tuple)): + raise TypeError(f"The src_size_cond must be a list or tuple, but got {type(src_size_cond)}.") + if len(src_size_cond) != 2: + raise ValueError(f"The src_size_cond must be a tuple of 2 integers, but got {len(src_size_cond)}.") + size_cond = list(src_size_cond) + [self.width, self.height, 0, 0] + image_meta_size = torch.as_tensor([size_cond] * 2 * args.batch_size, device=args.device) + style = torch.as_tensor([0, 0] * args.batch_size, device=args.device) + else: + image_meta_size = None + style = None + return image_meta_size, style + + + def _encode_prompt(self, prompt_info, batch_size, embedder_t5=False): + if not embedder_t5: + text_encoder = self.text_encoder + tokenizer = self.tokenizer + max_length = self.tokenizer.model_max_length + else: + text_encoder = self.text_encoder_2 + tokenizer = self.tokenizer_2 + max_length = TOKENIZER_MAX_LENGTH + + prompt, negative_prompt, num_images_per_prompt = prompt_info + # prompt_embeds + prompt_embeds, attention_mask = self._encode_embeds( + prompt, tokenizer, text_encoder, max_length, num_images_per_prompt) + if text_encoder is not None: + prompt_embeds_dtype = text_encoder.dtype + elif self.transformer is not None: + prompt_embeds_dtype = self.transformer.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=self.device) + bs_embed, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + # negative_prompt_embeds + negative_prompt_embeds, uncond_attention_mask = self._encode_negative_embeds( + negative_prompt, tokenizer, text_encoder, prompt_embeds, num_images_per_prompt) + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=self.device) + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds, negative_prompt_embeds, attention_mask, uncond_attention_mask + + + def _encode_embeds(self, prompt, tokenizer, text_encoder, max_length, num_images_per_prompt): + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + + attention_mask = text_inputs.attention_mask.to(self.device) + prompt_embeds = text_encoder( + text_input_ids.to(self.device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + attention_mask = attention_mask.repeat(num_images_per_prompt, 1) + + return prompt_embeds, attention_mask + + + def _encode_negative_embeds(self, negative_prompt, tokenizer, text_encoder, prompt_embeds, num_images_per_prompt): + uncond_tokens: List[str] + if isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + else: + uncond_tokens = negative_prompt + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + uncond_attention_mask = uncond_input.attention_mask.to(self.device) + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(self.device), + attention_mask=uncond_attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + uncond_attention_mask = uncond_attention_mask.repeat(num_images_per_prompt, 1) + + return negative_prompt_embeds, uncond_attention_mask + + + def _sampling(self, latents, step, prompt_embeds, transformer_input, seed_generator): + + timesteps, num_inference_steps = step + + if self.use_cache: + delta_cache = torch.zeros([2, 3840, 1408], dtype=torch.float16).to(self.device) + step_start = self.step_start + + num_warmup_steps = len(timesteps) - num_inference_steps + with self._progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + latent_model_input = torch.cat([latents] * self.hidden_states_batch) + t_expand = torch.tensor([t] * latent_model_input.shape[0], device=latent_model_input.device) + + # if use_fp16 + latent_model_input = latent_model_input.half() + t_expand = t_expand.half() + prompt_embeds = prompt_embeds.half() + + # predict the noise residual + tensor_input = (latent_model_input, t_expand, prompt_embeds, transformer_input, self.rotary_pos_emb) + if not self.use_cache: + noise_pred = self.transformer(tensor_input) + else: + cache_params = (self.block_start, self.num_blocks, delta_cache.half()) + inputs = [tensor_input, self.use_cache, cache_params, self.skip_flag_false] + if i < step_start: + noise_pred, delta_cache = self.transformer(*inputs) + else: + if i % self.step_interval == self.step_contrast: + noise_pred, delta_cache = self.transformer(*inputs) + else: + inputs[-1] = self.skip_flag_true + noise_pred, delta_cache = self.transformer(*inputs) + + # if learn_sigma + noise_pred, _ = noise_pred.chunk(2, dim=1) + # perform guidance + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, seed_generator) + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps): + progress_bar.update() + + return latents + + + def _progress_bar(self, total=None): + if not hasattr(self, "_progress_bar_config"): + self._progress_bar_config = {} + elif not isinstance(self._progress_bar_config, dict): + raise ValueError(f"_progress_bar_config should be dict, but is {type(self._progress_bar_config)}.") + + if total is not None: + return tqdm(total=total, **self._progress_bar_config) + else: + raise ValueError("total has to be defined.") + + +def check_call_input(prompt, num_images_per_prompt, num_inference_steps, seed_generator): + if not isinstance(prompt, str): + raise ValueError("The input prompt type must be strings.") + if len(prompt) == 0 or len(prompt) >= MAX_PROMPT_LENGTH: + raise ValueError( + f"The length of the prompt should be (0, {MAX_PROMPT_LENGTH}), but got {len(prompt)}.") + if not isinstance(num_images_per_prompt, int): + raise ValueError("The input num_images_per_prompt type must be an instance of int.") + if num_images_per_prompt < 0: + raise ValueError( + f"Input num_images_per_prompt should be a non-negative integer, but got {num_images_per_prompt}.") + if not isinstance(num_inference_steps, int): + raise ValueError("The input num_inference_steps type must be an instance of int.") + if num_inference_steps < 0: + raise ValueError( + f"Input num_inference_steps should be a non-negative integer, but got {num_inference_steps}.") + if not isinstance(seed_generator, torch.Generator): + raise ValueError( + f"The type of input seed_generator must be torch.Generator, but got {type(seed_generator)}.") + + +def map_to_standard_shapes(target_width, target_height): + target_ratio = target_width / target_height + closest_ratio_idx = np.argmin(np.abs(STANDARD_RATIO - target_ratio)) + closest_area_idx = np.argmin(np.abs(STANDARD_AREA[closest_ratio_idx] - target_width * target_height)) + width, height = STANDARD_SHAPE[closest_ratio_idx][closest_area_idx] + return width, height \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/schedulers/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/schedulers/__init__.py new file mode 100644 index 0000000000..5eaa6e6c66 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/schedulers/__init__.py @@ -0,0 +1,18 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from .ddpm import DDPMScheduler \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/schedulers/ddpm.py b/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/schedulers/ddpm.py new file mode 100644 index 0000000000..39d3c42314 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/schedulers/ddpm.py @@ -0,0 +1,115 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +import numpy as np + +from mindiesd import DiffusionScheduler +from ..utils import randn_tensor + + +class DDPMScheduler(DiffusionScheduler): + + def __init__( + self, + steps_offset: int = 1, + beta_start: float = 0.00085, + beta_end: float = 0.02, + num_train_timesteps: int = 1000, + ): + super().__init__() + + self.steps_offset = steps_offset + self.num_train_timesteps = num_train_timesteps + + self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + self.one = torch.tensor(1.0) + + # setable values + self.num_inference_steps = None + self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy()) + + + def set_timesteps(self, num_inference_steps: int = 100, device=None): + + if num_inference_steps > self.num_train_timesteps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.train_timesteps`:" + f" {self.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.num_train_timesteps} timesteps." + ) + self.num_inference_steps = num_inference_steps + + step_ratio = self.num_train_timesteps // self.num_inference_steps + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) + timesteps += self.steps_offset + + self.timesteps = torch.from_numpy(timesteps).to(device) + + + def step(self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor, generator=None): + + prev_t = self._previous_timestep(timestep) + + # 1. compute alphas, betas + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + current_alpha_t = alpha_prod_t / alpha_prod_t_prev + current_beta_t = 1 - current_alpha_t + # 2. compute predicted original sample from predicted noise also called + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + # 3. Compute coefficients for pred_original_sample x_0 and current sample x_t + pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t + current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t + # 4. Compute predicted previous sample µ_t + pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample + # 5. Add noise + variance = 0 + if timestep > 0: + device = model_output.device + variance_noise = randn_tensor(model_output.shape, + generator=generator, + device=device, + dtype=model_output.dtype) + variance = (self._get_variance(timestep) ** 0.5) * variance_noise + + pred_prev_sample = pred_prev_sample + variance + + return pred_prev_sample + + + def _previous_timestep(self, timestep): + num_inference_steps = (self.num_inference_steps if self.num_inference_steps else self.num_train_timesteps) + prev_t = timestep - self.num_train_timesteps // num_inference_steps + return prev_t + + + def _get_variance(self, timestep): + prev_t = self._previous_timestep(timestep) + + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one + current_beta_t = 1 - alpha_prod_t / alpha_prod_t_prev + + variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * current_beta_t + variance = torch.clamp(variance, min=1e-20) + + return variance \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/utils/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/utils/__init__.py new file mode 100644 index 0000000000..b17b7a5b78 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/utils/__init__.py @@ -0,0 +1,18 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from .utils import is_npu_available, postprocess_pil, set_seeds_generator, randn_tensor \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/utils/file_utils.py b/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/utils/file_utils.py new file mode 100644 index 0000000000..f88e4148c1 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/utils/file_utils.py @@ -0,0 +1,139 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright(C) 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + + +import os +from functools import reduce + +MAX_PATH_LENGTH = 4096 +MAX_FILE_SIZE = 10 * 1024 * 1024 * 1024 +SAFEOPEN_FILE_PERMISSION = 0o640 + +FLAG_OS_MAP = { + 'r': os.O_RDONLY, 'r+': os.O_RDWR, + 'w': os.O_CREAT | os.O_TRUNC | os.O_WRONLY, + 'w+': os.O_CREAT | os.O_TRUNC | os.O_RDWR, + 'a': os.O_CREAT | os.O_APPEND | os.O_WRONLY, + 'a+': os.O_CREAT | os.O_APPEND | os.O_RDWR, + 'x': os.O_CREAT | os.O_EXCL, + "b": getattr(os, "O_BINARY", 0) +} + + +def safe_open(file_path: str, mode='r', encoding=None, permission_mode=0o640, **kwargs): + """ + Args: + file_path (str): 文件路径 + mode (str): 文件打开模式 + encoding (str): 文件编码方式 + permission_mode: 文件权限最大值 + max_path_length (int): 文件路径最大长度 + max_file_size (int): 文件最大大小,单位: 字节, 默认值10MB + check_link (bool): 是否校验软链接 + kwargs: + """ + max_path_length = kwargs.get('max_path_length', MAX_PATH_LENGTH) + max_file_size = kwargs.get('max_file_size', MAX_FILE_SIZE) + check_link = kwargs.get('check_link', True) + + file_path = standardize_path(file_path, max_path_length, check_link) + check_file_safety(file_path, max_file_size, permission_mode) + + flags = [] + for item in list(mode): + if item == "+" and flags: + flags[-1] = f"{flags[-1]}+" + continue + flags.append(item) + flags = [FLAG_OS_MAP.get(mode, os.O_RDONLY) for mode in flags] + total_flag = reduce(lambda a, b: a | b, flags) + + return os.fdopen(os.open(file_path, total_flag, SAFEOPEN_FILE_PERMISSION), + mode, encoding=encoding) + + +def standardize_path(path: str, max_path_length=MAX_PATH_LENGTH, check_link=True): + """ + Check and standardize path. + Args: + path (str): 未标准化路径 + max_path_length (int): 文件路径最大长度 + check_link (bool): 是否校验软链接 + Return: + path (str): 标准化后的绝对路径 + """ + check_path_is_none(path) + check_path_length_lt(path, max_path_length) + if check_link: + check_path_is_link(path) + path = os.path.realpath(path) + return path + + +def is_path_exists(path: str): + return os.path.exists(path) + + +def check_path_is_none(path: str): + if path is None: + raise ValueError("The path should not be None.") + + +def check_path_is_link(path: str): + if os.path.islink(os.path.normpath(path)): + raise ValueError(f"The path:{path} is a symbolic link file.") + + +def check_path_length_lt(path: str, max_path_length=MAX_PATH_LENGTH): + if path.__len__() > max_path_length: + raise ValueError(f"The length of path is {path.__len__()}, which exceeds the limit {max_path_length}.") + + +def check_file_size_lt(path: str, max_file_size=MAX_FILE_SIZE): + if os.path.getsize(path) > max_file_size: + raise ValueError( + f"The size of file:{path} is {os.path.getsize(path)}, which exceeds the limit {max_file_size}.") + + +def check_owner(path: str): + path_stat = os.stat(path) + path_owner, path_gid = path_stat.st_uid, path_stat.st_gid + user_check = path_owner == os.getuid() and path_owner == os.geteuid() + if not (os.geteuid() == 0 or path_gid in os.getgroups() or user_check): + raise ValueError(f"The path:{path} is not owned by current user or root") + + +def check_max_permission(file_path: str, permission_mode=0o640): + # check permission + file_mode = os.stat(file_path).st_mode & 0o777 # use 777 as mask to get 3-digit octal number + # transeform file_mode into binary patten,remove the head '0b' string,expand to 9 bits + file_mode_bin = bin(file_mode)[2:].zfill(9) + # transeform permission_mode into binary patten,remove the head '0b' string,expand to 9 bits + max_mode_bin = bin(permission_mode)[2:].zfill(9) + for i in range(9): # 9 means 9-bit binary number, checking every bit + if file_mode_bin[i] > max_mode_bin[i]: + raise ValueError(f'The permission of {file_path} is higher than {oct(permission_mode)}') + + +def check_file_safety(file_path: str, max_file_size=MAX_FILE_SIZE, is_check_file_size=True, permission_mode=0o640): + if not is_path_exists(file_path): + raise ValueError(f"The path:{file_path} doesn't exist.") + if not os.path.isfile(file_path): + raise ValueError(f"The input:{file_path} is not a file.") + if is_check_file_size: + check_file_size_lt(file_path, max_file_size) + check_owner(file_path) + check_max_permission(file_path, permission_mode) \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/utils/utils.py b/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/utils/utils.py new file mode 100644 index 0000000000..89f87e88b4 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/utils/utils.py @@ -0,0 +1,107 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import importlib +import random +import numpy as np +import torch +import PIL +from PIL import Image + + +def is_npu_available(): + "Checks if `torch_npu` is installed and potentially if a NPU is in the environment" + if importlib.util.find_spec("torch") is None or importlib.util.find_spec("torch_npu") is None: + return False + + import torch_npu + + try: + # Will raise a RuntimeError if no NPU is found + _ = torch.npu.device_count() + return torch.npu.is_available() + except RuntimeError: + return False + + +def set_seeds_generator(seed, device=None): + + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + return torch.Generator(device).manual_seed(seed) + + +def randn_tensor( + shape: tuple, + generator: torch.Generator = None, + device: torch.device = None, + dtype: torch.dtype = None, + layout: torch.layout = None, +): + """ + A helper function to create random tensors on the desired `device` with the desired `dtype`. When passing + a list of generators, you can seed each batch size individually. If CPU generators are passed, the tensor + is always created on the CPU. + """ + # device on which tensor is created defaults to device + rand_device = device + layout = layout or torch.strided + device = device or torch.device("cpu") + + if generator is not None: + gen_device_type = generator.device.type + if gen_device_type != device.type and gen_device_type == "cpu": + rand_device = "cpu" + elif gen_device_type != device.type and gen_device_type == "npu": + raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.") + + latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device) + + return latents + + +def _denormalize(images): + return (images / 2 + 0.5).clamp(0, 1) + + +def _pt_to_numpy(images: torch.FloatTensor) -> np.ndarray: + images = images.cpu().permute(0, 2, 3, 1).float().numpy() + return images + + +def _numpy_to_pil(images: np.ndarray) -> PIL.Image.Image: + if images.ndim == 3: + images = images[None, ...] + images = (images * 255).round().astype("uint8") + if images.shape[-1] == 1: + pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images] + else: + pil_images = [Image.fromarray(image) for image in images] + + return pil_images + + +def postprocess_pil(image: torch.Tensor): + if not isinstance(image, torch.Tensor): + raise ValueError(f"The input image type must be a torch.FloatTensor, but got {type(image)}.") + + image = torch.stack([_denormalize(image[i]) for i in range(image.shape[0])]) + image = _pt_to_numpy(image) + + return _numpy_to_pil(image) \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/inference_hydit.py b/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/inference_hydit.py new file mode 100644 index 0000000000..30d297c83f --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/inference_hydit.py @@ -0,0 +1,335 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os +import random +import argparse +import time +import logging +import csv +import json + +import torch + +from diffusers import AutoencoderKL +from transformers import BertModel, BertTokenizer, T5EncoderModel, T5Tokenizer +from transformers.modeling_utils import logger as tf_logger + +from hydit import HunyuanDiTPipeline, HunyuanDiT2DModel, DDPMScheduler, set_seeds_generator +from hydit.utils import file_utils +from lora import multi_lora + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +class PromptLoader: + def __init__( + self, + prompt_file: str, + prompt_file_type: str, + num_images_per_prompt: int = 1, + max_num_prompts: int = 0 + ): + self.prompts = [] + self.catagories = ['Not_specified'] + self.num_images_per_prompt = num_images_per_prompt + self.max_num_prompts = max_num_prompts + + if prompt_file_type == 'plain': + self.load_prompts_plain(prompt_file) + elif prompt_file_type == 'parti': + self.load_prompts_parti(prompt_file) + elif prompt_file_type == 'hpsv2': + self.load_prompts_hpsv2(prompt_file) + else: + print("This operation is not supported!") + + self.current_id = 0 + self.inner_id = 0 + + def __len__(self): + return len(self.prompts) * self.num_images_per_prompt + + def __iter__(self): + return self + + def __next__(self): + if self.current_id == len(self.prompts): + raise StopIteration + + ret = { + 'prompts': [], + 'catagories': [], + 'save_names': [], + } + for _ in range(self.num_images_per_prompt): + if self.current_id == len(self.prompts): + ret['prompts'].append('') + ret['save_names'].append('') + ret['catagories'].append('') + else: + prompt, catagory_id = self.prompts[self.current_id] + ret['prompts'].append(prompt) + ret['catagories'].append(self.catagories[catagory_id]) + ret['save_names'].append(f'{self.current_id}_{self.inner_id}') + + self.inner_id += 1 + if self.inner_id == self.num_images_per_prompt: + self.inner_id = 0 + self.current_id += 1 + return ret + + def load_prompts_plain(self, file_path: str): + with file_utils.safe_open(file_path, "r", encoding="utf-8", + permission_mode=file_utils.SAFEOPEN_FILE_PERMISSION) as file: + for i, line in enumerate(file): + if self.max_num_prompts and i == self.max_num_prompts: + break + + prompt = line.strip() + self.prompts.append((prompt, 0)) + + def load_prompts_parti(self, file_path: str): + with file_utils.safe_open(file_path, "r", encoding="utf-8", + permission_mode=file_utils.SAFEOPEN_FILE_PERMISSION) as file: + # Skip the first line + next(file) + tsv_file = csv.reader(file, delimiter="\t") + for i, line in enumerate(tsv_file): + if self.max_num_prompts and i == self.max_num_prompts: + break + + prompt = line[0] + catagory = line[1] + if catagory not in self.catagories: + self.catagories.append(catagory) + + catagory_id = self.catagories.index(catagory) + self.prompts.append((prompt, catagory_id)) + + def load_prompts_hpsv2(self, file_path: str): + with file_utils.safe_open(file_path, "r", encoding="utf-8", + permission_mode=file_utils.SAFEOPEN_FILE_PERMISSION) as file: + all_prompts = json.load(file) + count = 0 + for style, prompts in all_prompts.items(): + for prompt in prompts: + count += 1 + if self.max_num_prompts and count >= self.max_num_prompts: + break + + if style not in self.catagories: + self.catagories.append(style) + + catagory_id = self.catagories.index(style) + self.prompts.append((prompt, catagory_id)) + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument("--path", type=str, default="ckpts/t2i", help="Path to the model directory") + parser.add_argument("--save_result_path", type=str, default="./results", help="Path to save result images") + parser.add_argument("--device_id", type=int, default=0, help="NPU device id") + parser.add_argument("--device", type=str, default="npu", help="NPU") + parser.add_argument("--prompt", type=str, default="渔舟唱晚", help="The prompt for generating images") + parser.add_argument("--test_acc", action="store_true", help="Run or not 'example_prompts.txt'") + parser.add_argument("--prompt_file", type=str, default="prompts/example_prompts.txt", help="The prompt list") + parser.add_argument("--prompt_file_type", choices=["plain", "parti", "hpsv2"], default="plain", + help="Type of prompt file") + parser.add_argument("--info_file_save_path", type=str, default="./image_info.json", + help="Path to save image information file") + parser.add_argument("--max_num_prompts", default=0, type=int, help="Limit the number of prompts (0: no limit)") + + parser.add_argument("--input_size", type=int, nargs="+", default=[1024, 1024], help="Image size (h, w)") + parser.add_argument("--type", type=str, default="fp16", help="The torch type is fp16 or bf16") + parser.add_argument("--batch_size", type=int, default=1, help="Per-NPU batch size") + parser.add_argument("--seed", type=int, default=42, help="A seed for all the prompts") + parser.add_argument("--infer_steps", type=int, default=100, help="Inference steps") + parser.add_argument("--guidance_scale", type=float, default=6.0, help="Guidance scale for classifier-free") + + parser.add_argument("--use_lora", action="store_true", help="Use LoRA checkpoint") + parser.add_argument("--lora_ckpt", type=str, default="ckpts/lora", help="LoRA checkpoint") + + parser.add_argument("--use_cache", action="store_true", help="Run or not using cache") + parser.add_argument("--step_start", type=int, default=9, help="The start iteration steps of cache") + parser.add_argument("--step_interval", type=int, default=2, help="The step interval of cache") + parser.add_argument("--block_start", type=int, default=5, help="The block start of cache") + parser.add_argument("--num_blocks", type=int, default=30, help="The num blocks of cache") + + parser.add_argument("--beta_end", type=float, default=0.02, help="Scheduler beta-end=0.03 if model<=1.1") + parser.add_argument("--use_style_cond", action="store_true", help="Use style condition. Only for model<=1.1") + parser.add_argument("--size_cond", type=int, nargs="+", default=None, + help="Size condition used in sampling. Default=[1024, 1024]. Only for model<=1.1") + return parser.parse_args() + + +def get_dtype(args): + dtype = torch.bfloat16 + if args.type == 'bf16': + dtype = torch.bfloat16 + elif args.type == 'fp16': + dtype = torch.float16 + else: + logger.error("Not supported.") + return dtype + + +def get_seed(args): + seed = args.seed + if seed is None: + seed = random.randint(0, 1_000_000) + if not isinstance(seed, int): + raise ValueError(f"The type of seed must be int, but got {type(seed)}.") + if seed < 0: + raise ValueError(f"Input seed must be a non-negative integer, but got {seed}.") + return set_seeds_generator(seed, device=args.device) + + +def get_save_path(args): + save_dir = args.save_result_path + if not os.path.exists(save_dir): + os.makedirs(save_dir) + + now_time = time.localtime(time.time()) + time_dir_name = time.strftime("%m%d%H%M%S", now_time) + time_dir = os.path.join(save_dir, time_dir_name) + os.makedirs(time_dir) + logger.info(f"Save result image to {time_dir}") + return time_dir + + +def get_pipeline(args): + tf_logger.setLevel('ERROR') + if len(args.input_size) != 2: + raise ValueError(f"The length of args.input_size must be 2, but got {len(args.input_size)}") + input_size = (args.input_size[0], args.input_size[1]) + dtype = get_dtype(args) + + scheduler = DDPMScheduler(beta_end=args.beta_end) + + text_encoder_path = os.path.join(args.path, "clip_text_encoder") + text_encoder = BertModel.from_pretrained(text_encoder_path).to(args.device) + tokenizer_path = os.path.join(args.path, "tokenizer") + tokenizer = BertTokenizer.from_pretrained(tokenizer_path) + + mt5_path = os.path.join(args.path, "mt5") + text_encoder_2 = T5EncoderModel.from_pretrained(mt5_path).to(args.device).eval().to(dtype) + tokenizer_2 = T5Tokenizer.from_pretrained(mt5_path) + + vae_path = os.path.join(args.path, "sdxl-vae-fp16-fix") + vae = AutoencoderKL.from_pretrained(vae_path).to(args.device) + + transformer_path = os.path.join(args.path, "model") + transformer = HunyuanDiT2DModel.from_pretrained(transformer_path, + input_size=input_size, + size_cond=args.size_cond, + use_style_cond=args.use_style_cond, + dtype=dtype) + transformer = transformer.to(args.device).eval() + + pipeline = HunyuanDiTPipeline(scheduler=scheduler, + text_encoder=text_encoder, + tokenizer=tokenizer, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, + transformer=transformer, + vae=vae, + args=args, + input_size=input_size) + return pipeline + + +def infer(args): + time_path = get_save_path(args) + seed_generator = get_seed(args) + pipeline = get_pipeline(args) + + if args.use_lora: + merge_state_dict = multi_lora(args, pipeline) + pipeline.transformer.load_state_dict(merge_state_dict) + + pipeline_total_time = 0.0 + infer_num = 0 + image_info = [] + current_prompt = None + prompt_loader = PromptLoader(args.prompt_file, args.prompt_file_type, args.batch_size, args.max_num_prompts) + if args.test_acc: + for i, input_info in enumerate(prompt_loader): + prompts = input_info['prompts'] + catagories = input_info['catagories'] + save_names = input_info['save_names'] + + start_time = time.time() + result_images = pipeline( + prompt=prompts[0], + num_images_per_prompt=args.batch_size, + num_inference_steps=args.infer_steps, + seed_generator=seed_generator, + )[0] + pipeline_time = time.time() - start_time + logger.info("HunyuanDiT [%d/%d] time: %.3f", infer_num + 1, len(prompt_loader), pipeline_time) + torch.npu.empty_cache() + + if infer_num >= (2 * args.batch_size): + pipeline_total_time += pipeline_time + infer_num += args.batch_size + + for j, img in enumerate(result_images): + save_path = os.path.join(time_path, f"{save_names[j]}.png") + img.save(save_path) + torch.npu.empty_cache() + + if current_prompt != prompts[j]: + current_prompt = prompts[j] + image_info.append({'images': [], 'prompt': current_prompt, 'category': catagories[j]}) + image_info[-1]['images'].append(save_path) + + if infer_num <= (2 * args.batch_size): + raise ValueError(f"The number of prompts must be greater than {2*args.batch_size}, but got {infer_num}") + pipeline_average_time = pipeline_total_time / (infer_num - (2 * args.batch_size)) + logger.info("HunyuanDiT pipeline_average_time: %.3f", pipeline_average_time) + + # Save image information to a json file + if args.prompt_file_type != "plain": + if os.path.exists(args.info_file_save_path): + os.remove(args.info_file_save_path) + with os.fdopen(os.open(args.info_file_save_path, os.O_RDWR | os.O_CREAT, 0o640), "w") as file: + json.dump(image_info, file) + else: + prompts = args.prompt + prompts = [prompts.strip()] + result_images = pipeline( + prompt=prompts[0], + num_images_per_prompt=args.batch_size, + num_inference_steps=args.infer_steps, + seed_generator=seed_generator, + )[0] + torch.npu.empty_cache() + for i, img in enumerate(result_images): + save_path = os.path.join(time_path, f"0_{i}.png") + img.save(save_path) + torch.npu.empty_cache() + + +if __name__ == "__main__": + inference_args = parse_arguments() + if not os.path.exists(inference_args.path): + raise ValueError(f"The model path not exists: {inference_args.path}") + + torch.npu.set_device(inference_args.device_id) + infer(inference_args) \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/lora/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/lora/__init__.py new file mode 100644 index 0000000000..a01b6bc70a --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/lora/__init__.py @@ -0,0 +1,18 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from .hydit_lora import multi_lora \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/lora/hydit_lora.py b/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/lora/hydit_lora.py new file mode 100644 index 0000000000..8ce6476714 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/lora/hydit_lora.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +from safetensors import safe_open + + +def multi_lora(args, pipeline): + transformer_state_dict = pipeline.transformer.state_dict() + lora_state_dict = {} + with safe_open(args.lora_ckpt, framework="pt", device=args.device) as f: + for k in f.keys(): + lora_state_dict[k[17:]] = f.get_tensor(k) + + num_blocks = pipeline.transformer.config.depth + merge_state_dict = load_lora(transformer_state_dict, lora_state_dict, num_blocks, lora_scale=1.0) + return merge_state_dict + + +def load_lora(transformer_state_dict, lora_state_dict, num_blocks, lora_scale): + + for i in range(num_blocks): + Wqkv = torch.matmul(lora_state_dict[f"blocks.{i}.attn1.Wqkv.lora_B.weight"], + lora_state_dict[f"blocks.{i}.attn1.Wqkv.lora_A.weight"]) + transformer_state_dict[f"blocks.{i}.attn1.qkv_proj.weight"] += lora_scale * Wqkv + + out_proj = torch.matmul(lora_state_dict[f"blocks.{i}.attn1.out_proj.lora_B.weight"], + lora_state_dict[f"blocks.{i}.attn1.out_proj.lora_A.weight"]) + transformer_state_dict[f"blocks.{i}.attn1.out_proj.weight"] += lora_scale * out_proj + + q_proj = torch.matmul(lora_state_dict[f"blocks.{i}.attn2.q_proj.lora_B.weight"], + lora_state_dict[f"blocks.{i}.attn2.q_proj.lora_A.weight"]) + transformer_state_dict[f"blocks.{i}.attn2.q_proj.weight"] += lora_scale * q_proj + + kv_proj = torch.matmul(lora_state_dict[f"blocks.{i}.attn2.kv_proj.lora_B.weight"], + lora_state_dict[f"blocks.{i}.attn2.kv_proj.lora_A.weight"]) + transformer_state_dict[f"blocks.{i}.attn2.kv_proj.weight"] += lora_scale * kv_proj + + out_proj = torch.matmul(lora_state_dict[f"blocks.{i}.attn2.out_proj.lora_B.weight"], + lora_state_dict[f"blocks.{i}.attn2.out_proj.lora_A.weight"]) + transformer_state_dict[f"blocks.{i}.attn2.out_proj.weight"] += lora_scale * out_proj + + q_proj = torch.matmul(lora_state_dict["pooler.q_proj.lora_B.weight"], + lora_state_dict["pooler.q_proj.lora_A.weight"]) + transformer_state_dict["pooler.attn.q_proj.weight"] += lora_scale * q_proj + + return transformer_state_dict diff --git a/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/prompts/example_prompts.txt b/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/prompts/example_prompts.txt new file mode 100644 index 0000000000..f590be43c4 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/prompts/example_prompts.txt @@ -0,0 +1,28 @@ +一只聪明的狐狸走在阔叶树林里, 旁边是一条小溪, 细节真实, 摄影 +湖水清澈,天空湛蓝,阳光灿烂。一只优雅的白天鹅在湖边游泳。它周围有几只小鸭子,看起来非常可爱,整个画面给人一种宁静祥和的感觉。 +太阳微微升起,花园里的玫瑰花瓣上露珠晶莹剔透,一只瓢虫正在爬向露珠,背景是清晨的花园,微距镜头 +一位女明星,中国人,头发是黑色,衣服是纯白色短袖,人物风格清新,城市背景 +后印象主义风格,一条古老的石板路上面散落着金黄色的树叶。路旁的风车在静谧地转动,后面竖着两个风车。背景是一片向日葵田,蓝天上飘着几朵白云 +一幅细致的油画描绘了一只年轻獾轻轻嗅着一朵明亮的黄色玫瑰时错综复杂的皮毛。背景是一棵大树干的粗糙纹理,獾的爪子轻轻地挖进树皮。在柔和的背景中,一个宁静的瀑布倾泻而下,它的水在绿色植物中闪烁着蓝色。 +渔舟唱晚 +请将杞人忧天的样子画出来 +一只长靴猫手持亮银色的宝剑,身着铠甲,眼神坚毅,站在一堆金币上,背景是暗色调的洞穴,图像上有金币的光影点缀。 +插画风格,一只狐狸和一只刺猬坐在水边的石头上,刺猬手里拿着一杯茶,狐狸旁边放着一个玻璃杯。周围是茂密的绿色植物和树木,阳光透过树叶洒在水面上,画面宁静温馨。 +泥塑风格,一座五彩斑斓的花园在画面中展现,各种各样的花朵,绿色的叶子和一只正在嬉戏的小猫形成了一幅生动的图像,背景是蓝天和白云 +枯藤老树昏鸦,小桥流水人家 +一张细致的照片捕捉到了一尊雕像的形象,这尊雕像酷似一位古代法老,头上出人意料地戴着一副青铜蒸汽朋克护目镜。这座雕像穿着复古时髦,一件清爽的白色T恤和一件合身的黑色皮夹克,与传统的头饰形成鲜明对比。背景是简单的纯色,突出了雕像的非传统服装和蒸汽朋克眼镜的复杂细节。 +一朵鲜艳的红色玫瑰花,花瓣撒有一些水珠,晶莹剔透,特写镜头, +一只可爱的猫, 细节真实, 摄影 +飞流直下三千尺,疑是银河落九天 +成语“鲤鱼跃龙门” +一颗新鲜的草莓特写,红色的外表,表面布满许多种子,背景是淡绿色的叶子 +九寨沟 +摄影风格,在画面中心是一盘热气腾腾的麻婆豆腐,豆腐呈白色,上面撒着一层红色的辣酱,有些许绿色的葱花点缀,背景是深色木质餐桌,桌子上放有辣椒和葱花作为点缀。 +一位年轻女子站在春季的火车站月台上。她身着蓝灰色长风衣,白色衬衫。她的深棕色头发扎成低马尾,几缕碎发随风飘扬。她的眼神充满期待,阳光洒在她温暖的脸庞上。 +一只优雅的白鹤在湖边静静地站立,它的身体纯白色,翅膀轻轻展开,背景是湖面和远处的山脉 +国画风格,苏州园林中的小桥流水,周围是郁郁葱葱的树,池塘里有几朵绽放的荷花,背景是宁静的江南水乡 +现实主义风格,画面主要描述一个巴洛克风格的花瓶,带有金色的装饰边框,花瓶上盛开着各种色彩鲜艳的花,白色背景 +醉后不知天在水,满船清梦压星河 +长城 +一个亚洲中年男士在夕阳下的公园长椅上静坐。他穿着一件深蓝色的针织毛衣和灰色裤子。他的头发略显花白,手中拿着一本敞开的书。面带微笑,眼神温和,周围是落日余晖和四周的绿树。 +风格是写实,画面主要描述一个亚洲戏曲艺术家正在表演,她穿着华丽的戏服,脸上戴着精致的面具,身姿优雅,背景是古色古香的舞台,镜头是近景 \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/requirents.txt b/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/requirents.txt new file mode 100644 index 0000000000..69d3b464b3 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/requirents.txt @@ -0,0 +1,18 @@ +colossalai==0.4.0 +strenum==0.4.15 +accelerate==0.29.3 +diffusers==0.26.3 +einops==0.7.0 +gradio==3.50.2 +huggingface-hub==0.24.7 +Jinja2==3.1.4 +matplotlib==3.9.2 +numpy==1.26.4 +peft==0.10.0 +safetensors==0.4.5 +timm==0.9.5 +tqdm==4.66.5 +torch==2.1.0 +torchvision==0.16.0 +tokenizers==0.15.2 +transformers==4.39.3 \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/inference_opensoraplan13.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/inference_opensoraplan13.py new file mode 100644 index 0000000000..f03426fa76 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/inference_opensoraplan13.py @@ -0,0 +1,162 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import argparse +import time +import logging + +import torch +import torch_npu +import imageio + +from transformers import AutoTokenizer, MT5EncoderModel +from open_sora_planv1_3.pipeline.open_soar_plan_pipeline import OpenSoraPlanPipeline13 +from open_sora_planv1_3.schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler +from open_sora_planv1_3.models.t2vdit import OpenSoraT2Vv1_3 +from open_sora_planv1_3.models.wfvae import WFVAEModelWrapper, ae_stride_config +from open_sora_planv1_3.utils import set_random_seed +from open_sora_planv1_3.models.parallel_mgr import init_parallel_env, get_sequence_parallel_rank +from open_sora_planv1_3.layers.cache_mgr import CacheManager, DitCacheConfig + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def parse_arguments(): + parser = argparse.ArgumentParser(description='Test Pipeline Argument Parser') + + parser.add_argument('--model_path', type=str, required=True, help='Path to the model directory') + parser.add_argument('--version', type=str, default='v1_3', help='Version of the model') + parser.add_argument('--dtype', type=str, default='fp16', help='Data type used in inference') + parser.add_argument('--num_frames', type=int, default=93, help='Number of frames') + parser.add_argument('--height', type=int, default=720, help='Height of the frames') + parser.add_argument('--width', type=int, default=1280, help='Width of the frames') + parser.add_argument('--text_encoder_name_1', type=str, required=True, help='Path to the text encoder model') + parser.add_argument('--text_prompt', type=str, required=True, help='Text prompt for the model') + parser.add_argument('--ae', type=str, default='WFVAEModel_D8_4x8x8', help='Autoencoder model type') + parser.add_argument('--ae_path', type=str, required=True, help='Path to the autoencoder model') + parser.add_argument('--save_img_path', type=str, default='./test', help='Path to save images') + parser.add_argument('--fps', type=int, default=24, help='Frames per second') + parser.add_argument('--guidance_scale', type=float, default=7.5, help='Guidance scale for the model') + parser.add_argument('--num_sampling_steps', type=int, default=10, help='Number of sampling steps') + parser.add_argument('--max_sequence_length', type=int, default=512, help='Maximum sequence length') + parser.add_argument('--seed', type=int, default=1234, help='Random seed') + parser.add_argument('--num_samples_per_prompt', type=int, default=1, help='Number of samples per prompt') + parser.add_argument('--rescale_betas_zero_snr', action='store_true', help='Rescale betas zero SNR') + parser.add_argument('--prediction_type', type=str, default='v_prediction', help='Type of prediction') + parser.add_argument('--save_memory', action='store_true', help='Save memory during processing') + parser.add_argument('--enable_tiling', action='store_true', help='Enable tiling for processing') + parser.add_argument('--sp', action='store_true') + parser.add_argument('--use_cache', action='store_true') + parser.add_argument('--cache_sampling_step_start', type=int, default=20, help='Sampling step begins to use cache') + parser.add_argument('--cache_sampling_step_interval', type=int, default=2, help='Sampling step interval of cache') + parser.add_argument('--cache_dit_block_start', type=int, default=2, help='DiT block id begins to be cached') + parser.add_argument('--cache_num_dit_blocks', type=int, default=20, help='DiT blocks cached in each step') + args = parser.parse_args() + return args + + +def infer(args): + dtype = torch.bfloat16 + if args.dtype == 'bf16': + dtype = torch.bfloat16 + elif args.dtype == 'fp16': + dtype = torch.float16 + else: + logger.error("Not supported.") + # === Initialize Distributed === + init_parallel_env(args.sp) + + set_random_seed(args.seed + get_sequence_parallel_rank()) + + negative_prompt = """ + nsfw, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, + low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry. + """ + positive_prompt = """ + high quality, high aesthetic, {} + """ + if not os.path.exists(args.save_img_path): + os.makedirs(args.save_img_path, exist_ok=True) + + if not isinstance(args.text_prompt, list): + args.text_prompt = [args.text_prompt] + if len(args.text_prompt) == 1 and args.text_prompt[0].endswith('txt'): + text_prompt = open(args.text_prompt[0], 'r').readlines() + args.text_prompt = [i.strip() for i in text_prompt] + + vae = WFVAEModelWrapper.from_pretrained(args.ae_path, dtype=torch.float16).to("npu").eval() + vae.vae_scale_factor = ae_stride_config[args.ae] + transformer = OpenSoraT2Vv1_3.from_pretrained(args.model_path).to(dtype).to("npu").eval() + + kwargs = dict( + prediction_type=args.prediction_type, + rescale_betas_zero_snr=args.rescale_betas_zero_snr, + timestep_spacing="trailing" if args.rescale_betas_zero_snr else 'leading', + ) + scheduler = EulerAncestralDiscreteScheduler(**kwargs) + text_encoder = MT5EncoderModel.from_pretrained(args.text_encoder_name_1, + torch_dtype=dtype).eval().to(dtype).to("npu") + tokenizer = AutoTokenizer.from_pretrained(args.text_encoder_name_1) + + if args.save_memory: + vae.vae.enable_tiling() + vae.vae.t_chunk_enc = 8 + vae.vae.t_chunk_dec = 2 + + pipeline = OpenSoraPlanPipeline13(vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler) + + if args.use_cache: + config = DitCacheConfig(step_start=20, step_interval=2, block_start=2, num_blocks=20) + cache = CacheManager(config) + pipeline.transformer.cache = cache + + with torch.no_grad(): + for i, input_prompt in enumerate(args.text_prompt): + input_prompt = positive_prompt.format(input_prompt) + start_time = time.time() + videos = pipeline( + input_prompt, + negative_prompt=negative_prompt, + num_frames=args.num_frames, + height=args.height, + width=args.width, + num_inference_steps=args.num_sampling_steps, + guidance_scale=args.guidance_scale, + num_samples_per_prompt=args.num_samples_per_prompt, + max_sequence_length=args.max_sequence_length, + )[0] + torch.npu.synchronize() + use_time = time.time() - start_time + logger.info("use_time: %.3f", use_time) + imageio.mimwrite( + os.path.join( + args.save_img_path, + f's{args.num_sampling_steps}_prompt{i}.mp4' + ), + videos[0], + fps=args.fps, + quality=6 + ) # highest quality is 10, lowest is 0 + +if __name__ == "__main__": + inference_args = parse_arguments() + infer(inference_args) \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/layers/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/layers/__init__.py new file mode 100644 index 0000000000..0de74ab931 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/layers/__init__.py @@ -0,0 +1,6 @@ +from .wavelet import (HaarWaveletTransform2D, HaarWaveletTransform3D, + InverseHaarWaveletTransform2D, InverseHaarWaveletTransform3D) +from .conv import AttnBlock3DFix +from .mlp import Mlp +from .norm import AdaLayerNorm, VideoLayerNorm, Normalize +from .sampling import Upsample, Downsample, Spatial2xTime2x3DDownsample, Spatial2xTime2x3DUpsample \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/layers/activation.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/layers/activation.py new file mode 100644 index 0000000000..34c26d1b02 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/layers/activation.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch.nn as nn +import torch_npu + + +class GEGLU(nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, hidden_states): + + return torch_npu.npu_geglu(hidden_states, dim=-1, approximate=1)[0] + + +ACTIVATION_FUNCTIONS = { + "swish": nn.SiLU(), + "silu": nn.SiLU(), + "mish": nn.Mish(), + "gelu": nn.GELU(), + "relu": nn.ReLU(), + "gelu-approximate": nn.GELU(approximate="tanh"), + "geglu": GEGLU() +} + + + +def get_activation_fn(act_fn: str) -> nn.Module: + """Helper function to get activation function from string. + + Args: + act_fn (str): Name of activation function. + + Returns: + nn.Module: Activation function. + """ + + act_fn = act_fn.lower() + if act_fn in ACTIVATION_FUNCTIONS: + return ACTIVATION_FUNCTIONS[act_fn] + else: + raise ValueError(f"Unsupported activation function: {act_fn}") \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/layers/attention.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/layers/attention.py new file mode 100644 index 0000000000..61dcaf9204 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/layers/attention.py @@ -0,0 +1,392 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math +import inspect +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.distributed as dist +import torch.nn.functional as F +import torch_npu +from einops import rearrange, repeat + +from .norm import get_normalization_helper +from .embedding import RoPE3D, PositionGetter3D, get_embedding_helper +from ..models.parallel_mgr import get_sequence_parallel_state, get_sequence_parallel_size +from ..models.comm import all_to_all_sbh +from .linear import QKVLinear + +ALIGNMENT_BASE = 16 + + +class ReconstitutionAttention(nn.Module): + r""" + Attention layer. + """ + def __init__( + self, + attention_dim: int, + cross_attention_dim: Optional[int] = None, + num_heads: int = 8, + head_dim: int = 64, + qkv_bias: bool = True, + out_proj_bias: bool = True, + num_norm_groups: Optional[int] = None, + attention_norm: Optional[str] = None, + position_embedding: Optional[str] = None, + add_proj_dim: Optional[int] = None, + add_proj_bias: bool = True, + enable_add_out_proj: bool = True, + scale_attention: bool = True, + eps: float = 1e-5, + processor: Optional["AttnProcessor"] = None, + ): + r""" + Attention layer init function. + Args: + attention_dim (`int`): + The number of channels in the hidden_states. + cross_attention_dim (`int`, *optional*, defaults to `None`): + The number of channels in the encoder_hidden_states. If not `None`, means cross attention. + num_heads (`int`, *optional*, defaults to 8): + The number of attention heads used in the multi-head attention layers. + head_dim (`int`, *optional*, defaults to 64): + The number of dims in each head. + qkv_bias (`bool`, *optional*, defaults to `True`): + Whether or not the quert, key and value linear layer to contain a bias parameter. + out_proj_bias (`bool`, *optional*, defaults to `True`): + Whether or not the out projection layer to contain a bias parameter. + num_norm_groups (`int`, *optional*, defaults to `None`): + The number of groups to use for the `GroupNorm` in attention. + If `None`, no `GroupNorm` is used. + attention_norm (`str`, *optional*, defaults to `None`): + The type of normalization to use for the query and key in attention. + Can be `None`, `layer_norm`, or `llama_rms_norm`. + position_embedding (`str`, *optional*, defaults to `None`): + The type of position embdding to use for the query and key in attention. Can be `None`, `rope`. + add_proj_dim (`int`, *optional*, defaults to `None`): + The number of channels to use for the additional projections. If `None`, no projection is used. + add_proj_bias (`bool`, *optional*, defaults to `True`): + Whether or not the additional projection layer to contain a bias parameter. + enable_add_out_proj (`bool`, *optional*, defaults to `True`): + Whether or not use the additional out projection. + scale_attention (`bool`, *optional*, defaults to `True`): + Set `True` to scale the query @ key result with by `1 / sqrt(head_dim)`. + eps (`float`, *optional*, defaults to 1e-5): + The additional value added to eh denominator in normalization. + processor (`AttnProcessor`, *optional*, defaults to `None`): + The attention processor to use. If `None`, `AttnProcessor` will be used. + """ + super().__init__() + + self.num_heads = num_heads + if head_dim <= 0: + raise ValueError(f"Input head_dim should be greater than zero, but got {head_dim}.") + self.head_dim = head_dim + self.pad_dim = self.head_dim + + self.is_cross_attention = cross_attention_dim is not None + + self.scale_value = 1 / math.sqrt(head_dim) if scale_attention else 1.0 + + # `hidden_size` is calculated by num_heads * head_dim -> H = N * D + hidden_size = num_heads * head_dim + + hidden_size = self._set_pad(hidden_size, num_norm_groups, attention_norm, position_embedding) + + # Normalize hidden states by group_norm + self.group_norm = nn.GroupNorm(num_channels=hidden_size, num_groups=num_norm_groups, eps=eps, affine=True) \ + if num_norm_groups is not None else None + + # Init normalization layer by get_normalization_helper. + self.norm_q = get_normalization_helper(attention_norm, head_dim, eps) + self.norm_k = get_normalization_helper(attention_norm, head_dim, eps) + + # Init position embedding by get_embedding_helper. + self.position_embedding = get_embedding_helper(position_embedding, head_dim) + + # QKVLinear + if self.is_cross_attention: + self.qkv_proj = QKVLinear(attention_dim, hidden_size, qkv_bias, cross_attention_dim) + else: + self.qkv_proj = QKVLinear(attention_dim, hidden_size, qkv_bias) + + # Additional qkv linear for Multi-Modal Diffusion Transformer + if add_proj_dim is not None: + self.add_qkv_proj = nn.Linear(add_proj_dim, 3 * hidden_size, bias=add_proj_bias) # 3: qkv + + # OutLinear + self.out_proj = nn.Linear(hidden_size, attention_dim, bias=out_proj_bias) + + # Additional out linear for Multi-Modal Diffusion Transformer + if add_proj_dim is not None: + # For the last attention layer in Multi-Modal Diffusion Transformer, + # no need to calculate the additional out linear + self.add_out_proj = nn.Linear(hidden_size, attention_dim, bias=out_proj_bias) \ + if enable_add_out_proj else nn.Identity() + + # Set default processor by AttnProcessor + attn_processor = processor if processor is not None else AttnProcessor() + self.set_processor(attn_processor) + + def set_processor(self, processor: "AttnProcessor"): + """ + Set the attention processor. + Users can develop different attention processor for `Attention` to achieve different functions. + Args: + processor: ("AttnProcessor"): + The attention processor to used for attention forward. + """ + self.processor = processor + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + attention_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """ + Attention forward function. + Args: + hidden_states (`torch.Tensor`): + The hidden states of attention query. + encoder_hidden_states (`torch.Tensor`, *optional*, defaults to `None`): + The hidden states of the encoder. + attention_mask (`torch.Tensor`, *optional*, defaults to `None`): + The mask of attention. + **kwargs: + The additional arguments to the attention processors. + For standard attention use `AttnProcessor`, kwargs is empty. + Returns: + `torch.Tensor`: The output of the attention layer. + """ + attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) + attn_kwargs = {key: value for key, value in kwargs.items() if key in attn_parameters} + + return self.processor( + self, + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **attn_kwargs + ) + + def _set_pad(self, hidden_size, num_norm_groups, attention_norm, position_embedding): + if self.head_dim % ALIGNMENT_BASE == 0: + return hidden_size + elif (num_norm_groups is not None) or (attention_norm is not None) or (position_embedding is not None): + return hidden_size + else: + self.pad_dim = (self.head_dim // ALIGNMENT_BASE + 1) * ALIGNMENT_BASE + hidden_size = self.pad_dim * self.num_heads + return hidden_size + + +class AttnProcessor: + """ + The standard attention processor. + """ + def __call__( + self, + attn: ReconstitutionAttention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + if hidden_states is None: + raise ValueError("Input hidden_states should not be none.") + + # only support BNC now. + if hidden_states.ndim != 3: # 3: BNC + raise ValueError(f"The dimensions of hidden_states should be 3, but got {hidden_states.ndim}") + + batch_size = hidden_states.shape[0] + + if attn.group_norm is not None: + # In `BSH`, `H` represents channel, so it needs to be transposed. + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + if attn.is_cross_attention: + query = attn.q_proj(hidden_states) + query = query.reshape(batch_size, -1, attn.num_heads, attn.head_dim).transpose(1, 2) # B S N D -> B N S D + + kv = attn.kv_proj(encoder_hidden_states) + kv = kv.reshape(batch_size, -1, 2, attn.num_heads, attn.head_dim) + key, value = kv.permute(2, 0, 3, 1, 4).unbind(0) # B S 2 N D -> 2 B N S D -> 2 * B N S D + else: + qkv = attn.qkv_proj(hidden_states) + qkv = qkv.reshape(batch_size, -1, 3, attn.num_heads, attn.head_dim) # 3: q,k,v + query, key, value = qkv.permute(2, 0, 3, 1, 4).unbind(0) # B S 3 N D -> 3 B N S D -> 3 * B N S D + + hidden_states = torch_npu.npu_prompt_flash_attention( + query, key, value, + num_heads=query.shape[1], + input_layout="BNSD", + atten_mask=attention_mask, + scale_value=attn.scale_value) + # transform the hidden_states layout from BNSD to BSH + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.num_heads * attn.head_dim) + hidden_states = attn.out_proj(hidden_states) + return hidden_states + + +class OpenSoraPlanAttnProcessor: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + """ + + def __init__(self, interpolation_scale_thw=(1, 1, 1), + sparse1d=False, sparse_n=2, sparse_group=False, is_cross_attn=True): + self.sparse1d = sparse1d + self.sparse_n = sparse_n + self.sparse_group = sparse_group + self.is_cross_attn = is_cross_attn + self.interpolation_scale_thw = interpolation_scale_thw + + self._init_rope(interpolation_scale_thw) + + def __call__( + self, + attn, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + temb: Optional[torch.FloatTensor] = None, + frame: int = 8, + height: int = 16, + width: int = 16, + *args, + **kwargs, + ) -> torch.FloatTensor: + _, batch_size, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape) + + if attn.is_cross_attention: + query, key, value = attn.qkv_proj(hidden_states, encoder_hidden_states) + else: + query, key, value = attn.qkv_proj(hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.num_heads + fa_head_num = attn.num_heads + total_frame = frame + + if get_sequence_parallel_state(): + sp_size = get_sequence_parallel_size() + fa_head_num = attn.num_heads // sp_size + total_frame = frame * sp_size + # apply all_to_all to gather sequence and split attention heads [s // sp * b, h, d] -> [s * b, h // sp, d] + query = all_to_all_sbh(query.view(-1, attn.num_heads, head_dim), scatter_dim=1, gather_dim=0) + key = all_to_all_sbh(key.view(-1, attn.num_heads, head_dim), scatter_dim=1, gather_dim=0) + value = all_to_all_sbh(value.view(-1, attn.num_heads, head_dim), scatter_dim=1, gather_dim=0) + query = query.view(-1, batch_size, fa_head_num, head_dim) + key = key.view(-1, batch_size, fa_head_num, head_dim) + + if not self.is_cross_attn: + # require the shape of (ntokens x batch_size x nheads x dim) + pos_thw = self.position_getter(batch_size, t=total_frame, h=height, w=width, device=query.device) + + query = self.rope(query, pos_thw) + key = self.rope(key, pos_thw) + + query = query.view(-1, batch_size, fa_head_num * head_dim) + key = key.view(-1, batch_size, fa_head_num * head_dim) + value = value.view(-1, batch_size, fa_head_num * head_dim) + if self.sparse1d: + query, pad_len = self._sparse_1d(query, total_frame, height, width) + if self.is_cross_attn: + key = self._sparse_1d_kv(key) + value = self._sparse_1d_kv(value) + else: + key, pad_len = self._sparse_1d(key, total_frame, height, width) + value, pad_len = self._sparse_1d(value, total_frame, height, width) + + rearrange_method = 's b (h d) -> b h s d' + # .contiguous() not need + query = rearrange(query, rearrange_method, h=fa_head_num) + key = rearrange(key, rearrange_method, h=fa_head_num) + value = rearrange(value, rearrange_method, h=fa_head_num) + + hidden_states = torch_npu.npu_fused_infer_attention_score(query, key, value, + atten_mask=attention_mask, input_layout="BNSD", scale=1 / math.sqrt(head_dim), + num_heads=fa_head_num)[0] + + hidden_states = rearrange(hidden_states, 'b h s d -> s b (h d)', h=fa_head_num).contiguous() + + if self.sparse1d: + hidden_states = self._reverse_sparse_1d( + hidden_states, total_frame, height, width, pad_len) + + # [s, b, h // sp * d] -> [s // sp * b, h, d] -> [s // sp, b, h * d] + if get_sequence_parallel_state(): + hidden_states = all_to_all_sbh(hidden_states.reshape(-1, fa_head_num, head_dim), + scatter_dim=0, gather_dim=1) + hidden_states = hidden_states.view(-1, batch_size, inner_dim) + + hidden_states = hidden_states.to(query.dtype) + # linear proj + hidden_states = attn.out_proj(hidden_states) + return hidden_states + + def _init_rope(self, interpolation_scale_thw): + self.rope = RoPE3D(interpolation_scale_thw=interpolation_scale_thw) + self.position_getter = PositionGetter3D() + + def _sparse_1d(self, x, frame, height, width): + """ + require the shape of (ntokens x batch_size x dim) + """ + seqlen = x.shape[0] + if seqlen != frame * height * width: + raise ValueError(f"x.shape[0] should be equal to frame*height*width") + pad_len = 0 + if seqlen % (self.sparse_n * self.sparse_n) != 0: + pad_len = self.sparse_n * self.sparse_n - seqlen % (self.sparse_n * self.sparse_n) + if pad_len != 0: + x = F.pad(x, (0, 0, 0, 0, 0, pad_len)) + if not self.sparse_group: + x = rearrange(x, '(g k) b d -> g (k b) d', k=self.sparse_n) + else: + x = rearrange(x, '(n m k) b d -> (n k) (m b) d', m=self.sparse_n, k=self.sparse_n) + return x, pad_len + + def _reverse_sparse_1d(self, x, frame, height, width, pad_len): + """ + require the shape of (ntokens x batch_size x dim) + """ + if x.shape[0] != (frame * height * width + pad_len) // self.sparse_n: + raise ValueError("x.shape[0] should be equal to" + f"f{(frame * height * width + pad_len) // self.sparse_n}") + if not self.sparse_group: + x = rearrange(x, 'g (k b) d -> (g k) b d', k=self.sparse_n) + else: + x = rearrange(x, '(n k) (m b) d -> (n m k) b d', m=self.sparse_n, k=self.sparse_n) + x = x[:frame * height * width, :, :] + return x + + def _sparse_1d_kv(self, x): + """ + require the shape of (ntokens x batch_size x dim) + """ + x = repeat(x, 's b d -> s (k b) d', k=self.sparse_n) + return x \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/layers/cache_mgr.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/layers/cache_mgr.py new file mode 100644 index 0000000000..86585f0d02 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/layers/cache_mgr.py @@ -0,0 +1,172 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +import torch.nn as nn + + +class CacheConfig(): + def __init__(self, method=None): + self.method = method + + +class CacheAgentConfig(CacheConfig): + """ + The DitCache Config. + """ + def __init__(self, policy_dir: str): + """ + Args: + policy_dir: The file containing the policy. + """ + super().__init__(method="CacheAgent") + self.policy_dir = policy_dir + + +class DitCacheConfig(CacheConfig): + """ + The DitCache Config. + """ + def __init__(self, step_start: int, step_interval: int, block_start: int, num_blocks: int): + """ + Args: + step_start: The starting step for caching. + step_interval: The interval at which caching should occur. + block_start: The starting block index for caching. + num_blocks: The number of blocks to cache. + """ + super().__init__(method="DitCache") + self.step_start = step_start + self.step_interval = step_interval + self.block_start = block_start + self.num_blocks = num_blocks + + +class CacheManager: + """ + The CacheManager class is interface to manage the cache algorithm. + """ + def __init__( + self, + config:CacheConfig + ): + """ + Args: + config: The configuration for the cache algorithm. + """ + if isinstance(config, CacheConfig): + self.method = config.method + self.cache_cls = Cache_cls[self.method](**vars(config)) + + def __call__(self, block, time_step, block_idx, hidden_states, *args, **kwargs + ): + """ + Args: + block: The block in the DiT module. + time_step: The current time step. + block_idx: The index of the block. + hidden_states: The hidden states. + *args: Additional arguments. + **kwargs: Additional keyword arguments. + """ + if not self._use_cache(time_step, block_idx): + old_hidden_states = hidden_states + if isinstance(block, list): + for blk in block: + hidden_states = blk(hidden_states, *args, **kwargs) + else: + hidden_states = block(hidden_states, *args, **kwargs) + self._update_cache(hidden_states, old_hidden_states, time_step, block_idx) + else: + hidden_states += self._get_cache(time_step, block_idx) + return hidden_states + + def _use_cache(self, time_step, block_idx): + return self.cache_cls.use_cache(time_step, block_idx) + + def _get_cache(self, time_step, block_idx): + return self.cache_cls.get_cache(time_step, block_idx) + + def _update_cache(self, hidden_states, old_hidden_states, time_step, block_idx): + self.cache_cls.update_cache(hidden_states, old_hidden_states, time_step, block_idx) + + +class CacheAgent(): + def __init__(self, policy_dir, **kwargs): + self.policy_dir = policy_dir + self.cache = [None for _ in range(32)] + + self.policy = torch.load(policy_dir) + + def use_cache(self, time_step, block_idx): + if time_step == 0: + return False + else: + return self.policy[time_step - 1, block_idx] + + def get_cache(self, time_step, block_idx): + return self.cache[block_idx] + + def update_cache(self, hidden_states, old_hidden_states, time_step, block_idx): + delta = hidden_states - old_hidden_states + self.cache[block_idx] = delta + + +class DitCache: + def __init__(self, step_start, step_interval, block_start, num_blocks, **kwargs): + self.step_start = step_start + self.step_interval = step_interval + self.block_start = block_start + self.num_blocks = num_blocks + self.block_end = block_start + num_blocks - 1 + self.cache = None + self.time_cache = {} + + def use_cache(self, time_step, block_idx): + if time_step < self.step_start: + return False + else: + diftime = time_step - self.step_start + if diftime not in self.time_cache: + self.time_cache[diftime] = diftime % self.step_interval == 0 + if self.time_cache[diftime]: + return False + elif block_idx < self.block_start or block_idx > self.block_end: + return False + else: + return True + + def get_cache(self, time_step, block_idx): + if block_idx == self.block_start: + return self.cache + else: + return 0 + + def update_cache(self, hidden_states, old_hidden_states, time_step, block_idx): + diftime = time_step - self.step_start + # when (time_step - self.step_start) % self.step_interval == 0: + if time_step >= self.step_start and self.time_cache[diftime]: + if block_idx == self.block_start: + self.cache = old_hidden_states + elif block_idx == self.block_end: + self.cache = hidden_states - self.cache + + +Cache_cls = { + "CacheAgent" : CacheAgent, + "DitCache" : DitCache +} diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/layers/conv.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/layers/conv.py new file mode 100644 index 0000000000..f9f55db5c1 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/layers/conv.py @@ -0,0 +1,152 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Optional, Tuple, Union +from collections import deque + +import torch +import torch.nn as nn +import torch_npu +from .utils import video_to_image, cast_tuple +from .norm import Normalize + + +class VideoConv2d(nn.Conv2d): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int]] = 3, + stride: Union[int, Tuple[int]] = 1, + padding: Union[str, int, Tuple[int]] = 0, + dilation: Union[int, Tuple[int]] = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = "zeros", + device=None, + dtype=None, + ) -> None: + super().__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + padding_mode, + device, + dtype, + ) + + @video_to_image + def forward(self, x): + return super().forward(x) + + +class PlanCausalConv3d(nn.Module): + def __init__( + self, + chan_in, + chan_out, + kernel_size: Union[int, Tuple[int, int, int]], + enable_cached=False, + bias=True, + **kwargs, + ): + super().__init__() + self.kernel_size = cast_tuple(kernel_size, 3) + self.time_kernel_size = self.kernel_size[0] + self.chan_in = chan_in + self.chan_out = chan_out + self.stride = kwargs.pop("stride", 1) + self.padding = kwargs.pop("padding", 0) + self.padding = list(cast_tuple(self.padding, 3)) + self.padding[0] = 0 + self.stride = cast_tuple(self.stride, 3) + self.conv = nn.Conv3d( + chan_in, + chan_out, + self.kernel_size, + stride=self.stride, + padding=self.padding, + bias=bias + ) + self.enable_cached = enable_cached + self.is_first_chunk = True + + self.causal_cached = deque() + self.cache_offset = 0 + + def forward(self, x): + if self.is_first_chunk: + first_frame_pad = x[:, :, :1, :, :].repeat( + (1, 1, self.time_kernel_size - 1, 1, 1) + ) + else: + first_frame_pad = self.causal_cached.popleft() + x = torch.concatenate((first_frame_pad, x), dim=2) + + if self.enable_cached and self.time_kernel_size != 1: + if (self.time_kernel_size - 1) // self.stride[0] != 0: + if self.cache_offset == 0: + self.causal_cached.append(x[:, :, -(self.time_kernel_size - 1) // self.stride[0]:].clone()) + else: + self.causal_cached.append(x[ + :, :, :-self.cache_offset][:, :, -(self.time_kernel_size - 1) // self.stride[0]:].clone()) + else: + self.causal_cached.append(x[:, :, 0:0, :, :].clone()) + elif self.enable_cached: + self.causal_cached.append(x[:, :, 0:0, :, :].clone()) + + x = self.conv(x) + return x + + +class AttnBlock3DFix(nn.Module): + + def __init__(self, in_channels, norm_type="groupnorm"): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels, norm_type=norm_type) + self.q = PlanCausalConv3d(in_channels, in_channels, kernel_size=1, stride=1) + self.k = PlanCausalConv3d(in_channels, in_channels, kernel_size=1, stride=1) + self.v = PlanCausalConv3d(in_channels, in_channels, kernel_size=1, stride=1) + self.proj_out = PlanCausalConv3d(in_channels, in_channels, kernel_size=1, stride=1) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + b, c, t, h, w = q.shape + q = q.permute(0, 2, 3, 4, 1).reshape(b * t, h * w, c).contiguous() + k = k.permute(0, 2, 3, 4, 1).reshape(b * t, h * w, c).contiguous() + v = v.permute(0, 2, 3, 4, 1).reshape(b * t, h * w, c).contiguous() + + attn_output = torch_npu.npu_fused_infer_attention_score(q, k, v, + atten_mask=None, input_layout="BSH", scale=1 / math.sqrt(c), + num_heads=1)[0] + + attn_output = attn_output.reshape(b, t, h, w, c).permute(0, 4, 1, 2, 3) + h_ = self.proj_out(attn_output) + + return x + h_ \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/layers/linear.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/layers/linear.py new file mode 100644 index 0000000000..e3a6571f4b --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/layers/linear.py @@ -0,0 +1,96 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +import torch.nn as nn +import torch_npu + + +class QKVLinear(nn.Module): + def __init__(self, attention_dim, hidden_size, qkv_bias=True, cross_attention_dim=None, device=None, dtype=None): + super(QKVLinear, self).__init__() + self.attention_dim = attention_dim + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + self.qkv_bias = qkv_bias + + factory_kwargs = {"device": device, "dtype": dtype} + + if not cross_attention_dim: + self.weight = nn.Parameter(torch.empty([self.attention_dim, 3 * self.hidden_size], **factory_kwargs)) + if self.qkv_bias: + self.bias = nn.Parameter(torch.empty([3 * self.hidden_size], **factory_kwargs)) + else: + self.q_weight = nn.Parameter(torch.empty([self.attention_dim, self.hidden_size], **factory_kwargs)) + self.kv_weight = nn.Parameter(torch.empty([self.attention_dim, 2 * self.hidden_size], **factory_kwargs)) + + if self.qkv_bias: + self.q_bias = nn.Parameter(torch.empty([self.hidden_size], **factory_kwargs)) + self.kv_bias = nn.Parameter(torch.empty([2 * self.hidden_size], **factory_kwargs)) + + + def forward(self, hidden_states, encoder_hidden_states=None): + + if self.cross_attention_dim is None: + if not self.qkv_bias: + qkv = torch.matmul(hidden_states, self.weight) + else: + qkv = torch.addmm( + self.bias, + hidden_states.view(hidden_states.size(0) * hidden_states.size(1), hidden_states.size(2)), + self.weight, + beta=1, + alpha=1 + ) + + batch, seqlen, _ = hidden_states.shape + qkv_shape = (batch, seqlen, 3, -1) + qkv = qkv.view(qkv_shape) + q, k, v = qkv.unbind(2) + + else: + if not self.qkv_bias: + q = torch.matmul(hidden_states, self.q_weight) + kv = torch.matmul(encoder_hidden_states, self.kv_weight) + else: + q = torch.addmm( + self.q_bias, + hidden_states.view(hidden_states.size(0) * hidden_states.size(1), hidden_states.size(2)), + self.q_weight, + beta=1, + alpha=1 + ) + kv = torch.addmm( + self.kv_bias, + encoder_hidden_states.view( + encoder_hidden_states.size(0) * encoder_hidden_states.size(1), + encoder_hidden_states.size(2)), + self.kv_weight, + beta=1, + alpha=1 + ) + + batch, seqlen, _ = encoder_hidden_states.shape + kv_shape = (batch, seqlen, 2, -1) + + kv = kv.view(kv_shape) + k, v = kv.unbind(2) + + batch, seqlen, _ = hidden_states.shape + q = q.view(batch, seqlen, -1) + + return q, k, v \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/layers/mlp.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/layers/mlp.py new file mode 100644 index 0000000000..faee15c282 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/layers/mlp.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import collections.abc +from itertools import repeat + +import torch.nn as nn +from .activation import get_activation_fn + + +class Mlp(nn.Module): + """ MLP as used in Vision Transformer, MLP-Mixer and related networks + """ + + def __init__( + self, + features_in, + act_layer, + features_hidden=None, + features_out=None, + norm_layer=None, + bias=True, + ): + super().__init__() + features_out = features_out or features_in + features_hidden = features_hidden or features_in + to_2tuple = self._ntuple(2) + bias = to_2tuple(bias) + linear_layer = nn.Linear + + self.fc1 = linear_layer(features_in, features_hidden, bias=bias[0]) + self.act = get_activation_fn(act_layer) + self.norm = norm_layer(features_hidden) if norm_layer is not None else nn.Identity() + self.fc2 = linear_layer(features_hidden, features_out, bias=bias[1]) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.norm(x) + x = self.fc2(x) + return x + + def _ntuple(self, n): + def parse(x): + if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): + return tuple(x) + return tuple(repeat(x, n)) + return parse \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/layers/norm.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/layers/norm.py new file mode 100644 index 0000000000..97e2bf6ea4 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/layers/norm.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + + +def get_normalization_helper(norm_type: str, norm_dim: int, eps: float = 1e-5): + match norm_type: + case None: + return nn.Identity() + case 'layer_norm': + return nn.LayerNorm(norm_dim, eps=eps) + case _: + raise ValueError(f"Unsupported norm_type:{norm_type}.") + + +class AdaLayerNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + super().__init__() + self.hidden_size = hidden_size + self.eps = eps + + def forward(self, x, shift, scale): + return F.layer_norm(x, normalized_shape=[self.hidden_size], weight=scale, bias=shift, eps=self.eps) + + +def Normalize(in_channels, num_groups=32, norm_type="groupnorm"): + if norm_type == "groupnorm": + return torch.nn.GroupNorm( + num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True + ) + elif norm_type == "layernorm": + return VideoLayerNorm(num_channels=in_channels, eps=1e-6) + + +class VideoLayerNorm(nn.Module): + def __init__(self, num_channels, eps=1e-6, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.norm = torch.nn.LayerNorm(num_channels, eps=eps, elementwise_affine=True) + def forward(self, x): + if x.dim() == 5: + x = rearrange(x, "b c t h w -> b t h w c") + x = self.norm(x) + x = rearrange(x, "b t h w c -> b c t h w") + else: + x = rearrange(x, "b c h w -> b h w c") + x = self.norm(x) + x = rearrange(x, "b h w c -> b c h w") + return x diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/layers/sampling.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/layers/sampling.py new file mode 100644 index 0000000000..be4b535be6 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/layers/sampling.py @@ -0,0 +1,131 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Union, Tuple +from collections import deque + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch_npu + +from .utils import video_to_image +from .conv import PlanCausalConv3d + + +class Upsample(nn.Module): + def __init__(self, in_channels, out_channels): + super().__init__() + self.with_conv = True + if self.with_conv: + self.conv = nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + @video_to_image + def forward(self, x): + x = F.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, out_channels, undown=False): + super().__init__() + self.with_conv = True + self.undown = undown + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + if self.undown: + self.conv = nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + else: + self.conv = nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=2, + padding=0) + + @video_to_image + def forward(self, x): + if self.with_conv: + if self.undown: + x = self.conv(x) + else: + pad = (0, 1, 0, 1) + x = F.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = F.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class Spatial2xTime2x3DDownsample(nn.Module): + def __init__(self, in_channels, out_channels): + super().__init__() + self.conv = PlanCausalConv3d(in_channels, out_channels, kernel_size=3, padding=0, stride=2) + + def forward(self, x): + pad = (0, 1, 0, 1, 0, 0) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + return x + + +class Spatial2xTime2x3DUpsample(nn.Module): + def __init__( + self, + in_channels, + out_channels, + t_interpolation="trilinear", + enable_cached=False, + ): + super().__init__() + self.t_interpolation = t_interpolation + self.conv = PlanCausalConv3d(in_channels, out_channels, kernel_size=3, padding=1) + self.enable_cached = enable_cached + self.causal_cached = deque() + + def forward(self, x): + mode_method = "trilinear" + if x.size(2) > 1 or self.causal_cached is not None : + if self.enable_cached and len(self.causal_cached) > 0: + x = torch.cat([self.causal_cached.popleft(), x], dim=2) + self.causal_cached.append(x[:, :, -2:-1].clone()) + x = F.interpolate(x, scale_factor=(2, 1, 1), mode=self.t_interpolation) + x = x[:, :, 2:] + x = F.interpolate(x, scale_factor=(1, 2, 2), mode=mode_method) + else: + if self.enable_cached: + self.causal_cached.append(x[:, :, -1:].clone()) + x, x_ = x[:, :, :1], x[:, :, 1:] + x_ = F.interpolate( + x_, scale_factor=(2, 1, 1), mode=self.t_interpolation + ) + x_ = F.interpolate(x_, scale_factor=(1, 2, 2), mode=mode_method) + x = F.interpolate(x, scale_factor=(1, 2, 2), mode=mode_method) + x = torch.concat([x, x_], dim=2) + else: + if self.enable_cached: + self.causal_cached.append(x[:, :, -1:].clone()) + x = F.interpolate(x, scale_factor=(1, 2, 2), mode=mode_method) + return self.conv(x) \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/layers/utils.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/layers/utils.py new file mode 100644 index 0000000000..7423287696 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/layers/utils.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import importlib +from einops import rearrange + + +def rearrange_flatten_t(x): + x_shape = x.shape + x = x.transpose(1, 2) + return x.view((x_shape[0] * x_shape[2]), x_shape[1], x_shape[3], x_shape[4]) + + +def rearrange_unflatten_t(x, b): + x_shape = x.shape + x = x.view(b, x_shape[0] // b, x_shape[1], x_shape[2], x_shape[3]) + return x.transpose(1, 2) + + +def video_to_image(func): + def wrapper(self, x, *args, **kwargs): + if x.dim() == 5: + t = x.shape[2] + x = rearrange(x, "b c t h w -> (b t) c h w") + x = func(self, x, *args, **kwargs) + x = rearrange(x, "(b t) c h w -> b c t h w", t=t) + return x + return wrapper + + +def cast_tuple(t, length=1): + return t if isinstance(t, tuple) or isinstance(t, list) else ((t,) * length) + + +MODULES_BASE = "open_sora_planv1_3.layers." + + +def resolve_str_to_obj(str_val, append=True): + if append: + str_val = MODULES_BASE + str_val + module_name, class_name = str_val.rsplit('.', 1) + module = importlib.import_module(module_name) + return getattr(module, class_name) \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/layers/vresnet.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/layers/vresnet.py new file mode 100644 index 0000000000..c75dc392dd --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/layers/vresnet.py @@ -0,0 +1,127 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn + +from .norm import Normalize +from .conv import PlanCausalConv3d +from .utils import video_to_image + +from .activation import get_activation_fn + + +class VideoResnetBlock2D(nn.Module): + def __init__( + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + norm_type, + **kwargs + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = in_channels if out_channels is None else out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels, norm_type=norm_type) + self.conv1 = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + self.norm2 = Normalize(out_channels, norm_type=norm_type) + self.conv2 = torch.nn.Conv2d( + out_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + self.nonlinearity = get_activation_fn("silu") + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + else: + self.nin_shortcut = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=1, stride=1, padding=0 + ) + + @video_to_image + def forward(self, x): + h = x + #CAST ? + h = self.norm1(h) + + h = self.nonlinearity(h) + h = self.conv1(h) + #CAST ? + h = self.norm2(h) + h = self.nonlinearity(h) + h = self.conv2(h) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + x = x + h + return x + + +class ResnetBlock3D(nn.Module): + def __init__( + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + norm_type, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = in_channels if out_channels is None else out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels, norm_type=norm_type) + self.conv1 = PlanCausalConv3d(in_channels, out_channels, 3, padding=1) + self.norm2 = Normalize(out_channels, norm_type=norm_type) + self.conv2 = PlanCausalConv3d(out_channels, out_channels, 3, padding=1) + self.nonlinearity = get_activation_fn("silu") + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = PlanCausalConv3d( + in_channels, out_channels, 3, padding=1 + ) + else: + self.nin_shortcut = PlanCausalConv3d( + in_channels, out_channels, 1, padding=0 + ) + + def forward(self, x): + h = x + h = self.norm1(h) + h = self.nonlinearity(h) + h = self.conv1(h) + #CAST float32 ? + h = self.norm2(h) + + h = self.nonlinearity(h) + h = self.conv2(h) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + return x + h \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/models/comm.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/models/comm.py new file mode 100644 index 0000000000..5943e5b325 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/models/comm.py @@ -0,0 +1,180 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.distributed as dist + +from .parallel_mgr import get_sequence_parallel_size, get_sequence_parallel_group + + +def _all_to_all_func(input_, world_size, process_group, scatter_dim=2, gather_dim=1): + input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)] + output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)] + dist.all_to_all(output_list, input_list, group=process_group) + return torch.cat(output_list, dim=gather_dim).contiguous() + + +def split_sequence(input_, process_group: dist.ProcessGroup, dim: int, pad: int): + world_size = dist.get_world_size(process_group) + rank = dist.get_rank(process_group) + if world_size == 1: + return input_ + + if pad > 0: + pad_size = list(input_.shape) + pad_size[dim] = pad + input_ = torch.cat([input_, torch.zeros(pad_size, dtype=input_.dtype, device=input_.device)], dim=dim) + + dim_size = input_.size(dim) + if dim_size % world_size != 0: + raise ValueError( + f"The th{dim} dimensions of input_:{input_.size()} is not divisible by world_size:{world_size}.") + + tensor_list = torch.split(input_, dim_size // world_size, dim=dim) + output = tensor_list[rank].contiguous() + return output + + +def gather_sequence(input_, process_group: dist.ProcessGroup, dim: int, pad: int): + input_ = input_.contiguous() + world_size = dist.get_world_size(process_group) + if world_size == 1: + return input_ + + #all gather + tensor_list = [torch.empty_like(input_) for _ in range(world_size)] + torch.distributed.all_gather(tensor_list, input_, group=process_group) + + #concat + output = torch.cat(tensor_list, dim=dim) + + if pad > 0: + output = output.narrow(dim, 0, output.size(dim) - pad) + + return output + +# ====== +# Pad +# ====== + +SPTIAL_PAD = 0 +TEMPORAL_PAD = 0 + + +def set_spatial_pad(dim_size: int): + sp_size = get_sequence_parallel_size() + pad = (sp_size - (dim_size % sp_size)) % sp_size + global SPTIAL_PAD + SPTIAL_PAD = pad + + +def get_spatial_pad() -> int: + return SPTIAL_PAD + + +def set_temporal_pad(dim_size: int): + sp_size = get_sequence_parallel_size() + pad = (sp_size - (dim_size % sp_size)) % sp_size + global TEMPORAL_PAD + TEMPORAL_PAD = pad + + +def get_temporal_pad() -> int: + return TEMPORAL_PAD + + +def all_to_all_with_pad( + input_: torch.Tensor, + process_group: dist.ProcessGroup, + **kwargs +): + scatter_dim = kwargs.get("scatter_dim", 2) + gather_dim = kwargs.get("gather_dim", 1) + scatter_pad = kwargs.get("scatter_pad", 0) + gather_pad = kwargs.get("gather_pad", 0) + + if scatter_pad > 0: + pad_shape = list(input_.shape) + pad_shape[scatter_dim] = scatter_pad + pad_tensor = torch.zeros(pad_shape, device=input_.device, dtype=input_.dtype) + input_ = torch.cat([input_, pad_tensor], dim=scatter_dim) + + world_size = dist.get_world_size(process_group) + if input_.shape[scatter_dim] % world_size != 0: + raise ValueError( + f"The scatter_dim:{scatter_dim} of input_:{input_.shape} is not divisible by world_size:{world_size}.") + + input_ = _all_to_all_func(input_, world_size, process_group, scatter_dim, gather_dim) + + if gather_pad > 0: + input_ = input_.narrow(gather_dim, 0, input_.size(gather_dim) - gather_pad) + + return input_ + + +def all_to_all( + tensor: torch.Tensor, + world_size: int, + scatter_dim: int, + gather_dim: int, + process_group: dist.ProcessGroup = None, +): + if process_group is None: + process_group = dist.group.WORLD + return _all_to_all_func(tensor, world_size, process_group, scatter_dim, gather_dim) + + +def all_to_all_sbh( + input_: torch.Tensor, + scatter_dim: int = 1, + gather_dim: int = 0, +): + return single_all_to_all(input_, scatter_dim, gather_dim) + + +def single_all_to_all( + input_: torch.Tensor, + scatter_dim: int, + gather_dim: int, +): + + sp_size = get_sequence_parallel_size() + inp_shape = list(input_.shape) + inp_shape[scatter_dim] = inp_shape[scatter_dim] // sp_size + if scatter_dim < 1: + input_t = input_.reshape( + [sp_size, inp_shape[scatter_dim]] + \ + inp_shape[scatter_dim + 1:] + ) + else: + # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them! + input_t = input_.reshape( + [-1, sp_size, inp_shape[scatter_dim]] + \ + inp_shape[scatter_dim + 1:] + ).transpose(0, 1).contiguous() + + output = torch.empty_like(input_t) + + dist.all_to_all_single(output, input_t, group=get_sequence_parallel_group()) + # if scattering the seq-dim, transpose the heads back to the original dimension + if scatter_dim < 1: + output = output.transpose(0, 1).contiguous() + + return output.reshape( + inp_shape[: gather_dim] + [inp_shape[gather_dim] * sp_size, ] + inp_shape[gather_dim + 1:]) \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/models/model_utils.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/models/model_utils.py new file mode 100644 index 0000000000..973dfe3149 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/models/model_utils.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +import torch_npu + + +def weight_switch(weights, prefix_key, new_weight, old_weight, transpose=None): + if prefix_key + old_weight in weights: + weights[prefix_key + new_weight] = weights.pop(prefix_key + old_weight) + if transpose: + weights[prefix_key + new_weight] = weights[prefix_key + new_weight + ].transpose(*transpose).contiguous() + + +def get_attn_weight(weights, prefix_key, cross_attention, fuse=True): + cache_weights = {} + # If self attention, fuse the qkv, crcoss attention fuse the kv + qkv = ["q", "k", "v"] if not cross_attention else ["k", "v"] + for wb in ["weight", "bias"]: + if fuse: + weight_name = [prefix_key + f'to_{i}.' + wb for i in qkv] + conds = [w in weights for w in weight_name] + # If weights do not contain all the q,k,v, put them in the cache_weights + # And the cache_weights will be added in the next shard weights + if not all(conds) and any(conds): + for w in weight_name: + if w in weights: + cache_weights[w] = weights.pop(w) + # weights contain all the q k v weight + if all(conds): + qkv_weight = [] + for w in weight_name: + qkv_weight.append(weights.pop(w)) + mid_key = "".join(qkv) + "_" if cross_attention else "" + if wb == "weight": + weights[prefix_key + "qkv_proj." + mid_key + wb] = torch.cat( + qkv_weight, dim=0).transpose(-1, 0).contiguous() + else: + weights[prefix_key + "qkv_proj." + mid_key + wb] = torch.cat(qkv_weight, dim=0) + else: + for q in qkv: + weight_switch(weights, prefix_key, q + '_proj.' + wb, f'to_{q}.' + wb) + + if cross_attention: + weight_switch(weights, prefix_key, 'qkv_proj.q_' + wb, 'to_q.' + wb, + transpose=(-1, 0) if wb == 'weight' else None) + + # switch out linear + weight_switch(weights, prefix_key, 'out_proj.' + wb, 'to_out.0.' + wb) + return cache_weights \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/models/parallel_mgr.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/models/parallel_mgr.py new file mode 100644 index 0000000000..23a4250942 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/models/parallel_mgr.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os + +import torch_npu +import torch.distributed as dist + + +class ParallelManager(): + def __init__(self, world_size=1, rank=0, group=None): + self.sp_size = world_size + self.sp_group = group + self.enable_sp = world_size > 1 + self.rank = rank + + +PARALLEL_MANAGER = ParallelManager() + + +def set_parallel_manager(world_size, rank, group): + global PARALLEL_MANAGER + PARALLEL_MANAGER = ParallelManager(world_size, rank, group) + + +def get_sequence_parallel_group(): + return PARALLEL_MANAGER.sp_group + + +def get_sequence_parallel_size(): + return PARALLEL_MANAGER.sp_size + + +def get_sequence_parallel_state(): + return PARALLEL_MANAGER.enable_sp + + +def get_sequence_parallel_rank(): + return PARALLEL_MANAGER.rank + + +def init_parallel_env(enable_sequence_parallelism): + rank = int(os.getenv('RANK', 0)) + world_size = int(os.getenv('WORLD_SIZE', 1)) + torch_npu.npu.set_device(rank) + dist.init_process_group( + backend='hccl', init_method='env://', + world_size=world_size, rank=rank + ) + if enable_sequence_parallelism: + set_parallel_manager(world_size, rank, dist.group.WORLD) \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/models/t2vdit.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/models/t2vdit.py new file mode 100644 index 0000000000..afaaf82222 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/models/t2vdit.py @@ -0,0 +1,458 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import Any, Dict, Optional, Tuple +import inspect + +import torch +import torch_npu +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, repeat +from mindiesd import ConfigMixin +from mindiesd import DiffusionModel + +from ..layers import Mlp, AdaLayerNorm +from ..layers.embedding import PatchEmbed2D +from ..layers.embedding import AdaLayerNormSingle +from ..layers.attention import ReconstitutionAttention, OpenSoraPlanAttnProcessor +from .model_utils import get_attn_weight, weight_switch + + +class BasicTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_eps: float = 1e-5, + ff_inner_dim: Optional[int] = None, + ff_bias: bool = True, + attention_out_bias: bool = True, + interpolation_scale_thw: Tuple[int] = (1, 1, 1), + sparse1d: bool = False, + sparse_n: int = 2, + sparse_group: bool = False, + ): + super().__init__() + + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + self.norm1 = AdaLayerNorm(dim, norm_eps) + + processor = OpenSoraPlanAttnProcessor( + interpolation_scale_thw=interpolation_scale_thw, sparse1d=sparse1d, sparse_n=sparse_n, + sparse_group=sparse_group, is_cross_attn=False + ) + self.attn1 = ReconstitutionAttention( + attention_dim=dim, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + num_heads=num_attention_heads, + head_dim=attention_head_dim, + qkv_bias=attention_bias, + out_proj_bias=attention_out_bias, + processor=processor + ) # is self-attn if encoder_hidden_states is none + + # 2. Cross-Attn + self.norm2 = AdaLayerNorm(dim, norm_eps) + + processor = OpenSoraPlanAttnProcessor( + interpolation_scale_thw=interpolation_scale_thw, sparse1d=sparse1d, sparse_n=sparse_n, + sparse_group=sparse_group, is_cross_attn=True + ) + self.attn2 = ReconstitutionAttention( + attention_dim=dim, + cross_attention_dim=cross_attention_dim if not double_self_attention else None, + num_heads=num_attention_heads, + head_dim=attention_head_dim, + qkv_bias=attention_bias, + out_proj_bias=attention_out_bias, + processor=processor + ) # is self-attn if encoder_hidden_states is none + + # 3. Feed-forward + ff_inner_dim = ff_inner_dim or 4 * dim + self.ff = Mlp(features_in=dim, features_hidden=ff_inner_dim, + act_layer=activation_fn, bias=ff_bias) + + # 4. Scale-shift. + self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) + + + def forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + timestep: Optional[torch.LongTensor] = None, + frame: int = None, + height: int = None, + width: int = None, + ) -> torch.FloatTensor: + + # 0. Self-Attention + batch_size = hidden_states.shape[1] + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[:, None] + timestep.reshape(6, batch_size, -1) + ).chunk(6, dim=0) + + norm_hidden_states = self.norm1(hidden_states, shift_msa, (1 + scale_msa)) + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=None, + attention_mask=attention_mask, frame=frame, height=height, width=width, + ) + + attn_output = gate_msa * attn_output + + hidden_states = attn_output + hidden_states + + # 3. Cross-Attention + norm_hidden_states = hidden_states + + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, frame=frame, height=height, width=width, + ) + hidden_states = attn_output + hidden_states + + # 4. Feed-forward + norm_hidden_states = self.norm2(hidden_states, shift_mlp, (1 + scale_mlp)) + + ff_output = self.ff(norm_hidden_states) + + ff_output = gate_mlp * ff_output + + hidden_states = ff_output + hidden_states + return hidden_states + + +class OpenSoraT2Vv1_3Config(ConfigMixin): + config_name = 'config.json' + + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + num_layers: int = 1, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = True, + sample_size_h: Optional[int] = None, + sample_size_w: Optional[int] = None, + sample_size_t: Optional[int] = None, + patch_size: Optional[int] = None, + patch_size_t: Optional[int] = None, + activation_fn: str = "geglu", + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = False, + norm_eps: float = 1e-6, + caption_channels: int = None, + interpolation_scale_h: float = 1.0, + interpolation_scale_w: float = 1.0, + interpolation_scale_t: float = 1.0, + sparse1d: bool = False, + sparse_n: int = 2, + ): + self._init(locals()) + + def _init(self, value): + init_signature = inspect.signature(self.__init__) + parameters = init_signature.parameters + for param_name, _ in parameters.items(): + if param_name != 'self': + setattr(self, param_name, value[param_name]) + + +class OpenSoraT2Vv1_3(DiffusionModel): + config_class = OpenSoraT2Vv1_3Config + weigths_name = "diffusion_pytorch_model.safetensors.index.json" + + def __init__(self, config): + super().__init__(config) + # Set some common variables used across the board. + self.out_channels = config.in_channels if config.out_channels is None else config.out_channels + self.config.hidden_size = self.config.num_attention_heads * self.config.attention_head_dim + self.use_cache = False + self.cache = None + self._prepare_patched_inputs() + + def load_weights(self, state_dict): + with torch.no_grad(): + weights = state_dict + # attention_block: + cache_weights = {} + for i in range(self.config.num_layers): + + prefix_key = 'transformer_blocks.' + str(i) + '.' + cache_weights1 = get_attn_weight(weights, prefix_key + "attn1.", cross_attention=False) + cache_weights2 = get_attn_weight(weights, prefix_key + "attn2.", cross_attention=True) + cache_weights.update(cache_weights1) + cache_weights.update(cache_weights2) + + prefix_key = prefix_key + 'ff.' + weight_switch(weights, prefix_key, 'fc1.weight', 'net.0.proj.weight') + weight_switch(weights, prefix_key, 'fc1.bias', 'net.0.proj.bias') + weight_switch(weights, prefix_key, 'fc2.weight', 'net.2.weight') + weight_switch(weights, prefix_key, 'fc2.bias', 'net.2.bias') + + prefix_key = "caption_projection." + weight_switch(weights, prefix_key, 'fc1.weight', 'linear_1.weight') + weight_switch(weights, prefix_key, 'fc1.bias', 'linear_1.bias') + weight_switch(weights, prefix_key, 'fc2.weight', 'linear_2.weight') + weight_switch(weights, prefix_key, 'fc2.bias', 'linear_2.bias') + + self.load_state_dict(state_dict, strict=False) + return state_dict.keys(), cache_weights + + def forward( + self, + hidden_states: torch.Tensor, + timestep: Optional[torch.LongTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + step_id: int = 0, + **kwargs, + ): + batch_size, _, frame, h, w = hidden_states.shape + attention_mask, encoder_attention_mask = self._standard_mask(attention_mask, encoder_attention_mask) + # 1. Input + frame = ((frame - 1) // self.config.patch_size_t + 1 + ) if frame % 2 == 1 else frame // self.config.patch_size_t # patchfy + height = hidden_states.shape[-2] // self.config.patch_size + width = hidden_states.shape[-1] // self.config.patch_size + + hidden_states, encoder_hidden_states, timestep, embedded_timestep = self._operate_on_patched_inputs( + hidden_states, encoder_hidden_states, timestep, batch_size, frame) + + # x (t*h*w b d) or (t//sp*h*w b d) + # cond_1 (l b d) or (l//sp b d) + hidden_states = rearrange(hidden_states, 'b s h -> s b h', b=batch_size).contiguous() + encoder_hidden_states = rearrange(encoder_hidden_states, 'b s h -> s b h', b=batch_size).contiguous() + timestep = timestep.view(batch_size, 6, -1).transpose(0, 1).contiguous() + + sparse_mask = {} + for sparse_n in [1, 4]: + sparse_mask[sparse_n] = prepare_sparse_mask(attention_mask, encoder_attention_mask, sparse_n) + + # 2. Blocks + for i, block in enumerate(self.transformer_blocks): + if i > 1 and i < 30: + mask_group = sparse_mask.get(block.attn1.processor.sparse_n, None) + attention_mask, encoder_attention_mask = mask_group.get(block.attn1.processor.sparse_group, None) + else: + mask_group = sparse_mask.get(1, None) + attention_mask, encoder_attention_mask = mask_group.get(block.attn1.processor.sparse_group, None) + if self.use_cache: + hidden_states = self.cache(block, step_id, i, hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + timestep=timestep, frame=frame, height=height, width=width, + ) + else: + hidden_states = block( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + timestep=timestep, frame=frame, height=height, width=width, + ) + # New shape (b, t*h*w, h) or (b, t//sp*h*w, h) + hidden_states = rearrange(hidden_states, 's b h -> b s h', b=batch_size).contiguous() + + # 3. Output + video_size = (frame, height, width) + output = self._get_output_for_patched_inputs( + hidden_states=hidden_states, + embedded_timestep=embedded_timestep, + video_size=video_size + ) # b c t h w + return (output,) + + def _prepare_patched_inputs(self): + self.config.sample_size = (self.config.sample_size_h, self.config.sample_size_w) + interpolation_scale_thw = ( + self.config.interpolation_scale_t, + self.config.interpolation_scale_h, + self.config.interpolation_scale_w + ) + + self.caption_projection = Mlp( + features_in=self.config.caption_channels, + features_hidden=self.config.hidden_size, + features_out=self.config.hidden_size, + act_layer="gelu-approximate" + ) + self.pos_embed = PatchEmbed2D( + patch_size=self.config.patch_size, + in_channels=self.config.in_channels, + embed_dim=self.config.hidden_size, + ) + + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + self.config.hidden_size, + self.config.num_attention_heads, + self.config.attention_head_dim, + cross_attention_dim=self.config.cross_attention_dim, + activation_fn=self.config.activation_fn, + attention_bias=self.config.attention_bias, + only_cross_attention=self.config.only_cross_attention, + double_self_attention=self.config.double_self_attention, + upcast_attention=self.config.upcast_attention, + norm_elementwise_affine=self.config.norm_elementwise_affine, + norm_eps=self.config.norm_eps, + interpolation_scale_thw=interpolation_scale_thw, + sparse1d=self.config.sparse1d if i > 1 and i < 30 else False, + sparse_n=self.config.sparse_n, + sparse_group=i % 2 == 1, + ) + for i in range(self.config.num_layers) + ] + ) + self.norm_out = nn.LayerNorm(self.config.hidden_size, elementwise_affine=False, eps=1e-6) + self.scale_shift_table = nn.Parameter(torch.randn(2, self.config.hidden_size) / self.config.hidden_size**0.5) + self.proj_out = nn.Linear( + self.config.hidden_size, + self.config.patch_size_t * self.config.patch_size * self.config.patch_size * self.out_channels + ) + self.adaln_single = AdaLayerNormSingle(self.config.hidden_size) + + def _operate_on_patched_inputs(self, hidden_states, encoder_hidden_states, timestep, batch_size, frame): + + hidden_states = self.pos_embed(hidden_states.to(self.dtype)) + + added_cond_kwargs = {"resolution": None, "aspect_ratio": None} + timestep, embedded_timestep = self.adaln_single( + timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=self.dtype + ) # b 6d, b d + + encoder_hidden_states = self.caption_projection(encoder_hidden_states) # b, 1, l, d or b, 1, l, d + encoder_hidden_states = rearrange(encoder_hidden_states, 'b 1 l d -> (b 1) l d') + + return hidden_states, encoder_hidden_states, timestep, embedded_timestep + + def _get_output_for_patched_inputs( + self, hidden_states, embedded_timestep, video_size + ): + (num_frames, height, width) = video_size + + shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states) + # Modulation + hidden_states = hidden_states * (1 + scale) + shift + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.squeeze(1) + + # unpatchify + hidden_states = hidden_states.reshape( + shape=(-1, num_frames, height, width, + self.config.patch_size_t, self.config.patch_size, self.config.patch_size, self.out_channels) + ) + hidden_states = torch.einsum("nthwopqc->nctohpwq", hidden_states) + output = hidden_states.reshape( + shape=(-1, self.out_channels, num_frames * self.config.patch_size_t, + height * self.config.patch_size, width * self.config.patch_size) + ) + return output + + def _standard_mask(self, attention_mask, encoder_attention_mask): + if attention_mask is not None and attention_mask.ndim == 4: + + attention_mask = attention_mask.to(self.dtype) + + attention_mask = attention_mask.unsqueeze(1) # b 1 t h w + attention_mask = F.max_pool3d( + attention_mask, + kernel_size=(self.config.patch_size_t, self.config.patch_size, self.config.patch_size), + stride=(self.config.patch_size_t, self.config.patch_size, self.config.patch_size) + ) + attention_mask = rearrange(attention_mask, 'b 1 t h w -> (b 1) 1 (t h w)') + attention_mask = (1 - attention_mask.bool().to(self.dtype)) * -10000.0 + + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 3: + # b, 1, l + encoder_attention_mask = (1 - encoder_attention_mask.to(self.dtype)) * -10000.0 + return attention_mask, encoder_attention_mask + + +def prepare_sparse_mask(attention_mask, encoder_attention_mask, sparse_n): + attention_mask = attention_mask.unsqueeze(1) + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + seqlen = attention_mask.shape[-1] + if seqlen % (sparse_n * sparse_n) == 0: + pad_len = 0 + else: + pad_len = sparse_n * sparse_n - seqlen % (sparse_n * sparse_n) + if pad_len != 0: + attention_mask_sparse = F.pad(attention_mask, (0, pad_len, 0, 0), value=-9980.0) + seqlen = attention_mask_sparse.shape[-1] + attention_mask_sparse_1d = rearrange( + attention_mask_sparse, + 'b 1 1 (g k) -> (k b) 1 1 g', + k=sparse_n + ) + attention_mask_sparse_1d_group = rearrange( + attention_mask_sparse, + 'b 1 1 (n m k) -> (m b) 1 1 (n k)', + m=sparse_n, + k=sparse_n + ) + + encoder_attention_mask_sparse = encoder_attention_mask.repeat(sparse_n, 1, 1, 1) + + encoder_attention_mask_sparse_1d = get_attention_mask( + encoder_attention_mask_sparse, int(seqlen / sparse_n) + ) + encoder_attention_mask_sparse_1d_group = encoder_attention_mask_sparse_1d + if pad_len != 0: + attention_mask_sparse_1d = get_attention_mask( + attention_mask_sparse_1d, attention_mask_sparse_1d.shape[-1]) + attention_mask_sparse_1d_group = get_attention_mask( + attention_mask_sparse_1d_group, attention_mask_sparse_1d_group.shape[-1]) + else: + attention_mask_sparse_1d = None + attention_mask_sparse_1d_group = None + + return { + False: (attention_mask_sparse_1d, encoder_attention_mask_sparse_1d), + True: (attention_mask_sparse_1d_group, encoder_attention_mask_sparse_1d_group) + } + + +def get_attention_mask(attention_mask, repeat_num): + attention_mask = attention_mask.to(torch.bool) + attention_mask = attention_mask.repeat_interleave(repeat_num, dim=-2) + return attention_mask \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/models/wfvae.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/models/wfvae.py new file mode 100644 index 0000000000..4f9e305b6d --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/models/wfvae.py @@ -0,0 +1,576 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from copy import deepcopy +from typing import List +from collections import deque +import inspect + +import torch +import torch.nn as nn +from einops import rearrange +from mindiesd import ConfigMixin + +from .model_utils import DiffusionModel +from ..layers.vresnet import ResnetBlock3D, VideoResnetBlock2D +from ..layers.norm import Normalize +from ..layers.conv import VideoConv2d, PlanCausalConv3d +from ..layers.utils import resolve_str_to_obj +from ..layers.wavelet import HaarWaveletTransform3D, InverseHaarWaveletTransform3D +from ..layers.activation import get_activation_fn + + +class DiagonalGaussianDistribution(object): + def __init__( + self, + parameters, + deterministic=False, + ): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device, dtype=self.mean.dtype) + + def sample(self): + # torch.randn: standard normal distribution + x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device, dtype=self.mean.dtype) + return x + + def mode(self): + return self.mean + + +class Encoder(nn.Module): + + def __init__( + self, + latent_dim: int = 8, + base_channels: int = 128, + num_resblocks: int = 2, + energy_flow_hidden_size: int = 64, + attention_type: str = "AttnBlock3DFix", + use_attention: bool = True, + norm_type: str = "groupnorm", + l1_dowmsample_block: str = "Downsample", + l1_downsample_wavelet: str = "HaarWaveletTransform2D", + l2_dowmsample_block: str = "Spatial2xTime2x3DDownsample", + l2_downsample_wavelet: str = "HaarWaveletTransform3D", + ) -> None: + super().__init__() + self.down1 = self._init_down1(base_channels, norm_type, num_resblocks, l1_dowmsample_block) + self.energy_flow_hidden_size = energy_flow_hidden_size + + self.down2 = self._init_down2(base_channels, norm_type, num_resblocks, l2_dowmsample_block) + + # Connection + if l1_dowmsample_block == "Downsample": # Bad code. For temporal usage. + l1_channels = 12 + else: + l1_channels = 24 + + self.connect_l1 = VideoConv2d( + l1_channels, energy_flow_hidden_size, kernel_size=3, stride=1, padding=1 + ) + self.connect_l2 = VideoConv2d( + 24, energy_flow_hidden_size, kernel_size=3, stride=1, padding=1 + ) + # Mid + mid_layers = [ + ResnetBlock3D( + in_channels=base_channels * 2 + energy_flow_hidden_size, + out_channels=base_channels * 4, + norm_type=norm_type, + ), + ResnetBlock3D( + in_channels=base_channels * 4, + out_channels=base_channels * 4, + norm_type=norm_type, + ), + ] + if use_attention: + mid_layers.insert( + 1, resolve_str_to_obj(attention_type)(in_channels=base_channels * 4, norm_type=norm_type) + ) + self.mid = nn.Sequential(*mid_layers) + self.norm_out = Normalize(base_channels * 4, norm_type=norm_type) + self.conv_out = PlanCausalConv3d( + base_channels * 4, latent_dim * 2, kernel_size=3, stride=1, padding=1 + ) + self.wavelet_transform_in = HaarWaveletTransform3D() + self.wavelet_transform_l1 = resolve_str_to_obj(l1_downsample_wavelet)() + self.wavelet_transform_l2 = resolve_str_to_obj(l2_downsample_wavelet)() + self.nonlinearity = get_activation_fn("silu") + + def forward(self, x): + coeffs = self.wavelet_transform_in(x) + l1_coeffs = coeffs[:, :3] + l1_coeffs = self.wavelet_transform_l1(l1_coeffs) + l1 = self.connect_l1(l1_coeffs) + l2_coeffs = self.wavelet_transform_l2(l1_coeffs[:, :3]) + l2 = self.connect_l2(l2_coeffs) + + h = self.down1(coeffs) + h = torch.concat([h, l1], dim=1) + h = self.down2(h) + h = torch.concat([h, l2], dim=1) + h = self.mid(h) + h = self.norm_out(h) + h = self.nonlinearity(h) + h = self.conv_out(h) + return h, (l1_coeffs, l2_coeffs) + + def _init_down1(self, base_channels, norm_type, num_resblocks, l1_dowmsample_block): + block = nn.Sequential( + VideoConv2d(24, base_channels, kernel_size=3, stride=1, padding=1), + *[ + VideoResnetBlock2D( + in_channels=base_channels, + out_channels=base_channels, + norm_type=norm_type, + ) + for _ in range(num_resblocks) + ], + resolve_str_to_obj(l1_dowmsample_block)(in_channels=base_channels, out_channels=base_channels), + ) + return block + + def _init_down2(self, base_channels, norm_type, num_resblocks, l2_dowmsample_block): + energy_flow_hidden_size = self.energy_flow_hidden_size + block = nn.Sequential( + VideoConv2d( + base_channels + energy_flow_hidden_size, + base_channels * 2, + kernel_size=3, + stride=1, + padding=1, + ), + *[ + ResnetBlock3D( + in_channels=base_channels * 2, + out_channels=base_channels * 2, + norm_type=norm_type, + ) + for _ in range(num_resblocks) + ], + resolve_str_to_obj(l2_dowmsample_block)(base_channels * 2, base_channels * 2), + ) + return block + + +class Decoder(nn.Module): + + def __init__( + self, + latent_dim: int = 8, + base_channels: int = 128, + num_resblocks: int = 2, + energy_flow_hidden_size: int = 128, + attention_type: str = "AttnBlock3DFix", + use_attention: bool = True, + norm_type: str = "groupnorm", + t_interpolation: str = "nearest", + connect_res_layer_num: int = 1, + l1_upsample_block: str = "Upsample", + l1_upsample_wavelet: str = "InverseHaarWaveletTransform2D", + l2_upsample_block: str = "Spatial2xTime2x3DUpsample", + l2_upsample_wavelet: str = "InverseHaarWaveletTransform3D", + ) -> None: + super().__init__() + self.energy_flow_hidden_size = energy_flow_hidden_size + self.norm_type = norm_type + self.conv_in = PlanCausalConv3d( + latent_dim, base_channels * 4, kernel_size=3, stride=1, padding=1) + + self.mid = self._init_mid(base_channels, attention_type, energy_flow_hidden_size, use_attention) + + self.up2 = self._init_up2(base_channels, num_resblocks, t_interpolation, unsample_type=l2_upsample_block) + self.up1 = self._init_up1(base_channels, num_resblocks, unsample_type=l1_upsample_block) + self.layer = nn.Sequential( + *[ + ResnetBlock3D( + in_channels=base_channels * (2 if i == 0 else 1), + out_channels=base_channels, + norm_type=norm_type, + ) + for i in range(2) + ], + ) + # Connection + if l1_upsample_block == "Upsample": # Bad code. For temporal usage. + l1_channels = 12 + else: + l1_channels = 24 + + self.connect_l1 = self._init_connect(energy_flow_hidden_size, connect_res_layer_num, l1_channels) + self.connect_l2 = self._init_connect(energy_flow_hidden_size, connect_res_layer_num, 24) + + # Out + self.norm_out = Normalize(base_channels, norm_type=norm_type) + self.conv_out = VideoConv2d(base_channels, 24, kernel_size=3, stride=1, padding=1) + + self.inverse_wavelet_transform_out = InverseHaarWaveletTransform3D() + self.inverse_wavelet_transform_l1 = resolve_str_to_obj(l1_upsample_wavelet)() + self.inverse_wavelet_transform_l2 = resolve_str_to_obj(l2_upsample_wavelet)() + self.nonlinearity = get_activation_fn("silu") + + def forward(self, z): + h = self.conv_in(z) + h = self.mid(h) + l2_coeffs = self.connect_l2(h[:, -self.energy_flow_hidden_size :]) + l2 = self.inverse_wavelet_transform_l2(l2_coeffs) + h = self.up2(h[:, : -self.energy_flow_hidden_size]) + l1_coeffs = h[:, -self.energy_flow_hidden_size :] + l1_coeffs = self.connect_l1(l1_coeffs) + l1_coeffs[:, :3] = l1_coeffs[:, :3] + l2 + l1 = self.inverse_wavelet_transform_l1(l1_coeffs) + + h = self.up1(h[:, : -self.energy_flow_hidden_size]) + + h = self.layer(h) + h = self.norm_out(h) + h = self.nonlinearity(h) + h = self.conv_out(h) + h[:, :3] = h[:, :3] + l1 + dec = self.inverse_wavelet_transform_out(h) + return dec, (l1_coeffs, l2_coeffs) + + def _init_mid(self, base_channels, attention_type, energy_flow_hidden_size, use_attention): + norm_type = self.norm_type + mid_layers = [ + ResnetBlock3D( + in_channels=base_channels * 4, + out_channels=base_channels * 4, + norm_type=norm_type, + ), + ResnetBlock3D( + in_channels=base_channels * 4, + out_channels=base_channels * 4 + energy_flow_hidden_size, + norm_type=norm_type, + ), + ] + if use_attention: + mid_layers.insert( + 1, resolve_str_to_obj(attention_type)(in_channels=base_channels * 4, norm_type=norm_type) + ) + return nn.Sequential(*mid_layers) + + def _init_up2(self, base_channels, num_resblocks, t_interpolation, unsample_type): + norm_type = self.norm_type + up_block = nn.Sequential( + *[ + ResnetBlock3D( + in_channels=base_channels * 4, + out_channels=base_channels * 4, + norm_type=norm_type, + ) + for _ in range(num_resblocks) + ], + resolve_str_to_obj(unsample_type)( + base_channels * 4, base_channels * 4, t_interpolation=t_interpolation + ), + ResnetBlock3D( + in_channels=base_channels * 4, + out_channels=base_channels * 4 + self.energy_flow_hidden_size, + norm_type=norm_type, + ), + ) + return up_block + + def _init_up1(self, base_channels, num_resblocks, unsample_type): + norm_type = self.norm_type + up_block = nn.Sequential( + *[ + ResnetBlock3D( + in_channels=base_channels * (4 if i == 0 else 2), + out_channels=base_channels * 2, + norm_type=norm_type, + ) + for i in range(num_resblocks) + ], + resolve_str_to_obj(unsample_type)(in_channels=base_channels * 2, out_channels=base_channels * 2), + ResnetBlock3D( + in_channels=base_channels * 2, + out_channels=base_channels * 2, + norm_type=norm_type, + ), + ) + return up_block + + def _init_connect(self, energy_flow_hidden_size, connect_res_layer_num, conv_channel): + norm_type = self.norm_type + connect = nn.Sequential( + *[ + ResnetBlock3D( + in_channels=energy_flow_hidden_size, + out_channels=energy_flow_hidden_size, + norm_type=norm_type, + ) + for _ in range(connect_res_layer_num) + ], + VideoConv2d(energy_flow_hidden_size, conv_channel, kernel_size=3, stride=1, padding=1), + ) + return connect + + +class WFVAEModelConfig(ConfigMixin): + config_name = "config.json" + + def __init__( + self, + latent_dim: int = 8, + base_channels: int = 128, + encoder_num_resblocks: int = 2, + encoder_energy_flow_hidden_size: int = 64, + decoder_num_resblocks: int = 2, + decoder_energy_flow_hidden_size: int = 128, + attention_type: str = "AttnBlock3DFix", + use_attention: bool = True, + norm_type: str = "groupnorm", + t_interpolation: str = "nearest", + connect_res_layer_num: int = 1, + scale: List[float] = None, + shift: List[float] = None, + l1_dowmsample_block: str = "Downsample", + l1_downsample_wavelet: str = "HaarWaveletTransform2D", + l2_dowmsample_block: str = "Spatial2xTime2x3DDownsample", + l2_downsample_wavelet: str = "HaarWaveletTransform3D", + l1_upsample_block: str = "Upsample", + l1_upsample_wavelet: str = "InverseHaarWaveletTransform2D", + l2_upsample_block: str = "Spatial2xTime2x3DUpsample", + l2_upsample_wavelet: str = "InverseHaarWaveletTransform3D", + ): + self._init(locals()) + if not scale: + self.scale = [0.18215, 0.18215, 0.18215, 0.18215, 0.18215, 0.18215, 0.18215, 0.18215] + if not shift: + self.shift = [0, 0, 0, 0, 0, 0, 0, 0] + + def _init(self, value): + init_signature = inspect.signature(self.__init__) + parameters = init_signature.parameters + for param_name, _ in parameters.items(): + if param_name != 'self': + setattr(self, param_name, value[param_name]) + + +class WFVAEModel(DiffusionModel): + config_class = WFVAEModelConfig + weigths_name = "merged.ckpt" + + def __init__( + self, + config + ) -> None: + super().__init__(config) + # Module config + + self.use_tiling = False + # Hardcode for now + self.t_chunk_enc = 8 + self.t_chunk_dec = 2 + self.t_upsample_times = 2 + + self.use_quant_layer = False + self.encoder = Encoder( + latent_dim=config.latent_dim, + base_channels=config.base_channels, + num_resblocks=config.encoder_num_resblocks, + energy_flow_hidden_size=config.encoder_energy_flow_hidden_size, + use_attention=config.use_attention, + norm_type=config.norm_type, + l1_dowmsample_block=config.l1_dowmsample_block, + l1_downsample_wavelet=config.l1_downsample_wavelet, + l2_dowmsample_block=config.l2_dowmsample_block, + l2_downsample_wavelet=config.l2_downsample_wavelet, + attention_type=config.attention_type + ) + self.decoder = Decoder( + latent_dim=config.latent_dim, + base_channels=config.base_channels, + num_resblocks=config.decoder_num_resblocks, + energy_flow_hidden_size=config.decoder_energy_flow_hidden_size, + use_attention=config.use_attention, + norm_type=config.norm_type, + t_interpolation=config.t_interpolation, + connect_res_layer_num=config.connect_res_layer_num, + l1_upsample_block=config.l1_upsample_block, + l1_upsample_wavelet=config.l1_upsample_wavelet, + l2_upsample_block=config.l2_upsample_block, + l2_upsample_wavelet=config.l2_upsample_wavelet + ) + + # Set cache offset for trilinear lossless upsample. + self._set_cache_offset([self.decoder.up2, self.decoder.connect_l2, self.decoder.conv_in, self.decoder.mid], 1) + self._set_cache_offset([ + self.decoder.up2[-2:], self.decoder.up1, self.decoder.connect_l1, self.decoder.layer], + self.t_upsample_times) + + def encode(self, x): + self._empty_causal_cached(self.encoder) + self._set_first_chunk(True) + + if self.use_tiling: + h = self._tile_encode(x) + l1, l2 = None, None + else: + h, (l1, l2) = self.encoder(x) + if self.use_quant_layer: + h = self.quant_conv(h) + + posterior = DiagonalGaussianDistribution(h) + return posterior + + def decode(self, z): + self._empty_causal_cached(self.decoder) + self._set_first_chunk(True) + + if self.use_tiling: + dec = self._tile_decode(z) + else: + if self.use_quant_layer: + z = self.post_quant_conv(z) + dec, _ = self.decoder(z) + + return dec + + def enable_tiling(self, use_tiling: bool = True): + self.use_tiling = use_tiling + self._set_causal_cached(use_tiling) + + def disable_tiling(self): + self.enable_tiling(False) + + def load_weights(self, state_dict): + with torch.no_grad(): + missing_keys, unexpected_keys = self.load_state_dict(state_dict, strict=False) + + def _empty_causal_cached(self, parent): + for _, module in parent.named_modules(): + if hasattr(module, 'causal_cached'): + module.causal_cached = deque() + + def _set_causal_cached(self, enable_cached=True): + for _, module in self.named_modules(): + if hasattr(module, 'enable_cached'): + module.enable_cached = enable_cached + + def _set_cache_offset(self, modules, cache_offset=0): + for module in modules: + for submodule in module.modules(): + if hasattr(submodule, 'cache_offset'): + submodule.cache_offset = cache_offset + + def _set_first_chunk(self, is_first_chunk=True): + for module in self.modules(): + if hasattr(module, 'is_first_chunk'): + module.is_first_chunk = is_first_chunk + + def _build_chunk_start_end(self, t, decoder_mode=False): + start_end = [[0, 1]] + start = 1 + end = start + while True: + if start >= t: + break + end = min(t, end + (self.t_chunk_dec if decoder_mode else self.t_chunk_enc)) + start_end.append([start, end]) + start = end + return start_end + + def _tile_encode(self, x): + b, c, t, h, w = x.shape + + start_end = self._build_chunk_start_end(t) + result = [] + for idx, (start, end) in enumerate(start_end): + self._set_first_chunk(idx == 0) + chunk = x[:, :, start:end, :, :] + chunk = self.encoder(chunk)[0] + if self.use_quant_layer: + chunk = self.quant_conv(chunk) + result.append(chunk) + + return torch.cat(result, dim=2) + + def _tile_decode(self, x): + b, c, t, h, w = x.shape + + start_end = self._build_chunk_start_end(t, decoder_mode=True) + result = [] + for idx, (start, end) in enumerate(start_end): + self._set_first_chunk(idx == 0) + + if end + 1 < t: + chunk = x[:, :, start:end + 1, :, :] + else: + chunk = x[:, :, start:end, :, :] + + if self.use_quant_layer: + chunk = self.post_quant_conv(chunk) + chunk = self.decoder(chunk)[0] + if end + 1 < t: + chunk = chunk[:, :, :-4] + result.append(chunk.clone()) + else: + result.append(chunk.clone()) + + return torch.cat(result, dim=2) + + + + +class WFVAEModelWrapper(nn.Module): + def __init__(self, model_path, subfolder=None, cache_dir=None, **kwargs): + super(WFVAEModelWrapper, self).__init__() + self.vae = WFVAEModel.from_pretrained(model_path, **kwargs) + self.register_buffer('shift', torch.tensor(self.vae.config.shift)[None, :, None, None, None]) + self.register_buffer('scale', torch.tensor(self.vae.config.scale)[None, :, None, None, None]) + + @property + def dtype(self): + return self.vae.dtype + + @property + def device(self): + return self.vae.device + + @classmethod + def from_pretrained(cls, model_path, **kwargs): + return cls(model_path, **kwargs) + + def encode(self, x): + x = (self.vae.encode(x).sample() - self.shift.to(x.device, dtype=x.dtype)) * \ + self.scale.to(x.device, dtype=x.dtype) + return x + + def decode(self, x): + x = x / self.scale.to(x.device, dtype=x.dtype) + self.shift.to(x.device, dtype=x.dtype) + x = self.vae.decode(x) + x = rearrange(x, 'b c t h w -> b t c h w').contiguous() + return x + + +ae_stride_config = { + 'WFVAEModel_D8_4x8x8': [4, 8, 8], + 'WFVAEModel_D16_4x8x8': [4, 8, 8], + 'WFVAEModel_D32_4x8x8': [4, 8, 8], + 'WFVAEModel_D32_8x8x8': [8, 8, 8], +} \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/pipeline/open_soar_plan_pipeline.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/pipeline/open_soar_plan_pipeline.py new file mode 100644 index 0000000000..75180201c1 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/pipeline/open_soar_plan_pipeline.py @@ -0,0 +1,463 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable, Dict, List, Optional, Tuple, Union +from dataclasses import dataclass + +import numpy as np +import torch +import torch_npu +from einops import rearrange +from tqdm import tqdm + + +from ..models.parallel_mgr import get_sequence_parallel_state, get_sequence_parallel_size, get_sequence_parallel_rank +from .pipeline_utils import DiffusionPipeline, rescale_noise_cfg, retrieve_timesteps + + +class OpenSoraPlanPipeline13(DiffusionPipeline): + + def __init__( + self, + vae, + text_encoder, + tokenizer, + transformer, + scheduler, + text_encoder_2=None, + tokenizer_2=None + ): + super().__init__() + + self.vae = vae + self.text_encoder = text_encoder + self.tokenizer = tokenizer + self.tokenizer_2 = tokenizer_2 + self.transformer = transformer + self.scheduler = scheduler + self.text_encoder_2 = text_encoder_2 + self._guidance_scale = None + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + num_frames: Optional[int] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: Optional[int] = 50, + timesteps: List[int] = None, + guidance_scale: Optional[float] = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_samples_per_prompt: Optional[int] = 1, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_2: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds_2: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + prompt_attention_mask_2: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_attention_mask_2: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + guidance_rescale: float = 0.0, + max_sequence_length: int = 512, + ): + # 0. default height and width + num_frames = num_frames or (self.transformer.config.sample_size_t - 1) * self.vae.vae_scale_factor[0] + 1 + height = height or self.transformer.config.sample_size[0] * self.vae.vae_scale_factor[1] + width = width or self.transformer.config.sample_size[1] * self.vae.vae_scale_factor[2] + video_size = (num_frames, height, width) + # 1. Check inputs. Raise error if not correct + prompts = (prompt, negative_prompt) + embeds = (prompt_embeds, prompt_embeds_2, negative_prompt_embeds, negative_prompt_embeds_2) + masks = (prompt_attention_mask, prompt_attention_mask_2, + negative_prompt_attention_mask, negative_prompt_attention_mask_2) + + self._check_inputs(prompts, video_size, embeds, masks) + self._guidance_scale = guidance_scale + + # 2. Define call parameters + batch_size = self._get_batch(prompt, prompt_embeds) + device = self.device + + # 3. Encode input prompt + encode_kwarg = {"num_samples_per_prompt":num_samples_per_prompt, + "do_classifier_free_guidance":self.do_classifier_free_guidance, + "max_sequence_length":max_sequence_length} + mask_emb_1 = (prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask) + mask_emb_2 = (prompt_embeds_2, negative_prompt_embeds_2, + prompt_attention_mask_2, negative_prompt_attention_mask_2) + all_mask_emb = (mask_emb_1, mask_emb_2) + # If sp, the prompt_embeds the size [B, S/N, C] + prompt_embeds, prompt_embeds_2, prompt_attention_mask = self._get_embeding(prompts, all_mask_emb, encode_kwarg) + + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + self._num_timesteps = len(timesteps) + + # 5. Prepare latent variables + # if sp, the latent [B C [T//N] W H] + latents = self._get_latent(video_size, batch_size, num_samples_per_prompt, latents) + + # 8. Denoising loop + all_guidance = (guidance_scale, guidance_rescale) + input_embeds = prompt_embeds, prompt_embeds_2, prompt_attention_mask + latents = self._sampling(timesteps, latents, input_embeds, all_guidance) + + if not output_type == "latent": + videos = self._decode_latents(latents) + videos = videos[:, :num_frames, :height, :width] + else: + videos = latents + return (videos, ) + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def device(self): + return next(self.transformer.parameters()).device + + def _get_encode_kwarg(self, encode_kwarg): + encode_kwarg = encode_kwarg or {} + num_samples_per_prompt = encode_kwarg.get("num_samples_per_prompt", 1) + do_classifier_free_guidance = encode_kwarg.get("do_classifier_free_guidance", True) + max_sequence_length = encode_kwarg.get("max_sequence_length", None) + return num_samples_per_prompt, do_classifier_free_guidance, max_sequence_length + + def _encode_prompt( + self, + prompts, + mask_emb=None, + encode_kwarg=None, + text_encoder_index: int = 0 + ): + (num_samples_per_prompt, do_classifier_free_guidance, + max_sequence_length) = self._get_encode_kwarg(encode_kwarg) + + (prompt, negative_prompt) = prompts + if mask_emb is None: + mask_emb = [None] * 4 + (prompt_embeds, negative_prompt_embeds, + prompt_attention_mask, negative_prompt_attention_mask) = mask_emb + + device = self.device + dtype = self.transformer.dtype + + tokenizers = [self.tokenizer, self.tokenizer_2] + text_encoders = [self.text_encoder, self.text_encoder_2] + encoder = (tokenizers[text_encoder_index], text_encoders[text_encoder_index], text_encoder_index) + + max_length = self._get_length(max_sequence_length, text_encoder_index) + batch_size = self._get_batch(prompt, prompt_embeds) + encode_kwarg["max_sequence_length"] = max_length + + + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._encode_prompt_process( + prompt, encode_kwarg, encoder, trunc_test=True) + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_samples_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_samples_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = self._standerd_neg_prompt(prompts, batch_size) + negative_prompt_embeds, negative_prompt_attention_mask = self._encode_prompt_process( + negative_prompt, encode_kwarg, encoder) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_samples_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_samples_per_prompt, seq_len, -1) + return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask + + def _get_length(self, max_sequence_length, text_encoder_index): + if max_sequence_length is None: + if text_encoder_index == 0: + max_length = 512 + if text_encoder_index == 1: + max_length = 77 + else: + max_length = max_sequence_length + return max_length + + def _get_batch(self, prompt, prompt_embeds): + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + return batch_size + + def _encode_prompt_process(self, prompt, encode_kwarg, encoder, trunc_test=False): + device = self.device + tokenizer, text_encoder, text_encoder_index = encoder + num_samples_per_prompt, _, max_length = self._get_encode_kwarg(encode_kwarg) + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + if trunc_test: + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + print("warning:", ( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + )) + + prompt_attention_mask = text_inputs.attention_mask.to(device) + prompt_embeds = text_encoder( + text_input_ids.to(device), + attention_mask=prompt_attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + if text_encoder_index == 1: + prompt_embeds = prompt_embeds.unsqueeze(1) # b d -> b 1 d for clip + + prompt_attention_mask = prompt_attention_mask.repeat(num_samples_per_prompt, 1) + return prompt_embeds, prompt_attention_mask + + def _standerd_neg_prompt(self, prompts, batch_size): + (prompt, negative_prompt) = prompts + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise ValueError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + return uncond_tokens + + def _sampling(self, timesteps, latents, input_embeds, all_guidance): + prompt_embeds, prompt_embeds_2, prompt_attention_mask = input_embeds + guidance_scale, guidance_rescale, = all_guidance + # ==================prepare my shape===================================== + # [B T W H] or [B T/N W H] + attention_mask = torch.ones_like(latents)[:, 0].repeat(2, 1, 1, 1).to(device=self.device) + # If sp, recover attention_mask to the [B T W H] + if get_sequence_parallel_state(): + attention_mask = attention_mask.repeat(1, get_sequence_parallel_size(), 1, 1) + for step_id, t in enumerate(tqdm(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + scale_model = False + if hasattr(self.scheduler, "scale_model_input"): + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + scale_model = True + # Expand scalar t to 1-D tensor to match the 1st dim of latent_model_input + if isinstance(t, torch.Tensor): + timestep = t.expand(latent_model_input.shape[0]) + else: + timestep = torch.tensor([t] * latent_model_input.shape[0], device=self.device).to( + dtype=latent_model_input.dtype) + + noise_pred = self.transformer( + latent_model_input, + attention_mask=attention_mask, + encoder_hidden_states=prompt_embeds, + encoder_attention_mask=prompt_attention_mask, + timestep=timestep, + pooled_projections=prompt_embeds_2, + step_id=step_id, + )[0] + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and guidance_rescale > 0.0 and scale_model: + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + # Compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if get_sequence_parallel_state(): + world_size = get_sequence_parallel_size() + latents_shape = list(latents.shape) # b c t//sp h w + full_shape = [latents_shape[0] * world_size] + latents_shape[1:] # # b*sp c t//sp h w + all_latents = torch.zeros(full_shape, dtype=latents.dtype, device=latents.device) + torch.distributed.all_gather_into_tensor(all_latents, latents) + latents_list = list(all_latents.chunk(world_size, dim=0)) + latents = torch.cat(latents_list, dim=2) + return latents + + def _check_inputs( + self, + prompts, + video_size, + embeds, + masks + ): + num_frames, height, width = video_size + # 1. Check inputs. Raise error if not correct + suffix_2 = "_2" + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + for i, prompt in enumerate(prompts): + if prompt is not None and embeds[i] is not None: + prefix = "" if i % 2 == 0 else "negative_" + raise ValueError( + f"Cannot forward both `{prefix}prompt`: {prompt} and `{prefix}prompt_embeds`. Please make sure to" + " only forward one of the two.") + for i in range(2): + if prompts[0] is None and embeds[i] is None: + suffix = "" if i % 2 == 0 else suffix_2 + raise ValueError( + f"Provide either `prompt` or `prompt_embeds{suffix}`. " + f"Cannot leave both `prompt` and `prompt_embeds{suffix}` undefined.") + + if prompts[0] is not None and (not isinstance(prompts[0], str) and not isinstance(prompts[0], list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompts[0])}") + + #embed contain all the + for i, (emb, mask) in enumerate(zip(embeds, masks)): + if emb is not None and mask is None: + prefix = "" if i < 2 else "negative_" + suffix = "" if i % 2 == 0 else suffix_2 + raise ValueError(f"Must provide `{prefix}prompt_attention_mask{suffix}` " + "when specifying `{prefix}prompt_embeds{suffix}`.") + + for i in range(2): + if embeds[i] is not None and embeds[i + 2] is not None: + suffix = "" if i % 2 == 0 else suffix_2 + raise ValueError( + f"`prompt_embeds{suffix}` and `negative_prompt_embeds{suffix}` must have the same shape" + f"when passed directly, but got: `prompt_embeds{suffix}` {embeds[i].shape} != " + f"`negative_prompt_embeds{suffix}` {embeds[i+2].shape}.") + + def _prepare_latents(self, batch_size, num_channels_latents, video_size, latents=None): + num_frames, height, width = video_size + shape = ( + batch_size, + num_channels_latents, + (int(num_frames) - 1) // self.vae.vae_scale_factor[0] + 1, + int(height) // self.vae.vae_scale_factor[1], + int(width) // self.vae.vae_scale_factor[2], + ) + device = self.device + dtype = self.transformer.dtype + + if latents is None: + latents = torch.randn(shape, dtype=dtype, device=device) + else: + latents = latents.to(device) + + if hasattr(self.scheduler, "init_noise_sigma"): + latents = latents * self.scheduler.init_noise_sigma + + return latents + + def _get_embeding(self, prompts, all_mask_emb, encode_kwarg): + device = self.device + mask_emb_1, mask_emb_2 = all_mask_emb + (prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + ) = self._encode_prompt(prompts, mask_emb_1, encode_kwarg, text_encoder_index=0) + + if self.tokenizer_2 is not None: + (prompt_embeds_2, + negative_prompt_embeds_2, + prompt_attention_mask_2, + negative_prompt_attention_mask_2, + ) = self._encode_prompt(prompts, mask_emb_2, encode_kwarg, text_encoder_index=1) + else: + prompt_embeds_2 = None + negative_prompt_embeds_2 = None + prompt_attention_mask_2 = None + negative_prompt_attention_mask_2 = None + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask]) + if self.tokenizer_2 is not None: + prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2]) + prompt_attention_mask_2 = torch.cat([negative_prompt_attention_mask_2, prompt_attention_mask_2]) + + prompt_embeds = prompt_embeds.to(device=device) + prompt_attention_mask = prompt_attention_mask.to(device=device) + if self.tokenizer_2 is not None: + prompt_embeds_2 = prompt_embeds_2.to(device=device) + prompt_attention_mask_2 = prompt_attention_mask_2.to(device=device) + # If sp, split prompt_embeds to [B S/N C] + if get_sequence_parallel_state(): + world_size = get_sequence_parallel_size() + prompt_embeds = rearrange(prompt_embeds, 'b (n x) h -> b n x h', + n=world_size, x=prompt_embeds.shape[1] // world_size).contiguous() + rank = get_sequence_parallel_rank() + prompt_embeds = prompt_embeds[:, rank, :, :] + + if prompt_embeds.ndim == 3: + prompt_embeds = prompt_embeds.unsqueeze(1) # b l d -> b 1 l d + if prompt_attention_mask.ndim == 2: + prompt_attention_mask = prompt_attention_mask.unsqueeze(1) # b l -> b 1 l + if prompt_embeds_2 is not None and prompt_embeds_2.ndim == 2: + prompt_embeds = prompt_embeds.unsqueeze(1) # b d -> b 1 d + return prompt_embeds, prompt_embeds_2, prompt_attention_mask + + def _get_latent(self, video_size, batch_size, num_samples_per_prompt, latents): + + (num_frames, height, width) = video_size + world_size = get_sequence_parallel_size() + num_channels_latents = self.transformer.config.in_channels + video_size = ( + (num_frames + world_size - 1) // world_size, + height, width) + + latents = self._prepare_latents( + batch_size * num_samples_per_prompt, + num_channels_latents, video_size, latents) + return latents + + def _decode_latents(self, latents): + video = self.vae.decode(latents.to(self.vae.vae.dtype)) + video = ((video / 2.0 + 0.5).clamp(0, 1) * 255).to(dtype=torch.uint8) + video = video.cpu().permute(0, 1, 3, 4, 2).contiguous() # b t h w c + return video \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/pipeline/pipeline_utils.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/pipeline/pipeline_utils.py new file mode 100644 index 0000000000..62507a2d6f --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/pipeline/pipeline_utils.py @@ -0,0 +1,131 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os +import inspect +import importlib +from typing import List, Optional, Union + +from tqdm import tqdm +import torch +from mindiesd import ConfigMixin + + +PIPELINE_CONFIG_NAME = "model_index.json" +VAE = 'vae' +TEXT_ENCODER = 'text_encoder' +TOKENIZER = 'tokenizer' +TRANSFORMER = 'transformer' +SCHEDULER = 'scheduler' + + +class DiffusionPipeline(ConfigMixin): + config_name = PIPELINE_CONFIG_NAME + + def __init__(self): + super().__init__() + + @classmethod + def from_pretrained(cls, model_path, **kwargs): + dtype = kwargs.pop('dtype', None) + real_path = os.path.abspath(model_path) + if not (os.path.exists(real_path) and os.path.isdir(real_path)): + raise ValueError("model path is invalid!") + + init_dict, config_dict = cls.load_config(real_path, **kwargs) + + all_parameters = inspect.signature(cls.__init__).parameters + required_param = {k: v for k, v in all_parameters.items() if v.default is inspect.Parameter.empty} + + # init the module from kwargs + passed_module = {k: kwargs.pop(k) for k in required_param if k in kwargs} + from_diffusers = None if '_diffusers_version' not in config_dict else config_dict['_diffusers_version'] + for key, item in tqdm(init_dict.items(), desc="Loading pipeline components..."): + if key in passed_module: + init_dict[key] = passed_module.pop(key) + else: + modules, cls_name = item + if from_diffusers: + try: + library = importlib.import_module("mindiesd") + class_obj = getattr(library, cls_name) + except ImportError: + print("Warning:", f"Cannot import {cls_name} from mindiesd. Use diffuser.") + library = importlib.import_module(modules) + class_obj = getattr(library, cls_name) + else: + library = importlib.import_module(modules) + class_obj = getattr(library, cls_name) + sub_folder = os.path.join(real_path, key) + if key.startswith(TOKENIZER): + init_dict[key] = class_obj.from_pretrained(sub_folder, **kwargs) + elif key.startswith(SCHEDULER): + init_dict[key] = class_obj.from_config(sub_folder, **kwargs) + else: + init_dict[key] = class_obj.from_pretrained(sub_folder, **kwargs).to(dtype) + + return cls(**init_dict) + + +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + + + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed." + " Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/utils/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/utils/__init__.py new file mode 100644 index 0000000000..f35da6dcea --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/utils/__init__.py @@ -0,0 +1 @@ +from .utils import set_random_seed \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/utils/utils.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/utils/utils.py new file mode 100644 index 0000000000..769d4de547 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_3/utils/utils.py @@ -0,0 +1,19 @@ + +import importlib +import random +import torch +import numpy as np + + +def set_random_seed(seed): + """Set random seed. + + Args: + seed (int, optional): Seed to be used. + + """ + + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + return seed \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/README.md b/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/README.md new file mode 100644 index 0000000000..8f3116bf0a --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/README.md @@ -0,0 +1,213 @@ +## 一、准备运行环境 + + **表 1** 版本配套表 + + | 配套 | 版本 | 环境准备指导 | + | ----- | ----- |-----| + | Python | 3.10.2 | - | + | torch | 2.1.0 | - | + +### 1.1 获取CANN&MindIE安装包&环境准备 +- 设备支持: +Atlas 800I A2推理设备:支持的卡数最小为1 +- [Atlas 800I A2](https://www.hiascend.com/developer/download/community/result?module=pt+ie+cann&product=4&model=32) +- [环境准备指导](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha002/softwareinst/instg/instg_0001.html) + +### 1.2 CANN安装 +```shell +# 增加软件包可执行权限,{version}表示软件版本号,{arch}表示CPU架构,{soc}表示昇腾AI处理器的版本。 +chmod +x ./Ascend-cann-toolkit_{version}_linux-{arch}.run +chmod +x ./Ascend-cann-kernels-{soc}_{version}_linux.run +# 校验软件包安装文件的一致性和完整性 +./Ascend-cann-toolkit_{version}_linux-{arch}.run --check +./Ascend-cann-kernels-{soc}_{version}_linux.run --check +# 安装 +./Ascend-cann-toolkit_{version}_linux-{arch}.run --install +./Ascend-cann-kernels-{soc}_{version}_linux.run --install + +# 设置环境变量 +source /usr/local/Ascend/ascend-toolkit/set_env.sh +``` + +### 1.3 环境依赖安装 +```shell +pip3 install -r requirements.txt +``` + +### 1.4 MindIE安装 +```shell +# 增加软件包可执行权限,{version}表示软件版本号,{arch}表示CPU架构。 +chmod +x ./Ascend-mindie_${version}_linux-${arch}.run +./Ascend-mindie_${version}_linux-${arch}.run --check + +# 方式一:默认路径安装 +./Ascend-mindie_${version}_linux-${arch}.run --install +# 设置环境变量 +cd /usr/local/Ascend/mindie && source set_env.sh + +# 方式二:指定路径安装 +./Ascend-mindie_${version}_linux-${arch}.run --install-path=${AieInstallPath} +# 设置环境变量 +cd ${AieInstallPath}/mindie && source set_env.sh +``` + +### 1.5 Torch_npu安装 +下载 pytorch_v{pytorchversion}_py{pythonversion}.tar.gz +```shell +tar -xzvf pytorch_v{pytorchversion}_py{pythonversion}.tar.gz +# 解压后,会有whl包 +pip install torch_npu-{pytorchversion}.xxxx.{arch}.whl +``` +## 二、下载本仓库 + +### 2.1 下载到本地 +```shell + git clone https://gitee.com/ascend/ModelZoo-PyTorch.git +``` + +## 三、OpenSora1.2使用 + +### 3.1 权重及配置文件说明 +1. text_encoder权重链接: +```shell + https://huggingface.co/DeepFloyd/t5-v1_1-xxl/tree/main +``` +2. tokenizer权重链接: +```shell + https://huggingface.co/DeepFloyd/t5-v1_1-xxl/tree/main +``` +3. STDiT3权重链接: +- 下载该权重,并重命名为transformer +```shell + https://huggingface.co/hpcai-tech/OpenSora-STDiT-v3/tree/main +``` +- 修改该权重的config.json +```shell + 将enable_flash_attn设置为true +``` +4. VAE权重链接: +- 下载该权重,并重命名为vae +```shell +https://huggingface.co/hpcai-tech/OpenSora-VAE-v1.2/tree/main +``` +- 修改该权重的config.json +```shell +修改architectures和model_type字段为VideoAutoencoder即可。 +``` +5. VAE_2d: +- 权重链接如下, 下载后将vae_2d的配置文件和权重文件放在vae/vae_2d/vae路径下。 +```shell +https://huggingface.co/PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers/tree/main +``` +6. scheduler: +- 新增scheduler_config.json配置文件, 内容如下所示: +```shell +{ + "_class_name": "RFlowScheduler", + "_mindiesd_version": "1.0.0", + "num_sampling_steps": 30, + "num_timesteps": 1000 +} +``` +7.新增model_index.json +将以上步骤下载的权重放在同一目录下, 并新增model_index.json文件, 该文件内容如下所示 +```shell +{ + "_class_name": "OpenSoraPipeline", + "_mindiesd_version": "1.0.0", + "scheduler": [ + "mindiesd", + "RFlowScheduler" + ], + "text_encoder": [ + "transformers", + "T5EncoderModel" + ], + "tokenizer": [ + "transformers", + "AutoTokenizer" + ], + "transformer": [ + "mindiesd", + "STDiT3" + ], + "vae": [ + "mindiesd", + "VideoAutoencoder" + ] +} +``` +8.各模型的配置文件、权重文件的层级样例如下所示。 +```commandline +|----open-sora +| |---- model_index.json +| |---- scheduler +| | |---- scheduler_config.json +| |---- text_encoder +| | |---- config.json +| | |---- 模型权重 +| |---- tokenizer +| | |---- config.json +| | |---- 模型权重 +| |---- transformer +| | |---- config.json +| | |---- 模型权重 +| |---- vae +| | |---- config.json +| | |---- 模型权重 +| | |---- vae_2d +| | | |---- vae +| | | | |---- config.json +| | | | |---- 模型权重 +``` + +### 3.2 单卡性能测试 +设置权重路径 +```shell +path = './path' +``` +执行命令: +```shell +python inference_opensora12.py \ + --path ${path} \ + --device_id 0 \ + --type bf16 \ + --num_frames 32 \ + --image_size 720,1280 \ + --fps 8 +``` +参数说明: +- path: 权重路径,包含vae、text_encoder、Tokenizer、Transformer和Scheduler五个模型的配置文件及权重。 +- device_id: 推理设备ID。 +- type: bf16、fp16。 +- num_frames:总帧数,范围:32, 128。 +- image_size:(720, 1280)、(512, 512)。 +- fps: 每秒帧数:8。 +- test_acc: 使用--test_acc开启全量视频生成,用于精度测试。性能测试时,不开启该参数。 + +### 3.3 多卡性能测试 +设置权重路径 +```shell +path = './path' +``` + +执行命令: +```shell +torchrun --nproc_per_node=4 inference_opensora12.py \ + --path ${path} \ + --type bf16 \ + --num_frames 32 \ + --image_size 720,1280 \ + --fps 8 \ + --enable_sequence_parallelism True +``` +参数说明: +- nproc_per_node: 并行推理的总卡数。 +- enable_sequence_parallelism 开启dsp 多卡并行 +- path: 权重路径,包含vae、text_encoder、Tokenizer、Transformer和Scheduler五个模型的配置文件及权重。 +- type: bf16、fp16。 +- num_frames:总帧数,范围:32, 128。 +- image_size:(720, 1280)、(512, 512)。 +- fps: 每秒帧数:8。 + + diff --git a/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/inference_opensora12.py b/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/inference_opensora12.py new file mode 100644 index 0000000000..37e643e87a --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/inference_opensora12.py @@ -0,0 +1,182 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import argparse +import time +import logging +import colossalai +import torch +import torch.distributed as dist +from torchvision.io import write_video + +from opensora import set_parallel_manager +from opensora import compile_pipe +from opensora import OpenSoraPipeline12 + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--path", + type=str, + default='/open-sora', + help="The path of all model weights, suach as vae, transformer, text_encoder, tokenizer, scheduler", + ) + parser.add_argument( + "--device_id", + type=int, + default=0, + help="NPU device id", + ) + parser.add_argument( + "--device", + type=str, + default='npu', + help="NPU", + ) + parser.add_argument( + "--type", + type=str, + default='bf16', + help="bf16 or fp16", + ) + parser.add_argument( + "--num_frames", + type=int, + default=32, + help="num_frames: 32 or 128", + ) + parser.add_argument( + "--image_size", + type=str, + default="(720, 1280)", + help="image_size: (720, 1280) or (512, 512)", + ) + parser.add_argument( + "--fps", + type=int, + default=8, + help="fps: 8", + ) + parser.add_argument( + "--enable_sequence_parallelism", + type=bool, + default=False, + help="enable_sequence_parallelism", + ) + parser.add_argument( + "--set_patch_parallel", + type=bool, + default=False, + help="set_patch_parallel", + ) + parser.add_argument( + "--prompts", + type=list, + default=[ + 'A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage. \ + She wears a black leather jacket, a long red dress, and black boots, and carries a black purse. \ + She wears sunglasses and red lipstick. She walks confidently and casually. \ + The street is damp and reflective, creating a mirror effect of the colorful lights. \ + Many pedestrians walk about.'], + help="prompts", + ) + parser.add_argument( + "--test_acc", + action="store_true", + help="Run or not.", + ) + return parser.parse_args() + + +def infer(args): + test_acc = args.test_acc + use_time = 0 + torch.npu.set_device(args.device_id) + dtype = torch.bfloat16 + if args.type == 'bf16': + dtype = torch.bfloat16 + elif args.type == 'fp16': + dtype = torch.float16 + else: + logger.error("Not supported.") + + # === Initialize Distributed === + if args.enable_sequence_parallelism or args.set_patch_parallel: + colossalai.launch_from_torch({}) + sp_size = dist.get_world_size() + set_parallel_manager(sp_size, sp_axis=0) + + args.image_size = eval(args.image_size) + + if not test_acc: + prompts = args.prompts + else: + lines_list = [] + with open('./prompts/t2v_sora.txt', 'r') as file: + for line in file: + line = line.strip() + lines_list.append(line) + prompts = lines_list + + if not test_acc: + loops = 5 + else: + loops = len(prompts) + + pipe = OpenSoraPipeline12.from_pretrained(model_path=args.path, + num_frames=args.num_frames, image_size=args.image_size, fps=args.fps, + enable_sequence_parallelism=args.enable_sequence_parallelism, + dtype=dtype, openmind_name="opensora_v1_2") + pipe = compile_pipe(pipe) + + for i in range(loops): + + start_time = time.time() + if test_acc: + video = pipe(prompts=[prompts[i]], output_type="thwc") + + else: + video = pipe(prompts=prompts) + + torch.npu.empty_cache() + + if test_acc: + if i < 10: + save_file_name = "sample_0{}.mp4".format(i) + else: + save_file_name = "sample_{}.mp4".format(i) + save_path = os.path.join(os.getcwd(), save_file_name) + + write_video(save_path, video, fps=8, video_codec="h264") + torch.npu.empty_cache() + else: + if i >= 2: + use_time += time.time() - start_time + logger.info("current_time is %.3f )", time.time() - start_time) + torch.npu.empty_cache() + + if not test_acc: + logger.info("use_time is %.3f)", use_time / 3) + + +if __name__ == "__main__": + inference_args = parse_arguments() + infer(inference_args) + diff --git a/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/__init__.py new file mode 100644 index 0000000000..b0dcd39a69 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/__init__.py @@ -0,0 +1,26 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .vae import (VideoAutoencoder, VideoAutoencoderConfig) +from .pipeline import (OpenSoraPipeline12, compile_pipe) +from .schedulers import RFlowScheduler +from .stdit3 import (STDiT3Config, STDiT3) +from .utils import (set_random_seed, append_score_to_prompts, extract_prompts_loop, merge_prompt, prepare_multi_resolution_info, + split_prompt, is_npu_available, exists, default, Patchify, Depatchify) +from .layer import (approx_gelu, CaptionEmbedder, PatchEmbed3D, PositionEmbedding2D, SizeEmbedder, TimestepEmbedder, RotaryEmbedding, + Mlp, AdaLayerNorm, PatchGroupNorm3d, GroupNorm3dAdapter, all_to_all_with_pad, get_spatial_pad, get_temporal_pad, + set_spatial_pad, set_temporal_pad, split_sequence, gather_sequence, set_parallel_manager, get_sequence_parallel_group, + get_sequence_parallel_size, rearrange_flatten_t, rearrange_unflatten_t, Conv3dAdapter, PatchConv3d, Attention, MultiHeadCrossAttention) \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/layer/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/layer/__init__.py new file mode 100644 index 0000000000..153e54956e --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/layer/__init__.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .activation import approx_gelu +from .embdding import (CaptionEmbedder, PatchEmbed3D, PositionEmbedding2D, SizeEmbedder, TimestepEmbedder, RotaryEmbedding) +from .mlp import Mlp +from .norm import (AdaLayerNorm, PatchGroupNorm3d, GroupNorm3dAdapter) +from .comm import ( + all_to_all_with_pad, + get_spatial_pad, + get_temporal_pad, + set_spatial_pad, + set_temporal_pad, + split_sequence, + gather_sequence, +) +from .parallel_mgr import (set_parallel_manager, get_sequence_parallel_group, get_sequence_parallel_size) +from .utils import (rearrange_flatten_t, rearrange_unflatten_t) +from .conv import (Conv3dAdapter, PatchConv3d) +from .attention import (Attention, MultiHeadCrossAttention) \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/layer/activation.py b/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/layer/activation.py new file mode 100644 index 0000000000..052e0d1031 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/layer/activation.py @@ -0,0 +1,21 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch.nn as nn + + +def approx_gelu(): + return nn.GELU(approximate="tanh") \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/layer/attention.py b/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/layer/attention.py new file mode 100644 index 0000000000..bb6393c6a4 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/layer/attention.py @@ -0,0 +1,274 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# PixArt: https://github.com/PixArt-alpha/PixArt-alpha +# Latte: https://github.com/Vchitect/Latte +# DiT: https://github.com/facebookresearch/DiT/tree/main +# GLIDE: https://github.com/openai/glide-text2im +# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py +# -------------------------------------------------------- + +import math +import logging +import inspect +from typing import Optional + +import torch +import torch.nn as nn +import torch_npu + +from .norm import get_normalization_helper, LlamaRMSNorm +from .embdding import get_embedding_helper +from ..utils.utils import is_npu_available + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) +MAX_TOKENS = 2147483647 + + +class Attention(nn.Module): + def __init__( + self, + dimension: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_norm: bool = False, + norm_layer: nn.Module = LlamaRMSNorm, + enable_flash_attn: bool = False, + rope=None, + ) -> None: + super().__init__() + if dimension % num_heads != 0: + logger.error("dimension should be divisible by num_heads") + raise ValueError('dimension should be divisible by num_heads') + self.dimension = dimension + self.num_heads = num_heads + self.head_dim = dimension // num_heads + self.scale = self.head_dim ** -0.5 + self.enable_flash_attn = enable_flash_attn + + self.qkv = nn.Linear(dimension, dimension * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.proj = nn.Linear(dimension, dimension) + + self.rope = False + if rope is not None: + self.rope = True + self.rotary_emb = rope + + def t_flash_attention(self, q, k, v): + x = torch_npu.npu_fusion_attention( + q, k, v, self.num_heads, input_layout="BNSD", + pse=None, + scale=self.scale, + pre_tockens=MAX_TOKENS, + next_tockens=MAX_TOKENS, + keep_prob=1., + sync=False, + inner_precise=0, + )[0] + + x = x.transpose(1, 2) + return x + + def s_flash_attention(self, q, k, v): + x = torch_npu.npu_prompt_flash_attention( + q, k, v, num_heads=self.num_heads, + input_layout="BNSD", + scale_value=1.0 / math.sqrt(self.head_dim), + pre_tokens=MAX_TOKENS, + next_tokens=MAX_TOKENS, + sparse_mode=0) + x = x.transpose(1, 2) + return x + + def no_fused_flash_attention(self, q, k, v): + dtype = q.dtype + q = q * self.scale + attn = q @ k.transpose(-2, -1) # translate attn to float32 + attn = attn.to(torch.float32) + attn = attn.softmax(dim=-1) + attn = attn.to(dtype) # cast back attn to original dtype + x = attn @ v + return x + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_shape0_b, x_shape1_n, x_shape2_c = x.shape + enable_flash_attn = self.enable_flash_attn + qkv = self.qkv(x) + qkv_shape = (x_shape0_b, x_shape1_n, 3, self.num_heads, self.head_dim) + + qkv = qkv.view(qkv_shape).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) + + q, k = self.q_norm(q), self.k_norm(k) + if self.rope: + q = self.rotary_emb(q) + k = self.rotary_emb(k) + + if enable_flash_attn: + if is_npu_available() and q.dtype in [torch.float16, torch.bfloat16]: + if self.rope: + x = self.t_flash_attention(q, k, v) + else: + x = self.s_flash_attention(q, k, v) + else: + from flash_attn import flash_attn_func + + # (B, #heads, N, #dim) -> (B, N, #heads, #dim) + q = q.permute(0, 2, 1, 3) + k = k.permute(0, 2, 1, 3) + v = v.permute(0, 2, 1, 3) + x = flash_attn_func( + q, + k, + v, + softmax_scale=self.scale, + ) + else: + x = self.no_fused_flash_attention(q, k, v) + + x_output_shape = (x_shape0_b, x_shape1_n, x_shape2_c) + if not enable_flash_attn: + x = x.transpose(1, 2) + x = x.reshape(x_output_shape) + x = self.proj(x) + return x + + +class MultiHeadCrossAttention(nn.Module): + def __init__(self, d_model, num_heads): + super(MultiHeadCrossAttention, self).__init__() + if num_heads == 0: + logger.error("num_heads cannot be zero") + raise ValueError('num_heads cannot be zero') + if d_model % num_heads != 0: + logger.error("d_model must be divisible by num_heads") + raise ValueError('d_model must be divisible by num_heads') + if d_model // num_heads <= 0: + logger.error("head_dim must be a positive integero") + raise ValueError('head_dim must be a positive integero') + + self.d_model = d_model + self.num_heads = num_heads + self.head_dim = d_model // num_heads + + self.q_linear = nn.Linear(d_model, d_model) + self.kv_linear = nn.Linear(d_model, d_model * 2) + self.proj = nn.Linear(d_model, d_model) + + def forward(self, x, cond, mask=None): + # query/value: img tokens; key: condition; mask: if padding tokens + x_shape0_b, x_shape1_n, x_shape2_c = x.shape + + if is_npu_available() and x.dtype in [torch.float16, torch.bfloat16]: + q = self.q_linear(x).view(-1, self.num_heads, self.head_dim) + kv = self.kv_linear(cond).view(-1, 2, self.num_heads, self.head_dim) + k, v = kv.unbind(1) + + actual_seq_qlen = [] + actual_seq_kvlen = [] + if mask is not None: + ans = 0 + for _ in range(x_shape0_b): + ans += x_shape1_n + actual_seq_qlen.append(ans) + ans = 0 + for m in mask: + ans += m + actual_seq_kvlen.append(ans) + + x = torch_npu.npu_fusion_attention( + q, k, v, self.num_heads, input_layout="TND", + pse=None, + scale=1.0 / math.sqrt(self.head_dim), + pre_tockens=MAX_TOKENS, + next_tockens=MAX_TOKENS, + actual_seq_qlen=tuple(actual_seq_qlen), + actual_seq_kvlen=tuple(actual_seq_kvlen), + keep_prob=1., + sparse_mode=0, + )[0] + else: + q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim) + kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim) + k, v = kv.unbind(2) + + attn_bias = None + if mask is not None: + attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens([x_shape1_n] * x_shape0_b, mask) + x = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=attn_bias) + + x = x.view(x_shape0_b, -1, x_shape2_c) + x = self.proj(x) + return x + + + +class AttnProcessor: + """ + The standard attention processor. + """ + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + if hidden_states is None: + logger.error("`hidden_states` can not be None.") + raise ValueError("`hidden_states` can not be None.") + + # only support BNC now. + if hidden_states.ndim != 3: # 3: BNC. + logger.error("`hidden_states` dim must be 3, but got %d", hidden_states.ndim) + raise ValueError(f"`hidden_states` dim must be 3, but got {hidden_states.ndim}") + + batch_size = hidden_states.shape[0] + + if attn.is_cross_attention: + query = attn.q_proj(hidden_states) + kv = attn.kv_proj(encoder_hidden_states) + key, value = kv.reshape(batch_size, -1, 2, attn.num_heads, attn.head_dim).unbind(2) # B S 2 H + else: + qkv = attn.qkv_proj(hidden_states) + query, key, value = qkv.reshape(batch_size, -1, 3, attn.num_heads, attn.head_dim).unbind(2) # B S 3 H + query = query.reshape(batch_size, -1, attn.num_heads, attn.head_dim).transpose(1, 2) # BNSD + key = key.reshape(batch_size, -1, attn.num_heads, attn.head_dim).transpose(1, 2) # BNSD + value = value.reshape(batch_size, -1, attn.num_heads, attn.head_dim).transpose(1, 2) # BNSD + + # norm q and k + query = attn.norm_q(query) + key = attn.norm_k(key) + + # position embedding q and k + query = attn.position_embedding(query) + key = attn.position_embedding(key) + + # need replaced by dispatch flash_attention function + hidden_states = torch.nn.functional.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, scale=attn.scale_value) + # transform the hidden_states layout from BNSD to BSH + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.num_heads * attn.head_dim) + hidden_states = attn.out_proj(hidden_states) + return hidden_states diff --git a/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/layer/comm.py b/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/layer/comm.py new file mode 100644 index 0000000000..7b103f83d7 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/layer/comm.py @@ -0,0 +1,121 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.distributed as dist +from .parallel_mgr import get_sequence_parallel_size + + +def _all_to_all_func(input_, world_size, process_group, scatter_dim=2, gather_dim=1): + input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)] + output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)] + dist.all_to_all(output_list, input_list, group=process_group) + return torch.cat(output_list, dim=gather_dim).contiguous() + + +def split_sequence(input_, process_group: dist.ProcessGroup, dim: int, pad: int): + world_size = dist.get_world_size(process_group) + rank = dist.get_rank(process_group) + if world_size == 1: + return input_ + + if pad > 0: + pad_size = list(input_.shape) + pad_size[dim] = pad + input_ = torch.cat([input_, torch.zeros(pad_size, dtype=input_.dtype, device=input_.device)], dim=dim) + + dim_size = input_.size(dim) + + tensor_list = torch.split(input_, dim_size // world_size, dim=dim) + output = tensor_list[rank].contiguous() + return output + + +def gather_sequence(input_, process_group: dist.ProcessGroup, dim: int, pad: int): + input_ = input_.contiguous() + world_size = dist.get_world_size(process_group) + if world_size == 1: + return input_ + + #all gather + tensor_list = [torch.empty_like(input_) for _ in range(world_size)] + torch.distributed.all_gather(tensor_list, input_, group=process_group) + + #concat + output = torch.cat(tensor_list, dim=dim) + + if pad > 0: + output = output.narrow(dim, 0, output.size(dim) - pad) + + return output + +# ====== +# Pad +# ====== + +SPTIAL_PAD = 0 +TEMPORAL_PAD = 0 + + +def set_spatial_pad(dim_size: int): + sp_size = get_sequence_parallel_size() + pad = (sp_size - (dim_size % sp_size)) % sp_size + global SPTIAL_PAD + SPTIAL_PAD = pad + + +def get_spatial_pad() -> int: + return SPTIAL_PAD + + +def set_temporal_pad(dim_size: int): + sp_size = get_sequence_parallel_size() + pad = (sp_size - (dim_size % sp_size)) % sp_size + global TEMPORAL_PAD + TEMPORAL_PAD = pad + + +def get_temporal_pad() -> int: + return TEMPORAL_PAD + + +def all_to_all_with_pad( + input_: torch.Tensor, + process_group: dist.ProcessGroup, + **kwargs +): + scatter_dim = kwargs.get("scatter_dim", 2) + gather_dim = kwargs.get("gather_dim", 1) + scatter_pad = kwargs.get("scatter_pad", 0) + gather_pad = kwargs.get("gather_pad", 0) + + if scatter_pad > 0: + pad_shape = list(input_.shape) + pad_shape[scatter_dim] = scatter_pad + pad_tensor = torch.zeros(pad_shape, device=input_.device, dtype=input_.dtype) + input_ = torch.cat([input_, pad_tensor], dim=scatter_dim) + + world_size = dist.get_world_size(process_group) + + input_ = _all_to_all_func(input_, world_size, process_group, scatter_dim, gather_dim) + + if gather_pad > 0: + input_ = input_.narrow(gather_dim, 0, input_.size(gather_dim) - gather_pad) + + return input_ diff --git a/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/layer/conv.py b/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/layer/conv.py new file mode 100644 index 0000000000..26a7a45436 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/layer/conv.py @@ -0,0 +1,247 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Optional, Tuple, Union +import torch +import torch.nn as nn +import torch.distributed as dist +from torch import Tensor +from torch.nn import functional as F +from torch.nn import init +from torch.nn.modules.utils import _triple, _reverse_repeat_tuple +from torch.nn.parameter import Parameter +from torch.nn.common_types import _size_3_t + + +class Conv3dAdapter(nn.Module): + def __init__( + self, + conv3d: nn.Conv3d, + is_casual=False, + block_size=2, + ): + super().__init__() + self.module = PatchConv3d( + in_channels=conv3d.in_channels, + out_channels=conv3d.out_channels, + kernel_size=conv3d.kernel_size, + stride=conv3d.stride, + padding=conv3d.padding, + dilation=conv3d.dilation, + groups=conv3d.groups, + bias=conv3d.bias is not None, + padding_mode=conv3d.padding_mode, + device=conv3d.weight.device, + dtype=conv3d.weight.dtype, + block_size=block_size, + is_casual=is_casual, + ) + self.module.weight.data = conv3d.weight.data + if conv3d.bias is not None: + self.module.bias.data = conv3d.bias.data + + def forward(self, x): + return self.module(x) + + +class PatchConv3d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: _size_3_t, + stride: _size_3_t = 1, + padding: Union[str, _size_3_t] = 0, + dilation: _size_3_t = 1, + transposed: bool = False, + groups: int = 1, + bias: bool = True, + padding_mode: str = 'zeros', + device=None, + dtype=None, + block_size: Union[int, Tuple[int, int]] = 2, + is_casual: bool = False, + is_overlap: bool = True + ) -> None: + self.padding = padding if isinstance(padding, str) else _triple(padding) + self.rank = dist.get_rank() + self.world_size = dist.get_world_size() + self.kernel_size = _triple(kernel_size) + self.stride = _triple(stride) + self.dilation = _triple(dilation) + self.groups = groups + self.padding_mode = padding_mode + self.block_size = block_size + self.is_casual = is_casual + self.is_overlap = is_overlap + self.rank = 0 + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + + if isinstance(self.padding, str): + self._reversed_padding_repeated_twice = [0, 0] * len(self.kernel_size) + if padding == 'same': + for d, k, i in zip(dilation, self.kernel_size, + range(len(self.kernel_size) - 1, -1, -1)): + total_padding = d * (k - 1) + left_pad = total_padding // 2 + self._reversed_padding_repeated_twice[2 * i] = left_pad + self._reversed_padding_repeated_twice[2 * i + 1] = ( + total_padding - left_pad) + else: + self._reversed_padding_repeated_twice = _reverse_repeat_tuple(self.padding, 2) + # initialize weight and bias + if transposed: + self.weight = Parameter(torch.empty( + (in_channels, out_channels // groups, *self.kernel_size), **factory_kwargs)) + else: + self.weight = Parameter(torch.empty( + (out_channels, in_channels // groups, *self.kernel_size), **factory_kwargs)) + if bias: + self.bias = Parameter(torch.empty(out_channels, **factory_kwargs)) + else: + self.register_parameter('bias', None) + + self.reset_parameters() + + def reset_parameters(self) -> None: + ch_in, ch_out, *_ = self.weight.shape + init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + if self.bias is not None: + fan_in = math.prod([item for item in self.kernel_size]) * ch_out + if fan_in != 0: + bound = 1 / math.sqrt(fan_in) + init.uniform_(self.bias, -bound, bound) + + def forward(self, patch_hidden_state: Tensor, weight: Tensor = None, bias: Tensor = None) -> Tensor: + if weight is None: + return self._conv_forward(patch_hidden_state, self.weight, self.bias) + else: + return self._conv_forward(patch_hidden_state, weight, bias) + + def _one_worldsize_conv(self, padding_mode, patch_hidden_state, weight, bias): + if padding_mode != 'zeros': + return F.conv3d(F.pad(patch_hidden_state, self._reversed_padding_repeated_twice, + mode=padding_mode), weight, bias, self.stride, + _triple(0), self.dilation, self.groups) + return F.conv3d(patch_hidden_state, weight, bias, self.stride, + self.padding, self.dilation, self.groups) + + def _pre_conv_forward(self, patch_hidden_state, shape): + bs, channels, t, h, _ = shape + if self.rank % 2 == 0 and self.rank != 0: + send = patch_hidden_state[..., :1].contiguous() + send_op = dist.P2POp(dist.isend, send, self.rank - 1) + recv = torch.zeros([bs, channels, t, h, 1], + dtype=patch_hidden_state.dtype, device=f"npu:{self.rank}") + recv_op = dist.P2POp(dist.irecv, recv, self.rank - 1) + dist.batch_isend_irecv([send_op, recv_op]) + return recv + elif self.rank % 2 != 0 and self.rank != self.world_size - 1: + send = patch_hidden_state[..., -1:].contiguous() + send_op = dist.P2POp(dist.isend, send, self.rank + 1) + recv = torch.zeros([bs, channels, t, h, 1], + dtype=patch_hidden_state.dtype, device=f"npu:{self.rank}") + recv_op = dist.P2POp(dist.irecv, recv, self.rank + 1) + dist.batch_isend_irecv([send_op, recv_op]) + return recv + return None + + + def _end_conv_forward(self, outputs, shape): + bs_, channels_, t_, h_, _ = shape + if self.rank % 2 == 0: + send = outputs[0][..., -1:].contiguous() + send_op = dist.P2POp(dist.isend, send, self.rank + 1) + recv = torch.zeros([bs_, channels_, t_, h_, 1], + dtype=outputs[0].dtype, device=f"npu:{self.rank}") + recv_op = dist.P2POp(dist.irecv, recv, self.rank + 1) + dist.batch_isend_irecv([send_op, recv_op]) + else: + send = outputs[0][..., :1].contiguous() + send_op = dist.P2POp(dist.isend, send, self.rank - 1) + recv = torch.zeros([bs_, channels_, t_, h_, 1], + dtype=outputs[0].dtype, device=f"npu:{self.rank}") + recv_op = dist.P2POp(dist.irecv, recv, self.rank - 1) + dist.batch_isend_irecv([send_op, recv_op]) + return recv + + def _parallel_conv_forward(self, patch_hidden_state, weight, bias): + shape = patch_hidden_state.shape + bs, channels, t, h, w = shape + patch_hidden_state, padding = self._adjust_padding_for_patch(patch_hidden_state, self.padding) + stride = (w - 1 + self.block_size - 1) // self.block_size + overlap = self.kernel_size[0] // 2 + outputs = [] + recv = None + # P2P communication + for step in range(self.block_size): + start_idx = step * stride + 1 - overlap + end_idx = min((step + 1) * stride + 1 + overlap, w) + if self.rank % 2 == 0: + input_patch = patch_hidden_state[..., w - end_idx:w - start_idx] + else: + input_patch = patch_hidden_state[..., start_idx:end_idx] + + if step == 0: + recv = self._pre_conv_forward(patch_hidden_state, shape) + if step == self.block_size - 1: + if overlap == 1: + input_patch = torch.cat([recv, input_patch], dim=-1) \ + if self.rank % 2 == 0 else torch.cat([input_patch, recv], dim=-1) + recv = self._end_conv_forward(outputs, outputs[0].shape) + + outputs.append(F.conv3d(input_patch, weight, bias, self.stride, padding, self.dilation, self.groups)) + + if step == 0: + if self.rank == 0: + recv = torch.zeros([bs, channels, t, h, 1], + dtype=patch_hidden_state.dtype, device=f"npu:{self.rank}") + elif self.rank == self.world_size - 1: + recv = torch.zeros([bs, channels, t, h, 1], + dtype=patch_hidden_state.dtype, device=f"npu:{self.rank}") + if step == self.block_size - 1: + if self.rank % 2 == 0: + outputs.insert(0, recv) + outputs.reverse() + else: + outputs.insert(0, recv) + + return torch.cat(outputs, dim=-1) + + def _conv_forward(self, patch_hidden_state: Tensor, weight: Tensor, bias: Optional[Tensor]): + self._get_world_size_and_rank() + if (self.world_size == 1): + return self._one_worldsize_conv(self.padding_mode, patch_hidden_state, weight, bias) + else: + return self._parallel_conv_forward(patch_hidden_state, weight, bias) + + def _get_world_size_and_rank(self): + world_size = 1 + rank = 0 + if dist.is_available() and dist.is_initialized(): + rank = dist.get_rank() + world_size = dist.get_world_size() + self.world_size = world_size + self.rank = rank + + def _adjust_padding_for_patch(self, patch_input, padding): + if self.kernel_size[-1] == 3 and self.is_casual: + patch_input = patch_input[..., 1:-1] + padding = list(padding) + padding[-1] = 0 + return patch_input, tuple(padding) \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/layer/embdding.py b/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/layer/embdding.py new file mode 100644 index 0000000000..2d758b3a86 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/layer/embdding.py @@ -0,0 +1,423 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import functools +import math +from math import pi +from typing import Literal, Union, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import nn, einsum, broadcast_tensors, Tensor +from einops import rearrange + +from .mlp import Mlp +from ..utils.utils import exists, default + + +LANG_FREQS = 'lang' +PIXEL_FREQS = 'pixel' +CONSTANT_FREQS = 'constant' + + +def get_embedding_helper(embedding_type: str, embdding_dim: int): + match embedding_type: + case None: + return nn.Identity() + case 'rope': + return RotaryEmbedding(dim=embdding_dim) + case _: + error_msg = "`embdding_type` is not supported!" + raise ValueError(error_msg) + +class PatchEmbed3D(nn.Module): + """Video to Patch Embedding. + + Args: + patch_size (int): Patch token size. Default: (2,4,4). + in_chans (int): Number of input video channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__( + self, + patch_size=(2, 4, 4), + in_chans=3, + embed_dim=96, + norm_layer=None, + flatten=True, + ): + super().__init__() + self.patch_size = patch_size + self.flatten = flatten + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + """Forward function.""" + # padding + _, _, x_shape2_d, x_shape3_h, x_shape4_w = x.size() + if x_shape4_w % self.patch_size[2] != 0: + x = F.pad(x, (0, self.patch_size[2] - x_shape4_w % self.patch_size[2])) + if x_shape3_h % self.patch_size[1] != 0: + x = F.pad(x, (0, 0, 0, self.patch_size[1] - x_shape3_h % self.patch_size[1])) + if x_shape2_d % self.patch_size[0] != 0: + x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - x_shape2_d % self.patch_size[0])) + + x = self.proj(x) # (B C T H W) + if self.norm is not None: + x_shape2_d, x_size_3, x_sie_4 = x.size(2), x.size(3), x.size(4) + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + x = x.transpose(1, 2).view(-1, self.embed_dim, x_shape2_d, x_size_3, x_sie_4) + if self.flatten: + x = x.flatten(2).transpose(1, 2) # BCTHW -> BNC + return x + + +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. + """ + half = dim // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half) + freqs = freqs.to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t, dtype): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + if t_freq.dtype != dtype: + t_freq = t_freq.to(dtype) + t_emb = self.mlp(t_freq) + return t_emb + + +class SizeEmbedder(TimestepEmbedder): + """ + Embeds scalar timesteps into vector representations. + """ + + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__(hidden_size=hidden_size, frequency_embedding_size=frequency_embedding_size) + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + self.outdim = hidden_size + + @property + def dtype(self): + return next(self.parameters()).dtype + + def forward(self, s, bs): + if s.ndim == 1: + s = s[:, None] + if s.shape[0] != bs: + s = s.repeat(bs // s.shape[0], 1) + b, dims = s.shape[0], s.shape[1] + s = s.reshape(b * dims) + s_freq = self.timestep_embedding(s, self.frequency_embedding_size).to(self.dtype) + s_emb = self.mlp(s_freq) + s_emb = s_emb.view(b, dims, self.outdim) + s_emb = s_emb.view(b, dims * self.outdim) + return s_emb + + +class CaptionEmbedder(nn.Module): + """ + Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. + """ + + def __init__( + self, + in_channels, + hidden_size, + act_layer=nn.GELU(approximate="tanh"), + token_num=120, + ): + super().__init__() + self.y_proj = Mlp( + features_in=in_channels, + features_hidden=hidden_size, + features_out=hidden_size, + act_layer=act_layer, + ) + + self.register_buffer( + "y_embedding", + torch.randn(token_num, in_channels) / in_channels ** 0.5, + ) + + def forward(self, caption): + caption = self.y_proj(caption) + return caption + + +class PositionEmbedding2D(nn.Module): + def __init__(self, dim: int) -> None: + super().__init__() + self.dim = dim + + half_dim = dim // 2 + inv_freq = 1.0 / (10000 ** (torch.arange(0, half_dim, 2).float() / half_dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + def forward(self, x: torch.Tensor, h: int, w: int, scale: Optional[float] = 1.0) -> torch.Tensor: + s_hw = h * w + base_size = round(s_hw ** 0.5) + return self._get_cached_emb(x.device, x.dtype, (h, w), scale, base_size) + + @functools.lru_cache(maxsize=512) + def _get_cached_emb( + self, + device: torch.device, + dtype: torch.dtype, + image_size: Tuple[int, int], + scale: float = 1.0, + base_size: Optional[int] = None, + ): + grid_h = torch.arange(image_size[0], device=device) / scale + grid_w = torch.arange(image_size[1], device=device) / scale + if base_size is not None: + grid_h *= base_size / image_size[0] + grid_w *= base_size / image_size[1] + grid_h, grid_w = torch.meshgrid( + grid_w, + grid_h, + indexing="ij", + ) # here w goes first + grid_h = grid_h.t().reshape(-1) + grid_w = grid_w.t().reshape(-1) + emb_h = self._get_sin_cos_emb(grid_h) + emb_w = self._get_sin_cos_emb(grid_w) + return torch.concat([emb_h, emb_w], dim=-1).unsqueeze(0).to(dtype) + + def _get_sin_cos_emb(self, t: torch.Tensor): + out = torch.einsum("i,d->id", t, self.inv_freq) + emb_cos = torch.cos(out) + emb_sin = torch.sin(out) + return torch.cat((emb_sin, emb_cos), dim=-1) + + + +class RotaryEmbedding(nn.Module): + def __init__(self, + dim, + custom_freqs: Optional[Tensor] = None, + freqs_for: Union[ + Literal[LANG_FREQS], + Literal[PIXEL_FREQS], + Literal[CONSTANT_FREQS] + ] = LANG_FREQS, + theta=10000, + max_freq=10, + num_freqs=1, + learned_freq=False, + xpos_scale_base=512, + interpolate_factor=1., + theta_rescale_factor=1., + seq_before_head_dim=False, + cache_if_possible=True + ): + super().__init__() + + theta *= theta_rescale_factor ** (dim / (dim - 2)) + + self.freqs_for = freqs_for + + if exists(custom_freqs): + freqs = custom_freqs + elif freqs_for == 'lang': + freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim)) + elif freqs_for == 'pixel': + freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi + elif freqs_for == 'constant': + freqs = torch.ones(num_freqs).float() + + self.cache_if_possible = cache_if_possible + + self.tmp_store('cached_freqs', None) + self.tmp_store('cached_scales', None) + + self.freqs = nn.Parameter(freqs, requires_grad=learned_freq) + + self.learned_freq = learned_freq + + # dummy for device + + self.tmp_store('dummy', torch.tensor(0)) + + # default sequence dimension + + self.seq_before_head_dim = seq_before_head_dim + self.default_seq_dim = -3 if seq_before_head_dim else -2 + + self.interpolate_factor = interpolate_factor + + scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim) + self.scale_base = xpos_scale_base + self.tmp_store('scale', scale) + + @property + def device(self): + return self.dummy.device + + def tmp_store(self, key, value): + self.register_buffer(key, value, persistent=False) + + def get_seq_pos(self, seq_len, device, dtype, offset=0): + return (torch.arange(seq_len, device=device, dtype=dtype) + offset) / self.interpolate_factor + + def rearrange_nd_2_n1d(self, x, transform_type='n d -> n 1 d'): + if transform_type == 'n d -> n 1 d': + shape = x.shape + x = x.view(shape[0], shape[1]) + return x.view(shape[0], 1, shape[1]) + return x + + def rotate_queries_or_keys(self, t, seq_dim=None, offset=0, freq_seq_len=None): + # 进入这个函数 + seq_dim = default(seq_dim, self.default_seq_dim) + + device, dtype, seq_len = t.device, t.dtype, t.shape[seq_dim] + + if exists(freq_seq_len): + seq_len = freq_seq_len + + freqs = self.forward(self.get_seq_pos(seq_len, device=device, dtype=dtype, offset=offset), seq_len=seq_len, + offset=offset) + + if seq_dim == -3: + freqs = rearrange(freqs, 'n d -> n 1 d') + + return self.apply_rotary_emb(freqs, t, seq_dim=seq_dim) + + def get_axial_freqs(self, *dims): + colon = slice(None) + all_freqs = [] + + for ind, dim in enumerate(dims): + if self.freqs_for == 'pixel': + pos = torch.linspace(-1, 1, steps=dim, device=self.device) + else: + pos = torch.arange(dim, device=self.device) + + freqs = self.forward(pos, seq_len=dim) + + all_axis = [None] * len(dims) + all_axis[ind] = colon + + new_axis_slice = (Ellipsis, *all_axis, colon) + all_freqs.append(freqs[new_axis_slice]) + + all_freqs = broadcast_tensors(*all_freqs) + return torch.cat(all_freqs, dim=-1) + + def rotate_half(self, x): + shape = x.shape + new_shape = shape[:-1] + (shape[-1] // 2, 2) + x = x.view(new_shape) + + x1, x2 = x.unbind(dim=-1) + x = torch.stack((-x2, x1), dim=-1) + shape = x.shape + new_shape = shape[:-2] + (shape[-1] * shape[-2],) + x = x.view(new_shape) + return x + + def apply_rotary_emb(self, freqs, t, start_index=0, scale=1., seq_dim=-2): + if t.ndim == 3: + seq_len = t.shape[seq_dim] + freqs = freqs[-seq_len:].to(t) + + rot_dim = freqs.shape[-1] + end_index = start_index + rot_dim + + t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:] + + cos = freqs.cos() * scale + sin = freqs.sin() * scale + t = (t * cos) + (self.rotate_half(t) * sin) + + return torch.cat((t_left, t, t_right), dim=-1) + + def forward( + self, + t: Tensor, + seq_len=None, + offset=0 + ): + should_cache = ( + self.cache_if_possible and \ + not self.learned_freq and \ + exists(seq_len) and \ + self.freqs_for != 'pixel' + ) + + if ( + should_cache and \ + exists(self.cached_freqs) and \ + (offset + seq_len) <= self.cached_freqs.shape[0] + ): + return self.cached_freqs[offset:(offset + seq_len)].detach() + + freqs = self.freqs + + freqs = einsum('..., f -> ... f', t.type(freqs.dtype), freqs) + freqs = torch.repeat_interleave(freqs, repeats=2, dim=-1) + + if should_cache: + self.tmp_store('cached_freqs', freqs.detach()) + + return freqs \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/layer/mlp.py b/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/layer/mlp.py new file mode 100644 index 0000000000..98f5dcba9f --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/layer/mlp.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import collections.abc +from itertools import repeat +from functools import partial + +import torch.nn as nn + + +class Mlp(nn.Module): + """ MLP as used in Vision Transformer, MLP-Mixer and related networks + """ + + def __init__( + self, + features_in, + features_hidden=None, + features_out=None, + act_layer=nn.GELU, + norm_layer=None, + bias=True, + use_conv=False, + ): + super().__init__() + features_out = features_out or features_in + features_hidden = features_hidden or features_in + to_2tuple = self._ntuple(2) + bias = to_2tuple(bias) + linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear + + self.fc1 = linear_layer(features_in, features_hidden, bias=bias[0]) + self.act = act_layer() + self.norm = norm_layer(features_hidden) if norm_layer is not None else nn.Identity() + self.fc2 = linear_layer(features_hidden, features_out, bias=bias[1]) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.norm(x) + x = self.fc2(x) + return x + + def _ntuple(self, n): + def parse(x): + if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): + return tuple(x) + return tuple(repeat(x, n)) + return parse \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/layer/norm.py b/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/layer/norm.py new file mode 100644 index 0000000000..09e1a3212f --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/layer/norm.py @@ -0,0 +1,151 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from torch import Tensor +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import init +import torch_npu +from torch.nn.parameter import Parameter +import torch.distributed as dist +from ..utils import is_npu_available + + +def get_normalization_helper(norm_type: str, norm_dim: int, eps: float = 1e-5): + match norm_type: + case None: + return nn.Identity() + case 'layer_norm': + return nn.LayerNorm(norm_dim, eps=eps) + case 'llama_rms_norm': + return LlamaRMSNorm(norm_dim, eps=eps) + case _: + error_msg = "`norm_type` is not supported!" + raise ValueError(error_msg) + +class AdaLayerNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + super().__init__() + self.hidden_size = hidden_size + self.eps = eps + + def forward(self, x, shift, scale): + if is_npu_available(): + return torch_npu.npu_layer_norm_eval( + x, normalized_shape=[self.hidden_size], weight=scale, bias=shift, eps=self.eps) + else: + return F.layer_norm(x, normalized_shape=[self.hidden_size], weight=scale, bias=shift, eps=self.eps) + +class GroupNorm3dAdapter(nn.Module): + def __init__(self, group_norm: nn.GroupNorm): + super().__init__() + self.module = PatchGroupNorm3d( + num_groups=group_norm.num_groups, + num_channels=group_norm.num_channels, + eps=group_norm.eps, + affine=group_norm.affine + ) + if group_norm.affine: + self.module.weight = group_norm.weight + self.module.bias = group_norm.bias + + def forward(self, x): + return self.module(x) + +class PatchGroupNorm3d(nn.Module): + def __init__(self, num_groups: int, num_channels: int, eps: float = 1e-5, affine: bool = True, + device=None, dtype=None) -> None: + super().__init__() + self.factory_kwargs = {'device': device, 'dtype': dtype} + if num_channels % num_groups != 0: + raise ValueError('num_channels must be divisible by num_groups') + self.init_paramsters(num_groups, num_channels, eps, affine) + if self.affine: + self.init_weight_bias() + else: + self.init_register_parameter() + + self.reset_parameters() + + def init_paramsters(self, num_groups, num_channels, eps, affine): + self.num_groups = num_groups + self.num_channels = num_channels + self.eps = eps + self.affine = affine + + def init_weight_bias(self): + self.weight = Parameter(torch.empty(self.num_channels, self.factory_kwargs)) + self.bias = Parameter(torch.empty(self.num_channels, self.factory_kwargs)) + + def init_register_parameter(self): + self.register_parameter('weight', None) + self.register_parameter('bias', None) + + def reset_parameters(self) -> None: + if self.affine: + init.ones_(self.weight) + init.zeros_(self.bias) + + def forward(self, x: Tensor) -> Tensor: + rank = dist.get_rank() + width = torch.tensor(x.shape[-1], dtype=torch.int64, device=x.device) - 1 + dist.all_reduce(width) + + channels_per_group = x.shape[1] // self.num_groups + nelements_rank = channels_per_group * x.shape[-3] * x.shape[-2] * (x.shape[-1] - 1) + nelements = channels_per_group * x.shape[-3] * x.shape[-2] * width + + x = x.view(x.shape[0], self.num_groups, -1, *x.shape[2:]) + if rank % 2 == 0: + group_sum = x[..., :-1].sum(dim=(2, 3, 4, 5), dtype=x.dtype, keepdim=True) + else: + group_sum = x[..., 1:].sum(dim=(2, 3, 4, 5), dtype=x.dtype, keepdim=True) + dist.all_reduce(group_sum) + avg = (group_sum / nelements).to(x.dtype) + + group_var_sum = torch.empty((x.shape[0], self.num_groups), dtype=x.dtype, device=x.device) + if rank % 2 == 0: + torch.var(x[..., :-1], dim=(2, 3, 4, 5), out=group_var_sum, keepdim=True) + else: + torch.var(x[..., 1:], dim=(2, 3, 4, 5), out=group_var_sum, keepdim=True) + group_var_sum = group_var_sum * (nelements_rank - 1) + dist.all_reduce(group_var_sum) + var = (group_var_sum / (nelements - 1)).to(x.dtype) + + x = (x - avg) / torch.sqrt(var + self.eps) + x = x.view(x.shape[0], -1, *x.shape[3:]) + x = x * self.weight[None, :, None, None, None] + self.bias[None, :, None, None, None] + return x + +class LlamaRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + if is_npu_available(): + return torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.variance_epsilon)[0] + else: + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/layer/parallel_mgr.py b/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/layer/parallel_mgr.py new file mode 100644 index 0000000000..465b56b7b6 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/layer/parallel_mgr.py @@ -0,0 +1,59 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch.distributed as dist +from colossalai.cluster.process_group_mesh import ProcessGroupMesh +from torch.distributed import ProcessGroup + +PARALLEL_MANAGER = None + + +class ParallelManager(ProcessGroupMesh): + def __init__(self, sp_size, sp_axis): + super().__init__(sp_size) + self.sp_size = sp_size + self.sp_axis = sp_axis + self.sp_group: ProcessGroup = self.get_group_along_axis(sp_axis) + self.sp_rank = dist.get_rank(self.sp_group) + self.enable_sp = sp_size > 1 + + +def set_parallel_manager(sp_size, sp_axis): + global PARALLEL_MANAGER + PARALLEL_MANAGER = ParallelManager(sp_size, sp_axis) + + +def get_sequence_parallel_group(): + return PARALLEL_MANAGER.sp_group + + +def get_sequence_parallel_size(): + return PARALLEL_MANAGER.sp_size + + +def get_sequence_parallel_rank(): + return PARALLEL_MANAGER.sp_rank + + +def use_sequence_parallel(): + return PARALLEL_MANAGER.enable_sp + + +def get_parallel_manager(): + return PARALLEL_MANAGER \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/layer/utils.py b/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/layer/utils.py new file mode 100644 index 0000000000..f850deacf9 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/layer/utils.py @@ -0,0 +1,28 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + + +def rearrange_flatten_t(x): + x_shape = x.shape + x = x.transpose(1, 2) + return x.view((x_shape[0] * x_shape[2]), x_shape[1], x_shape[3], x_shape[4]) + + +def rearrange_unflatten_t(x, b): + x_shape = x.shape + x = x.view(b, x_shape[0] // b, x_shape[1], x_shape[2], x_shape[3]) + return x.transpose(1, 2) \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/pipeline/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/pipeline/__init__.py new file mode 100644 index 0000000000..61b98d663c --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/pipeline/__init__.py @@ -0,0 +1,18 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .open_sora_pipeline import OpenSoraPipeline12 +from .compile_pipe import compile_pipe \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/pipeline/compile_pipe.py b/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/pipeline/compile_pipe.py new file mode 100644 index 0000000000..493af6ffd8 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/pipeline/compile_pipe.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright(C) 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch.nn as nn +from ..utils import is_npu_available + + +def compile_pipe(pipe, cfg=None): + if is_npu_available(): + device = 'npu' + if hasattr(pipe, "text_encoder") and isinstance(pipe.text_encoder, nn.Module): + pipe.text_encoder.to(device) + if hasattr(pipe, "transformer") and isinstance(pipe.transformer, nn.Module): + pipe.transformer.to(device) + if hasattr(pipe, "vae") and isinstance(pipe.vae, nn.Module): + pipe.vae.to(device) + return pipe + else: + raise RuntimeError("NPU is not available.") \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/pipeline/open_sora_pipeline.py b/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/pipeline/open_sora_pipeline.py new file mode 100644 index 0000000000..d9decdd169 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/pipeline/open_sora_pipeline.py @@ -0,0 +1,262 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Tuple, List + +import torch +import torch_npu +from tqdm import tqdm +from torch import Tensor + +from transformers import AutoTokenizer, T5EncoderModel + +from .pipeline_utils import OpenSoraPipeline +from ..utils import ( + set_random_seed, append_score_to_prompts, extract_prompts_loop, + merge_prompt, prepare_multi_resolution_info, split_prompt, is_npu_available) +from ..stdit3 import STDiT3 +from ..vae import VideoAutoencoder +from ..schedulers import RFlowScheduler + +torch_npu.npu.config.allow_internal_format = False +NUM_FRAMES = 'num_frames' + + +target_image_size = [(720, 1280), (512, 512)] +target_num_frames = [32, 128] +target_fps = [8] +target_output_type = ["latent", "thwc"] +target_dtype = [torch.bfloat16, torch.float16] +MAX_PROMPT_LENGTH = 1024 # the limits of open-sora1.2 + + +class OpenSoraPipeline12(OpenSoraPipeline): + + def __init__(self, text_encoder: T5EncoderModel, tokenizer: AutoTokenizer, transformer: STDiT3, + vae: VideoAutoencoder, scheduler: RFlowScheduler, + num_frames: int = 32, image_size: Tuple[int, int] = (720, 1280), fps: int = 8, + dtype: torch.dtype = torch.bfloat16): + + super().__init__() + torch.set_grad_enabled(False) + + self.text_encoder = text_encoder + self.tokenizer = tokenizer + self.transformer = transformer + self.vae = vae + self.scheduler = scheduler + self.num_frames = num_frames + self.image_size = image_size + self.fps = fps + + if is_npu_available(): + self.device = 'npu' + else: + self.device = 'cpu' + + self.dtype = dtype + self.text_encoder.to(self.dtype) + + @torch.no_grad() + def __call__(self, prompts: List[str], seed: int = 42, output_type: str = "latent"): + + set_random_seed(seed=seed) + + # 1.0 Encode input prompt + text_encoder_res_list = self._encode_prompt(prompts, self.text_encoder) + torch.npu.empty_cache() + + input_size = (self.num_frames, *self.image_size) + latent_size = self.vae.get_latent_size(input_size) + + batch_size = 1 + num_sample = 1 + + # == Iter over all samples == + all_videos = [] + for i in range(0, len(prompts), batch_size): + # == prepare batch prompts == + batch_prompts = prompts[i: i + batch_size] + + # == multi-resolution info == + model_args = prepare_multi_resolution_info( + 'STDiT2', (len(batch_prompts), self.image_size, self.num_frames, self.fps), self.device, self.dtype) + + # == Iter over number of sampling for one prompt == + for _ in range(num_sample): + # == Iter over loop generation == + z = torch.randn(len(batch_prompts), self.vae.out_channels, + *latent_size, device=self.device, dtype=self.dtype) + + # 2.0 Prepare timesteps + timesteps = self._retrieve_timesteps(z, additional_args=model_args, ) + + samples = self._sample( + self.transformer, + text_encoder_res_list[i], + z=z, + timesteps=timesteps, + additional_args=model_args, + ) + + del z, timesteps, text_encoder_res_list + + samples = self.vae.decode(samples.to(self.dtype), num_frames=self.num_frames) + all_videos.append(samples) + + del samples + torch.npu.empty_cache() + + if not output_type == "latent": + videos = self._video_write(all_videos) + return videos + else: + return all_videos + + def _video_write(self, x): + x = [x[0][0]] + x = torch.cat(x, dim=1) + value_range = (-1, 1) + low, high = value_range + x.clamp_(min=low, max=high) + x.sub_(low).div_(max(high - low, 1e-5)) + x = x.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 3, 0).to("cpu", torch.uint8) + return x + + def _retrieve_timesteps(self, z, additional_args=None, ): + # prepare timesteps + timesteps = [(1.0 - i / self.scheduler.num_sampling_steps) * self.scheduler.num_timesteps + for i in range(self.scheduler.num_sampling_steps)] + if self.scheduler.use_discrete_timesteps: + timesteps = [int(round(t)) for t in timesteps] + timesteps = [torch.tensor([t] * z.shape[0], device=self.device) for t in timesteps] + + if self.scheduler.use_timestep_transform: + timesteps = [self._timestep_transform(t, additional_args, + num_timesteps=self.scheduler.num_timesteps) for t in timesteps] + return timesteps + + def _sample(self, model, model_args, z, timesteps, additional_args=None, ): + + if additional_args is not None: + model_args.update(additional_args) + + for i in tqdm(range(0, len(timesteps), 1)): + t = timesteps[i] + model_args['t_idx'] = i + + # classifier-free guidance + z_in = torch.cat([z, z], 0) + t = torch.cat([t, t], 0) + pred = model(z_in, t, **model_args).chunk(2, dim=1)[0] + z = self.scheduler.step(pred, timesteps, i, z) + return z + + def _timestep_transform(self, t, model_kwargs, num_timesteps): + base_resolution = 512 * 512 + scale = 1.0 + + t = t / num_timesteps + resolution = model_kwargs["height"].to(torch.float32) * model_kwargs["width"].to(torch.float32) + ratio_space = (resolution / base_resolution).sqrt() + # NOTE: currently, we do not take fps into account + # NOTE: temporal_reduction is hardcoded, this should be equal to the temporal reduction factor of the vae + if model_kwargs[NUM_FRAMES][0] == 1: + num_frames = torch.ones_like(model_kwargs[NUM_FRAMES]) + else: + num_frames = model_kwargs[NUM_FRAMES] // 17 * 5 + ratio_time = num_frames.sqrt() + + ratio = ratio_space * ratio_time * scale + new_t = ratio * t / (1 + (ratio - 1) * t) + + new_t = new_t * num_timesteps + return new_t + + def _encode(self, text): + caption_embs, emb_masks = self._get_text_embeddings(text) + caption_embs = caption_embs[:, None] + return dict(y=caption_embs, mask=emb_masks) + + def _null(self, n): + null_y = self.transformer.y_embedder.y_embedding[None].repeat(n, 1, 1)[:, None] + return null_y + + def _get_text_embeddings(self, texts): + text_tokens_and_mask = self.tokenizer( + texts, + max_length=300, + padding="max_length", + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + + input_ids = text_tokens_and_mask["input_ids"].to(self.device) + attention_mask = text_tokens_and_mask["attention_mask"].to(self.device) + with torch.no_grad(): + text_encoder_embs = self.text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + )["last_hidden_state"].detach() + return text_encoder_embs, attention_mask + + def _extract_text_res(self, prompts, text_encoder): + text_encoder_res = self._encode(prompts) + n = len(prompts) + y_null = self._null(n) + + text_encoder_res["y"] = torch.cat([text_encoder_res["y"], y_null], 0) + return text_encoder_res + + def _encode_prompt(self, prompts, text_encoder): + cfg_aes = 6.5 + cfg_flow = None + cfg_camera_motion = None + text_encoder_res_list = [] + for i in range(len(prompts)): + # == prepare batch prompts == + batch_prompts = prompts[i: i + 1] + + # 0. split prompt + # each element in the list is [prompt_segment_list, loop_idx_list] + batched_prompt_segment_list = [] + batched_loop_idx_list = [] + for prompt in batch_prompts: + prompt_segment_list, loop_idx_list = split_prompt(prompt) + batched_prompt_segment_list.append(prompt_segment_list) + batched_loop_idx_list.append(loop_idx_list) + + # append score + for idx, prompt_segment_list in enumerate(batched_prompt_segment_list): + batched_prompt_segment_list[idx] = append_score_to_prompts( + prompt_segment_list, + aes=cfg_aes, + flow=cfg_flow, + camera_motion=cfg_camera_motion, + ) + + # merge to obtain the final prompt + batch_prompts = [] + for prompt_segment_list, loop_idx_list in zip(batched_prompt_segment_list, batched_loop_idx_list): + batch_prompts.append(merge_prompt(prompt_segment_list, loop_idx_list)) + + batch_prompts_loop = extract_prompts_loop(batch_prompts) + + text_encoder_res_list.append(self._extract_text_res(batch_prompts_loop, text_encoder)) + + return text_encoder_res_list \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/pipeline/pipeline_utils.py b/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/pipeline/pipeline_utils.py new file mode 100644 index 0000000000..0f0c76bc77 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/pipeline/pipeline_utils.py @@ -0,0 +1,169 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os +import inspect +import logging +import importlib +from typing import Callable, Dict, List, Optional, Tuple, Union + +import torch +from tqdm import tqdm + +from mindiesd import ConfigMixin + + +PIPELINE_CONFIG_NAME = "model_index.json" +OPEN_SORA_DEFAULT_VIDEO_FRAME = 32 +OPEN_SORA_DEFAULT_IMAGE_SIZE = (720, 1280) +OPEN_SORA_DEFAULT_FPS = 8 +ENABLE_SEQUENCE_PARALLELISM_DEFAULT_VALUE = False + +VAE = 'vae' +TEXT_ENCODER = 'text_encoder' +TOKENIZER = 'tokenizer' +TRANSFORMER = 'transformer' +SCHEDULER = 'scheduler' +NUM_FRAMES = 'num_frames' +IMAGE_SIZE = 'image_size' +ENABLE_SEQUENCE_PARALLELISM = 'enable_sequence_parallelism' +FPS = 'fps' +DTYPE = 'dtype' +SET_PATCH_PARALLEL = 'set_patch_parallel' +FROM_PRETRAINED = 'from_pretrained' +logger = logging.getLogger(__name__) # init python log + + +class OpenSoraPipeline(ConfigMixin): + r""" + Base class for all OpenSora pipelines. + The OpenSoraPipeline class is mainly provides `from_pretrained` method to + initialize the OpenSora pipeline components from `config_name` file, + and loads the weights of the components from `model_path` directory. + """ + + config_name = PIPELINE_CONFIG_NAME + + def __init__(self): + super().__init__() + + @classmethod + def from_pretrained(cls, model_path, **kwargs): + r""" + The from_pretrained class method is used to initialize the OpenSora pipeline components + and loads the weights of the components from the model directory. + + Args: + model_path (str): The path to the model directory. + **kwargs: Additional keyword arguments for the pipeline components. + + Returns: + OpenSoraPipeline: The initialized OpenSora pipeline. + """ + num_frames = kwargs.pop(NUM_FRAMES, OPEN_SORA_DEFAULT_VIDEO_FRAME) + image_size = kwargs.pop(IMAGE_SIZE, OPEN_SORA_DEFAULT_IMAGE_SIZE) + fps = kwargs.pop(FPS, OPEN_SORA_DEFAULT_FPS) + enable_sequence_parallelism = kwargs.pop(ENABLE_SEQUENCE_PARALLELISM, + ENABLE_SEQUENCE_PARALLELISM_DEFAULT_VALUE) + set_patch_parallel = kwargs.pop(SET_PATCH_PARALLEL, False) + dtype = kwargs.pop("dtype", torch.bfloat16) + init_dict, _ = cls.load_config(model_path, **kwargs) + + init_list = [VAE, TEXT_ENCODER, TOKENIZER, TRANSFORMER, SCHEDULER] + pipe_init_dict = {} + + all_parameters = inspect.signature(cls.__init__).parameters + + required_param = {k: v for k, v in all_parameters.items() if v.default == inspect.Parameter.empty} + expected_modules = set(required_param.keys()) - {"self"} + # init the module from kwargs + passed_module = {k: kwargs.pop(k) for k in expected_modules if k in kwargs} + + for key in tqdm(init_list, desc="Loading open-sora-pipeline components"): + if key not in init_dict: + raise ValueError(f"Failed to get {key} from init config!") + if key in passed_module: + pipe_init_dict[key] = passed_module.pop(key) + else: + modules, cls_name = init_dict[key] + if modules == "mindiesd": + library = importlib.import_module("opensora") + else: + library = importlib.import_module(modules) + class_obj = getattr(library, cls_name) + + sub_folder = os.path.join(model_path, key) + + from_pretrained = kwargs.pop(FROM_PRETRAINED, sub_folder) + + if key == TRANSFORMER: + _check(pipe_init_dict) + latent_size = pipe_init_dict.get(VAE).get_latent_size((num_frames, *image_size)) + in_channels = pipe_init_dict.get(VAE).out_channels + caption_channels = pipe_init_dict.get(TEXT_ENCODER).config.d_model + pipe_init_dict[key] = class_obj.from_pretrained(sub_folder, input_size=latent_size, + in_channels=in_channels, caption_channels=caption_channels, + enable_sequence_parallelism=enable_sequence_parallelism, dtype=dtype, **kwargs) + else: + initializer = _get_initializers().get(key, init_default) + pipe_init_dict[key] = initializer(class_obj, sub_folder, from_pretrained, + set_patch_parallel, dtype=dtype, **kwargs) + + pipe_init_dict[NUM_FRAMES] = num_frames + pipe_init_dict[IMAGE_SIZE] = image_size + pipe_init_dict[FPS] = fps + pipe_init_dict[DTYPE] = dtype + + return cls(**pipe_init_dict) + + +def _get_initializers(): + initializers = { + TEXT_ENCODER: init_text_encoder, SCHEDULER: init_scheduler, + VAE: init_vae, TOKENIZER: init_default + } + return initializers + + +def _check(pipe_init_dict): + if pipe_init_dict.get(VAE) is None: + raise ValueError("Cannot get module 'vae' in init list!") + if not hasattr(pipe_init_dict.get(VAE), 'get_latent_size'): + raise ValueError("Vae has no attribute 'get_latent_size'!") + if pipe_init_dict.get(TEXT_ENCODER) is None: + raise ValueError("Cannot get module 'text_encoder' in init list!") + + +def init_text_encoder(class_obj, sub_folder, from_pretrained, set_patch_parallel, **kwargs): + return class_obj.from_pretrained(sub_folder, local_files_only=True).to(kwargs.get(DTYPE)) + + +def init_scheduler(class_obj, sub_folder, from_pretrained, set_patch_parallel, **kwargs): + return class_obj.from_config(sub_folder) + + +def init_vae(class_obj, sub_folder, from_pretrained, set_patch_parallel, **kwargs): + vae = class_obj.from_pretrained(sub_folder, + from_pretrained=from_pretrained, + set_patch_parallel=set_patch_parallel, + **kwargs) + vae.to(kwargs.get(DTYPE)) + return vae + + +def init_default(class_obj, sub_folder, from_pretrained, set_patch_parallel, **kwargs): + return class_obj.from_pretrained(sub_folder, local_files_only=True, **kwargs) \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/schedulers/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/schedulers/__init__.py new file mode 100644 index 0000000000..f04cedc480 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/schedulers/__init__.py @@ -0,0 +1,18 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from .rectified_flow import RFlowScheduler \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/schedulers/rectified_flow.py b/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/schedulers/rectified_flow.py new file mode 100644 index 0000000000..d12b1c0cf7 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/schedulers/rectified_flow.py @@ -0,0 +1,102 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import List + +import torch +from torch.distributions import LogisticNormal + +from mindiesd.schedulers.scheduler_utils import DiffusionScheduler + +logging.basicConfig(level=logging.ERROR) +logger = logging.getLogger(__name__) + +# some code are inspired by +# https://github.com/magic-research/piecewise-rectified-flow/blob/main/scripts/train_perflow.py +# and https://github.com/magic-research/piecewise-rectified-flow/blob/main/src/scheduler_perflow.py + +UNIFORM_CONSTANT = "uniform" +LOGIT_NORMAL_CONSTANT = "logit-normal" + + +class RFlowScheduler(DiffusionScheduler): + def __init__( + self, + num_timesteps: int = 1000, + num_sampling_steps: int = 30, + ): + + super().__init__() + + self.num_timesteps = num_timesteps + self.num_sampling_steps = num_sampling_steps + self.use_discrete_timesteps = False + self.sample_method = UNIFORM_CONSTANT + self.loc = 0.0 + self.scale = 1.0 + self.use_timestep_transform = True + self.transform_scale = 1.0 + + # sample method + if self.sample_method not in [UNIFORM_CONSTANT, LOGIT_NORMAL_CONSTANT]: + logger.error("sample_method must be either 'uniform' or 'logit-normal'") + raise ValueError("sample_method must be either 'uniform' or 'logit-normal'") + + if self.sample_method != UNIFORM_CONSTANT and self.use_discrete_timesteps: + logger.error("Only uniform sampling is supported for discrete timesteps") + raise ValueError("Only uniform sampling is supported for discrete timesteps") + + if self.sample_method == LOGIT_NORMAL_CONSTANT: + self.distribution = LogisticNormal(torch.tensor([self.loc]), torch.tensor([self.scale])) + self.sample_t = self._sample_t_function + + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.IntTensor, + ) -> torch.FloatTensor: + timepoints = timesteps.float() / self.num_timesteps + timepoints = 1 - timepoints # [1,1/1000] + + # timepoint (bsz) noise: (bsz, 4, frame, w ,h) + # expand timepoint to noise shape + timepoints = timepoints.unsqueeze(1).unsqueeze(1).unsqueeze(1).unsqueeze(1) + timepoints = timepoints.repeat(1, noise.shape[1], noise.shape[2], noise.shape[3], noise.shape[4]) + res = timepoints * original_samples + (1 - timepoints) * noise + + return res + + def step(self, pred: torch.FloatTensor, + timesteps: List[float], + i: int, + noise: torch.FloatTensor, + guidance_scale: float = 7.0): + + pred_cond, pred_uncond = pred.chunk(2, dim=0) + v_pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond) + + # update z + dt = timesteps[i] - timesteps[i + 1] if i < len(timesteps) - 1 else timesteps[i] + dt = dt / self.num_timesteps + noise = noise + v_pred * dt[:, None, None, None, None] + return noise + + def _sample_t_function(self, x): + return self.distribution.sample((x.shape[0],))[:, 0].to(x.device) + + diff --git a/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/stdit3/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/stdit3/__init__.py new file mode 100644 index 0000000000..87101e574f --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/stdit3/__init__.py @@ -0,0 +1,18 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from .stdit3 import (STDiT3Config, STDiT3) \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/stdit3/stdit3.py b/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/stdit3/stdit3.py new file mode 100644 index 0000000000..e589920628 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/stdit3/stdit3.py @@ -0,0 +1,563 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import functools +from typing import Tuple + +import numpy as np +import torch +import torch.nn as nn +from einops import rearrange + +from mindiesd.config_utils import ConfigMixin +from mindiesd.models.model_utils import DiffusionModel +from ..layer import approx_gelu +from ..layer import Attention, MultiHeadCrossAttention +from ..layer import CaptionEmbedder, PatchEmbed3D, PositionEmbedding2D, SizeEmbedder, \ + TimestepEmbedder, RotaryEmbedding +from ..layer import Mlp +from ..layer import AdaLayerNorm +from ..layer import ( + all_to_all_with_pad, + get_spatial_pad, + get_temporal_pad, + set_spatial_pad, + set_temporal_pad, + split_sequence, + gather_sequence, +) +from ..layer import get_sequence_parallel_group + +MAX_IN_CHANNELS = 4 +MAX_CAPTIOIN_CHANNELS = 4096 + + +class STDiT3Config(ConfigMixin): + config_name = 'config.json' + + def __init__( + self, + input_size: Tuple[int, int, int] = (None, None, None), + in_channels: int = 4, + caption_channels: int = 4096, + enable_flash_attn: bool = True, + enable_sequence_parallelism: bool = False, + use_cache: bool = True, + cache_interval: int = 2, + cache_start: int = 3, + num_cache_layer: int = 13, + cache_start_steps: int = 5, + ): + super().__init__() + + self.input_size = input_size + self.in_channels = in_channels + self.caption_channels = caption_channels + self.enable_flash_attn = enable_flash_attn + self.enable_sequence_parallelism = enable_sequence_parallelism + self.use_cache = use_cache + self.cache_interval = cache_interval + self.cache_start = cache_start + self.num_cache_layer = num_cache_layer + self.cache_start_steps = cache_start_steps + + +class T2IFinalLayer(nn.Module): + """ + The final layer of PixArt. + """ + + def __init__(self, hidden_size, num_patch, out_channels, d_t=None, d_s=None): + super().__init__() + self.norm_final = AdaLayerNorm(hidden_size, eps=1e-6) + + self.linear = nn.Linear(hidden_size, num_patch * out_channels, bias=True) + self.scale_shift_table = nn.Parameter(torch.randn(2, hidden_size) / hidden_size ** 0.5) + self.out_channels = out_channels + self.d_t = d_t + self.d_s = d_s + + def t_mask_select(self, x_mask, x, masked_x, t, s): + # x: [B, (T, S), C], mased_x: [B, (T, S), C], x_mask: [B, T] + x = rearrange(x, "b (t s) c -> b t s c", t=t, s=s) + masked_x = rearrange(masked_x, "b (t s) c -> b t s c", t=t, s=s) + x = torch.where(x_mask[:, :, None, None], x, masked_x) + x = rearrange(x, "b t s c -> b (t s) c") + return x + + def forward(self, x, t, x_mask=None, t0=None, t_s=(None, None)): + d_t = t_s[0] + d_s = t_s[1] + if d_t is None: + d_t = self.d_t + if d_s is None: + d_s = self.d_s + shift, scale = (self.scale_shift_table[None] + t[:, None]).chunk(2, dim=1) + x = self.norm_final(x, shift, 1 + scale[0]) + + if x_mask is not None: + shift_zero, scale_zero = (self.scale_shift_table[None] + t0[:, None]).chunk(2, dim=1) + x_zero = self.norm_final(x, shift_zero, 1 + scale_zero) + x = self.t_mask_select(x_mask, x, x_zero, d_t, d_s) + x = self.linear(x) + return x + + +class STDiT3Block(nn.Module): + def __init__( + self, + hidden_size, + num_heads, + mlp_ratio=4.0, + rope=None, + qk_norm=False, + temporal=False, + enable_flash_attn=False, + enable_sequence_parallelism=False, + ): + super().__init__() + self.temporal = temporal + self.hidden_size = hidden_size + self.enable_flash_attn = enable_flash_attn + self.enable_sequence_parallelism = enable_sequence_parallelism + + attn_cls = Attention + mha_cls = MultiHeadCrossAttention + + self.norm1 = AdaLayerNorm(hidden_size, eps=1e-6) + self.attn = attn_cls( + hidden_size, + num_heads=num_heads, + qkv_bias=True, + qk_norm=qk_norm, + rope=rope, + enable_flash_attn=enable_flash_attn, + ) + + self.cross_attn = mha_cls(hidden_size, num_heads) + self.norm2 = AdaLayerNorm(hidden_size, eps=1e-6) + + self.mlp = Mlp( + features_in=hidden_size, features_hidden=int(hidden_size * mlp_ratio), act_layer=approx_gelu) + self.scale_shift_table = nn.Parameter(torch.zeros(6, hidden_size) / hidden_size ** 0.5) + + def t_mask_select(self, x_mask, x, masked_x, t, s): + # x: [B, (T, S), C], mased_x: [B, (T, S), C], x_mask: [B, T] + x = rearrange(x, "b (t s) c -> b t s c", t=t, s=s) + masked_x = rearrange(masked_x, "b (t s) c -> b t s c", t=t, s=s) + x = torch.where(x_mask[:, :, None, None], x, masked_x) + x = rearrange(x, "b t s c -> b (t s) c") + return x + + def forward( + self, x, y, t, mask=None, x_mask=None, t0=None, number_frames=None, number_pixel_patches=None + ): + # prepare modulate parameters + x_shape0_b, x_shape1_n, x_shape2_c = x.shape + + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[None] + t.reshape(x_shape0_b, 6, -1) + ).chunk(6, dim=1) + + if x_mask is not None: + shift_msa_zero, scale_msa_zero, gate_msa_zero, shift_mlp_zero, scale_mlp_zero, gate_mlp_zero = ( + self.scale_shift_table[None] + t0.reshape(x_shape0_b, 6, -1) + ).chunk(6, dim=1) + + # modulate attention + x_m = self.norm1(x, shift_msa, 1 + scale_msa[0]) + if x_mask is not None: + x_m_zero = self.norm1(x, shift_msa_zero, scale_msa_zero) + x_m = self.t_mask_select(x_mask, x_m, x_m_zero, number_frames, number_pixel_patches) + + # modulate attention + if self.temporal: + if self.enable_sequence_parallelism: + x_m, number_pixel_patches, number_frames = self.dynamic_switch( + x_m, number_pixel_patches, number_frames, temporal_to_spatial=True) + x_m = rearrange(x_m, "b (t s) c -> (b s) t c", t=number_frames, s=number_pixel_patches) + x_m = self.attn(x_m) + x_m = rearrange(x_m, "(b s) t c -> b (t s) c", t=number_frames, s=number_pixel_patches) + # because x_mask split on the dim 1 + if self.enable_sequence_parallelism: + x_m, number_pixel_patches, number_frames = self.dynamic_switch( + x_m, number_pixel_patches, number_frames, temporal_to_spatial=False) + + else: + x_m = rearrange(x_m, "b (t s) c -> (b t) s c", t=number_frames, s=number_pixel_patches) + x_m = self.attn(x_m) + x_m = rearrange(x_m, "(b t) s c -> b (t s) c", t=number_frames, s=number_pixel_patches) + + # modulate attention + x_m_s = gate_msa * x_m + if x_mask is not None: + x_m_s_zero = gate_msa_zero * x_m + x_m_s = self.t_mask_select(x_mask, x_m_s, x_m_s_zero, number_frames, number_pixel_patches) + + # residual + x = x + x_m_s + + # cross attention + x = x + self.cross_attn(x, y, mask) + + # modulate MLP + x_m = self.norm2(x, shift_mlp, 1 + scale_mlp[0]) + if x_mask is not None: + x_m_zero = self.norm2(x, shift_mlp_zero, scale_mlp_zero) + x_m = self.t_mask_select(x_mask, x_m, x_m_zero, number_frames, number_pixel_patches) + + # MLP + x_m = self.mlp(x_m) + + # modulate MLP + x_m_s = gate_mlp * x_m + if x_mask is not None: + x_m_s_zero = gate_mlp_zero * x_m + x_m_s = self.t_mask_select(x_mask, x_m_s, x_m_s_zero, number_frames, number_pixel_patches) + + # residual + x = x + x_m_s + + return x + + def dynamic_switch(self, x, s, t, temporal_to_spatial: bool): + if temporal_to_spatial: + scatter_dim, gather_dim = 2, 1 + scatter_pad = get_spatial_pad() + gather_pad = get_temporal_pad() + else: + scatter_dim, gather_dim = 1, 2 + scatter_pad = get_temporal_pad() + gather_pad = get_spatial_pad() + + x = rearrange(x, "b (t s) c -> b t s c", t=t, s=s) + + x = all_to_all_with_pad( + x, + get_sequence_parallel_group(), + scatter_dim=scatter_dim, + gather_dim=gather_dim, + scatter_pad=scatter_pad, + gather_pad=gather_pad, + ) + + new_s, new_t = x.shape[2], x.shape[1] + + x = rearrange(x, "b t s c -> b (t s) c", t=new_t, s=new_s) + return x, new_s, new_t + + +class STDiT3(DiffusionModel): + config_class = STDiT3Config + weigths_name = 'model.safetensors' + + def __init__(self, config): + super().__init__(config) + + self.pred_sigma = True + self.in_channels = config.in_channels + self.out_channels = config.in_channels * 2 if self.pred_sigma else config.in_channels + + # model size related + self.depth = 28 + self.mlp_ratio = 4.0 + self.hidden_size = 1152 + self.num_heads = 16 + + # computation related + self.enable_flash_attn = config.enable_flash_attn + self.enable_sequence_parallelism = config.enable_sequence_parallelism + + # input size related + self.patch_size = (1, 2, 2) + self.input_sq_size = 512 + self.pos_embed = PositionEmbedding2D(self.hidden_size) + self.rope = RotaryEmbedding(dim=self.hidden_size // self.num_heads) + + # embedding + self._init_embedding(config) + + self._init_blocks(config) + + # final layer + self.final_layer = T2IFinalLayer(self.hidden_size, np.prod(self.patch_size), self.out_channels) + + self._initialize_weights() + + self._init_cache(config) + + def forward( + self, + x: torch.Tensor, + timestep: torch.Tensor, + y: torch.Tensor, + mask: torch.Tensor = None, + x_mask: torch.Tensor = None, + fps: torch.Tensor = None, + height: torch.Tensor = None, + width: torch.Tensor = None, + t_idx: int = 0, + **kwargs + ) -> torch.Tensor: + + dtype = self.x_embedder.proj.weight.dtype + x = x.to(dtype) + timestep = timestep.to(dtype) + y = y.to(dtype) + + if fps is None: + fps = torch.tensor([8], device=x.device) + if height is None: + height = torch.tensor([720], device=x.device) + if width is None: + width = torch.tensor([1280], device=x.device) + + # get shape + _, _, tx, hx, wx = x.size() + x_shape0_t, x_shape1_h, x_shape2_w = self._get_dynamic_size(x) + s_hw = x_shape1_h * x_shape2_w + resolution_sq = (height[0].to(torch.float32).item() * width[0].to(torch.float32).item()) ** 0.5 + scale = resolution_sq / self.input_sq_size + + # === get pos embed === + pos_emb = self.pos_embed(x, x_shape1_h, x_shape2_w, scale) + + # === get timestep embed === + t, t_mlp, t0, t0_mlp = self._get_t_embed(x, timestep, fps, x_mask) + + # === get y embed === + y, y_lens = self._get_y_embed(y, mask) + + # === get x embed === + x = self._get_x_embed(x, pos_emb, x_shape0_t, s_hw) + + x = rearrange(x, "b t s c -> b (t s) c", t=x_shape0_t, s=s_hw) + + x = self._forward_blocks(x, y, x_mask, x_shape0_t, s_hw, y_lens, t_mlp, t0_mlp, t_idx) + + # === final layer === + x = self.final_layer(x, t, x_mask, t0, (x_shape0_t, s_hw)) + x = self._unpatchify(x, x_shape0_t, x_shape1_h, x_shape2_w, (tx, hx, wx)) + + # cast to float32 for better accuracy + x = x.to(torch.float32) + return x + + def _init_embedding(self, config): + self.x_embedder = PatchEmbed3D(self.patch_size, config.in_channels, self.hidden_size) + self.t_embedder = TimestepEmbedder(self.hidden_size) + self.fps_embedder = SizeEmbedder(self.hidden_size) + self.t_block = nn.Sequential( + nn.SiLU(), + nn.Linear(self.hidden_size, 6 * self.hidden_size, bias=True), + ) + self.y_embedder = CaptionEmbedder( + in_channels=config.caption_channels, + hidden_size=self.hidden_size, + act_layer=approx_gelu, + token_num=300, + ) + + def _init_blocks(self, config): + # spatial blocks + self.spatial_blocks = nn.ModuleList( + [ + STDiT3Block( + hidden_size=self.hidden_size, + num_heads=self.num_heads, + mlp_ratio=self.mlp_ratio, + qk_norm=True, + enable_flash_attn=config.enable_flash_attn, + enable_sequence_parallelism=config.enable_sequence_parallelism, + ) + for _ in range(self.depth) + ] + ) + + # temporal blocks + self.temporal_blocks = nn.ModuleList( + [ + STDiT3Block( + hidden_size=self.hidden_size, + num_heads=self.num_heads, + mlp_ratio=self.mlp_ratio, + qk_norm=True, + enable_flash_attn=config.enable_flash_attn, + enable_sequence_parallelism=config.enable_sequence_parallelism, + temporal=True, + rope=self.rope.rotate_queries_or_keys, + ) + for _ in range(self.depth) + ] + ) + + def _init_cache(self, config): + self.use_cache = config.use_cache + self.cache_interval = config.cache_interval + self.cache_start = config.cache_start + self.num_cache_layer = config.num_cache_layer + self.cache_start_steps = config.cache_start_steps + + self.delta_cache = None + + def _initialize_weights(self): + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + + # Initialize fps_embedder + nn.init.normal_(self.fps_embedder.mlp[0].weight, std=0.02) + nn.init.constant_(self.fps_embedder.mlp[0].bias, 0) + nn.init.constant_(self.fps_embedder.mlp[2].weight, 0) + nn.init.constant_(self.fps_embedder.mlp[2].bias, 0) + + # Initialize timporal blocks + for block in self.temporal_blocks: + nn.init.constant_(block.attn.proj.weight, 0) + nn.init.constant_(block.cross_attn.proj.weight, 0) + nn.init.constant_(block.mlp.fc2.weight, 0) + + def _get_dynamic_size(self, x): + _, _, x_shape2_t, x_shape3_h, x_shape4_w = x.size() + if x_shape2_t % self.patch_size[0] != 0: + x_shape2_t += self.patch_size[0] - x_shape2_t % self.patch_size[0] + if x_shape3_h % self.patch_size[1] != 0: + x_shape3_h += self.patch_size[1] - x_shape3_h % self.patch_size[1] + if x_shape4_w % self.patch_size[2] != 0: + x_shape4_w += self.patch_size[2] - x_shape4_w % self.patch_size[2] + x_shape2_t = x_shape2_t // self.patch_size[0] + x_shape3_h = x_shape3_h // self.patch_size[1] + x_shape4_w = x_shape4_w // self.patch_size[2] + return x_shape2_t, x_shape3_h, x_shape4_w + + def _get_t_embed(self, x, timestep, fps, x_mask): + x_shape0_b = x.size(0) + t = self.t_embedder(timestep, dtype=x.dtype) # [B, C] + fps = self.fps_embedder(fps.unsqueeze(1), x_shape0_b) + t = t + fps + t_mlp = self.t_block(t) + t0 = t0_mlp = None + if x_mask is not None: + t0_timestep = torch.zeros_like(timestep) + t0 = self.t_embedder(t0_timestep, dtype=x.dtype) + t0 = t0 + fps + t0_mlp = self.t_block(t0) + return t, t_mlp, t0, t0_mlp + + def _get_x_embed(self, x, pos_emb, t, s): + x = self.x_embedder(x) # [B, N, C] + x = rearrange(x, "b (t s) c -> b t s c", t=t, s=s) + x = x + pos_emb + return x + + def _get_y_embed(self, y, mask): + y, y_lens = self._encode_text(y, mask) + + return y, y_lens + + def _encode_text(self, y, mask=None): + y = self.y_embedder(y) # [B, 1, N_token, C] + if mask is not None: + if mask.shape[0] != y.shape[0]: + mask = mask.repeat(y.shape[0] // mask.shape[0], 1) + mask = mask.squeeze(1).squeeze(1) + y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, self.hidden_size) + y_lens = mask.sum(dim=1).tolist() + else: + y_lens = [y.shape[2]] * y.shape[0] + y = y.squeeze(1).view(1, -1, self.hidden_size) + return y, y_lens + + # forward blocks in range [start_idx, end_idx), then return input and output + def _forward_blocks_range(self, x, y, x_mask, t, s, y_lens, t_mlp, t0_mlp, start_idx, end_idx): + for spatial_block, temporal_block in zip(self.spatial_blocks[start_idx: end_idx], + self.temporal_blocks[start_idx: end_idx]): + x = spatial_block(x, y, t_mlp, y_lens, x_mask, t0_mlp, t, s) + x = temporal_block(x, y, t_mlp, y_lens, x_mask, t0_mlp, t, s) + + return x + + def _forward_blocks(self, x, y, x_mask, t, s, y_lens, t_mlp, t0_mlp, t_idx): + # === if dsp parallel, split x across T === + if self.enable_sequence_parallelism: + set_temporal_pad(t) + set_spatial_pad(s) + x = rearrange(x, "b (t s) c -> b t s c", t=t, s=s) + x = split_sequence( + x, get_sequence_parallel_group(), dim=1, pad=get_temporal_pad() + ) + t = x.shape[1] + x = rearrange(x, "b t s c -> b (t s) c", t=t, s=s) + if x_mask is not None: + x_mask = split_sequence( + x_mask, get_sequence_parallel_group(), dim=1, pad=get_temporal_pad() + ) + + num_blocks = len(self.spatial_blocks) + if not self.use_cache or (t_idx < self.cache_start_steps): + x = self._forward_blocks_range(x, y, x_mask, t, s, y_lens, t_mlp, t0_mlp, + 0, num_blocks) + else: + # infer [0, cache_start) + x = self._forward_blocks_range(x, y, x_mask, t, s, y_lens, t_mlp, t0_mlp, + 0, self.cache_start) + # infer [cache_start, cache_end) + cache_end = np.minimum(self.cache_start + self.num_cache_layer, num_blocks) + x_before_cache = x.clone() + if t_idx % self.cache_interval == (self.cache_start_steps % 2): + x = self._forward_blocks_range(x, y, x_mask, t, s, y_lens, t_mlp, t0_mlp, + self.cache_start, cache_end) + self.delta_cache = x - x_before_cache + else: + x = x_before_cache + self.delta_cache + # infer [cache_end, num_blocks) + x = self._forward_blocks_range(x, y, x_mask, t, s, y_lens, t_mlp, t0_mlp, + cache_end, num_blocks) + + # === if dsp parallel, gather x across T === + if self.enable_sequence_parallelism: + x = rearrange(x, "b (t s) c -> b t s c", t=t, s=s) + x = gather_sequence( + x, get_sequence_parallel_group(), dim=1, pad=get_temporal_pad() + ) + t, s = x.shape[1], x.shape[2] + x = rearrange(x, "b t s c -> b (t s) c", t=t, s=s) + + return x + + def _unpatchify(self, x, n_t, n_h, n_w, shape_org): + """ + Args: + x (torch.Tensor): of shape [B, N, C] + + Return: + x (torch.Tensor): of shape [B, C_out, T, H, W] + """ + r_t, r_h, r_w = shape_org + t_p, h_p, w_p = self.patch_size + + x_shape0_b = x.shape[0] + x = x.reshape(x_shape0_b, n_t, n_h, n_w, t_p, h_p, w_p, self.out_channels) + x = x.permute(0, 7, 1, 4, 2, 5, 3, 6) + x = x.reshape(x_shape0_b, self.out_channels, n_t * t_p, n_h * h_p, n_w * w_p) + + # unpad + x = x[:, :, :r_t, :r_h, :r_w] + return x diff --git a/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/utils/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/utils/__init__.py new file mode 100644 index 0000000000..ffb91659f0 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/utils/__init__.py @@ -0,0 +1,24 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from .utils import ( + set_random_seed, append_score_to_prompts, extract_prompts_loop, + merge_prompt, prepare_multi_resolution_info, split_prompt, is_npu_available, exists, default +) +from .patch_utils import ( + Patchify, Depatchify +) \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/utils/patch_utils.py b/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/utils/patch_utils.py new file mode 100644 index 0000000000..84805f4eee --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/utils/patch_utils.py @@ -0,0 +1,76 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +import torch.distributed as dist + + +class Patchify(nn.Module): + def __init__(self): + super().__init__() + self.rank = dist.get_rank() + self.world_size = dist.get_world_size() + + def forward(self, hidden_state, dim, is_overlap): + length = hidden_state.shape[dim] + if is_overlap: + overlap = self.rank % 2 + start_idx = (length + self.world_size - 1) // self.world_size * self.rank - overlap + end_idx = min((length + self.world_size - 1) // self. world_size * (self.rank + 1) - overlap + 1, length) + else: + start_idx = (length + self.world_size - 1) // self.world_size * self.rank + end_idx = min((length + self.world_size - 1) // self.world_size * (self.rank + 1), length) + idx = torch.arange(start_idx, end_idx, device=f"npu:{self.rank}") + return hidden_state.index_select(dim, idx).clone() + + +class Depatchify(nn.Module): + def __init__(self): + super().__init__() + self.rank = dist.get_rank() + self.world_size = dist.get_world_size() + + def forward(self, patch_hidden_state, dim, is_overlap): + if is_overlap: + overlap = self.rank % 2 + start_idx = overlap + end_idx = patch_hidden_state.shape[dim] + overlap - 1 + idx = torch.arange(start_idx, end_idx, device=f"npu:{self.rank}") + patch_hidden_state = patch_hidden_state.index_select(dim, idx) + + patch_length_list = [torch.empty([1], dtype=torch.int64, device=f"npu:{self.rank}") + for _ in range(self.world_size)] + dist.all_gather( + patch_length_list, + torch.tensor( + [patch_hidden_state.shape[dim]], + dtype=torch.int64, + device=f"npu:{self.rank}" + ) + ) + patch_shape = list(patch_hidden_state.shape) + patch_hidden_state_list = [] + for i in range(self.world_size): + patch_shape[dim] = patch_length_list[i].item() + patch_hidden_state_list.append( + torch.empty(tuple(patch_shape), dtype=patch_hidden_state.dtype, device=f"npu:{self.rank}")) + dist.all_gather( + patch_hidden_state_list, + patch_hidden_state.contiguous() + ) + + return torch.cat(patch_hidden_state_list, dim) \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/utils/utils.py b/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/utils/utils.py new file mode 100644 index 0000000000..c324585f37 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/utils/utils.py @@ -0,0 +1,140 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import importlib +import random +import torch +import numpy as np + +IMG_FPS = 8 + + +def is_npu_available(): + "Checks if `torch_npu` is installed and potentially if a NPU is in the environment" + if importlib.util.find_spec("torch") is None or importlib.util.find_spec("torch_npu") is None: + return False + + import torch_npu + + try: + # Will raise a RuntimeError if no NPU is found + _ = torch.npu.device_count() + return torch.npu.is_available() + except RuntimeError: + return False + + +def exists(val): + return val is not None + + +def default(val, d): + return val if exists(val) else d + + +def set_random_seed(seed): + """Set random seed. + + Args: + seed (int, optional): Seed to be used. + + """ + + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + return seed + + +def prepare_multi_resolution_info(info_type, video_property, device, dtype): + (batch_size, image_size, num_frames, fps) = video_property + if info_type is None: + return dict() + elif info_type == "PixArtMS": + hw = torch.tensor([image_size], device=device, dtype=dtype).repeat(batch_size, 1) + ar = torch.tensor([[image_size[0] / image_size[1]]], device=device, dtype=dtype).repeat(batch_size, 1) + return dict(ar=ar, hw=hw) + elif info_type in ["STDiT2", "OpenSora"]: + fps = fps if num_frames > 1 else IMG_FPS + fps = torch.tensor([fps], device=device, dtype=dtype).repeat(batch_size) + height = torch.tensor([image_size[0]], device=device, dtype=dtype).repeat(batch_size) + width = torch.tensor([image_size[1]], device=device, dtype=dtype).repeat(batch_size) + num_frames = torch.tensor([num_frames], device=device, dtype=dtype).repeat(batch_size) + ar = torch.tensor([image_size[0] / image_size[1]], device=device, dtype=dtype).repeat(batch_size) + return dict(height=height, width=width, num_frames=num_frames, ar=ar, fps=fps) + else: + raise NotImplementedError + + +def extract_prompts_loop(prompts, num_loop=0): + ret_prompts = [] + for prompt in prompts: + if prompt.startswith("|0|"): + prompt_list = prompt.split("|")[1:] + text_list = [] + for i in range(0, len(prompt_list), 2): + start_loop = int(prompt_list[i]) + text = prompt_list[i + 1] + end_loop = int(prompt_list[i + 2]) if i + 2 < len(prompt_list) else num_loop + 1 + text_list.extend([text] * (end_loop - start_loop)) + prompt = text_list[num_loop] + ret_prompts.append(prompt) + return ret_prompts + + +def split_prompt(prompt_text): + if prompt_text.startswith("|0|"): + # this is for prompts which look like + # |0| a beautiful day |1| a sunny day |2| a rainy day + # we want to parse it into a list of prompts with the loop index + prompt_list = prompt_text.split("|")[1:] + text_list = [] + loop_idx = [] + for i in range(0, len(prompt_list), 2): + start_loop = int(prompt_list[i]) + text = prompt_list[i + 1].strip() + text_list.append(text) + loop_idx.append(start_loop) + return text_list, loop_idx + else: + return_value = None + return [prompt_text], return_value + + +def merge_prompt(text_list, loop_idx_list=None): + if loop_idx_list is None: + return text_list[0] + else: + prompt = "" + for i, text in enumerate(text_list): + prompt += f"|{loop_idx_list[i]}|{text}" + return prompt + + +def append_score_to_prompts(prompts, aes=None, flow=None, camera_motion=None): + new_prompts = [] + for prompt in prompts: + new_prompt = prompt + if aes is not None and "aesthetic score:" not in prompt: + new_prompt = f"{new_prompt} aesthetic score: {aes:.1f}." + if flow is not None and "motion score:" not in prompt: + new_prompt = f"{new_prompt} motion score: {flow:.1f}." + if camera_motion is not None and "camera motion:" not in prompt: + new_prompt = f"{new_prompt} camera motion: {camera_motion}." + new_prompts.append(new_prompt) + return new_prompts + diff --git a/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/vae/VideoAutoencoder.py b/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/vae/VideoAutoencoder.py new file mode 100644 index 0000000000..2db644efc6 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/vae/VideoAutoencoder.py @@ -0,0 +1,199 @@ +import os + +import torch +import torch.nn as nn +from mindiesd.config_utils import ConfigMixin +from mindiesd.models.model_utils import DiffusionModel +from diffusers.models import AutoencoderKL + +from ..layer import PatchConv3d, Conv3dAdapter +from ..layer import PatchGroupNorm3d, GroupNorm3dAdapter +from ..layer import rearrange_flatten_t, rearrange_unflatten_t + +from .vae_temporal import vae_temporal_sd +from ..utils import Patchify, Depatchify + + +class VideoAutoencoderConfig(ConfigMixin): + config_name = 'config.json' + + def __init__( + self, + from_pretrained, + set_patch_parallel=False, + **kwargs, + ): + from_pretrained = os.path.join(from_pretrained, "vae_2d") + vae_2d = dict(from_pretrained=from_pretrained, + subfolder="vae", + micro_batch_size=4) + self.vae_2d = vae_2d + self.freeze_vae_2d = False + self.micro_frame_size = 17 + + self.shift = (-0.10, 0.34, 0.27, 0.98) + self.scale = (3.85, 2.32, 2.33, 3.06) + + self.set_patch_parallel = set_patch_parallel + + super().__init__(**kwargs) + + +class VideoAutoencoderKL(nn.Module): + def __init__( + self, from_pretrained, micro_batch_size=None, cache_dir=None, subfolder=None + ): + super().__init__() + + path_check = os.path.join(from_pretrained, subfolder) + + self.module = AutoencoderKL.from_pretrained( + from_pretrained, + cache_dir=cache_dir, + local_files_only=True, + subfolder=subfolder, + ) + self.out_channels = self.module.config.latent_channels + self.patch_size = (1, 8, 8) + self.micro_batch_size = micro_batch_size + + @property + def dtype(self): + return next(self.parameters()).dtype + + @property + def device(self): + return next(self.parameters()).device + + def encode(self, x): + # x shape is : (B, C, T, H, W) + x_shape0_b = x.shape[0] + x = rearrange_flatten_t(x) + + if self.micro_batch_size is None: + x = self.module.encode(x).latent_dist.sample().mul_(0.18215) + else: + # NOTE: cannot be used for training + bs = self.micro_batch_size + x_out = [] + for i in range(0, x.shape[0], bs): + x_bs = x[i: i + bs] + x_bs = self.module.encode(x_bs).latent_dist.sample().mul_(0.18215) + x_out.append(x_bs) + x = torch.cat(x_out, dim=0) + x = rearrange_unflatten_t(x, b=x_shape0_b) + return x + + def decode(self, x, **kwargs): + # x shape is : (B, C, T, H, W) + x_shape0_b = x.shape[0] + x = rearrange_flatten_t(x) + + if self.micro_batch_size is None: + x = self.module.decode(x / 0.18215).sample + else: + # NOTE: cannot be used for training + bs = self.micro_batch_size + x_out = [] + for i in range(0, x.shape[0], bs): + x_bs = x[i: i + bs] + x_bs = self.module.decode(x_bs / 0.18215).sample + x_out.append(x_bs) + x = torch.cat(x_out, dim=0) + x = rearrange_unflatten_t(x, b=x_shape0_b) + return x + + def get_latent_size(self, input_size): + latent_size = [] + for i in range(3): + latent_size.append(input_size[i] // self.patch_size[i] if input_size[i] is not None else None) + return latent_size + + +class VideoAutoencoder(DiffusionModel): + config_class = VideoAutoencoderConfig + + weigths_name = 'model.safetensors' + + def __init__(self, config: VideoAutoencoderConfig): + super().__init__(config=config) + + self.set_patch_parallel = config.set_patch_parallel + self.spatial_vae = VideoAutoencoderKL(**config.vae_2d) + self.spatial_vae.to("npu") + self.temporal_vae = vae_temporal_sd() + self.temporal_vae.to("npu") + + self.micro_frame_size = config.micro_frame_size + self.micro_z_frame_size = self.temporal_vae.get_latent_size([config.micro_frame_size, None, None])[0] + + if config.freeze_vae_2d: + for param in self.spatial_vae.parameters(): + param.requires_grad = False + + self.out_channels = self.temporal_vae.out_channels + + # normalization parameters + scale = torch.tensor(config.scale) + shift = torch.tensor(config.shift) + if len(scale.shape) > 0: + scale = scale[None, :, None, None, None] + if len(shift.shape) > 0: + shift = shift[None, :, None, None, None] + self.register_buffer("scale", scale) + self.register_buffer("shift", shift) + + # Patchify and DePatchify + if self.set_patch_parallel: + self.patchify = Patchify() + self.depatchify = Depatchify() + + def get_latent_size(self, input_size): + if self.micro_frame_size is None or input_size[0] is None: + return self.temporal_vae.get_latent_size(self.spatial_vae.get_latent_size(input_size)) + else: + sub_input_size = [self.micro_frame_size, input_size[1], input_size[2]] + sub_latent_size = self.temporal_vae.get_latent_size(self.spatial_vae.get_latent_size(sub_input_size)) + sub_latent_size[0] = sub_latent_size[0] * (input_size[0] // self.micro_frame_size) + remain_temporal_size = [input_size[0] % self.micro_frame_size, None, None] + if remain_temporal_size[0] > 0: + remain_size = self.temporal_vae.get_latent_size(remain_temporal_size) + sub_latent_size[0] += remain_size[0] + return sub_latent_size + + def decode(self, z, num_frames): + if self.set_patch_parallel: + for _, module in self.temporal_vae.named_modules(): + if isinstance(module, PatchConv3d) or isinstance(module, PatchGroupNorm3d): + continue + for subname, submodule in module.named_children(): + if isinstance(submodule, nn.Conv3d): + wrapped_submodule = Conv3dAdapter(submodule, isinstance(module, CausalConv3d)) + setattr(module, subname, wrapped_submodule) + elif isinstance(submodule, nn.GroupNorm): + wrapped_submodule = GroupNorm3dAdapter(submodule) + setattr(module, subname, wrapped_submodule) + + z = z * self.scale.to(z.dtype) + self.shift.to(z.dtype) + + if self.set_patch_parallel: + z_patch = self.patchify(z, dim=-1, is_overlap=True) + x_z_patch = self.temporal_vae.decode(z_patch, num_frames=num_frames) + x_z = self.depatchify(x_z_patch, dim=-1, is_overlap=True) + x_z_patch = self.patchify(x_z, dim=-3, is_overlap=True) + x_patch = self.spatial_vae.decode(x_z_patch) + x = self.depatchify(x_patch, dim=-3, is_overlap=True) + elif self.micro_frame_size is None: + x_z = self.temporal_vae.decode(z, num_frames=num_frames) + x = self.spatial_vae.decode(x_z) + else: + x_z_list = [] + for i in range(0, z.size(2), self.micro_z_frame_size): + z_bs = z[:, :, i: i + self.micro_z_frame_size] + x_z_bs = self.temporal_vae.decode(z_bs, num_frames=min(self.micro_frame_size, num_frames)) + x_z_list.append(x_z_bs) + num_frames -= self.micro_frame_size + x_z = torch.cat(x_z_list, dim=2) + x = self.spatial_vae.decode(x_z) + + return x \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/vae/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/vae/__init__.py new file mode 100644 index 0000000000..ffcdbe1e74 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/vae/__init__.py @@ -0,0 +1,17 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .VideoAutoencoder import (VideoAutoencoder, VideoAutoencoderConfig) \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/vae/vae_temporal.py b/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/vae/vae_temporal.py new file mode 100644 index 0000000000..13e246c904 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/opensora/vae/vae_temporal.py @@ -0,0 +1,472 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Tuple, Union +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class DiagonalGaussianDistribution(object): + def __init__( + self, + parameters, + deterministic=False, + ): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device, dtype=self.mean.dtype) + + def sample(self): + # torch.randn: standard normal distribution + x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device, dtype=self.mean.dtype) + return x + + def mode(self): + return self.mean + + +def cast_tuple(t, length=1): + return t if isinstance(t, tuple) else ((t,) * length) + + +def divisible_by(num, den): + return (num % den) == 0 + + +def is_odd(n): + return not divisible_by(n, 2) + + +def pad_at_dim(t, pad, dim=-1): + dims_from_right = (-dim - 1) if dim < 0 else (t.ndim - dim - 1) + zeros = (0, 0) * dims_from_right + return F.pad(t, (*zeros, *pad), mode="constant") + + +def exists(v): + return v is not None + + +class CausalConv3d(nn.Module): + def __init__( + self, + chan_in, + chan_out, + kernel_size: Union[int, Tuple[int, int, int]], + pad_mode="constant", + strides=None, # allow custom stride + **kwargs, + ): + super().__init__() + kernel_size = cast_tuple(kernel_size, 3) + + time_kernel_size, height_kernel_size, width_kernel_size = kernel_size + + dilation = kwargs.pop("dilation", 1) + stride = strides[0] if strides is not None else kwargs.pop("stride", 1) + + self.pad_mode = pad_mode + time_pad = dilation * (time_kernel_size - 1) + (1 - stride) + height_pad = height_kernel_size // 2 + width_pad = width_kernel_size // 2 + + self.time_pad = time_pad + self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0) + + stride = strides if strides is not None else (stride, 1, 1) + dilation = (dilation, 1, 1) + self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs) + + def forward(self, x): + x = F.pad(x, self.time_causal_padding, mode=self.pad_mode) + x = self.conv(x) + return x + + +class ResBlock(nn.Module): + def __init__( + self, + in_channels, # SCH: added + filters, + conv_fn, + norm_fn, + activation_fn=nn.SiLU, + use_conv_shortcut=False, + num_groups=32, + ): + super().__init__() + self.in_channels = in_channels + self.filters = filters + self.activate = activation_fn() + self.use_conv_shortcut = use_conv_shortcut + + # SCH: MAGVIT uses GroupNorm by default + self.norm1 = norm_fn(num_groups, in_channels) + self.conv1 = conv_fn(in_channels, self.filters, kernel_size=(3, 3, 3), bias=False) + self.norm2 = norm_fn(num_groups, self.filters) + self.conv2 = conv_fn(self.filters, self.filters, kernel_size=(3, 3, 3), bias=False) + if in_channels != filters: + if self.use_conv_shortcut: + self.conv3 = conv_fn(in_channels, self.filters, kernel_size=(3, 3, 3), bias=False) + else: + self.conv3 = conv_fn(in_channels, self.filters, kernel_size=(1, 1, 1), bias=False) + + def forward(self, x): + residual = x + x = self.norm1(x) + x = self.activate(x) + x = self.conv1(x) + x = self.norm2(x) + x = self.activate(x) + x = self.conv2(x) + if self.in_channels != self.filters: # SCH: ResBlock X->Y + residual = self.conv3(residual) + return x + residual + + +def get_activation_fn(activation): + if activation == "relu": + activation_fn = nn.ReLU + elif activation == "swish": + activation_fn = nn.SiLU + else: + raise NotImplementedError + return activation_fn + + +class Encoder(nn.Module): + """Encoder Blocks.""" + + def __init__(self, in_out_channels=4, latent_embed_dim=512, # num channels for latent vector + filters=128, num_res_blocks=4, channel_multipliers=(1, 2, 2, 4), + temporal_downsample=(False, True, True), + num_groups=32, # for nn.GroupNorm + activation_fn="swish", + ): + super().__init__() + self.filters = filters + self.num_res_blocks = num_res_blocks + self.num_blocks = len(channel_multipliers) + self.channel_multipliers = channel_multipliers + self.temporal_downsample = temporal_downsample + self.num_groups = num_groups + self.embedding_dim = latent_embed_dim + + self.activation_fn = get_activation_fn(activation_fn) + self.activate = self.activation_fn() + self.conv_fn = CausalConv3d + self.norm_fn = nn.GroupNorm + self.block_args = dict( + conv_fn=self.conv_fn, + norm_fn=self.norm_fn, + activation_fn=self.activation_fn, + use_conv_shortcut=False, + num_groups=self.num_groups, + ) + + # first layer conv + self._init_first_layer(in_out_channels) + + # ResBlocks and conv downsample + prev_filters, filters = self._init_res_conv_layer() + + # # last layer res block + self._init_last_layer(prev_filters, filters) + + def forward(self, x): + x = self.conv_in(x) + + for i in range(self.num_blocks): + for j in range(self.num_res_blocks): + x = self.block_res_blocks[i][j](x) + if i < self.num_blocks - 1: + x = self.conv_blocks[i](x) + for i in range(self.num_res_blocks): + x = self.res_blocks[i](x) + + x = self.norm1(x) + x = self.activate(x) + x = self.conv2(x) + return x + + def _init_first_layer(self, in_out_channels): + self.conv_in = self.conv_fn( + in_out_channels, + self.filters, + kernel_size=(3, 3, 3), + bias=False, + ) + + def _init_res_conv_layer(self): + self.block_res_blocks = nn.ModuleList([]) + self.conv_blocks = nn.ModuleList([]) + + filters = self.filters + prev_filters = filters # record for in_channels + for i in range(self.num_blocks): + filters = self.filters * self.channel_multipliers[i] + block_items = nn.ModuleList([]) + for _ in range(self.num_res_blocks): + block_items.append(ResBlock(prev_filters, filters, **self.block_args)) + prev_filters = filters # update in_channels + self.block_res_blocks.append(block_items) + + if i < self.num_blocks - 1: + if self.temporal_downsample[i]: + t_stride = 2 if self.temporal_downsample[i] else 1 + s_stride = 1 + self.conv_blocks.append( + self.conv_fn( + prev_filters, filters, kernel_size=(3, 3, 3), strides=(t_stride, s_stride, s_stride) + ) + ) + prev_filters = filters # update in_channels + else: + # if no t downsample, don't add since this does nothing for pipeline models + self.conv_blocks.append(nn.Identity(prev_filters)) # Identity + prev_filters = filters # update in_channels + return prev_filters, filters + + def _init_last_layer(self, prev_filters, filters): + # last layer res block + self.res_blocks = nn.ModuleList([]) + for _ in range(self.num_res_blocks): + self.res_blocks.append(ResBlock(prev_filters, filters, **self.block_args)) + prev_filters = filters # update in_channels + + # MAGVIT uses Group Normalization + self.norm1 = self.norm_fn(self.num_groups, prev_filters) + + self.conv2 = self.conv_fn(prev_filters, self.embedding_dim, kernel_size=(1, 1, 1), padding="same") + + +class Decoder(nn.Module): + """Decoder Blocks.""" + + def __init__( + self, + in_out_channels=4, + latent_embed_dim=512, + filters=128, + num_res_blocks=4, + channel_multipliers=(1, 2, 2, 4), + temporal_downsample=(False, True, True), + num_groups=32, # for nn.GroupNorm + activation_fn="swish", + ): + super().__init__() + self.in_out_channels = in_out_channels + self.filters = filters + self.num_res_blocks = num_res_blocks + self.num_blocks = len(channel_multipliers) + self.channel_multipliers = channel_multipliers + self.temporal_downsample = temporal_downsample + self.num_groups = num_groups + self.embedding_dim = latent_embed_dim + self.s_stride = 1 + + self.activation_fn = get_activation_fn(activation_fn) + self.activate = self.activation_fn() + self.conv_fn = CausalConv3d + self.norm_fn = nn.GroupNorm + self.block_args = dict( + conv_fn=self.conv_fn, + norm_fn=self.norm_fn, + activation_fn=self.activation_fn, + use_conv_shortcut=False, + num_groups=self.num_groups, + ) + self._init_layers() + + def forward(self, x): + x = self.conv1(x) + for i in range(self.num_res_blocks): + x = self.res_blocks[i](x) + for i in reversed(range(self.num_blocks)): + for j in range(self.num_res_blocks): + x = self.block_res_blocks[i][j](x) + if i > 0: + t_stride = 2 if self.temporal_downsample[i - 1] else 1 + x = self.conv_blocks[i - 1](x) + + x = self.rearrange_with_reshape(x, t_stride, self.s_stride, self.s_stride) + x = self.norm1(x) + x = self.activate(x) + x = self.conv_out(x) + return x + + def rearrange_with_reshape(self, x, ts, hs, ws): + x_shape0_b, c_ts_hs_ws, x_shape2_t, x_shape3_h, x_shape4_w = x.shape + c = c_ts_hs_ws // (ts * hs * ws) + + x = x.reshape(x_shape0_b, c, ts, hs, ws, x_shape2_t, x_shape3_h, x_shape4_w) + x = x.permute(0, 1, 5, 2, 6, 3, 7, 4) + x = x.reshape(x_shape0_b, c, x_shape2_t * ts, x_shape3_h * hs, x_shape4_w * ws) + return x + + def _init_layers(self): + filters = self.filters * self.channel_multipliers[-1] + prev_filters = filters + + # last conv + self.conv1 = self.conv_fn(self.embedding_dim, filters, kernel_size=(3, 3, 3), bias=True) + + # last layer res block + self.res_blocks = nn.ModuleList([]) + for _ in range(self.num_res_blocks): + self.res_blocks.append(ResBlock(filters, filters, **self.block_args)) + + # ResBlocks and conv upsample + self.block_res_blocks = nn.ModuleList([]) + self.num_blocks = len(self.channel_multipliers) + self.conv_blocks = nn.ModuleList([]) + # reverse to keep track of the in_channels, but append also in a reverse direction + for i in reversed(range(self.num_blocks)): + filters = self.filters * self.channel_multipliers[i] + # resblock handling + block_items = nn.ModuleList([]) + for _ in range(self.num_res_blocks): + block_items.append(ResBlock(prev_filters, filters, **self.block_args)) + prev_filters = filters # SCH: update in_channels + self.block_res_blocks.insert(0, block_items) # SCH: append in front + + # conv blocks with upsampling + if i > 0: + if self.temporal_downsample[i - 1]: + t_stride = 2 if self.temporal_downsample[i - 1] else 1 + # SCH: T-Causal Conv 3x3x3, f -> (t_stride * 2 * 2) * f, depth to space t_stride x 2 x 2 + self.conv_blocks.insert( + 0, + self.conv_fn( + prev_filters, prev_filters * t_stride * self.s_stride * self.s_stride, + kernel_size=(3, 3, 3) + ), + ) + else: + self.conv_blocks.insert( + 0, + nn.Identity(prev_filters), + ) + + self.norm1 = self.norm_fn(self.num_groups, prev_filters) + self.conv_out = self.conv_fn(filters, self.in_out_channels, 3) + + +class VAETemporal(nn.Module): + def __init__(self, in_out_channels=4, latent_embed_dim=4, embed_dim=4, filters=128, num_res_blocks=4, + channel_multipliers=(1, 2, 2, 4), temporal_downsample=(True, True, False), + num_groups=32, # for nn.GroupNorm + activation_fn="swish", + ): + super().__init__() + + self.time_downsample_factor = 2 ** sum(temporal_downsample) + self.patch_size = (self.time_downsample_factor, 1, 1) + self.out_channels = in_out_channels + + # NOTE: following MAGVIT, conv in bias=False in encoder first conv + self.encoder = Encoder( + in_out_channels=in_out_channels, + latent_embed_dim=latent_embed_dim * 2, + filters=filters, + num_res_blocks=num_res_blocks, + channel_multipliers=channel_multipliers, + temporal_downsample=temporal_downsample, + num_groups=num_groups, # for nn.GroupNorm + activation_fn=activation_fn, + ) + self.quant_conv = CausalConv3d(2 * latent_embed_dim, 2 * embed_dim, 1) + + self.post_quant_conv = CausalConv3d(embed_dim, latent_embed_dim, 1) + self.decoder = Decoder( + in_out_channels=in_out_channels, + latent_embed_dim=latent_embed_dim, + filters=filters, + num_res_blocks=num_res_blocks, + channel_multipliers=channel_multipliers, + temporal_downsample=temporal_downsample, + num_groups=num_groups, # for nn.GroupNorm + activation_fn=activation_fn, + ) + + def get_latent_size(self, input_size): + latent_size = [] + for i in range(3): + if input_size[i] is None: + lsize = None + elif i == 0: + time_padding = ( + 0 + if (input_size[i] % self.time_downsample_factor == 0) + else self.time_downsample_factor - input_size[i] % self.time_downsample_factor + ) + lsize = (input_size[i] + time_padding) // self.patch_size[i] + else: + lsize = input_size[i] // self.patch_size[i] + latent_size.append(lsize) + return latent_size + + def encode(self, x): + time_padding = ( + 0 + if (x.shape[2] % self.time_downsample_factor == 0) + else self.time_downsample_factor - x.shape[2] % self.time_downsample_factor + ) + x = pad_at_dim(x, (time_padding, 0), dim=2) + encoded_feature = self.encoder(x) + moments = self.quant_conv(encoded_feature).to(x.dtype) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z, num_frames=None): + time_padding = ( + 0 + if (num_frames % self.time_downsample_factor == 0) + else self.time_downsample_factor - num_frames % self.time_downsample_factor + ) + z = self.post_quant_conv(z) + x = self.decoder(z) + x = x[:, :, time_padding:] + return x + + def forward(self, x, sample_posterior=True): + posterior = self.encode(x) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + recon_video = self.decode(z, num_frames=x.shape[2]) + return recon_video, posterior, z + + +def vae_temporal_sd(**kwargs): + model = VAETemporal( + in_out_channels=4, + latent_embed_dim=4, + embed_dim=4, + filters=128, + num_res_blocks=4, + channel_multipliers=(1, 2, 2, 4), + temporal_downsample=(False, True, True), + **kwargs, + ) + return model diff --git a/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/prompts/t2v_sora.txt b/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/prompts/t2v_sora.txt new file mode 100644 index 0000000000..6b73dea0fc --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/prompts/t2v_sora.txt @@ -0,0 +1,48 @@ +A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage. She wears a black leather jacket, a long red dress, and black boots, and carries a black purse. She wears sunglasses and red lipstick. She walks confidently and casually. The street is damp and reflective, creating a mirror effect of the colorful lights. Many pedestrians walk about. +Several giant wooly mammoths approach treading through a snowy meadow, their long wooly fur lightly blows in the wind as they walk, snow covered trees and dramatic snow capped mountains in the distance, mid afternoon light with wispy clouds and a sun high in the distance creates a warm glow, the low camera view is stunning capturing the large furry mammal with beautiful photography, depth of field. +A movie trailer featuring the adventures of the 30 year old space man wearing a red wool knitted motorcycle helmet, blue sky, salt desert, cinematic style, shot on 35mm film, vivid colors. +Drone view of waves crashing against the rugged cliffs along Big Sur’s garay point beach. The crashing blue waters create white-tipped waves, while the golden light of the setting sun illuminates the rocky shore. A small island with a lighthouse sits in the distance, and green shrubbery covers the cliff’s edge. The steep drop from the road down to the beach is a dramatic feat, with the cliff’s edges jutting out over the sea. This is a view that captures the raw beauty of the coast and the rugged landscape of the Pacific Coast Highway. +Animated scene features a close-up of a short fluffy monster kneeling beside a melting red candle. The art style is 3D and realistic, with a focus on lighting and texture. The mood of the painting is one of wonder and curiosity, as the monster gazes at the flame with wide eyes and open mouth. Its pose and expression convey a sense of innocence and playfulness, as if it is exploring the world around it for the first time. The use of warm colors and dramatic lighting further enhances the cozy atmosphere of the image. +A gorgeously rendered papercraft world of a coral reef, rife with colorful fish and sea creatures. +This close-up shot of a Victoria crowned pigeon showcases its striking blue plumage and red chest. Its crest is made of delicate, lacy feathers, while its eye is a striking red color. The bird’s head is tilted slightly to the side, giving the impression of it looking regal and majestic. The background is blurred, drawing attention to the bird’s striking appearance. +Photorealistic closeup video of two pirate ships battling each other as they sail inside a cup of coffee. +A young man at his 20s is sitting on a piece of cloud in the sky, reading a book. +Historical footage of California during the gold rush. +A close up view of a glass sphere that has a zen garden within it. There is a small dwarf in the sphere who is raking the zen garden and creating patterns in the sand. +Extreme close up of a 24 year old woman’s eye blinking, standing in Marrakech during magic hour, cinematic film shot in 70mm, depth of field, vivid colors, cinematic +A cartoon kangaroo disco dances. +A beautiful homemade video showing the people of Lagos, Nigeria in the year 2056. Shot with a mobile phone camera. +A petri dish with a bamboo forest growing within it that has tiny red pandas running around. +The camera rotates around a large stack of vintage televisions all showing different programs — 1950s sci-fi movies, horror movies, news, static, a 1970s sitcom, etc, set inside a large New York museum gallery. +3D animation of a small, round, fluffy creature with big, expressive eyes explores a vibrant, enchanted forest. The creature, a whimsical blend of a rabbit and a squirrel, has soft blue fur and a bushy, striped tail. It hops along a sparkling stream, its eyes wide with wonder. The forest is alive with magical elements: flowers that glow and change colors, trees with leaves in shades of purple and silver, and small floating lights that resemble fireflies. The creature stops to interact playfully with a group of tiny, fairy-like beings dancing around a mushroom ring. The creature looks up in awe at a large, glowing tree that seems to be the heart of the forest. +The camera follows behind a white vintage SUV with a black roof rack as it speeds up a steep dirt road surrounded by pine trees on a steep mountain slope, dust kicks up from it’s tires, the sunlight shines on the SUV as it speeds along the dirt road, casting a warm glow over the scene. The dirt road curves gently into the distance, with no other cars or vehicles in sight. The trees on either side of the road are redwoods, with patches of greenery scattered throughout. The car is seen from the rear following the curve with ease, making it seem as if it is on a rugged drive through the rugged terrain. The dirt road itself is surrounded by steep hills and mountains, with a clear blue sky above with wispy clouds. +Reflections in the window of a train traveling through the Tokyo suburbs. +A drone camera circles around a beautiful historic church built on a rocky outcropping along the Amalfi Coast, the view showcases historic and magnificent architectural details and tiered pathways and patios, waves are seen crashing against the rocks below as the view overlooks the horizon of the coastal waters and hilly landscapes of the Amalfi Coast Italy, several distant people are seen walking and enjoying vistas on patios of the dramatic ocean views, the warm glow of the afternoon sun creates a magical and romantic feeling to the scene, the view is stunning captured with beautiful photography. +A large orange octopus is seen resting on the bottom of the ocean floor, blending in with the sandy and rocky terrain. Its tentacles are spread out around its body, and its eyes are closed. The octopus is unaware of a king crab that is crawling towards it from behind a rock, its claws raised and ready to attack. The crab is brown and spiny, with long legs and antennae. The scene is captured from a wide angle, showing the vastness and depth of the ocean. The water is clear and blue, with rays of sunlight filtering through. The shot is sharp and crisp, with a high dynamic range. The octopus and the crab are in focus, while the background is slightly blurred, creating a depth of field effect. +A flock of paper airplanes flutters through a dense jungle, weaving around trees as if they were migrating birds. +A cat waking up its sleeping owner demanding breakfast. The owner tries to ignore the cat, but the cat tries new tactics and finally the owner pulls out a secret stash of treats from under the pillow to hold the cat off a little longer. +Borneo wildlife on the Kinabatangan River +A Chinese Lunar New Year celebration video with Chinese Dragon. +Tour of an art gallery with many beautiful works of art in different styles. +Beautiful, snowy Tokyo city is bustling. The camera moves through the bustling city street, following several people enjoying the beautiful snowy weather and shopping at nearby stalls. Gorgeous sakura petals are flying through the wind along with snowflakes. +A stop motion animation of a flower growing out of the windowsill of a suburban house. +The story of a robot’s life in a cyberpunk setting. +An extreme close-up of an gray-haired man with a beard in his 60s, he is deep in thought pondering the history of the universe as he sits at a cafe in Paris, his eyes focus on people offscreen as they walk as he sits mostly motionless, he is dressed in a wool coat suit coat with a button-down shirt , he wears a brown beret and glasses and has a very professorial appearance, and the end he offers a subtle closed-mouth smile as if he found the answer to the mystery of life, the lighting is very cinematic with the golden light and the Parisian streets and city in the background, depth of field, cinematic 35mm film. +A beautiful silhouette animation shows a wolf howling at the moon, feeling lonely, until it finds its pack. +New York City submerged like Atlantis. Fish, whales, sea turtles and sharks swim through the streets of New York. +A litter of golden retriever puppies playing in the snow. Their heads pop out of the snow, covered in. +Step-printing scene of a person running, cinematic film shot in 35mm. +Five gray wolf pups frolicking and chasing each other around a remote gravel road, surrounded by grass. The pups run and leap, chasing each other, and nipping at each other, playing. +Basketball through hoop then explodes. +Archeologists discover a generic plastic chair in the desert, excavating and dusting it with great care. +A grandmother with neatly combed grey hair stands behind a colorful birthday cake with numerous candles at a wood dining room table, expression is one of pure joy and happiness, with a happy glow in her eye. She leans forward and blows out the candles with a gentle puff, the cake has pink frosting and sprinkles and the candles cease to flicker, the grandmother wears a light blue blouse adorned with floral patterns, several happy friends and family sitting at the table can be seen celebrating, out of focus. The scene is beautifully captured, cinematic, showing a 3/4 view of the grandmother and the dining room. Warm color tones and soft lighting enhance the mood. +The camera directly faces colorful buildings in Burano Italy. An adorable dalmation looks through a window on a building on the ground floor. Many people are walking and cycling along the canal streets in front of the buildings. +An adorable happy otter confidently stands on a surfboard wearing a yellow lifejacket, riding along turquoise tropical waters near lush tropical islands, 3D digital render art style. +This close-up shot of a chameleon showcases its striking color changing capabilities. The background is blurred, drawing attention to the animal’s striking appearance. +A corgi vlogging itself in tropical Maui. +A white and orange tabby cat is seen happily darting through a dense garden, as if chasing something. Its eyes are wide and happy as it jogs forward, scanning the branches, flowers, and leaves as it walks. The path is narrow as it makes its way between all the plants. the scene is captured from a ground-level angle, following the cat closely, giving a low and intimate perspective. The image is cinematic with warm tones and a grainy texture. The scattered daylight between the leaves and plants above creates a warm contrast, accentuating the cat’s orange fur. The shot is clear and sharp, with a shallow depth of field. +Aerial view of Santorini during the blue hour, showcasing the stunning architecture of white Cycladic buildings with blue domes. The caldera views are breathtaking, and the lighting creates a beautiful, serene atmosphere. +Tiltshift of a construction site filled with workers, equipment, and heavy machinery. +A giant, towering cloud in the shape of a man looms over the earth. The cloud man shoots lighting bolts down to the earth. +A Samoyed and a Golden Retriever dog are playfully romping through a futuristic neon city at night. The neon lights emitted from the nearby buildings glistens off of their fur. +The Glenfinnan Viaduct is a historic railway bridge in Scotland, UK, that crosses over the west highland line between the towns of Mallaig and Fort William. It is a stunning sight as a steam train leaves the bridge, traveling over the arch-covered viaduct. The landscape is dotted with lush greenery and rocky mountains, creating a picturesque backdrop for the train journey. The sky is blue and the sun is shining, making for a beautiful day to explore this majestic spot. \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/requirents.txt b/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/requirents.txt new file mode 100644 index 0000000000..3d3a438158 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/opensora1.2/requirents.txt @@ -0,0 +1,16 @@ +colossalai==0.3.7 +setuptools==57.5.0 +torch==2.1.0 +diffusers==0.26.3 +transformers==4.44.2 +open_clip_torch==2.20.0 +av==12.0.0 +tqdm==4.66.1 +timm==0.9.12 +tensorboard==2.11.0 +pre-commit==3.8.0 +mmengine==0.10.4 +ftfy==6.1.3 +accelerate==0.26.1 +bs4 +torchvision==0.16.0 \ No newline at end of file diff --git a/MindIE/MultiModal/CogVideoX/README.md b/MindIE/MultiModal/CogVideoX/README.md new file mode 100644 index 0000000000..d37c419df5 --- /dev/null +++ b/MindIE/MultiModal/CogVideoX/README.md @@ -0,0 +1,181 @@ +--- +license: apache-2.0 +frameworks: + - PyTorch +language: + - en +hardwares: + - NPU +--- +## 一、准备运行环境 + + **表 1** 版本配套表 + + | 配套 | 版本 | 环境准备指导 | + | ----- | ----- |-----| + | Python | 3.10.2 | - | + | torch | 2.1.0 | - | + +### 1.1 获取CANN&MindIE安装包&环境准备 +- 设备支持: +Atlas 800I A2/Atlas 800T A2设备:支持的卡数最小为1 +- [Atlas 800I A2/Atlas 800T A2](https://www.hiascend.com/developer/download/community/result?module=pt+ie+cann&product=4&model=32) +- [环境准备指导](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha002/softwareinst/instg/instg_0001.html) + +### 1.2 CANN安装 +```shell +# 增加软件包可执行权限,{version}表示软件版本号,{arch}表示CPU架构,{soc}表示昇腾AI处理器的版本。 +chmod +x ./Ascend-cann-toolkit_{version}_linux-{arch}.run +chmod +x ./Ascend-cann-kernels-{soc}_{version}_linux.run +# 校验软件包安装文件的一致性和完整性 +./Ascend-cann-toolkit_{version}_linux-{arch}.run --check +./Ascend-cann-kernels-{soc}_{version}_linux.run --check +# 安装 +./Ascend-cann-toolkit_{version}_linux-{arch}.run --install +./Ascend-cann-kernels-{soc}_{version}_linux.run --install + +# 设置环境变量 +source /usr/local/Ascend/ascend-toolkit/set_env.sh +``` + +### 1.3 MindIE安装 +```shell +# 增加软件包可执行权限,{version}表示软件版本号,{arch}表示CPU架构。 +chmod +x ./Ascend-mindie_${version}_linux-${arch}.run +./Ascend-mindie_${version}_linux-${arch}.run --check + +# 方式一:默认路径安装 +./Ascend-mindie_${version}_linux-${arch}.run --install +# 设置环境变量 +cd /usr/local/Ascend/mindie && source set_env.sh + +# 方式二:指定路径安装 +./Ascend-mindie_${version}_linux-${arch}.run --install-path=${AieInstallPath} +# 设置环境变量 +cd ${AieInstallPath}/mindie && source set_env.sh +``` + +### 1.4 Torch_npu安装 +安装pytorch框架 版本2.1.0 +[安装包下载](https://download.pytorch.org/whl/cpu/torch/) + +使用pip安装 +```shell +# {version}表示软件版本号,{arch}表示CPU架构。 +pip install torch-${version}-cp310-cp310-linux_${arch}.whl +``` +下载 pytorch_v{pytorchversion}_py{pythonversion}.tar.gz +```shell +tar -xzvf pytorch_v{pytorchversion}_py{pythonversion}.tar.gz +# 解压后,会有whl包 +pip install torch_npu-{pytorchversion}.xxxx.{arch}.whl +``` + +### 1.5 安装所需依赖。 +```shell +pip3 install -r requirements.txt +``` + +## 二、下载本仓库 + +### 2.1 下载到本地 +```shell + git clone https://gitee.com/ascend/ModelZoo-PyTorch.git +``` + +## 三、CogVideoX-5b / CogVideoX-2b使用 + +### 3.1 权重及配置文件说明 +1. 下载CogVideoX-5b / CogVideoX-2b权重:(scheduler、text_encoder、tokenizer、transformer、vae,5个模型的配置文件及权重) +```shell + git clone https://huggingface.co/THUDM/CogVideoX-5b + git clone https://huggingface.co/THUDM/CogVideoX-2b +``` +2. 各模型的配置文件、权重文件的层级样例如下所示。 +```commandline +|----CogVideoX-5b / CogVideoX-2b +| |---- model_index.json +| |---- scheduler +| | |---- scheduler_config.json +| |---- text_encoder +| | |---- config.json +| | |---- 模型权重 +| |---- tokenizer +| | |---- config.json +| | |---- 模型权重 +| |---- transformer +| | |---- config.json +| | |---- 模型权重 +| |---- vae +| | |---- config.json +| | |---- 模型权重 +``` + +### 3.2 RoPE算子编译 +进入算子路径,执行编译命令 +```shell +cd pta_plugin +bash build.sh +``` +编译成功后会在build文件夹下生成.so结尾的算子文件 + + + +在cogvideox_5b/models/attention_processor.py脚本中添加编译生成的算子路径 +```python +torch.ops.load_library("./pta_plugin/build/libPTAExtensionOPS.so") +``` +注意:首次运行需要加载RoPE算子,请在正式推理前进行warmup + +### 3.3 单卡单prompt功能测试 +1. 设置CogVideoX-5b权重路径: +```shell +model_path='data/CogVideoX-5b' +``` + +或者设置CogVideoX-2b权重路径: +```shell +model_path='data/CogVideoX-5b' +``` + +2. 执行命令: +```shell +export CPU_AFFINITY_CONF=1 +export HCCL_OP_EXPANSION_MODE="AIV" +TASK_QUEUE_ENABLE=2 ASCEND_RT_VISIBLE_DEVICES=0 torchrun --master_port=2002 --nproc_per_node=1 inference.py\ + --prompt "A dog" \ + --model_path ${model_path} \ + --num_frames 48 \ + --width 720 \ + --height 480 \ + --fps 8 \ + --num_inference_steps 50 \ + --dtype bfloat16 +``` +参数说明: +- CPU_AFFINITY_CONF=1:环境变量,绑核。 +- HCCL_OP_EXPANSION_MODE="AIV":环境变量,通信算子编排。 +- TASK_QUEUE_ENABLE=2:开启二级流水。 +- ASCEND_RT_VISIBLE_DEVICES=0:device id,可设置其他卡数。 +- prompt:用于视频生成的文字描述提示。 +- model_path:权重路径,包含scheduler、text_encoder、tokenizer、transformer、vae,5个模型的配置文件及权重。 +- num_frames:生成视频的帧数。 +- width:生成视频的分辨率,宽。 +- height:生成视频的分辨率,高。 +- fps:生成视频的帧率,默认值为8。 +- num_inference_steps:推理迭代步数,默认值为50。 +- dtype:数据类型,默认值为bfloat16,可设置为float16,需要在命令前加INF_NAN_MODE_FORCE_DISABLE=1,开启饱和模式避免数值溢出。 + + +## 四、推理性能结果参考 +### CogVideoX-5b +| 硬件形态 | cpu规格 | batch size | 迭代次数 | 平均耗时 | +| :------: | :------: | :------: |:----:| :------: | +| Atlas 800I A2(8*64G) | 64核(arm) | 1 | 50 | 240s | + +### CogVideoX-2b +| 硬件形态 | cpu规格 | batch size | 迭代次数 | 平均耗时 | +| :------: | :------: | :------: |:----:| :------: | +| Atlas 800I A2(8*64G) | 64核(arm) | 1 | 50 | 102s | + +性能测试需要独占npu和cpu \ No newline at end of file diff --git a/MindIE/MultiModal/CogVideoX/cogvideox_5b/__init__.py b/MindIE/MultiModal/CogVideoX/cogvideox_5b/__init__.py new file mode 100644 index 0000000000..0428c592e8 --- /dev/null +++ b/MindIE/MultiModal/CogVideoX/cogvideox_5b/__init__.py @@ -0,0 +1,4 @@ +from .pipelines import CogVideoXPipeline +from .models import CogVideoXTransformer3DModel +from .utils import get_world_size, get_rank, all_gather +from .utils import get_sp_world_size, get_sp_rank, get_dp_rank, get_dp_world_size, get_sp_group, get_dp_group \ No newline at end of file diff --git a/MindIE/MultiModal/CogVideoX/cogvideox_5b/models/__init__.py b/MindIE/MultiModal/CogVideoX/cogvideox_5b/models/__init__.py new file mode 100644 index 0000000000..a267e101cd --- /dev/null +++ b/MindIE/MultiModal/CogVideoX/cogvideox_5b/models/__init__.py @@ -0,0 +1 @@ +from .transformers import CogVideoXTransformer3DModel diff --git a/MindIE/MultiModal/CogVideoX/cogvideox_5b/models/activations.py b/MindIE/MultiModal/CogVideoX/cogvideox_5b/models/activations.py new file mode 100644 index 0000000000..7cd6938b22 --- /dev/null +++ b/MindIE/MultiModal/CogVideoX/cogvideox_5b/models/activations.py @@ -0,0 +1,165 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn.functional as F +from torch import nn + +from diffusers.utils import deprecate +from diffusers.utils.import_utils import is_torch_npu_available + + +if is_torch_npu_available(): + import torch_npu + +ACTIVATION_FUNCTIONS = { + "swish": nn.SiLU(), + "silu": nn.SiLU(), + "mish": nn.Mish(), + "gelu": nn.GELU(), + "relu": nn.ReLU(), +} + + +def get_activation(act_fn: str) -> nn.Module: + """Helper function to get activation function from string. + + Args: + act_fn (str): Name of activation function. + + Returns: + nn.Module: Activation function. + """ + + act_fn = act_fn.lower() + if act_fn in ACTIVATION_FUNCTIONS: + return ACTIVATION_FUNCTIONS[act_fn] + else: + raise ValueError(f"Unsupported activation function: {act_fn}") + + +class FP32SiLU(nn.Module): + r""" + SiLU activation function with input upcasted to torch.float32. + """ + + def __init__(self): + super().__init__() + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + return F.silu(inputs.float(), inplace=False).to(inputs.dtype) + + +class GELU(nn.Module): + r""" + GELU activation function with tanh approximation support with `approximate="tanh"`. + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out, bias=bias) + self.approximate = approximate + + def gelu(self, gate: torch.Tensor) -> torch.Tensor: + if gate.device.type != "mps": + return F.gelu(gate, approximate=self.approximate) + # mps: gelu is not implemented for float16 + return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype) + + def forward(self, hidden_states): + hidden_states = self.proj(hidden_states) + hidden_states = self.gelu(hidden_states) + return hidden_states + + +class GEGLU(nn.Module): + r""" + A [variant](https://arxiv.org/abs/2002.05202) of the gated linear unit activation function. + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + def __init__(self, dim_in: int, dim_out: int, bias: bool = True): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias) + + def gelu(self, gate: torch.Tensor) -> torch.Tensor: + if gate.device.type != "mps": + return F.gelu(gate) + # mps: gelu is not implemented for float16 + return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) + + def forward(self, hidden_states, *args, **kwargs): + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + hidden_states = self.proj(hidden_states) + if is_torch_npu_available(): + # using torch_npu.npu_geglu can run faster and save memory on NPU. + return torch_npu.npu_geglu(hidden_states, dim=-1, approximate=1)[0] + else: + hidden_states, gate = hidden_states.chunk(2, dim=-1) + return hidden_states * self.gelu(gate) + + +class SwiGLU(nn.Module): + r""" + A [variant](https://arxiv.org/abs/2002.05202) of the gated linear unit activation function. It's similar to `GEGLU` + but uses SiLU / Swish instead of GeLU. + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + def __init__(self, dim_in: int, dim_out: int, bias: bool = True): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias) + self.activation = nn.SiLU() + + def forward(self, hidden_states): + hidden_states = self.proj(hidden_states) + hidden_states, gate = hidden_states.chunk(2, dim=-1) + return hidden_states * self.activation(gate) + + +class ApproximateGELU(nn.Module): + r""" + The approximate form of the Gaussian Error Linear Unit (GELU). For more details, see section 2 of this + [paper](https://arxiv.org/abs/1606.08415). + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + def __init__(self, dim_in: int, dim_out: int, bias: bool = True): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out, bias=bias) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + return x * torch.sigmoid(1.702 * x) diff --git a/MindIE/MultiModal/CogVideoX/cogvideox_5b/models/attention.py b/MindIE/MultiModal/CogVideoX/cogvideox_5b/models/attention.py new file mode 100644 index 0000000000..df242f8b01 --- /dev/null +++ b/MindIE/MultiModal/CogVideoX/cogvideox_5b/models/attention.py @@ -0,0 +1,1228 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Dict, List, Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import nn + +from diffusers.utils import deprecate, logging +from diffusers.utils.torch_utils import maybe_allow_in_graph +from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, SwiGLU +from .attention_processor import Attention, JointAttnProcessor2_0 +from .embeddings import SinusoidalPositionalEmbedding +from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX + +logger = logging.get_logger(__name__) + + +def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int): + # "feed_forward_chunk_size" can be used to save memory + if hidden_states.shape[chunk_dim] % chunk_size != 0: + raise ValueError( + f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." + ) + + num_chunks = hidden_states.shape[chunk_dim] // chunk_size + ff_output = torch.cat( + [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)], + dim=chunk_dim, + ) + return ff_output + + +@maybe_allow_in_graph +class GatedSelfAttentionDense(nn.Module): + r""" + A gated self-attention dense layer that combines visual features and object features. + + Parameters: + query_dim (`int`): The number of channels in the query. + context_dim (`int`): The number of channels in the context. + n_heads (`int`): The number of heads to use for attention. + d_head (`int`): The number of channels in each head. + """ + + def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int): + super().__init__() + + # we need a linear projection since we need cat visual feature and obj feature + self.linear = nn.Linear(context_dim, query_dim) + + self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head) + self.ff = FeedForward(query_dim, activation_fn="geglu") + + self.norm1 = nn.LayerNorm(query_dim) + self.norm2 = nn.LayerNorm(query_dim) + + self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0))) + self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0))) + + self.enabled = True + + def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor: + if not self.enabled: + return x + + n_visual = x.shape[1] + objs = self.linear(objs) + + x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :] + x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x)) + + return x + + +@maybe_allow_in_graph +class JointTransformerBlock(nn.Module): + r""" + A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3. + + Reference: https://arxiv.org/abs/2403.03206 + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the + processing of `context` conditions. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + context_pre_only: bool = False, + qk_norm: Optional[str] = None, + use_dual_attention: bool = False, + ): + super().__init__() + + self.use_dual_attention = use_dual_attention + self.context_pre_only = context_pre_only + context_norm_type = "ada_norm_continous" if context_pre_only else "ada_norm_zero" + + if use_dual_attention: + self.norm1 = SD35AdaLayerNormZeroX(dim) + else: + self.norm1 = AdaLayerNormZero(dim) + + if context_norm_type == "ada_norm_continous": + self.norm1_context = AdaLayerNormContinuous( + dim, dim, elementwise_affine=False, eps=1e-6, bias=True, norm_type="layer_norm" + ) + elif context_norm_type == "ada_norm_zero": + self.norm1_context = AdaLayerNormZero(dim) + else: + raise ValueError( + f"Unknown context_norm_type: {context_norm_type}, currently only support `ada_norm_continous`, `ada_norm_zero`" + ) + + if hasattr(F, "scaled_dot_product_attention"): + processor = JointAttnProcessor2_0() + else: + raise ValueError( + "The current PyTorch version does not support the `scaled_dot_product_attention` function." + ) + + self.attn = Attention( + query_dim=dim, + cross_attention_dim=None, + added_kv_proj_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + context_pre_only=context_pre_only, + bias=True, + processor=processor, + qk_norm=qk_norm, + eps=1e-6, + ) + + if use_dual_attention: + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=None, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + bias=True, + processor=processor, + qk_norm=qk_norm, + eps=1e-6, + ) + else: + self.attn2 = None + + self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + + if not context_pre_only: + self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + else: + self.norm2_context = None + self.ff_context = None + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + # Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor + ): + if self.use_dual_attention: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1( + hidden_states, emb=temb + ) + else: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) + + if self.context_pre_only: + norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb) + else: + norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( + encoder_hidden_states, emb=temb + ) + + # Attention. + attn_output, context_attn_output = self.attn( + hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states + ) + + # Process attention outputs for the `hidden_states`. + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = hidden_states + attn_output + + if self.use_dual_attention: + attn_output2 = self.attn2(hidden_states=norm_hidden_states2) + attn_output2 = gate_msa2.unsqueeze(1) * attn_output2 + hidden_states = hidden_states + attn_output2 + + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + if self._chunk_size is not None: + # "feed_forward_chunk_size" can be used to save memory + ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) + else: + ff_output = self.ff(norm_hidden_states) + ff_output = gate_mlp.unsqueeze(1) * ff_output + + hidden_states = hidden_states + ff_output + + # Process attention outputs for the `encoder_hidden_states`. + if self.context_pre_only: + encoder_hidden_states = None + else: + context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output + encoder_hidden_states = encoder_hidden_states + context_attn_output + + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) + norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + if self._chunk_size is not None: + # "feed_forward_chunk_size" can be used to save memory + context_ff_output = _chunked_feed_forward( + self.ff_context, norm_encoder_hidden_states, self._chunk_dim, self._chunk_size + ) + else: + context_ff_output = self.ff_context(norm_encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output + + return encoder_hidden_states, hidden_states + + +@maybe_allow_in_graph +class BasicTransformerBlock(nn.Module): + r""" + A basic Transformer block. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used. + double_self_attention (`bool`, *optional*): + Whether to use two self-attention layers. In this case no cross attention layers are used. + upcast_attention (`bool`, *optional*): + Whether to upcast the attention computation to float32. This is useful for mixed precision training. + norm_elementwise_affine (`bool`, *optional*, defaults to `True`): + Whether to use learnable elementwise affine parameters for normalization. + norm_type (`str`, *optional*, defaults to `"layer_norm"`): + The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`. + final_dropout (`bool` *optional*, defaults to False): + Whether to apply a final dropout after the last feed-forward layer. + attention_type (`str`, *optional*, defaults to `"default"`): + The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`. + positional_embeddings (`str`, *optional*, defaults to `None`): + The type of positional embeddings to apply to. + num_positional_embeddings (`int`, *optional*, defaults to `None`): + The maximum number of positional embeddings to apply. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen' + norm_eps: float = 1e-5, + final_dropout: bool = False, + attention_type: str = "default", + positional_embeddings: Optional[str] = None, + num_positional_embeddings: Optional[int] = None, + ada_norm_continous_conditioning_embedding_dim: Optional[int] = None, + ada_norm_bias: Optional[int] = None, + ff_inner_dim: Optional[int] = None, + ff_bias: bool = True, + attention_out_bias: bool = True, + ): + super().__init__() + self.dim = dim + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + self.dropout = dropout + self.cross_attention_dim = cross_attention_dim + self.activation_fn = activation_fn + self.attention_bias = attention_bias + self.double_self_attention = double_self_attention + self.norm_elementwise_affine = norm_elementwise_affine + self.positional_embeddings = positional_embeddings + self.num_positional_embeddings = num_positional_embeddings + self.only_cross_attention = only_cross_attention + + # We keep these boolean flags for backward-compatibility. + self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" + self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" + self.use_ada_layer_norm_single = norm_type == "ada_norm_single" + self.use_layer_norm = norm_type == "layer_norm" + self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous" + + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: + raise ValueError( + f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" + f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." + ) + + self.norm_type = norm_type + self.num_embeds_ada_norm = num_embeds_ada_norm + + if positional_embeddings and (num_positional_embeddings is None): + raise ValueError( + "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined." + ) + + if positional_embeddings == "sinusoidal": + self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings) + else: + self.pos_embed = None + + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + if norm_type == "ada_norm": + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif norm_type == "ada_norm_zero": + self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) + elif norm_type == "ada_norm_continuous": + self.norm1 = AdaLayerNormContinuous( + dim, + ada_norm_continous_conditioning_embedding_dim, + norm_elementwise_affine, + norm_eps, + ada_norm_bias, + "rms_norm", + ) + else: + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + out_bias=attention_out_bias, + ) + + # 2. Cross-Attn + if cross_attention_dim is not None or double_self_attention: + # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. + # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during + # the second cross attention block. + if norm_type == "ada_norm": + self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif norm_type == "ada_norm_continuous": + self.norm2 = AdaLayerNormContinuous( + dim, + ada_norm_continous_conditioning_embedding_dim, + norm_elementwise_affine, + norm_eps, + ada_norm_bias, + "rms_norm", + ) + else: + self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) + + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim if not double_self_attention else None, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + out_bias=attention_out_bias, + ) # is self-attn if encoder_hidden_states is none + else: + if norm_type == "ada_norm_single": # For Latte + self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) + else: + self.norm2 = None + self.attn2 = None + + # 3. Feed-forward + if norm_type == "ada_norm_continuous": + self.norm3 = AdaLayerNormContinuous( + dim, + ada_norm_continous_conditioning_embedding_dim, + norm_elementwise_affine, + norm_eps, + ada_norm_bias, + "layer_norm", + ) + + elif norm_type in ["ada_norm_zero", "ada_norm", "layer_norm"]: + self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) + elif norm_type == "layer_norm_i2vgen": + self.norm3 = None + + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn=activation_fn, + final_dropout=final_dropout, + inner_dim=ff_inner_dim, + bias=ff_bias, + ) + + # 4. Fuser + if attention_type == "gated" or attention_type == "gated-text-image": + self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim) + + # 5. Scale-shift for PixArt-Alpha. + if norm_type == "ada_norm_single": + self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[torch.LongTensor] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + ) -> torch.Tensor: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + # Notice that normalization is always applied before the real computation in the following blocks. + # 0. Self-Attention + batch_size = hidden_states.shape[0] + + if self.norm_type == "ada_norm": + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.norm_type == "ada_norm_zero": + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]: + norm_hidden_states = self.norm1(hidden_states) + elif self.norm_type == "ada_norm_continuous": + norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"]) + elif self.norm_type == "ada_norm_single": + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) + ).chunk(6, dim=1) + norm_hidden_states = self.norm1(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa + else: + raise ValueError("Incorrect norm used") + + if self.pos_embed is not None: + norm_hidden_states = self.pos_embed(norm_hidden_states) + + # 1. Prepare GLIGEN inputs + cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} + gligen_kwargs = cross_attention_kwargs.pop("gligen", None) + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + + if self.norm_type == "ada_norm_zero": + attn_output = gate_msa.unsqueeze(1) * attn_output + elif self.norm_type == "ada_norm_single": + attn_output = gate_msa * attn_output + + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + # 1.2 GLIGEN Control + if gligen_kwargs is not None: + hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"]) + + # 3. Cross-Attention + if self.attn2 is not None: + if self.norm_type == "ada_norm": + norm_hidden_states = self.norm2(hidden_states, timestep) + elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]: + norm_hidden_states = self.norm2(hidden_states) + elif self.norm_type == "ada_norm_single": + # For PixArt norm2 isn't applied here: + # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103 + norm_hidden_states = hidden_states + elif self.norm_type == "ada_norm_continuous": + norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"]) + else: + raise ValueError("Incorrect norm") + + if self.pos_embed is not None and self.norm_type != "ada_norm_single": + norm_hidden_states = self.pos_embed(norm_hidden_states) + + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + hidden_states = attn_output + hidden_states + + # 4. Feed-forward + # i2vgen doesn't have this norm 🤷‍♂️ + if self.norm_type == "ada_norm_continuous": + norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"]) + elif not self.norm_type == "ada_norm_single": + norm_hidden_states = self.norm3(hidden_states) + + if self.norm_type == "ada_norm_zero": + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + if self.norm_type == "ada_norm_single": + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp + + if self._chunk_size is not None: + # "feed_forward_chunk_size" can be used to save memory + ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) + else: + ff_output = self.ff(norm_hidden_states) + + if self.norm_type == "ada_norm_zero": + ff_output = gate_mlp.unsqueeze(1) * ff_output + elif self.norm_type == "ada_norm_single": + ff_output = gate_mlp * ff_output + + hidden_states = ff_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + return hidden_states + + +class LuminaFeedForward(nn.Module): + r""" + A feed-forward layer. + + Parameters: + hidden_size (`int`): + The dimensionality of the hidden layers in the model. This parameter determines the width of the model's + hidden representations. + intermediate_size (`int`): The intermediate dimension of the feedforward layer. + multiple_of (`int`, *optional*): Value to ensure hidden dimension is a multiple + of this value. + ffn_dim_multiplier (float, *optional*): Custom multiplier for hidden + dimension. Defaults to None. + """ + + def __init__( + self, + dim: int, + inner_dim: int, + multiple_of: Optional[int] = 256, + ffn_dim_multiplier: Optional[float] = None, + ): + super().__init__() + inner_dim = int(2 * inner_dim / 3) + # custom hidden_size factor multiplier + if ffn_dim_multiplier is not None: + inner_dim = int(ffn_dim_multiplier * inner_dim) + inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of) + + self.linear_1 = nn.Linear( + dim, + inner_dim, + bias=False, + ) + self.linear_2 = nn.Linear( + inner_dim, + dim, + bias=False, + ) + self.linear_3 = nn.Linear( + dim, + inner_dim, + bias=False, + ) + self.silu = FP32SiLU() + + def forward(self, x): + return self.linear_2(self.silu(self.linear_1(x)) * self.linear_3(x)) + + +@maybe_allow_in_graph +class TemporalBasicTransformerBlock(nn.Module): + r""" + A basic Transformer block for video like data. + + Parameters: + dim (`int`): The number of channels in the input and output. + time_mix_inner_dim (`int`): The number of channels for temporal attention. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + """ + + def __init__( + self, + dim: int, + time_mix_inner_dim: int, + num_attention_heads: int, + attention_head_dim: int, + cross_attention_dim: Optional[int] = None, + ): + super().__init__() + self.is_res = dim == time_mix_inner_dim + + self.norm_in = nn.LayerNorm(dim) + + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + self.ff_in = FeedForward( + dim, + dim_out=time_mix_inner_dim, + activation_fn="geglu", + ) + + self.norm1 = nn.LayerNorm(time_mix_inner_dim) + self.attn1 = Attention( + query_dim=time_mix_inner_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + cross_attention_dim=None, + ) + + # 2. Cross-Attn + if cross_attention_dim is not None: + # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. + # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during + # the second cross attention block. + self.norm2 = nn.LayerNorm(time_mix_inner_dim) + self.attn2 = Attention( + query_dim=time_mix_inner_dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + ) # is self-attn if encoder_hidden_states is none + else: + self.norm2 = None + self.attn2 = None + + # 3. Feed-forward + self.norm3 = nn.LayerNorm(time_mix_inner_dim) + self.ff = FeedForward(time_mix_inner_dim, activation_fn="geglu") + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = None + + def set_chunk_feed_forward(self, chunk_size: Optional[int], **kwargs): + # Sets chunk feed-forward + self._chunk_size = chunk_size + # chunk dim should be hardcoded to 1 to have better speed vs. memory trade-off + self._chunk_dim = 1 + + def forward( + self, + hidden_states: torch.Tensor, + num_frames: int, + encoder_hidden_states: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + # Notice that normalization is always applied before the real computation in the following blocks. + # 0. Self-Attention + batch_size = hidden_states.shape[0] + + batch_frames, seq_length, channels = hidden_states.shape + batch_size = batch_frames // num_frames + + hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, seq_length, channels) + hidden_states = hidden_states.permute(0, 2, 1, 3) + hidden_states = hidden_states.reshape(batch_size * seq_length, num_frames, channels) + + residual = hidden_states + hidden_states = self.norm_in(hidden_states) + + if self._chunk_size is not None: + hidden_states = _chunked_feed_forward(self.ff_in, hidden_states, self._chunk_dim, self._chunk_size) + else: + hidden_states = self.ff_in(hidden_states) + + if self.is_res: + hidden_states = hidden_states + residual + + norm_hidden_states = self.norm1(hidden_states) + attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None) + hidden_states = attn_output + hidden_states + + # 3. Cross-Attention + if self.attn2 is not None: + norm_hidden_states = self.norm2(hidden_states) + attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states) + hidden_states = attn_output + hidden_states + + # 4. Feed-forward + norm_hidden_states = self.norm3(hidden_states) + + if self._chunk_size is not None: + ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) + else: + ff_output = self.ff(norm_hidden_states) + + if self.is_res: + hidden_states = ff_output + hidden_states + else: + hidden_states = ff_output + + hidden_states = hidden_states[None, :].reshape(batch_size, seq_length, num_frames, channels) + hidden_states = hidden_states.permute(0, 2, 1, 3) + hidden_states = hidden_states.reshape(batch_size * num_frames, seq_length, channels) + + return hidden_states + + +class SkipFFTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + kv_input_dim: int, + kv_input_dim_proj_use_bias: bool, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + attention_out_bias: bool = True, + ): + super().__init__() + if kv_input_dim != dim: + self.kv_mapper = nn.Linear(kv_input_dim, dim, kv_input_dim_proj_use_bias) + else: + self.kv_mapper = None + + self.norm1 = RMSNorm(dim, 1e-06) + + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim, + out_bias=attention_out_bias, + ) + + self.norm2 = RMSNorm(dim, 1e-06) + + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + out_bias=attention_out_bias, + ) + + def forward(self, hidden_states, encoder_hidden_states, cross_attention_kwargs): + cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} + + if self.kv_mapper is not None: + encoder_hidden_states = self.kv_mapper(F.silu(encoder_hidden_states)) + + norm_hidden_states = self.norm1(hidden_states) + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + **cross_attention_kwargs, + ) + + hidden_states = attn_output + hidden_states + + norm_hidden_states = self.norm2(hidden_states) + + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + **cross_attention_kwargs, + ) + + hidden_states = attn_output + hidden_states + + return hidden_states + + +@maybe_allow_in_graph +class FreeNoiseTransformerBlock(nn.Module): + r""" + A FreeNoise Transformer block. + + Parameters: + dim (`int`): + The number of channels in the input and output. + num_attention_heads (`int`): + The number of heads to use for multi-head attention. + attention_head_dim (`int`): + The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): + The dropout probability to use. + cross_attention_dim (`int`, *optional*): + The size of the encoder_hidden_states vector for cross attention. + activation_fn (`str`, *optional*, defaults to `"geglu"`): + Activation function to be used in feed-forward. + num_embeds_ada_norm (`int`, *optional*): + The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (`bool`, defaults to `False`): + Configure if the attentions should contain a bias parameter. + only_cross_attention (`bool`, defaults to `False`): + Whether to use only cross-attention layers. In this case two cross attention layers are used. + double_self_attention (`bool`, defaults to `False`): + Whether to use two self-attention layers. In this case no cross attention layers are used. + upcast_attention (`bool`, defaults to `False`): + Whether to upcast the attention computation to float32. This is useful for mixed precision training. + norm_elementwise_affine (`bool`, defaults to `True`): + Whether to use learnable elementwise affine parameters for normalization. + norm_type (`str`, defaults to `"layer_norm"`): + The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`. + final_dropout (`bool` defaults to `False`): + Whether to apply a final dropout after the last feed-forward layer. + attention_type (`str`, defaults to `"default"`): + The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`. + positional_embeddings (`str`, *optional*): + The type of positional embeddings to apply to. + num_positional_embeddings (`int`, *optional*, defaults to `None`): + The maximum number of positional embeddings to apply. + ff_inner_dim (`int`, *optional*): + Hidden dimension of feed-forward MLP. + ff_bias (`bool`, defaults to `True`): + Whether or not to use bias in feed-forward MLP. + attention_out_bias (`bool`, defaults to `True`): + Whether or not to use bias in attention output project layer. + context_length (`int`, defaults to `16`): + The maximum number of frames that the FreeNoise block processes at once. + context_stride (`int`, defaults to `4`): + The number of frames to be skipped before starting to process a new batch of `context_length` frames. + weighting_scheme (`str`, defaults to `"pyramid"`): + The weighting scheme to use for weighting averaging of processed latent frames. As described in the + Equation 9. of the [FreeNoise](https://arxiv.org/abs/2310.15169) paper, "pyramid" is the default setting + used. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout: float = 0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", + norm_eps: float = 1e-5, + final_dropout: bool = False, + positional_embeddings: Optional[str] = None, + num_positional_embeddings: Optional[int] = None, + ff_inner_dim: Optional[int] = None, + ff_bias: bool = True, + attention_out_bias: bool = True, + context_length: int = 16, + context_stride: int = 4, + weighting_scheme: str = "pyramid", + ): + super().__init__() + self.dim = dim + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + self.dropout = dropout + self.cross_attention_dim = cross_attention_dim + self.activation_fn = activation_fn + self.attention_bias = attention_bias + self.double_self_attention = double_self_attention + self.norm_elementwise_affine = norm_elementwise_affine + self.positional_embeddings = positional_embeddings + self.num_positional_embeddings = num_positional_embeddings + self.only_cross_attention = only_cross_attention + + self.set_free_noise_properties(context_length, context_stride, weighting_scheme) + + # We keep these boolean flags for backward-compatibility. + self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" + self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" + self.use_ada_layer_norm_single = norm_type == "ada_norm_single" + self.use_layer_norm = norm_type == "layer_norm" + self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous" + + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: + raise ValueError( + f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" + f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." + ) + + self.norm_type = norm_type + self.num_embeds_ada_norm = num_embeds_ada_norm + + if positional_embeddings and (num_positional_embeddings is None): + raise ValueError( + "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined." + ) + + if positional_embeddings == "sinusoidal": + self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings) + else: + self.pos_embed = None + + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + out_bias=attention_out_bias, + ) + + # 2. Cross-Attn + if cross_attention_dim is not None or double_self_attention: + self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) + + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim if not double_self_attention else None, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + out_bias=attention_out_bias, + ) # is self-attn if encoder_hidden_states is none + + # 3. Feed-forward + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn=activation_fn, + final_dropout=final_dropout, + inner_dim=ff_inner_dim, + bias=ff_bias, + ) + + self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + def _get_frame_indices(self, num_frames: int) -> List[Tuple[int, int]]: + frame_indices = [] + for i in range(0, num_frames - self.context_length + 1, self.context_stride): + window_start = i + window_end = min(num_frames, i + self.context_length) + frame_indices.append((window_start, window_end)) + return frame_indices + + def _get_frame_weights(self, num_frames: int, weighting_scheme: str = "pyramid") -> List[float]: + if weighting_scheme == "flat": + weights = [1.0] * num_frames + + elif weighting_scheme == "pyramid": + if num_frames % 2 == 0: + # num_frames = 4 => [1, 2, 2, 1] + mid = num_frames // 2 + weights = list(range(1, mid + 1)) + weights = weights + weights[::-1] + else: + # num_frames = 5 => [1, 2, 3, 2, 1] + mid = (num_frames + 1) // 2 + weights = list(range(1, mid)) + weights = weights + [mid] + weights[::-1] + + elif weighting_scheme == "delayed_reverse_sawtooth": + if num_frames % 2 == 0: + # num_frames = 4 => [0.01, 2, 2, 1] + mid = num_frames // 2 + weights = [0.01] * (mid - 1) + [mid] + weights = weights + list(range(mid, 0, -1)) + else: + # num_frames = 5 => [0.01, 0.01, 3, 2, 1] + mid = (num_frames + 1) // 2 + weights = [0.01] * mid + weights = weights + list(range(mid, 0, -1)) + else: + raise ValueError(f"Unsupported value for weighting_scheme={weighting_scheme}") + + return weights + + def set_free_noise_properties( + self, context_length: int, context_stride: int, weighting_scheme: str = "pyramid" + ) -> None: + self.context_length = context_length + self.context_stride = context_stride + self.weighting_scheme = weighting_scheme + + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0) -> None: + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + *args, + **kwargs, + ) -> torch.Tensor: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} + + # hidden_states: [B x H x W, F, C] + device = hidden_states.device + dtype = hidden_states.dtype + + num_frames = hidden_states.size(1) + frame_indices = self._get_frame_indices(num_frames) + frame_weights = self._get_frame_weights(self.context_length, self.weighting_scheme) + frame_weights = torch.tensor(frame_weights, device=device, dtype=dtype).unsqueeze(0).unsqueeze(-1) + is_last_frame_batch_complete = frame_indices[-1][1] == num_frames + + if not is_last_frame_batch_complete: + if num_frames < self.context_length: + raise ValueError(f"Expected {num_frames=} to be greater or equal than {self.context_length=}") + last_frame_batch_length = num_frames - frame_indices[-1][1] + frame_indices.append((num_frames - self.context_length, num_frames)) + + num_times_accumulated = torch.zeros((1, num_frames, 1), device=device) + accumulated_values = torch.zeros_like(hidden_states) + + for i, (frame_start, frame_end) in enumerate(frame_indices): + # The reason for slicing here is to ensure that if (frame_end - frame_start) is to handle + # cases like frame_indices=[(0, 16), (16, 20)], if the user provided a video with 19 frames, or + # essentially a non-multiple of `context_length`. + weights = torch.ones_like(num_times_accumulated[:, frame_start:frame_end]) + weights *= frame_weights + + hidden_states_chunk = hidden_states[:, frame_start:frame_end] + + # Notice that normalization is always applied before the real computation in the following blocks. + # 1. Self-Attention + norm_hidden_states = self.norm1(hidden_states_chunk) + + if self.pos_embed is not None: + norm_hidden_states = self.pos_embed(norm_hidden_states) + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + + hidden_states_chunk = attn_output + hidden_states_chunk + if hidden_states_chunk.ndim == 4: + hidden_states_chunk = hidden_states_chunk.squeeze(1) + + # 2. Cross-Attention + if self.attn2 is not None: + norm_hidden_states = self.norm2(hidden_states_chunk) + + if self.pos_embed is not None and self.norm_type != "ada_norm_single": + norm_hidden_states = self.pos_embed(norm_hidden_states) + + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + hidden_states_chunk = attn_output + hidden_states_chunk + + if i == len(frame_indices) - 1 and not is_last_frame_batch_complete: + accumulated_values[:, -last_frame_batch_length:] += ( + hidden_states_chunk[:, -last_frame_batch_length:] * weights[:, -last_frame_batch_length:] + ) + num_times_accumulated[:, -last_frame_batch_length:] += weights[:, -last_frame_batch_length] + else: + accumulated_values[:, frame_start:frame_end] += hidden_states_chunk * weights + num_times_accumulated[:, frame_start:frame_end] += weights + + hidden_states = torch.cat( + [ + torch.where(num_times_split > 0, accumulated_split / num_times_split, accumulated_split) + for accumulated_split, num_times_split in zip( + accumulated_values.split(self.context_length, dim=1), + num_times_accumulated.split(self.context_length, dim=1), + ) + ], + dim=1, + ).to(dtype) + + # 3. Feed-forward + norm_hidden_states = self.norm3(hidden_states) + + if self._chunk_size is not None: + ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) + else: + ff_output = self.ff(norm_hidden_states) + + hidden_states = ff_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + return hidden_states + + +class FeedForward(nn.Module): + r""" + A feed-forward layer. + + Parameters: + dim (`int`): The number of channels in the input. + dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. + mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + mult: int = 4, + dropout: float = 0.0, + activation_fn: str = "geglu", + final_dropout: bool = False, + inner_dim=None, + bias: bool = True, + ): + super().__init__() + if inner_dim is None: + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + + if activation_fn == "gelu": + act_fn = GELU(dim, inner_dim, bias=bias) + if activation_fn == "gelu-approximate": + act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias) + elif activation_fn == "geglu": + act_fn = GEGLU(dim, inner_dim, bias=bias) + elif activation_fn == "geglu-approximate": + act_fn = ApproximateGELU(dim, inner_dim, bias=bias) + elif activation_fn == "swiglu": + act_fn = SwiGLU(dim, inner_dim, bias=bias) + + self.net = nn.ModuleList([]) + # project in + self.net.append(act_fn) + # project dropout + self.net.append(nn.Dropout(dropout)) + # project out + self.net.append(nn.Linear(inner_dim, dim_out, bias=bias)) + # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout + if final_dropout: + self.net.append(nn.Dropout(dropout)) + + def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + for module in self.net: + hidden_states = module(hidden_states) + return hidden_states diff --git a/MindIE/MultiModal/CogVideoX/cogvideox_5b/models/attention_processor.py b/MindIE/MultiModal/CogVideoX/cogvideox_5b/models/attention_processor.py new file mode 100644 index 0000000000..15c4a30541 --- /dev/null +++ b/MindIE/MultiModal/CogVideoX/cogvideox_5b/models/attention_processor.py @@ -0,0 +1,4320 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect +import math +from typing import Callable, List, Optional, Tuple, Union + +import torch +import torch_npu +import torch.nn.functional as F +import torch.distributed as dist +from torch import nn +torch.ops.load_library("./pta_plugin/build/libPTAExtensionOPS.so") + +from diffusers.image_processor import IPAdapterMaskProcessor +from diffusers.utils import deprecate, logging +from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available +from diffusers.utils.torch_utils import is_torch_version, maybe_allow_in_graph +from ..utils.parallel_state import get_world_size, get_rank, get_sp_world_size, get_sp_group, all_gather_variable_with_group + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +MAX_TOKENS = 2147483647 + +if is_torch_npu_available(): + import torch_npu + +if is_xformers_available(): + import xformers + import xformers.ops +else: + xformers = None + + +@maybe_allow_in_graph +class Attention(nn.Module): + r""" + A cross attention layer. + + Parameters: + query_dim (`int`): + The number of channels in the query. + cross_attention_dim (`int`, *optional*): + The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. + heads (`int`, *optional*, defaults to 8): + The number of heads to use for multi-head attention. + kv_heads (`int`, *optional*, defaults to `None`): + The number of key and value heads to use for multi-head attention. Defaults to `heads`. If + `kv_heads=heads`, the model will use Multi Head Attention (MHA), if `kv_heads=1` the model will use Multi + Query Attention (MQA) otherwise GQA is used. + dim_head (`int`, *optional*, defaults to 64): + The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): + The dropout probability to use. + bias (`bool`, *optional*, defaults to False): + Set to `True` for the query, key, and value linear layers to contain a bias parameter. + upcast_attention (`bool`, *optional*, defaults to False): + Set to `True` to upcast the attention computation to `float32`. + upcast_softmax (`bool`, *optional*, defaults to False): + Set to `True` to upcast the softmax computation to `float32`. + cross_attention_norm (`str`, *optional*, defaults to `None`): + The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`. + cross_attention_norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups to use for the group norm in the cross attention. + added_kv_proj_dim (`int`, *optional*, defaults to `None`): + The number of channels to use for the added key and value projections. If `None`, no projection is used. + norm_num_groups (`int`, *optional*, defaults to `None`): + The number of groups to use for the group norm in the attention. + spatial_norm_dim (`int`, *optional*, defaults to `None`): + The number of channels to use for the spatial normalization. + out_bias (`bool`, *optional*, defaults to `True`): + Set to `True` to use a bias in the output linear layer. + scale_qk (`bool`, *optional*, defaults to `True`): + Set to `True` to scale the query and key by `1 / sqrt(dim_head)`. + only_cross_attention (`bool`, *optional*, defaults to `False`): + Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if + `added_kv_proj_dim` is not `None`. + eps (`float`, *optional*, defaults to 1e-5): + An additional value added to the denominator in group normalization that is used for numerical stability. + rescale_output_factor (`float`, *optional*, defaults to 1.0): + A factor to rescale the output by dividing it with this value. + residual_connection (`bool`, *optional*, defaults to `False`): + Set to `True` to add the residual connection to the output. + _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`): + Set to `True` if the attention block is loaded from a deprecated state dict. + processor (`AttnProcessor`, *optional*, defaults to `None`): + The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and + `AttnProcessor` otherwise. + """ + + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + kv_heads: Optional[int] = None, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + cross_attention_norm: Optional[str] = None, + cross_attention_norm_num_groups: int = 32, + qk_norm: Optional[str] = None, + added_kv_proj_dim: Optional[int] = None, + added_proj_bias: Optional[bool] = True, + norm_num_groups: Optional[int] = None, + spatial_norm_dim: Optional[int] = None, + out_bias: bool = True, + scale_qk: bool = True, + only_cross_attention: bool = False, + eps: float = 1e-5, + rescale_output_factor: float = 1.0, + residual_connection: bool = False, + _from_deprecated_attn_block: bool = False, + processor: Optional["AttnProcessor"] = None, + out_dim: int = None, + context_pre_only=None, + pre_only=False, + elementwise_affine: bool = True, + ): + super().__init__() + + # To prevent circular import. + from .normalization import FP32LayerNorm, RMSNorm + + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads + self.query_dim = query_dim + self.use_bias = bias + self.is_cross_attention = cross_attention_dim is not None + self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + self.upcast_attention = upcast_attention + self.upcast_softmax = upcast_softmax + self.rescale_output_factor = rescale_output_factor + self.residual_connection = residual_connection + self.dropout = dropout + self.fused_projections = False + self.out_dim = out_dim if out_dim is not None else query_dim + self.context_pre_only = context_pre_only + self.pre_only = pre_only + + # we make use of this private variable to know whether this class is loaded + # with an deprecated state dict so that we can convert it on the fly + self._from_deprecated_attn_block = _from_deprecated_attn_block + + self.scale_qk = scale_qk + self.scale = dim_head**-0.5 if self.scale_qk else 1.0 + + self.heads = out_dim // dim_head if out_dim is not None else heads + # for slice_size > 0 the attention score computation + # is split across the batch axis to save memory + # You can set slice_size with `set_attention_slice` + self.sliceable_head_dim = heads + + self.added_kv_proj_dim = added_kv_proj_dim + self.only_cross_attention = only_cross_attention + + if self.added_kv_proj_dim is None and self.only_cross_attention: + raise ValueError( + "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`." + ) + + if norm_num_groups is not None: + self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True) + else: + self.group_norm = None + + if spatial_norm_dim is not None: + self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim) + else: + self.spatial_norm = None + + if qk_norm is None: + self.norm_q = None + self.norm_k = None + elif qk_norm == "layer_norm": + self.norm_q = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + self.norm_k = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + elif qk_norm == "fp32_layer_norm": + self.norm_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) + self.norm_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) + elif qk_norm == "layer_norm_across_heads": + # Lumina applys qk norm across all heads + self.norm_q = nn.LayerNorm(dim_head * heads, eps=eps) + self.norm_k = nn.LayerNorm(dim_head * kv_heads, eps=eps) + elif qk_norm == "rms_norm": + self.norm_q = RMSNorm(dim_head, eps=eps) + self.norm_k = RMSNorm(dim_head, eps=eps) + else: + raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None,'layer_norm','fp32_layer_norm','rms_norm'") + + if cross_attention_norm is None: + self.norm_cross = None + elif cross_attention_norm == "layer_norm": + self.norm_cross = nn.LayerNorm(self.cross_attention_dim) + elif cross_attention_norm == "group_norm": + if self.added_kv_proj_dim is not None: + # The given `encoder_hidden_states` are initially of shape + # (batch_size, seq_len, added_kv_proj_dim) before being projected + # to (batch_size, seq_len, cross_attention_dim). The norm is applied + # before the projection, so we need to use `added_kv_proj_dim` as + # the number of channels for the group norm. + norm_cross_num_channels = added_kv_proj_dim + else: + norm_cross_num_channels = self.cross_attention_dim + + self.norm_cross = nn.GroupNorm( + num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True + ) + else: + raise ValueError( + f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'" + ) + + self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias) + + if not self.only_cross_attention: + # only relevant for the `AddedKVProcessor` classes + self.to_k = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) + self.to_v = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) + else: + self.to_k = None + self.to_v = None + + self.added_proj_bias = added_proj_bias + if self.added_kv_proj_dim is not None: + self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias) + self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias) + if self.context_pre_only is not None: + self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + + if not self.pre_only: + self.to_out = nn.ModuleList([]) + self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) + self.to_out.append(nn.Dropout(dropout)) + + if self.context_pre_only is not None and not self.context_pre_only: + self.to_add_out = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias) + + if qk_norm is not None and added_kv_proj_dim is not None: + if qk_norm == "fp32_layer_norm": + self.norm_added_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) + self.norm_added_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) + elif qk_norm == "rms_norm": + self.norm_added_q = RMSNorm(dim_head, eps=eps) + self.norm_added_k = RMSNorm(dim_head, eps=eps) + else: + raise ValueError( + f"unknown qk_norm: {qk_norm}. Should be one of `None,'layer_norm','fp32_layer_norm','rms_norm'`" + ) + else: + self.norm_added_q = None + self.norm_added_k = None + + if processor is None: + processor = ( + AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() + ) + self.set_processor(processor) + + def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None: + r""" + Set whether to use npu flash attention from `torch_npu` or not. + + """ + if use_npu_flash_attention: + processor = AttnProcessorNPU() + else: + processor = ( + AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() + ) + self.set_processor(processor) + + def set_use_memory_efficient_attention_xformers( + self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None + ) -> None: + r""" + Set whether to use memory efficient attention from `xformers` or not. + + Args: + use_memory_efficient_attention_xformers (`bool`): + Whether to use memory efficient attention from `xformers` or not. + attention_op (`Callable`, *optional*): + The attention operation to use. Defaults to `None` which uses the default attention operation from + `xformers`. + """ + is_custom_diffusion = hasattr(self, "processor") and isinstance( + self.processor, + (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor, CustomDiffusionAttnProcessor2_0), + ) + is_added_kv_processor = hasattr(self, "processor") and isinstance( + self.processor, + ( + AttnAddedKVProcessor, + AttnAddedKVProcessor2_0, + SlicedAttnAddedKVProcessor, + XFormersAttnAddedKVProcessor, + ), + ) + + if use_memory_efficient_attention_xformers: + if is_added_kv_processor and is_custom_diffusion: + raise NotImplementedError( + f"Memory efficient attention is currently not supported for custom diffusion for attention processor type {self.processor}" + ) + if not is_xformers_available(): + raise ModuleNotFoundError( + ( + "Refer to https://github.com/facebookresearch/xformers for more information on how to install" + " xformers" + ), + name="xformers", + ) + elif not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is" + " only available for GPU " + ) + else: + try: + # Make sure we can run the memory efficient attention + _ = xformers.ops.memory_efficient_attention( + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + ) + except Exception as e: + raise e + + if is_custom_diffusion: + processor = CustomDiffusionXFormersAttnProcessor( + train_kv=self.processor.train_kv, + train_q_out=self.processor.train_q_out, + hidden_size=self.processor.hidden_size, + cross_attention_dim=self.processor.cross_attention_dim, + attention_op=attention_op, + ) + processor.load_state_dict(self.processor.state_dict()) + if hasattr(self.processor, "to_k_custom_diffusion"): + processor.to(self.processor.to_k_custom_diffusion.weight.device) + elif is_added_kv_processor: + logger.info( + "Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation." + ) + processor = XFormersAttnAddedKVProcessor(attention_op=attention_op) + else: + processor = XFormersAttnProcessor(attention_op=attention_op) + else: + if is_custom_diffusion: + attn_processor_class = ( + CustomDiffusionAttnProcessor2_0 + if hasattr(F, "scaled_dot_product_attention") + else CustomDiffusionAttnProcessor + ) + processor = attn_processor_class( + train_kv=self.processor.train_kv, + train_q_out=self.processor.train_q_out, + hidden_size=self.processor.hidden_size, + cross_attention_dim=self.processor.cross_attention_dim, + ) + processor.load_state_dict(self.processor.state_dict()) + if hasattr(self.processor, "to_k_custom_diffusion"): + processor.to(self.processor.to_k_custom_diffusion.weight.device) + else: + processor = ( + AttnProcessor2_0() + if hasattr(F, "scaled_dot_product_attention") and self.scale_qk + else AttnProcessor() + ) + + self.set_processor(processor) + + def set_attention_slice(self, slice_size: int) -> None: + r""" + Set the slice size for attention computation. + + Args: + slice_size (`int`): + The slice size for attention computation. + """ + if slice_size is not None and slice_size > self.sliceable_head_dim: + raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.") + + if slice_size is not None and self.added_kv_proj_dim is not None: + processor = SlicedAttnAddedKVProcessor(slice_size) + elif slice_size is not None: + processor = SlicedAttnProcessor(slice_size) + elif self.added_kv_proj_dim is not None: + processor = AttnAddedKVProcessor() + else: + processor = ( + AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() + ) + + self.set_processor(processor) + + def set_processor(self, processor: "AttnProcessor") -> None: + r""" + Set the attention processor to use. + + Args: + processor (`AttnProcessor`): + The attention processor to use. + """ + # if current processor is in `self._modules` and if passed `processor` is not, we need to + # pop `processor` from `self._modules` + if ( + hasattr(self, "processor") + and isinstance(self.processor, torch.nn.Module) + and not isinstance(processor, torch.nn.Module) + ): + logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}") + self._modules.pop("processor") + + self.processor = processor + + def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor": + r""" + Get the attention processor in use. + + Args: + return_deprecated_lora (`bool`, *optional*, defaults to `False`): + Set to `True` to return the deprecated LoRA attention processor. + + Returns: + "AttentionProcessor": The attention processor in use. + """ + if not return_deprecated_lora: + return self.processor + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + **cross_attention_kwargs, + ) -> torch.Tensor: + r""" + The forward method of the `Attention` class. + + Args: + hidden_states (`torch.Tensor`): + The hidden states of the query. + encoder_hidden_states (`torch.Tensor`, *optional*): + The hidden states of the encoder. + attention_mask (`torch.Tensor`, *optional*): + The attention mask to use. If `None`, no mask is applied. + **cross_attention_kwargs: + Additional keyword arguments to pass along to the cross attention. + + Returns: + `torch.Tensor`: The output of the attention layer. + """ + # The `Attention` class can call different attention processors / attention functions + # here we simply pass along all tensors to the selected processor class + # For standard processors that are defined here, `**cross_attention_kwargs` is empty + + attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) + quiet_attn_parameters = {"ip_adapter_masks"} + unused_kwargs = [ + k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters + ] + if len(unused_kwargs) > 0: + logger.warning( + f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored." + ) + cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters} + + return self.processor( + self, + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + + def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor: + r""" + Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads` + is the number of heads initialized while constructing the `Attention` class. + + Args: + tensor (`torch.Tensor`): The tensor to reshape. + + Returns: + `torch.Tensor`: The reshaped tensor. + """ + head_size = self.heads + batch_size, seq_len, dim = tensor.shape + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) + return tensor + + def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor: + r""" + Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]` `heads` is + the number of heads initialized while constructing the `Attention` class. + + Args: + tensor (`torch.Tensor`): The tensor to reshape. + out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is + reshaped to `[batch_size * heads, seq_len, dim // heads]`. + + Returns: + `torch.Tensor`: The reshaped tensor. + """ + head_size = self.heads + if tensor.ndim == 3: + batch_size, seq_len, dim = tensor.shape + extra_dim = 1 + else: + batch_size, extra_dim, seq_len, dim = tensor.shape + tensor = tensor.reshape(batch_size, seq_len * extra_dim, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3) + + if out_dim == 3: + tensor = tensor.reshape(batch_size * head_size, seq_len * extra_dim, dim // head_size) + + return tensor + + def get_attention_scores( + self, query: torch.Tensor, key: torch.Tensor, attention_mask: Optional[torch.Tensor] = None + ) -> torch.Tensor: + r""" + Compute the attention scores. + + Args: + query (`torch.Tensor`): The query tensor. + key (`torch.Tensor`): The key tensor. + attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied. + + Returns: + `torch.Tensor`: The attention probabilities/scores. + """ + dtype = query.dtype + if self.upcast_attention: + query = query.float() + key = key.float() + + if attention_mask is None: + baddbmm_input = torch.empty( + query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device + ) + beta = 0 + else: + baddbmm_input = attention_mask + beta = 1 + + attention_scores = torch.baddbmm( + baddbmm_input, + query, + key.transpose(-1, -2), + beta=beta, + alpha=self.scale, + ) + del baddbmm_input + + if self.upcast_softmax: + attention_scores = attention_scores.float() + + attention_probs = attention_scores.softmax(dim=-1) + del attention_scores + + attention_probs = attention_probs.to(dtype) + + return attention_probs + + def prepare_attention_mask( + self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3 + ) -> torch.Tensor: + r""" + Prepare the attention mask for the attention computation. + + Args: + attention_mask (`torch.Tensor`): + The attention mask to prepare. + target_length (`int`): + The target length of the attention mask. This is the length of the attention mask after padding. + batch_size (`int`): + The batch size, which is used to repeat the attention mask. + out_dim (`int`, *optional*, defaults to `3`): + The output dimension of the attention mask. Can be either `3` or `4`. + + Returns: + `torch.Tensor`: The prepared attention mask. + """ + head_size = self.heads + if attention_mask is None: + return attention_mask + + current_length: int = attention_mask.shape[-1] + if current_length != target_length: + if attention_mask.device.type == "mps": + # HACK: MPS: Does not support padding by greater than dimension of input tensor. + # Instead, we can manually construct the padding tensor. + padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length) + padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device) + attention_mask = torch.cat([attention_mask, padding], dim=2) + else: + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + + if out_dim == 3: + if attention_mask.shape[0] < batch_size * head_size: + attention_mask = attention_mask.repeat_interleave(head_size, dim=0) + elif out_dim == 4: + attention_mask = attention_mask.unsqueeze(1) + attention_mask = attention_mask.repeat_interleave(head_size, dim=1) + + return attention_mask + + def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor: + r""" + Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the + `Attention` class. + + Args: + encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder. + + Returns: + `torch.Tensor`: The normalized encoder hidden states. + """ + if self.norm_cross is None: + raise ValueError("self.norm_cross must be defined to call self.norm_encoder_hidden_states") + + if isinstance(self.norm_cross, nn.LayerNorm): + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + elif isinstance(self.norm_cross, nn.GroupNorm): + encoder_hidden_states = encoder_hidden_states.transpose(1, 2) + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.transpose(1, 2) + else: + raise ValueError("Unsupported condition") + + return encoder_hidden_states + + @torch.no_grad() + def fuse_projections(self, fuse=True): + device = self.to_q.weight.data.device + dtype = self.to_q.weight.data.dtype + + if not self.is_cross_attention: + # fetch weight matrices. + concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data]) + in_features = concatenated_weights.shape[1] + out_features = concatenated_weights.shape[0] + + # create a new single projection layer and copy over the weights. + self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype) + self.to_qkv.weight.copy_(concatenated_weights) + if self.use_bias: + concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data]) + self.to_qkv.bias.copy_(concatenated_bias) + + else: + concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data]) + in_features = concatenated_weights.shape[1] + out_features = concatenated_weights.shape[0] + + self.to_kv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype) + self.to_kv.weight.copy_(concatenated_weights) + if self.use_bias: + concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data]) + self.to_kv.bias.copy_(concatenated_bias) + + # handle added projections for SD3 and others. + if hasattr(self, "add_q_proj") and hasattr(self, "add_k_proj") and hasattr(self, "add_v_proj"): + concatenated_weights = torch.cat( + [self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data] + ) + in_features = concatenated_weights.shape[1] + out_features = concatenated_weights.shape[0] + + self.to_added_qkv = nn.Linear( + in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype + ) + self.to_added_qkv.weight.copy_(concatenated_weights) + if self.added_proj_bias: + concatenated_bias = torch.cat( + [self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data] + ) + self.to_added_qkv.bias.copy_(concatenated_bias) + + self.fused_projections = fuse + + +class AttnProcessor: + r""" + Default processor for performing attention-related computations. + """ + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + *args, + **kwargs, + ) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class CustomDiffusionAttnProcessor(nn.Module): + r""" + Processor for implementing attention for the Custom Diffusion method. + + Args: + train_kv (`bool`, defaults to `True`): + Whether to newly train the key and value matrices corresponding to the text features. + train_q_out (`bool`, defaults to `True`): + Whether to newly train query matrices corresponding to the latent image features. + hidden_size (`int`, *optional*, defaults to `None`): + The hidden size of the attention layer. + cross_attention_dim (`int`, *optional*, defaults to `None`): + The number of channels in the `encoder_hidden_states`. + out_bias (`bool`, defaults to `True`): + Whether to include the bias parameter in `train_q_out`. + dropout (`float`, *optional*, defaults to 0.0): + The dropout probability to use. + """ + + def __init__( + self, + train_kv: bool = True, + train_q_out: bool = True, + hidden_size: Optional[int] = None, + cross_attention_dim: Optional[int] = None, + out_bias: bool = True, + dropout: float = 0.0, + ): + super().__init__() + self.train_kv = train_kv + self.train_q_out = train_q_out + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + + # `_custom_diffusion` id for easy serialization and loading. + if self.train_kv: + self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + if self.train_q_out: + self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False) + self.to_out_custom_diffusion = nn.ModuleList([]) + self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias)) + self.to_out_custom_diffusion.append(nn.Dropout(dropout)) + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + batch_size, sequence_length, _ = hidden_states.shape + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + if self.train_q_out: + query = self.to_q_custom_diffusion(hidden_states).to(attn.to_q.weight.dtype) + else: + query = attn.to_q(hidden_states.to(attn.to_q.weight.dtype)) + + if encoder_hidden_states is None: + crossattn = False + encoder_hidden_states = hidden_states + else: + crossattn = True + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + if self.train_kv: + key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype)) + value = self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype)) + key = key.to(attn.to_q.weight.dtype) + value = value.to(attn.to_q.weight.dtype) + else: + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + if crossattn: + detach = torch.ones_like(key) + detach[:, :1, :] = detach[:, :1, :] * 0.0 + key = detach * key + (1 - detach) * key.detach() + value = detach * value + (1 - detach) * value.detach() + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + if self.train_q_out: + # linear proj + hidden_states = self.to_out_custom_diffusion[0](hidden_states) + # dropout + hidden_states = self.to_out_custom_diffusion[1](hidden_states) + else: + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states + + +class AttnAddedKVProcessor: + r""" + Processor for performing attention-related computations with extra learnable key and value matrices for the text + encoder. + """ + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + *args, + **kwargs, + ) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + residual = hidden_states + + hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) + batch_size, sequence_length, _ = hidden_states.shape + + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + query = attn.head_to_batch_dim(query) + + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj) + encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj) + + if not attn.only_cross_attention: + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + key = torch.cat([encoder_hidden_states_key_proj, key], dim=1) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=1) + else: + key = encoder_hidden_states_key_proj + value = encoder_hidden_states_value_proj + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape) + hidden_states = hidden_states + residual + + return hidden_states + + +class AttnAddedKVProcessor2_0: + r""" + Processor for performing scaled dot-product attention (enabled by default if you're using PyTorch 2.0), with extra + learnable key and value matrices for the text encoder. + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "AttnAddedKVProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + *args, + **kwargs, + ) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + residual = hidden_states + + hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) + batch_size, sequence_length, _ = hidden_states.shape + + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size, out_dim=4) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + query = attn.head_to_batch_dim(query, out_dim=4) + + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj, out_dim=4) + encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4) + + if not attn.only_cross_attention: + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + key = attn.head_to_batch_dim(key, out_dim=4) + value = attn.head_to_batch_dim(value, out_dim=4) + key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) + else: + key = encoder_hidden_states_key_proj + value = encoder_hidden_states_value_proj + + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1]) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape) + hidden_states = hidden_states + residual + + return hidden_states + + +class JointAttnProcessor2_0: + """Attention processor used typically in processing the SD3-like self-attention projections.""" + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, + *args, + **kwargs, + ) -> torch.FloatTensor: + residual = hidden_states + + batch_size = hidden_states.shape[0] + + # `sample` projections. + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # `context` projections. + if encoder_hidden_states is not None: + encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + + encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + + if attn.norm_added_q is not None: + encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) + if attn.norm_added_k is not None: + encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) + + query = torch.cat([query, encoder_hidden_states_query_proj], dim=2) + key = torch.cat([key, encoder_hidden_states_key_proj], dim=2) + value = torch.cat([value, encoder_hidden_states_value_proj], dim=2) + + hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + if encoder_hidden_states is not None: + # Split the attention outputs. + hidden_states, encoder_hidden_states = ( + hidden_states[:, : residual.shape[1]], + hidden_states[:, residual.shape[1] :], + ) + if not attn.context_pre_only: + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if encoder_hidden_states is not None: + return hidden_states, encoder_hidden_states + else: + return hidden_states + + +class PAGJointAttnProcessor2_0: + """Attention processor used typically in processing the SD3-like self-attention projections.""" + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "PAGJointAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + ) -> torch.FloatTensor: + residual = hidden_states + + input_ndim = hidden_states.ndim + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + context_input_ndim = encoder_hidden_states.ndim + if context_input_ndim == 4: + batch_size, channel, height, width = encoder_hidden_states.shape + encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + # store the length of image patch sequences to create a mask that prevents interaction between patches + # similar to making the self-attention map an identity matrix + identity_block_size = hidden_states.shape[1] + + # chunk + hidden_states_org, hidden_states_ptb = hidden_states.chunk(2) + encoder_hidden_states_org, encoder_hidden_states_ptb = encoder_hidden_states.chunk(2) + + ################## original path ################## + batch_size = encoder_hidden_states_org.shape[0] + + # `sample` projections. + query_org = attn.to_q(hidden_states_org) + key_org = attn.to_k(hidden_states_org) + value_org = attn.to_v(hidden_states_org) + + # `context` projections. + encoder_hidden_states_org_query_proj = attn.add_q_proj(encoder_hidden_states_org) + encoder_hidden_states_org_key_proj = attn.add_k_proj(encoder_hidden_states_org) + encoder_hidden_states_org_value_proj = attn.add_v_proj(encoder_hidden_states_org) + + # attention + query_org = torch.cat([query_org, encoder_hidden_states_org_query_proj], dim=1) + key_org = torch.cat([key_org, encoder_hidden_states_org_key_proj], dim=1) + value_org = torch.cat([value_org, encoder_hidden_states_org_value_proj], dim=1) + + inner_dim = key_org.shape[-1] + head_dim = inner_dim // attn.heads + query_org = query_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key_org = key_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value_org = value_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + hidden_states_org = F.scaled_dot_product_attention( + query_org, key_org, value_org, dropout_p=0.0, is_causal=False + ) + hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states_org = hidden_states_org.to(query_org.dtype) + + # Split the attention outputs. + hidden_states_org, encoder_hidden_states_org = ( + hidden_states_org[:, : residual.shape[1]], + hidden_states_org[:, residual.shape[1] :], + ) + + # linear proj + hidden_states_org = attn.to_out[0](hidden_states_org) + # dropout + hidden_states_org = attn.to_out[1](hidden_states_org) + if not attn.context_pre_only: + encoder_hidden_states_org = attn.to_add_out(encoder_hidden_states_org) + + if input_ndim == 4: + hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width) + if context_input_ndim == 4: + encoder_hidden_states_org = encoder_hidden_states_org.transpose(-1, -2).reshape( + batch_size, channel, height, width + ) + + ################## perturbed path ################## + batch_size = encoder_hidden_states_ptb.shape[0] + + # `sample` projections. + query_ptb = attn.to_q(hidden_states_ptb) + key_ptb = attn.to_k(hidden_states_ptb) + value_ptb = attn.to_v(hidden_states_ptb) + + # `context` projections. + encoder_hidden_states_ptb_query_proj = attn.add_q_proj(encoder_hidden_states_ptb) + encoder_hidden_states_ptb_key_proj = attn.add_k_proj(encoder_hidden_states_ptb) + encoder_hidden_states_ptb_value_proj = attn.add_v_proj(encoder_hidden_states_ptb) + + # attention + query_ptb = torch.cat([query_ptb, encoder_hidden_states_ptb_query_proj], dim=1) + key_ptb = torch.cat([key_ptb, encoder_hidden_states_ptb_key_proj], dim=1) + value_ptb = torch.cat([value_ptb, encoder_hidden_states_ptb_value_proj], dim=1) + + inner_dim = key_ptb.shape[-1] + head_dim = inner_dim // attn.heads + query_ptb = query_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key_ptb = key_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value_ptb = value_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # create a full mask with all entries set to 0 + seq_len = query_ptb.size(2) + full_mask = torch.zeros((seq_len, seq_len), device=query_ptb.device, dtype=query_ptb.dtype) + + # set the attention value between image patches to -inf + full_mask[:identity_block_size, :identity_block_size] = float("-inf") + + # set the diagonal of the attention value between image patches to 0 + full_mask[:identity_block_size, :identity_block_size].fill_diagonal_(0) + + # expand the mask to match the attention weights shape + full_mask = full_mask.unsqueeze(0).unsqueeze(0) # Add batch and num_heads dimensions + + hidden_states_ptb = F.scaled_dot_product_attention( + query_ptb, key_ptb, value_ptb, attn_mask=full_mask, dropout_p=0.0, is_causal=False + ) + hidden_states_ptb = hidden_states_ptb.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states_ptb = hidden_states_ptb.to(query_ptb.dtype) + + # split the attention outputs. + hidden_states_ptb, encoder_hidden_states_ptb = ( + hidden_states_ptb[:, : residual.shape[1]], + hidden_states_ptb[:, residual.shape[1] :], + ) + + # linear proj + hidden_states_ptb = attn.to_out[0](hidden_states_ptb) + # dropout + hidden_states_ptb = attn.to_out[1](hidden_states_ptb) + if not attn.context_pre_only: + encoder_hidden_states_ptb = attn.to_add_out(encoder_hidden_states_ptb) + + if input_ndim == 4: + hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width) + if context_input_ndim == 4: + encoder_hidden_states_ptb = encoder_hidden_states_ptb.transpose(-1, -2).reshape( + batch_size, channel, height, width + ) + + ################ concat ############### + hidden_states = torch.cat([hidden_states_org, hidden_states_ptb]) + encoder_hidden_states = torch.cat([encoder_hidden_states_org, encoder_hidden_states_ptb]) + + return hidden_states, encoder_hidden_states + + +class PAGCFGJointAttnProcessor2_0: + """Attention processor used typically in processing the SD3-like self-attention projections.""" + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "PAGCFGJointAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, + *args, + **kwargs, + ) -> torch.FloatTensor: + residual = hidden_states + + input_ndim = hidden_states.ndim + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + context_input_ndim = encoder_hidden_states.ndim + if context_input_ndim == 4: + batch_size, channel, height, width = encoder_hidden_states.shape + encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + identity_block_size = hidden_states.shape[ + 1 + ] # patch embeddings width * height (correspond to self-attention map width or height) + + # chunk + hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3) + hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org]) + + ( + encoder_hidden_states_uncond, + encoder_hidden_states_org, + encoder_hidden_states_ptb, + ) = encoder_hidden_states.chunk(3) + encoder_hidden_states_org = torch.cat([encoder_hidden_states_uncond, encoder_hidden_states_org]) + + ################## original path ################## + batch_size = encoder_hidden_states_org.shape[0] + + # `sample` projections. + query_org = attn.to_q(hidden_states_org) + key_org = attn.to_k(hidden_states_org) + value_org = attn.to_v(hidden_states_org) + + # `context` projections. + encoder_hidden_states_org_query_proj = attn.add_q_proj(encoder_hidden_states_org) + encoder_hidden_states_org_key_proj = attn.add_k_proj(encoder_hidden_states_org) + encoder_hidden_states_org_value_proj = attn.add_v_proj(encoder_hidden_states_org) + + # attention + query_org = torch.cat([query_org, encoder_hidden_states_org_query_proj], dim=1) + key_org = torch.cat([key_org, encoder_hidden_states_org_key_proj], dim=1) + value_org = torch.cat([value_org, encoder_hidden_states_org_value_proj], dim=1) + + inner_dim = key_org.shape[-1] + head_dim = inner_dim // attn.heads + query_org = query_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key_org = key_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value_org = value_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + hidden_states_org = F.scaled_dot_product_attention( + query_org, key_org, value_org, dropout_p=0.0, is_causal=False + ) + hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states_org = hidden_states_org.to(query_org.dtype) + + # Split the attention outputs. + hidden_states_org, encoder_hidden_states_org = ( + hidden_states_org[:, : residual.shape[1]], + hidden_states_org[:, residual.shape[1] :], + ) + + # linear proj + hidden_states_org = attn.to_out[0](hidden_states_org) + # dropout + hidden_states_org = attn.to_out[1](hidden_states_org) + if not attn.context_pre_only: + encoder_hidden_states_org = attn.to_add_out(encoder_hidden_states_org) + + if input_ndim == 4: + hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width) + if context_input_ndim == 4: + encoder_hidden_states_org = encoder_hidden_states_org.transpose(-1, -2).reshape( + batch_size, channel, height, width + ) + + ################## perturbed path ################## + batch_size = encoder_hidden_states_ptb.shape[0] + + # `sample` projections. + query_ptb = attn.to_q(hidden_states_ptb) + key_ptb = attn.to_k(hidden_states_ptb) + value_ptb = attn.to_v(hidden_states_ptb) + + # `context` projections. + encoder_hidden_states_ptb_query_proj = attn.add_q_proj(encoder_hidden_states_ptb) + encoder_hidden_states_ptb_key_proj = attn.add_k_proj(encoder_hidden_states_ptb) + encoder_hidden_states_ptb_value_proj = attn.add_v_proj(encoder_hidden_states_ptb) + + # attention + query_ptb = torch.cat([query_ptb, encoder_hidden_states_ptb_query_proj], dim=1) + key_ptb = torch.cat([key_ptb, encoder_hidden_states_ptb_key_proj], dim=1) + value_ptb = torch.cat([value_ptb, encoder_hidden_states_ptb_value_proj], dim=1) + + inner_dim = key_ptb.shape[-1] + head_dim = inner_dim // attn.heads + query_ptb = query_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key_ptb = key_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value_ptb = value_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # create a full mask with all entries set to 0 + seq_len = query_ptb.size(2) + full_mask = torch.zeros((seq_len, seq_len), device=query_ptb.device, dtype=query_ptb.dtype) + + # set the attention value between image patches to -inf + full_mask[:identity_block_size, :identity_block_size] = float("-inf") + + # set the diagonal of the attention value between image patches to 0 + full_mask[:identity_block_size, :identity_block_size].fill_diagonal_(0) + + # expand the mask to match the attention weights shape + full_mask = full_mask.unsqueeze(0).unsqueeze(0) # Add batch and num_heads dimensions + + hidden_states_ptb = F.scaled_dot_product_attention( + query_ptb, key_ptb, value_ptb, attn_mask=full_mask, dropout_p=0.0, is_causal=False + ) + hidden_states_ptb = hidden_states_ptb.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states_ptb = hidden_states_ptb.to(query_ptb.dtype) + + # split the attention outputs. + hidden_states_ptb, encoder_hidden_states_ptb = ( + hidden_states_ptb[:, : residual.shape[1]], + hidden_states_ptb[:, residual.shape[1] :], + ) + + # linear proj + hidden_states_ptb = attn.to_out[0](hidden_states_ptb) + # dropout + hidden_states_ptb = attn.to_out[1](hidden_states_ptb) + if not attn.context_pre_only: + encoder_hidden_states_ptb = attn.to_add_out(encoder_hidden_states_ptb) + + if input_ndim == 4: + hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width) + if context_input_ndim == 4: + encoder_hidden_states_ptb = encoder_hidden_states_ptb.transpose(-1, -2).reshape( + batch_size, channel, height, width + ) + + ################ concat ############### + hidden_states = torch.cat([hidden_states_org, hidden_states_ptb]) + encoder_hidden_states = torch.cat([encoder_hidden_states_org, encoder_hidden_states_ptb]) + + return hidden_states, encoder_hidden_states + + +class FusedJointAttnProcessor2_0: + """Attention processor used typically in processing the SD3-like self-attention projections.""" + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, + *args, + **kwargs, + ) -> torch.FloatTensor: + residual = hidden_states + + input_ndim = hidden_states.ndim + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + context_input_ndim = encoder_hidden_states.ndim + if context_input_ndim == 4: + batch_size, channel, height, width = encoder_hidden_states.shape + encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size = encoder_hidden_states.shape[0] + + # `sample` projections. + qkv = attn.to_qkv(hidden_states) + split_size = qkv.shape[-1] // 3 + query, key, value = torch.split(qkv, split_size, dim=-1) + + # `context` projections. + encoder_qkv = attn.to_added_qkv(encoder_hidden_states) + split_size = encoder_qkv.shape[-1] // 3 + ( + encoder_hidden_states_query_proj, + encoder_hidden_states_key_proj, + encoder_hidden_states_value_proj, + ) = torch.split(encoder_qkv, split_size, dim=-1) + + # attention + query = torch.cat([query, encoder_hidden_states_query_proj], dim=1) + key = torch.cat([key, encoder_hidden_states_key_proj], dim=1) + value = torch.cat([value, encoder_hidden_states_value_proj], dim=1) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # Split the attention outputs. + hidden_states, encoder_hidden_states = ( + hidden_states[:, : residual.shape[1]], + hidden_states[:, residual.shape[1] :], + ) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + if not attn.context_pre_only: + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + if context_input_ndim == 4: + encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + return hidden_states, encoder_hidden_states + + +class AuraFlowAttnProcessor2_0: + """Attention processor used typically in processing Aura Flow.""" + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention") and is_torch_version("<", "2.1"): + raise ImportError( + "AuraFlowAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to at least 2.1 or above as we use `scale` in `F.scaled_dot_product_attention()`. " + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + *args, + **kwargs, + ) -> torch.FloatTensor: + batch_size = hidden_states.shape[0] + + # `sample` projections. + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + # `context` projections. + if encoder_hidden_states is not None: + encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + + # Reshape. + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + query = query.view(batch_size, -1, attn.heads, head_dim) + key = key.view(batch_size, -1, attn.heads, head_dim) + value = value.view(batch_size, -1, attn.heads, head_dim) + + # Apply QK norm. + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Concatenate the projections. + if encoder_hidden_states is not None: + encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( + batch_size, -1, attn.heads, head_dim + ) + encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(batch_size, -1, attn.heads, head_dim) + encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( + batch_size, -1, attn.heads, head_dim + ) + + if attn.norm_added_q is not None: + encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) + if attn.norm_added_k is not None: + encoder_hidden_states_key_proj = attn.norm_added_q(encoder_hidden_states_key_proj) + + query = torch.cat([encoder_hidden_states_query_proj, query], dim=1) + key = torch.cat([encoder_hidden_states_key_proj, key], dim=1) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=1) + + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + # Attention. + hidden_states = F.scaled_dot_product_attention( + query, key, value, dropout_p=0.0, scale=attn.scale, is_causal=False + ) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # Split the attention outputs. + if encoder_hidden_states is not None: + hidden_states, encoder_hidden_states = ( + hidden_states[:, encoder_hidden_states.shape[1] :], + hidden_states[:, : encoder_hidden_states.shape[1]], + ) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + if encoder_hidden_states is not None: + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + if encoder_hidden_states is not None: + return hidden_states, encoder_hidden_states + else: + return hidden_states + + +class FusedAuraFlowAttnProcessor2_0: + """Attention processor used typically in processing Aura Flow with fused projections.""" + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention") and is_torch_version("<", "2.1"): + raise ImportError( + "FusedAuraFlowAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to at least 2.1 or above as we use `scale` in `F.scaled_dot_product_attention()`. " + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + *args, + **kwargs, + ) -> torch.FloatTensor: + batch_size = hidden_states.shape[0] + + # `sample` projections. + qkv = attn.to_qkv(hidden_states) + split_size = qkv.shape[-1] // 3 + query, key, value = torch.split(qkv, split_size, dim=-1) + + # `context` projections. + if encoder_hidden_states is not None: + encoder_qkv = attn.to_added_qkv(encoder_hidden_states) + split_size = encoder_qkv.shape[-1] // 3 + ( + encoder_hidden_states_query_proj, + encoder_hidden_states_key_proj, + encoder_hidden_states_value_proj, + ) = torch.split(encoder_qkv, split_size, dim=-1) + + # Reshape. + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + query = query.view(batch_size, -1, attn.heads, head_dim) + key = key.view(batch_size, -1, attn.heads, head_dim) + value = value.view(batch_size, -1, attn.heads, head_dim) + + # Apply QK norm. + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Concatenate the projections. + if encoder_hidden_states is not None: + encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( + batch_size, -1, attn.heads, head_dim + ) + encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(batch_size, -1, attn.heads, head_dim) + encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( + batch_size, -1, attn.heads, head_dim + ) + + if attn.norm_added_q is not None: + encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) + if attn.norm_added_k is not None: + encoder_hidden_states_key_proj = attn.norm_added_q(encoder_hidden_states_key_proj) + + query = torch.cat([encoder_hidden_states_query_proj, query], dim=1) + key = torch.cat([encoder_hidden_states_key_proj, key], dim=1) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=1) + + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + # Attention. + hidden_states = F.scaled_dot_product_attention( + query, key, value, dropout_p=0.0, scale=attn.scale, is_causal=False + ) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # Split the attention outputs. + if encoder_hidden_states is not None: + hidden_states, encoder_hidden_states = ( + hidden_states[:, encoder_hidden_states.shape[1] :], + hidden_states[:, : encoder_hidden_states.shape[1]], + ) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + if encoder_hidden_states is not None: + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + if encoder_hidden_states is not None: + return hidden_states, encoder_hidden_states + else: + return hidden_states + + +class FluxAttnProcessor2_0: + """Attention processor used typically in processing the SD3-like self-attention projections.""" + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + + # `sample` projections. + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` + if encoder_hidden_states is not None: + # `context` projections. + encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + + encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + + if attn.norm_added_q is not None: + encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) + if attn.norm_added_k is not None: + encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) + + # attention + query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) + key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) + + if image_rotary_emb is not None: + from .embeddings import apply_rotary_emb + + query = apply_rotary_emb(query, image_rotary_emb) + key = apply_rotary_emb(key, image_rotary_emb) + + hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + if encoder_hidden_states is not None: + encoder_hidden_states, hidden_states = ( + hidden_states[:, : encoder_hidden_states.shape[1]], + hidden_states[:, encoder_hidden_states.shape[1] :], + ) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + else: + return hidden_states + + +class FusedFluxAttnProcessor2_0: + """Attention processor used typically in processing the SD3-like self-attention projections.""" + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "FusedFluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + + # `sample` projections. + qkv = attn.to_qkv(hidden_states) + split_size = qkv.shape[-1] // 3 + query, key, value = torch.split(qkv, split_size, dim=-1) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` + # `context` projections. + if encoder_hidden_states is not None: + encoder_qkv = attn.to_added_qkv(encoder_hidden_states) + split_size = encoder_qkv.shape[-1] // 3 + ( + encoder_hidden_states_query_proj, + encoder_hidden_states_key_proj, + encoder_hidden_states_value_proj, + ) = torch.split(encoder_qkv, split_size, dim=-1) + + encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + + if attn.norm_added_q is not None: + encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) + if attn.norm_added_k is not None: + encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) + + # attention + query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) + key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) + + if image_rotary_emb is not None: + from .embeddings import apply_rotary_emb + + query = apply_rotary_emb(query, image_rotary_emb) + key = apply_rotary_emb(key, image_rotary_emb) + + hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + if encoder_hidden_states is not None: + encoder_hidden_states, hidden_states = ( + hidden_states[:, : encoder_hidden_states.shape[1]], + hidden_states[:, encoder_hidden_states.shape[1] :], + ) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + else: + return hidden_states + + +class CogVideoXAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on + query and key vectors, but does not include spatial normalization. + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + text_seq_length = encoder_hidden_states.size(1) + latent_seq_length = hidden_states.size(1) + + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if hasattr(attn, "qkvLinear"): + query, key, value = attn.qkvLinear(hidden_states) + else: + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Apply RoPE if needed + if image_rotary_emb is not None: + cos, sin = image_rotary_emb + query[:, :, text_seq_length:] = torch.ops.mindie.rope_mindie_sd(query[:, :, text_seq_length:], cos[None, None], sin[None, None], mode=1) + if not attn.is_cross_attention: + key[:, :, text_seq_length:] = torch.ops.mindie.rope_mindie_sd(key[:, :, text_seq_length:], cos[None, None], sin[None, None], mode=1) + + if get_sp_world_size() == 1: + hidden_states = torch_npu.npu_prompt_flash_attention( + query, key, value, num_heads=attn.heads, + input_layout='BNSD', + scale_value=1.0 / math.sqrt(query.shape[-1]), + atten_mask=attention_mask, + pre_tokens=MAX_TOKENS, + next_tokens=MAX_TOKENS, + sparse_mode=0 + ) + else: + hidden_states = gather_parrellel_ga(query, key, value, 1.0 / math.sqrt(query.shape[-1]), get_sp_world_size()) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + encoder_hidden_states, hidden_states = hidden_states.split( + [text_seq_length, latent_seq_length], dim=1 + ) + return hidden_states, encoder_hidden_states + + +def gather_parrellel_ga( + q, k, v, + scale_value, + world_size, + num_head_split=8, +): + """ + All Gather key-value pairs in parallel for Flash attention . + + + Args: + qkv_list (List[torch.Tensor]): A list containing query (q), key (k), and value (v) tensors. + the key and value should in the shape [B N S D] + head_dim (int): The dimension of each attention head. + world_size (int): The number of distributed processes. + num_head_split (int, optional): The number of splits for the attention heads. Defaults to 8. + + Returns: + torch.Tensor: The output tensor after applying parallel attention. + The shape [B N S D] + """ + local_size = torch.tensor([k.size(-2)], dtype=torch.long, device=k.device) + size_list = [torch.zeros(1, dtype=torch.long, device=k.device) for _ in range(get_sp_world_size())] + dist.all_gather(size_list, local_size, group=get_sp_group()) + sizes = [int(size.item()) for size in size_list] + + max_size = max(sizes) + pad_num = 0 + if k.size(-2) < max_size: + pad_size = list(k.size()) + pad_size[-2] = max_size - k.size(-2) + pad_num = pad_size[-2] + k = torch.cat([k, torch.zeros(pad_size, device=k.device, dtype=k.dtype)], dim=-2).contiguous() + v = torch.cat([v, torch.zeros(pad_size, device=v.device, dtype=v.dtype)], dim=-2).contiguous() + + + q_list = q.chunk(num_head_split, dim=1) + + kv = torch.cat((k, v), dim=0) + kv_list = kv.chunk(num_head_split, dim=1) + kv_split = kv_list[0].contiguous() + b, n, s, d = kv_split.shape + kv_full = torch.empty([world_size, b, n, s, d], dtype=kv_split.dtype, device=kv_split.device) + torch.distributed.all_gather_into_tensor(kv_full, kv_split, group=get_sp_group()) + kv_full = kv_full.permute(1, 2, 0, 3, 4).reshape(b, n, -1, d) + + out = [] + for step in range(num_head_split): + k, v = kv_full.chunk(2, dim=0) + if step != num_head_split - 1: + kv_split = kv_list[step + 1].contiguous() + b, n, s, d = kv_split.shape + kv_full = torch.empty([world_size, b, n, s, d], dtype=kv_split.dtype, device=kv_split.device) + req = torch.distributed.all_gather_into_tensor(kv_full, kv_split, async_op=True, group=get_sp_group()) + + output = torch_npu.npu_prompt_flash_attention( + q_list[step], + torch.narrow(k, dim=-2, start=0, length=k.size(-2) - pad_num), + torch.narrow(v, dim=-2, start=0, length=v.size(-2) - pad_num), + num_heads=k.shape[1], + input_layout="BNSD", + scale_value=scale_value, + pre_tokens=MAX_TOKENS, + next_tokens=MAX_TOKENS + ) + + out.append(output) + + if step != num_head_split - 1: + req.wait() + kv_full = kv_full.permute(1, 2, 0, 3, 4).reshape(b, n, -1, d) + out = torch.cat(out, dim=1) + return out + + +class FusedCogVideoXAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on + query and key vectors, but does not include spatial normalization. + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + text_seq_length = encoder_hidden_states.size(1) + latent_seq_length = hidden_states.size(1) + + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + qkv = attn.to_qkv(hidden_states) + split_size = qkv.shape[-1] // 3 + query, key, value = torch.split(qkv, split_size, dim=-1) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Apply RoPE if needed + if image_rotary_emb is not None: + from .embeddings import apply_rotary_emb + + query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb) + if not attn.is_cross_attention: + key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb) + + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + encoder_hidden_states, hidden_states = hidden_states.split( + [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1 + ) + return hidden_states, encoder_hidden_states + + +class XFormersAttnAddedKVProcessor: + r""" + Processor for implementing memory efficient attention using xFormers. + + Args: + attention_op (`Callable`, *optional*, defaults to `None`): + The base + [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to + use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best + operator. + """ + + def __init__(self, attention_op: Optional[Callable] = None): + self.attention_op = attention_op + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + residual = hidden_states + hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) + batch_size, sequence_length, _ = hidden_states.shape + + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + query = attn.head_to_batch_dim(query) + + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj) + encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj) + + if not attn.only_cross_attention: + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + key = torch.cat([encoder_hidden_states_key_proj, key], dim=1) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=1) + else: + key = encoder_hidden_states_key_proj + value = encoder_hidden_states_value_proj + + hidden_states = xformers.ops.memory_efficient_attention( + query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale + ) + hidden_states = hidden_states.to(query.dtype) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape) + hidden_states = hidden_states + residual + + return hidden_states + + +class XFormersAttnProcessor: + r""" + Processor for implementing memory efficient attention using xFormers. + + Args: + attention_op (`Callable`, *optional*, defaults to `None`): + The base + [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to + use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best + operator. + """ + + def __init__(self, attention_op: Optional[Callable] = None): + self.attention_op = attention_op + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + *args, + **kwargs, + ) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, key_tokens, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size) + if attention_mask is not None: + _, query_tokens, _ = hidden_states.shape + attention_mask = attention_mask.expand(-1, query_tokens, -1) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query).contiguous() + key = attn.head_to_batch_dim(key).contiguous() + value = attn.head_to_batch_dim(value).contiguous() + + hidden_states = xformers.ops.memory_efficient_attention( + query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale + ) + hidden_states = hidden_states.to(query.dtype) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class AttnProcessorNPU: + r""" + Processor for implementing flash attention using torch_npu. Torch_npu supports only fp16 and bf16 data types. If + fp32 is used, F.scaled_dot_product_attention will be used for computation, but the acceleration effect on NPU is + not significant. + + """ + + def __init__(self): + if not is_torch_npu_available(): + raise ImportError("AttnProcessorNPU requires torch_npu extensions and is supported only on npu devices.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + *args, + **kwargs, + ) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if query.dtype in (torch.float16, torch.bfloat16): + hidden_states = torch_npu.npu_fusion_attention( + query, + key, + value, + attn.heads, + input_layout="BNSD", + pse=None, + atten_mask=attention_mask, + scale=1.0 / math.sqrt(query.shape[-1]), + pre_tockens=65536, + next_tockens=65536, + keep_prob=1.0, + sync=False, + inner_precise=0, + )[0] + else: + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class AttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + *args, + **kwargs, + ) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class StableAudioAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is + used in the Stable Audio model. It applies rotary embedding on query and key vector, and allows MHA, GQA or MQA. + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "StableAudioAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) + + def apply_partial_rotary_emb( + self, + x: torch.Tensor, + freqs_cis: Tuple[torch.Tensor], + ) -> torch.Tensor: + from .embeddings import apply_rotary_emb + + rot_dim = freqs_cis[0].shape[-1] + x_to_rotate, x_unrotated = x[..., :rot_dim], x[..., rot_dim:] + + x_rotated = apply_rotary_emb(x_to_rotate, freqs_cis, use_real=True, use_real_unbind_dim=-2) + + out = torch.cat((x_rotated, x_unrotated), dim=-1) + return out + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + from .embeddings import apply_rotary_emb + + residual = hidden_states + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + head_dim = query.shape[-1] // attn.heads + kv_heads = key.shape[-1] // head_dim + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2) + + if kv_heads != attn.heads: + # if GQA or MQA, repeat the key/value heads to reach the number of query heads. + heads_per_kv_head = attn.heads // kv_heads + key = torch.repeat_interleave(key, heads_per_kv_head, dim=1) + value = torch.repeat_interleave(value, heads_per_kv_head, dim=1) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Apply RoPE if needed + if rotary_emb is not None: + query_dtype = query.dtype + key_dtype = key.dtype + query = query.to(torch.float32) + key = key.to(torch.float32) + + rot_dim = rotary_emb[0].shape[-1] + query_to_rotate, query_unrotated = query[..., :rot_dim], query[..., rot_dim:] + query_rotated = apply_rotary_emb(query_to_rotate, rotary_emb, use_real=True, use_real_unbind_dim=-2) + + query = torch.cat((query_rotated, query_unrotated), dim=-1) + + if not attn.is_cross_attention: + key_to_rotate, key_unrotated = key[..., :rot_dim], key[..., rot_dim:] + key_rotated = apply_rotary_emb(key_to_rotate, rotary_emb, use_real=True, use_real_unbind_dim=-2) + + key = torch.cat((key_rotated, key_unrotated), dim=-1) + + query = query.to(query_dtype) + key = key.to(key_dtype) + + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class HunyuanAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is + used in the HunyuanDiT model. It applies a s normalization layer and rotary embedding on query and key vector. + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + from .embeddings import apply_rotary_emb + + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Apply RoPE if needed + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb) + if not attn.is_cross_attention: + key = apply_rotary_emb(key, image_rotary_emb) + + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class FusedHunyuanAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0) with fused + projection layers. This is used in the HunyuanDiT model. It applies a s normalization layer and rotary embedding on + query and key vector. + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "FusedHunyuanAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + from .embeddings import apply_rotary_emb + + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + if encoder_hidden_states is None: + qkv = attn.to_qkv(hidden_states) + split_size = qkv.shape[-1] // 3 + query, key, value = torch.split(qkv, split_size, dim=-1) + else: + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + query = attn.to_q(hidden_states) + + kv = attn.to_kv(encoder_hidden_states) + split_size = kv.shape[-1] // 2 + key, value = torch.split(kv, split_size, dim=-1) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Apply RoPE if needed + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb) + if not attn.is_cross_attention: + key = apply_rotary_emb(key, image_rotary_emb) + + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class PAGHunyuanAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is + used in the HunyuanDiT model. It applies a normalization layer and rotary embedding on query and key vector. This + variant of the processor employs [Pertubed Attention Guidance](https://arxiv.org/abs/2403.17377). + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "PAGHunyuanAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + from .embeddings import apply_rotary_emb + + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + # chunk + hidden_states_org, hidden_states_ptb = hidden_states.chunk(2) + + # 1. Original Path + batch_size, sequence_length, _ = ( + hidden_states_org.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states_org) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states_org + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Apply RoPE if needed + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb) + if not attn.is_cross_attention: + key = apply_rotary_emb(key, image_rotary_emb) + + hidden_states_org = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states_org = hidden_states_org.to(query.dtype) + + # linear proj + hidden_states_org = attn.to_out[0](hidden_states_org) + # dropout + hidden_states_org = attn.to_out[1](hidden_states_org) + + if input_ndim == 4: + hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width) + + # 2. Perturbed Path + if attn.group_norm is not None: + hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2) + + hidden_states_ptb = attn.to_v(hidden_states_ptb) + hidden_states_ptb = hidden_states_ptb.to(query.dtype) + + # linear proj + hidden_states_ptb = attn.to_out[0](hidden_states_ptb) + # dropout + hidden_states_ptb = attn.to_out[1](hidden_states_ptb) + + if input_ndim == 4: + hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width) + + # cat + hidden_states = torch.cat([hidden_states_org, hidden_states_ptb]) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class PAGCFGHunyuanAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is + used in the HunyuanDiT model. It applies a normalization layer and rotary embedding on query and key vector. This + variant of the processor employs [Pertubed Attention Guidance](https://arxiv.org/abs/2403.17377). + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "PAGCFGHunyuanAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + from .embeddings import apply_rotary_emb + + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + # chunk + hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3) + hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org]) + + # 1. Original Path + batch_size, sequence_length, _ = ( + hidden_states_org.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states_org) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states_org + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Apply RoPE if needed + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb) + if not attn.is_cross_attention: + key = apply_rotary_emb(key, image_rotary_emb) + + hidden_states_org = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states_org = hidden_states_org.to(query.dtype) + + # linear proj + hidden_states_org = attn.to_out[0](hidden_states_org) + # dropout + hidden_states_org = attn.to_out[1](hidden_states_org) + + if input_ndim == 4: + hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width) + + # 2. Perturbed Path + if attn.group_norm is not None: + hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2) + + hidden_states_ptb = attn.to_v(hidden_states_ptb) + hidden_states_ptb = hidden_states_ptb.to(query.dtype) + + # linear proj + hidden_states_ptb = attn.to_out[0](hidden_states_ptb) + # dropout + hidden_states_ptb = attn.to_out[1](hidden_states_ptb) + + if input_ndim == 4: + hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width) + + # cat + hidden_states = torch.cat([hidden_states_org, hidden_states_ptb]) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class LuminaAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is + used in the LuminaNextDiT model. It applies a s normalization layer and rotary embedding on query and key vector. + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + query_rotary_emb: Optional[torch.Tensor] = None, + key_rotary_emb: Optional[torch.Tensor] = None, + base_sequence_length: Optional[int] = None, + ) -> torch.Tensor: + from .embeddings import apply_rotary_emb + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = hidden_states.shape + + # Get Query-Key-Value Pair + query = attn.to_q(hidden_states) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query_dim = query.shape[-1] + inner_dim = key.shape[-1] + head_dim = query_dim // attn.heads + dtype = query.dtype + + # Get key-value heads + kv_heads = inner_dim // head_dim + + # Apply Query-Key Norm if needed + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + query = query.view(batch_size, -1, attn.heads, head_dim) + + key = key.view(batch_size, -1, kv_heads, head_dim) + value = value.view(batch_size, -1, kv_heads, head_dim) + + # Apply RoPE if needed + if query_rotary_emb is not None: + query = apply_rotary_emb(query, query_rotary_emb, use_real=False) + if key_rotary_emb is not None: + key = apply_rotary_emb(key, key_rotary_emb, use_real=False) + + query, key = query.to(dtype), key.to(dtype) + + # Apply proportional attention if true + if key_rotary_emb is None: + softmax_scale = None + else: + if base_sequence_length is not None: + softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale + else: + softmax_scale = attn.scale + + # perform Grouped-qurey Attention (GQA) + n_rep = attn.heads // kv_heads + if n_rep >= 1: + key = key.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) + value = value.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) + + attention_mask = attention_mask.bool().view(batch_size, 1, 1, -1) + attention_mask = attention_mask.expand(-1, attn.heads, sequence_length, -1) + + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, scale=softmax_scale + ) + hidden_states = hidden_states.transpose(1, 2).to(dtype) + + return hidden_states + + +class FusedAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). It uses + fused projection layers. For self-attention modules, all projection matrices (i.e., query, key, value) are fused. + For cross-attention modules, key and value projection matrices are fused. + + + + This API is currently 🧪 experimental in nature and can change in future. + + + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "FusedAttnProcessor2_0 requires at least PyTorch 2.0, to use it. Please upgrade PyTorch to > 2.0." + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + *args, + **kwargs, + ) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + if encoder_hidden_states is None: + qkv = attn.to_qkv(hidden_states) + split_size = qkv.shape[-1] // 3 + query, key, value = torch.split(qkv, split_size, dim=-1) + else: + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + query = attn.to_q(hidden_states) + + kv = attn.to_kv(encoder_hidden_states) + split_size = kv.shape[-1] // 2 + key, value = torch.split(kv, split_size, dim=-1) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class CustomDiffusionXFormersAttnProcessor(nn.Module): + r""" + Processor for implementing memory efficient attention using xFormers for the Custom Diffusion method. + + Args: + train_kv (`bool`, defaults to `True`): + Whether to newly train the key and value matrices corresponding to the text features. + train_q_out (`bool`, defaults to `True`): + Whether to newly train query matrices corresponding to the latent image features. + hidden_size (`int`, *optional*, defaults to `None`): + The hidden size of the attention layer. + cross_attention_dim (`int`, *optional*, defaults to `None`): + The number of channels in the `encoder_hidden_states`. + out_bias (`bool`, defaults to `True`): + Whether to include the bias parameter in `train_q_out`. + dropout (`float`, *optional*, defaults to 0.0): + The dropout probability to use. + attention_op (`Callable`, *optional*, defaults to `None`): + The base + [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to use + as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best operator. + """ + + def __init__( + self, + train_kv: bool = True, + train_q_out: bool = False, + hidden_size: Optional[int] = None, + cross_attention_dim: Optional[int] = None, + out_bias: bool = True, + dropout: float = 0.0, + attention_op: Optional[Callable] = None, + ): + super().__init__() + self.train_kv = train_kv + self.train_q_out = train_q_out + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + self.attention_op = attention_op + + # `_custom_diffusion` id for easy serialization and loading. + if self.train_kv: + self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + if self.train_q_out: + self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False) + self.to_out_custom_diffusion = nn.ModuleList([]) + self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias)) + self.to_out_custom_diffusion.append(nn.Dropout(dropout)) + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if self.train_q_out: + query = self.to_q_custom_diffusion(hidden_states).to(attn.to_q.weight.dtype) + else: + query = attn.to_q(hidden_states.to(attn.to_q.weight.dtype)) + + if encoder_hidden_states is None: + crossattn = False + encoder_hidden_states = hidden_states + else: + crossattn = True + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + if self.train_kv: + key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype)) + value = self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype)) + key = key.to(attn.to_q.weight.dtype) + value = value.to(attn.to_q.weight.dtype) + else: + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + if crossattn: + detach = torch.ones_like(key) + detach[:, :1, :] = detach[:, :1, :] * 0.0 + key = detach * key + (1 - detach) * key.detach() + value = detach * value + (1 - detach) * value.detach() + + query = attn.head_to_batch_dim(query).contiguous() + key = attn.head_to_batch_dim(key).contiguous() + value = attn.head_to_batch_dim(value).contiguous() + + hidden_states = xformers.ops.memory_efficient_attention( + query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale + ) + hidden_states = hidden_states.to(query.dtype) + hidden_states = attn.batch_to_head_dim(hidden_states) + + if self.train_q_out: + # linear proj + hidden_states = self.to_out_custom_diffusion[0](hidden_states) + # dropout + hidden_states = self.to_out_custom_diffusion[1](hidden_states) + else: + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states + + +class CustomDiffusionAttnProcessor2_0(nn.Module): + r""" + Processor for implementing attention for the Custom Diffusion method using PyTorch 2.0’s memory-efficient scaled + dot-product attention. + + Args: + train_kv (`bool`, defaults to `True`): + Whether to newly train the key and value matrices corresponding to the text features. + train_q_out (`bool`, defaults to `True`): + Whether to newly train query matrices corresponding to the latent image features. + hidden_size (`int`, *optional*, defaults to `None`): + The hidden size of the attention layer. + cross_attention_dim (`int`, *optional*, defaults to `None`): + The number of channels in the `encoder_hidden_states`. + out_bias (`bool`, defaults to `True`): + Whether to include the bias parameter in `train_q_out`. + dropout (`float`, *optional*, defaults to 0.0): + The dropout probability to use. + """ + + def __init__( + self, + train_kv: bool = True, + train_q_out: bool = True, + hidden_size: Optional[int] = None, + cross_attention_dim: Optional[int] = None, + out_bias: bool = True, + dropout: float = 0.0, + ): + super().__init__() + self.train_kv = train_kv + self.train_q_out = train_q_out + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + + # `_custom_diffusion` id for easy serialization and loading. + if self.train_kv: + self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + if self.train_q_out: + self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False) + self.to_out_custom_diffusion = nn.ModuleList([]) + self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias)) + self.to_out_custom_diffusion.append(nn.Dropout(dropout)) + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + batch_size, sequence_length, _ = hidden_states.shape + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + if self.train_q_out: + query = self.to_q_custom_diffusion(hidden_states) + else: + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + crossattn = False + encoder_hidden_states = hidden_states + else: + crossattn = True + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + if self.train_kv: + key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype)) + value = self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype)) + key = key.to(attn.to_q.weight.dtype) + value = value.to(attn.to_q.weight.dtype) + + else: + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + if crossattn: + detach = torch.ones_like(key) + detach[:, :1, :] = detach[:, :1, :] * 0.0 + key = detach * key + (1 - detach) * key.detach() + value = detach * value + (1 - detach) * value.detach() + + inner_dim = hidden_states.shape[-1] + + head_dim = inner_dim // attn.heads + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + if self.train_q_out: + # linear proj + hidden_states = self.to_out_custom_diffusion[0](hidden_states) + # dropout + hidden_states = self.to_out_custom_diffusion[1](hidden_states) + else: + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states + + +class SlicedAttnProcessor: + r""" + Processor for implementing sliced attention. + + Args: + slice_size (`int`, *optional*): + The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and + `attention_head_dim` must be a multiple of the `slice_size`. + """ + + def __init__(self, slice_size: int): + self.slice_size = slice_size + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + residual = hidden_states + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + dim = query.shape[-1] + query = attn.head_to_batch_dim(query) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + batch_size_attention, query_tokens, _ = query.shape + hidden_states = torch.zeros( + (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype + ) + + for i in range((batch_size_attention - 1) // self.slice_size + 1): + start_idx = i * self.slice_size + end_idx = (i + 1) * self.slice_size + + query_slice = query[start_idx:end_idx] + key_slice = key[start_idx:end_idx] + attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None + + attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) + + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) + + hidden_states[start_idx:end_idx] = attn_slice + + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class SlicedAttnAddedKVProcessor: + r""" + Processor for implementing sliced attention with extra learnable key and value matrices for the text encoder. + + Args: + slice_size (`int`, *optional*): + The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and + `attention_head_dim` must be a multiple of the `slice_size`. + """ + + def __init__(self, slice_size): + self.slice_size = slice_size + + def __call__( + self, + attn: "Attention", + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) + + batch_size, sequence_length, _ = hidden_states.shape + + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + dim = query.shape[-1] + query = attn.head_to_batch_dim(query) + + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + + encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj) + encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj) + + if not attn.only_cross_attention: + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + key = torch.cat([encoder_hidden_states_key_proj, key], dim=1) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=1) + else: + key = encoder_hidden_states_key_proj + value = encoder_hidden_states_value_proj + + batch_size_attention, query_tokens, _ = query.shape + hidden_states = torch.zeros( + (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype + ) + + for i in range((batch_size_attention - 1) // self.slice_size + 1): + start_idx = i * self.slice_size + end_idx = (i + 1) * self.slice_size + + query_slice = query[start_idx:end_idx] + key_slice = key[start_idx:end_idx] + attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None + + attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) + + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) + + hidden_states[start_idx:end_idx] = attn_slice + + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape) + hidden_states = hidden_states + residual + + return hidden_states + + +class SpatialNorm(nn.Module): + """ + Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002. + + Args: + f_channels (`int`): + The number of channels for input to group normalization layer, and output of the spatial norm layer. + zq_channels (`int`): + The number of channels for the quantized vector as described in the paper. + """ + + def __init__( + self, + f_channels: int, + zq_channels: int, + ): + super().__init__() + self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True) + self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) + self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor: + f_size = f.shape[-2:] + zq = F.interpolate(zq, size=f_size, mode="nearest") + norm_f = self.norm_layer(f) + new_f = norm_f * self.conv_y(zq) + self.conv_b(zq) + return new_f + + +class IPAdapterAttnProcessor(nn.Module): + r""" + Attention processor for Multiple IP-Adapters. + + Args: + hidden_size (`int`): + The hidden size of the attention layer. + cross_attention_dim (`int`): + The number of channels in the `encoder_hidden_states`. + num_tokens (`int`, `Tuple[int]` or `List[int]`, defaults to `(4,)`): + The context length of the image features. + scale (`float` or List[`float`], defaults to 1.0): + the weight scale of image prompt. + """ + + def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0): + super().__init__() + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + + if not isinstance(num_tokens, (tuple, list)): + num_tokens = [num_tokens] + self.num_tokens = num_tokens + + if not isinstance(scale, list): + scale = [scale] * len(num_tokens) + if len(scale) != len(num_tokens): + raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.") + self.scale = scale + + self.to_k_ip = nn.ModuleList( + [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))] + ) + self.to_v_ip = nn.ModuleList( + [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))] + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + scale: float = 1.0, + ip_adapter_masks: Optional[torch.Tensor] = None, + ): + residual = hidden_states + + # separate ip_hidden_states from encoder_hidden_states + if encoder_hidden_states is not None: + if isinstance(encoder_hidden_states, tuple): + encoder_hidden_states, ip_hidden_states = encoder_hidden_states + else: + deprecation_message = ( + "You have passed a tensor as `encoder_hidden_states`. This is deprecated and will be removed in a future release." + " Please make sure to update your script to pass `encoder_hidden_states` as a tuple to suppress this warning." + ) + deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False) + end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0] + encoder_hidden_states, ip_hidden_states = ( + encoder_hidden_states[:, :end_pos, :], + [encoder_hidden_states[:, end_pos:, :]], + ) + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + if ip_adapter_masks is not None: + if not isinstance(ip_adapter_masks, List): + # for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width] + ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1)) + if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)): + raise ValueError( + f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match " + f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states " + f"({len(ip_hidden_states)})" + ) + else: + for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)): + if not isinstance(mask, torch.Tensor) or mask.ndim != 4: + raise ValueError( + "Each element of the ip_adapter_masks array should be a tensor with shape " + "[1, num_images_for_ip_adapter, height, width]." + " Please use `IPAdapterMaskProcessor` to preprocess your mask" + ) + if mask.shape[1] != ip_state.shape[1]: + raise ValueError( + f"Number of masks ({mask.shape[1]}) does not match " + f"number of ip images ({ip_state.shape[1]}) at index {index}" + ) + if isinstance(scale, list) and not len(scale) == mask.shape[1]: + raise ValueError( + f"Number of masks ({mask.shape[1]}) does not match " + f"number of scales ({len(scale)}) at index {index}" + ) + else: + ip_adapter_masks = [None] * len(self.scale) + + # for ip-adapter + for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip( + ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks + ): + skip = False + if isinstance(scale, list): + if all(s == 0 for s in scale): + skip = True + elif scale == 0: + skip = True + if not skip: + if mask is not None: + if not isinstance(scale, list): + scale = [scale] * mask.shape[1] + + current_num_images = mask.shape[1] + for i in range(current_num_images): + ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :]) + ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :]) + + ip_key = attn.head_to_batch_dim(ip_key) + ip_value = attn.head_to_batch_dim(ip_value) + + ip_attention_probs = attn.get_attention_scores(query, ip_key, None) + _current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) + _current_ip_hidden_states = attn.batch_to_head_dim(_current_ip_hidden_states) + + mask_downsample = IPAdapterMaskProcessor.downsample( + mask[:, i, :, :], + batch_size, + _current_ip_hidden_states.shape[1], + _current_ip_hidden_states.shape[2], + ) + + mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device) + + hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample) + else: + ip_key = to_k_ip(current_ip_hidden_states) + ip_value = to_v_ip(current_ip_hidden_states) + + ip_key = attn.head_to_batch_dim(ip_key) + ip_value = attn.head_to_batch_dim(ip_value) + + ip_attention_probs = attn.get_attention_scores(query, ip_key, None) + current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) + current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states) + + hidden_states = hidden_states + scale * current_ip_hidden_states + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class IPAdapterAttnProcessor2_0(torch.nn.Module): + r""" + Attention processor for IP-Adapter for PyTorch 2.0. + + Args: + hidden_size (`int`): + The hidden size of the attention layer. + cross_attention_dim (`int`): + The number of channels in the `encoder_hidden_states`. + num_tokens (`int`, `Tuple[int]` or `List[int]`, defaults to `(4,)`): + The context length of the image features. + scale (`float` or `List[float]`, defaults to 1.0): + the weight scale of image prompt. + """ + + def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0): + super().__init__() + + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + + if not isinstance(num_tokens, (tuple, list)): + num_tokens = [num_tokens] + self.num_tokens = num_tokens + + if not isinstance(scale, list): + scale = [scale] * len(num_tokens) + if len(scale) != len(num_tokens): + raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.") + self.scale = scale + + self.to_k_ip = nn.ModuleList( + [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))] + ) + self.to_v_ip = nn.ModuleList( + [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))] + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + scale: float = 1.0, + ip_adapter_masks: Optional[torch.Tensor] = None, + ): + residual = hidden_states + + # separate ip_hidden_states from encoder_hidden_states + if encoder_hidden_states is not None: + if isinstance(encoder_hidden_states, tuple): + encoder_hidden_states, ip_hidden_states = encoder_hidden_states + else: + deprecation_message = ( + "You have passed a tensor as `encoder_hidden_states`. This is deprecated and will be removed in a future release." + " Please make sure to update your script to pass `encoder_hidden_states` as a tuple to suppress this warning." + ) + deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False) + end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0] + encoder_hidden_states, ip_hidden_states = ( + encoder_hidden_states[:, :end_pos, :], + [encoder_hidden_states[:, end_pos:, :]], + ) + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + if ip_adapter_masks is not None: + if not isinstance(ip_adapter_masks, List): + ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1)) + if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)): + raise ValueError( + f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match " + f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states " + f"({len(ip_hidden_states)})" + ) + else: + for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)): + if not isinstance(mask, torch.Tensor) or mask.ndim != 4: + raise ValueError( + "Each element of the ip_adapter_masks array should be a tensor with shape " + "[1, num_images_for_ip_adapter, height, width]." + " Please use `IPAdapterMaskProcessor` to preprocess your mask" + ) + if mask.shape[1] != ip_state.shape[1]: + raise ValueError( + f"Number of masks ({mask.shape[1]}) does not match " + f"number of ip images ({ip_state.shape[1]}) at index {index}" + ) + if isinstance(scale, list) and not len(scale) == mask.shape[1]: + raise ValueError( + f"Number of masks ({mask.shape[1]}) does not match " + f"number of scales ({len(scale)}) at index {index}" + ) + else: + ip_adapter_masks = [None] * len(self.scale) + + # for ip-adapter + for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip( + ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks + ): + skip = False + if isinstance(scale, list): + if all(s == 0 for s in scale): + skip = True + elif scale == 0: + skip = True + if not skip: + if mask is not None: + if not isinstance(scale, list): + scale = [scale] * mask.shape[1] + + current_num_images = mask.shape[1] + for i in range(current_num_images): + ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :]) + ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :]) + + ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + _current_ip_hidden_states = F.scaled_dot_product_attention( + query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False + ) + + _current_ip_hidden_states = _current_ip_hidden_states.transpose(1, 2).reshape( + batch_size, -1, attn.heads * head_dim + ) + _current_ip_hidden_states = _current_ip_hidden_states.to(query.dtype) + + mask_downsample = IPAdapterMaskProcessor.downsample( + mask[:, i, :, :], + batch_size, + _current_ip_hidden_states.shape[1], + _current_ip_hidden_states.shape[2], + ) + + mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device) + hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample) + else: + ip_key = to_k_ip(current_ip_hidden_states) + ip_value = to_v_ip(current_ip_hidden_states) + + ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + current_ip_hidden_states = F.scaled_dot_product_attention( + query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False + ) + + current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape( + batch_size, -1, attn.heads * head_dim + ) + current_ip_hidden_states = current_ip_hidden_states.to(query.dtype) + + hidden_states = hidden_states + scale * current_ip_hidden_states + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class PAGIdentitySelfAttnProcessor2_0: + r""" + Processor for implementing PAG using scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + PAG reference: https://arxiv.org/abs/2403.17377 + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "PAGIdentitySelfAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + temb: Optional[torch.FloatTensor] = None, + ) -> torch.Tensor: + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + # chunk + hidden_states_org, hidden_states_ptb = hidden_states.chunk(2) + + # original path + batch_size, sequence_length, _ = hidden_states_org.shape + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states_org) + key = attn.to_k(hidden_states_org) + value = attn.to_v(hidden_states_org) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + hidden_states_org = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states_org = hidden_states_org.to(query.dtype) + + # linear proj + hidden_states_org = attn.to_out[0](hidden_states_org) + # dropout + hidden_states_org = attn.to_out[1](hidden_states_org) + + if input_ndim == 4: + hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width) + + # perturbed path (identity attention) + batch_size, sequence_length, _ = hidden_states_ptb.shape + + if attn.group_norm is not None: + hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2) + + hidden_states_ptb = attn.to_v(hidden_states_ptb) + hidden_states_ptb = hidden_states_ptb.to(query.dtype) + + # linear proj + hidden_states_ptb = attn.to_out[0](hidden_states_ptb) + # dropout + hidden_states_ptb = attn.to_out[1](hidden_states_ptb) + + if input_ndim == 4: + hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width) + + # cat + hidden_states = torch.cat([hidden_states_org, hidden_states_ptb]) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class PAGCFGIdentitySelfAttnProcessor2_0: + r""" + Processor for implementing PAG using scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + PAG reference: https://arxiv.org/abs/2403.17377 + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "PAGCFGIdentitySelfAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + temb: Optional[torch.FloatTensor] = None, + ) -> torch.Tensor: + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + # chunk + hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3) + hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org]) + + # original path + batch_size, sequence_length, _ = hidden_states_org.shape + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states_org) + key = attn.to_k(hidden_states_org) + value = attn.to_v(hidden_states_org) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + hidden_states_org = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states_org = hidden_states_org.to(query.dtype) + + # linear proj + hidden_states_org = attn.to_out[0](hidden_states_org) + # dropout + hidden_states_org = attn.to_out[1](hidden_states_org) + + if input_ndim == 4: + hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width) + + # perturbed path (identity attention) + batch_size, sequence_length, _ = hidden_states_ptb.shape + + if attn.group_norm is not None: + hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2) + + value = attn.to_v(hidden_states_ptb) + hidden_states_ptb = value + hidden_states_ptb = hidden_states_ptb.to(query.dtype) + + # linear proj + hidden_states_ptb = attn.to_out[0](hidden_states_ptb) + # dropout + hidden_states_ptb = attn.to_out[1](hidden_states_ptb) + + if input_ndim == 4: + hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width) + + # cat + hidden_states = torch.cat([hidden_states_org, hidden_states_ptb]) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class LoRAAttnProcessor: + def __init__(self): + pass + + +class LoRAAttnProcessor2_0: + def __init__(self): + pass + + +class LoRAXFormersAttnProcessor: + def __init__(self): + pass + + +class LoRAAttnAddedKVProcessor: + def __init__(self): + pass + + +class FluxSingleAttnProcessor2_0(FluxAttnProcessor2_0): + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + """ + + def __init__(self): + deprecation_message = "`FluxSingleAttnProcessor2_0` is deprecated and will be removed in a future version. Please use `FluxAttnProcessor2_0` instead." + deprecate("FluxSingleAttnProcessor2_0", "0.32.0", deprecation_message) + super().__init__() + + +ADDED_KV_ATTENTION_PROCESSORS = ( + AttnAddedKVProcessor, + SlicedAttnAddedKVProcessor, + AttnAddedKVProcessor2_0, + XFormersAttnAddedKVProcessor, +) + +CROSS_ATTENTION_PROCESSORS = ( + AttnProcessor, + AttnProcessor2_0, + XFormersAttnProcessor, + SlicedAttnProcessor, + IPAdapterAttnProcessor, + IPAdapterAttnProcessor2_0, +) + +AttentionProcessor = Union[ + AttnProcessor, + AttnProcessor2_0, + FusedAttnProcessor2_0, + XFormersAttnProcessor, + SlicedAttnProcessor, + AttnAddedKVProcessor, + SlicedAttnAddedKVProcessor, + AttnAddedKVProcessor2_0, + XFormersAttnAddedKVProcessor, + CustomDiffusionAttnProcessor, + CustomDiffusionXFormersAttnProcessor, + CustomDiffusionAttnProcessor2_0, + PAGCFGIdentitySelfAttnProcessor2_0, + PAGIdentitySelfAttnProcessor2_0, + PAGCFGHunyuanAttnProcessor2_0, + PAGHunyuanAttnProcessor2_0, +] diff --git a/MindIE/MultiModal/CogVideoX/cogvideox_5b/models/embeddings.py b/MindIE/MultiModal/CogVideoX/cogvideox_5b/models/embeddings.py new file mode 100644 index 0000000000..05b5adc153 --- /dev/null +++ b/MindIE/MultiModal/CogVideoX/cogvideox_5b/models/embeddings.py @@ -0,0 +1,1808 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + +from diffusers.utils import deprecate +from .activations import FP32SiLU, get_activation +from .attention_processor import Attention + + +def get_timestep_embedding( + timesteps: torch.Tensor, + embedding_dim: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 1, + scale: float = 1, + max_period: int = 10000, +): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. + + Args + timesteps (torch.Tensor): + a 1-D Tensor of N indices, one per batch element. These may be fractional. + embedding_dim (int): + the dimension of the output. + flip_sin_to_cos (bool): + Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False) + downscale_freq_shift (float): + Controls the delta between frequencies between dimensions + scale (float): + Scaling factor applied to the embeddings. + max_period (int): + Controls the maximum frequency of the embeddings + Returns + torch.Tensor: an [N x dim] Tensor of positional embeddings. + """ + if len(timesteps.shape) != 1: + raise ValueError("Timesteps should be a 1d-array") + + half_dim = embedding_dim // 2 + exponent = -math.log(max_period) * torch.arange( + start=0, end=half_dim, dtype=torch.float32, device=timesteps.device + ) + exponent = exponent / (half_dim - downscale_freq_shift) + + emb = torch.exp(exponent) + emb = timesteps[:, None].float() * emb[None, :] + + # scale embeddings + emb = scale * emb + + # concat sine and cosine embeddings + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) + + # flip sine and cosine embeddings + if flip_sin_to_cos: + emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) + + # zero pad + if embedding_dim % 2 == 1: + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +def get_3d_sincos_pos_embed( + embed_dim: int, + spatial_size: Union[int, Tuple[int, int]], + temporal_size: int, + spatial_interpolation_scale: float = 1.0, + temporal_interpolation_scale: float = 1.0, +) -> np.ndarray: + r""" + Args: + embed_dim (`int`): + spatial_size (`int` or `Tuple[int, int]`): + temporal_size (`int`): + spatial_interpolation_scale (`float`, defaults to 1.0): + temporal_interpolation_scale (`float`, defaults to 1.0): + """ + if embed_dim % 4 != 0: + raise ValueError("`embed_dim` must be divisible by 4") + if isinstance(spatial_size, int): + spatial_size = (spatial_size, spatial_size) + + embed_dim_spatial = 3 * embed_dim // 4 + embed_dim_temporal = embed_dim // 4 + + # 1. Spatial + grid_h = np.arange(spatial_size[1], dtype=np.float32) / spatial_interpolation_scale + grid_w = np.arange(spatial_size[0], dtype=np.float32) / spatial_interpolation_scale + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, spatial_size[1], spatial_size[0]]) + pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid) + + # 2. Temporal + grid_t = np.arange(temporal_size, dtype=np.float32) / temporal_interpolation_scale + pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(embed_dim_temporal, grid_t) + + # 3. Concat + pos_embed_spatial = pos_embed_spatial[np.newaxis, :, :] + pos_embed_spatial = np.repeat(pos_embed_spatial, temporal_size, axis=0) # [T, H*W, D // 4 * 3] + + pos_embed_temporal = pos_embed_temporal[:, np.newaxis, :] + pos_embed_temporal = np.repeat(pos_embed_temporal, spatial_size[0] * spatial_size[1], axis=1) # [T, H*W, D // 4] + + pos_embed = np.concatenate([pos_embed_temporal, pos_embed_spatial], axis=-1) # [T, H*W, D] + return pos_embed + + +def get_2d_sincos_pos_embed( + embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16 +): + """ + grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or + [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + if isinstance(grid_size, int): + grid_size = (grid_size, grid_size) + + grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / interpolation_scale + grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / interpolation_scale + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size[1], grid_size[0]]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token and extra_tokens > 0: + pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + if embed_dim % 2 != 0: + raise ValueError("embed_dim must be divisible by 2") + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D) + """ + if embed_dim % 2 != 0: + raise ValueError("embed_dim must be divisible by 2") + + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +class PatchEmbed(nn.Module): + """2D Image to Patch Embedding with support for SD3 cropping.""" + + def __init__( + self, + height=224, + width=224, + patch_size=16, + in_channels=3, + embed_dim=768, + layer_norm=False, + flatten=True, + bias=True, + interpolation_scale=1, + pos_embed_type="sincos", + pos_embed_max_size=None, # For SD3 cropping + ): + super().__init__() + + num_patches = (height // patch_size) * (width // patch_size) + self.flatten = flatten + self.layer_norm = layer_norm + self.pos_embed_max_size = pos_embed_max_size + + self.proj = nn.Conv2d( + in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias + ) + if layer_norm: + self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6) + else: + self.norm = None + + self.patch_size = patch_size + self.height, self.width = height // patch_size, width // patch_size + self.base_size = height // patch_size + self.interpolation_scale = interpolation_scale + + # Calculate positional embeddings based on max size or default + if pos_embed_max_size: + grid_size = pos_embed_max_size + else: + grid_size = int(num_patches**0.5) + + if pos_embed_type is None: + self.pos_embed = None + elif pos_embed_type == "sincos": + pos_embed = get_2d_sincos_pos_embed( + embed_dim, grid_size, base_size=self.base_size, interpolation_scale=self.interpolation_scale + ) + persistent = True if pos_embed_max_size else False + self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=persistent) + else: + raise ValueError(f"Unsupported pos_embed_type: {pos_embed_type}") + + def cropped_pos_embed(self, height, width): + """Crops positional embeddings for SD3 compatibility.""" + if self.pos_embed_max_size is None: + raise ValueError("`pos_embed_max_size` must be set for cropping.") + + height = height // self.patch_size + width = width // self.patch_size + if height > self.pos_embed_max_size: + raise ValueError( + f"Height ({height}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}." + ) + if width > self.pos_embed_max_size: + raise ValueError( + f"Width ({width}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}." + ) + + top = (self.pos_embed_max_size - height) // 2 + left = (self.pos_embed_max_size - width) // 2 + spatial_pos_embed = self.pos_embed.reshape(1, self.pos_embed_max_size, self.pos_embed_max_size, -1) + spatial_pos_embed = spatial_pos_embed[:, top : top + height, left : left + width, :] + spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1]) + return spatial_pos_embed + + def forward(self, latent): + if self.pos_embed_max_size is not None: + height, width = latent.shape[-2:] + else: + height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size + + latent = self.proj(latent) + if self.flatten: + latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC + if self.layer_norm: + latent = self.norm(latent) + if self.pos_embed is None: + return latent.to(latent.dtype) + # Interpolate or crop positional embeddings as needed + if self.pos_embed_max_size: + pos_embed = self.cropped_pos_embed(height, width) + else: + if self.height != height or self.width != width: + pos_embed = get_2d_sincos_pos_embed( + embed_dim=self.pos_embed.shape[-1], + grid_size=(height, width), + base_size=self.base_size, + interpolation_scale=self.interpolation_scale, + ) + pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).to(latent.device) + else: + pos_embed = self.pos_embed + + return (latent + pos_embed).to(latent.dtype) + + +class LuminaPatchEmbed(nn.Module): + """2D Image to Patch Embedding with support for Lumina-T2X""" + + def __init__(self, patch_size=2, in_channels=4, embed_dim=768, bias=True): + super().__init__() + self.patch_size = patch_size + self.proj = nn.Linear( + in_features=patch_size * patch_size * in_channels, + out_features=embed_dim, + bias=bias, + ) + + def forward(self, x, freqs_cis): + """ + Patchifies and embeds the input tensor(s). + + Args: + x (List[torch.Tensor] | torch.Tensor): The input tensor(s) to be patchified and embedded. + + Returns: + Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], torch.Tensor]: A tuple containing the patchified + and embedded tensor(s), the mask indicating the valid patches, the original image size(s), and the + frequency tensor(s). + """ + freqs_cis = freqs_cis.to(x[0].device) + patch_height = patch_width = self.patch_size + batch_size, channel, height, width = x.size() + height_tokens, width_tokens = height // patch_height, width // patch_width + + x = x.view(batch_size, channel, height_tokens, patch_height, width_tokens, patch_width).permute( + 0, 2, 4, 1, 3, 5 + ) + x = x.flatten(3) + x = self.proj(x) + x = x.flatten(1, 2) + + mask = torch.ones(x.shape[0], x.shape[1], dtype=torch.int32, device=x.device) + + return ( + x, + mask, + [(height, width)] * batch_size, + freqs_cis[:height_tokens, :width_tokens].flatten(0, 1).unsqueeze(0), + ) + + +class CogVideoXPatchEmbed(nn.Module): + def __init__( + self, + patch_size: int = 2, + in_channels: int = 16, + embed_dim: int = 1920, + text_embed_dim: int = 4096, + bias: bool = True, + sample_width: int = 90, + sample_height: int = 60, + sample_frames: int = 49, + temporal_compression_ratio: int = 4, + max_text_seq_length: int = 226, + spatial_interpolation_scale: float = 1.875, + temporal_interpolation_scale: float = 1.0, + use_positional_embeddings: bool = True, + use_learned_positional_embeddings: bool = True, + ) -> None: + super().__init__() + + self.patch_size = patch_size + self.embed_dim = embed_dim + self.sample_height = sample_height + self.sample_width = sample_width + self.sample_frames = sample_frames + self.temporal_compression_ratio = temporal_compression_ratio + self.max_text_seq_length = max_text_seq_length + self.spatial_interpolation_scale = spatial_interpolation_scale + self.temporal_interpolation_scale = temporal_interpolation_scale + self.use_positional_embeddings = use_positional_embeddings + self.use_learned_positional_embeddings = use_learned_positional_embeddings + + self.proj = nn.Conv2d( + in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias + ) + self.text_proj = nn.Linear(text_embed_dim, embed_dim) + + if use_positional_embeddings or use_learned_positional_embeddings: + persistent = use_learned_positional_embeddings + pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames) + self.register_buffer("pos_embedding", pos_embedding, persistent=persistent) + + def _get_positional_embeddings(self, sample_height: int, sample_width: int, sample_frames: int) -> torch.Tensor: + post_patch_height = sample_height // self.patch_size + post_patch_width = sample_width // self.patch_size + post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1 + num_patches = post_patch_height * post_patch_width * post_time_compression_frames + + pos_embedding = get_3d_sincos_pos_embed( + self.embed_dim, + (post_patch_width, post_patch_height), + post_time_compression_frames, + self.spatial_interpolation_scale, + self.temporal_interpolation_scale, + ) + pos_embedding = torch.from_numpy(pos_embedding).flatten(0, 1) + joint_pos_embedding = torch.zeros( + 1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False + ) + joint_pos_embedding.data[:, self.max_text_seq_length :].copy_(pos_embedding) + + return joint_pos_embedding + + def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor): + r""" + Args: + text_embeds (`torch.Tensor`): + Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim). + image_embeds (`torch.Tensor`): + Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width). + """ + text_embeds = self.text_proj(text_embeds) + + batch, num_frames, channels, height, width = image_embeds.shape + image_embeds = image_embeds.reshape(-1, channels, height, width) + image_embeds = self.proj(image_embeds) + image_embeds = image_embeds.view(batch, num_frames, *image_embeds.shape[1:]) + image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels] + image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels] + + embeds = torch.cat( + [text_embeds, image_embeds], dim=1 + ).contiguous() # [batch, seq_length + num_frames x height x width, channels] + + if self.use_positional_embeddings or self.use_learned_positional_embeddings: + if self.use_learned_positional_embeddings and (self.sample_width != width or self.sample_height != height): + raise ValueError( + "It is currently not possible to generate videos at a different resolution that the defaults. This should only be the case with 'THUDM/CogVideoX-5b-I2V'." + "If you think this is incorrect, please open an issue at https://github.com/huggingface/diffusers/issues." + ) + + pre_time_compression_frames = (num_frames - 1) * self.temporal_compression_ratio + 1 + + if ( + self.sample_height != height + or self.sample_width != width + or self.sample_frames != pre_time_compression_frames + ): + pos_embedding = self._get_positional_embeddings(height, width, pre_time_compression_frames) + pos_embedding = pos_embedding.to(embeds.device, dtype=embeds.dtype) + else: + pos_embedding = self.pos_embedding + + embeds = embeds + pos_embedding + + return embeds + + +class CogView3PlusPatchEmbed(nn.Module): + def __init__( + self, + in_channels: int = 16, + hidden_size: int = 2560, + patch_size: int = 2, + text_hidden_size: int = 4096, + pos_embed_max_size: int = 128, + ): + super().__init__() + self.in_channels = in_channels + self.hidden_size = hidden_size + self.patch_size = patch_size + self.text_hidden_size = text_hidden_size + self.pos_embed_max_size = pos_embed_max_size + # Linear projection for image patches + self.proj = nn.Linear(in_channels * patch_size**2, hidden_size) + + # Linear projection for text embeddings + self.text_proj = nn.Linear(text_hidden_size, hidden_size) + + pos_embed = get_2d_sincos_pos_embed(hidden_size, pos_embed_max_size, base_size=pos_embed_max_size) + pos_embed = pos_embed.reshape(pos_embed_max_size, pos_embed_max_size, hidden_size) + self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float(), persistent=False) + + def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, channel, height, width = hidden_states.shape + + if height % self.patch_size != 0 or width % self.patch_size != 0: + raise ValueError("Height and width must be divisible by patch size") + + height = height // self.patch_size + width = width // self.patch_size + hidden_states = hidden_states.view(batch_size, channel, height, self.patch_size, width, self.patch_size) + hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5).contiguous() + hidden_states = hidden_states.view(batch_size, height * width, channel * self.patch_size * self.patch_size) + + # Project the patches + hidden_states = self.proj(hidden_states) + encoder_hidden_states = self.text_proj(encoder_hidden_states) + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + # Calculate text_length + text_length = encoder_hidden_states.shape[1] + + image_pos_embed = self.pos_embed[:height, :width].reshape(height * width, -1) + text_pos_embed = torch.zeros( + (text_length, self.hidden_size), dtype=image_pos_embed.dtype, device=image_pos_embed.device + ) + pos_embed = torch.cat([text_pos_embed, image_pos_embed], dim=0)[None, ...] + + return (hidden_states + pos_embed).to(hidden_states.dtype) + + +def get_3d_rotary_pos_embed( + embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + RoPE for video tokens with 3D structure. + + Args: + embed_dim: (`int`): + The embedding dimension size, corresponding to hidden_size_head. + crops_coords (`Tuple[int]`): + The top-left and bottom-right coordinates of the crop. + grid_size (`Tuple[int]`): + The grid size of the spatial positional embedding (height, width). + temporal_size (`int`): + The size of the temporal dimension. + theta (`float`): + Scaling factor for frequency computation. + + Returns: + `torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`. + """ + if use_real is not True: + raise ValueError(" `use_real = False` is not currently supported for get_3d_rotary_pos_embed") + start, stop = crops_coords + grid_size_h, grid_size_w = grid_size + grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32) + grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32) + grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32) + + # Compute dimensions for each axis + dim_t = embed_dim // 4 + dim_h = embed_dim // 8 * 3 + dim_w = embed_dim // 8 * 3 + + # Temporal frequencies + freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, use_real=True) + # Spatial frequencies for height and width + freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, use_real=True) + freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, use_real=True) + + # BroadCast and concatenate temporal and spaial frequencie (height and width) into a 3d tensor + def combine_time_height_width(freqs_t, freqs_h, freqs_w): + freqs_t = freqs_t[:, None, None, :].expand( + -1, grid_size_h, grid_size_w, -1 + ) # temporal_size, grid_size_h, grid_size_w, dim_t + freqs_h = freqs_h[None, :, None, :].expand( + temporal_size, -1, grid_size_w, -1 + ) # temporal_size, grid_size_h, grid_size_2, dim_h + freqs_w = freqs_w[None, None, :, :].expand( + temporal_size, grid_size_h, -1, -1 + ) # temporal_size, grid_size_h, grid_size_2, dim_w + + freqs = torch.cat( + [freqs_t, freqs_h, freqs_w], dim=-1 + ) # temporal_size, grid_size_h, grid_size_w, (dim_t + dim_h + dim_w) + freqs = freqs.view( + temporal_size * grid_size_h * grid_size_w, -1 + ) # (temporal_size * grid_size_h * grid_size_w), (dim_t + dim_h + dim_w) + return freqs + + t_cos, t_sin = freqs_t # both t_cos and t_sin has shape: temporal_size, dim_t + h_cos, h_sin = freqs_h # both h_cos and h_sin has shape: grid_size_h, dim_h + w_cos, w_sin = freqs_w # both w_cos and w_sin has shape: grid_size_w, dim_w + cos = combine_time_height_width(t_cos, h_cos, w_cos) + sin = combine_time_height_width(t_sin, h_sin, w_sin) + return cos, sin + + +def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True): + """ + RoPE for image tokens with 2d structure. + + Args: + embed_dim: (`int`): + The embedding dimension size + crops_coords (`Tuple[int]`) + The top-left and bottom-right coordinates of the crop. + grid_size (`Tuple[int]`): + The grid size of the positional embedding. + use_real (`bool`): + If True, return real part and imaginary part separately. Otherwise, return complex numbers. + + Returns: + `torch.Tensor`: positional embedding with shape `( grid_size * grid_size, embed_dim/2)`. + """ + start, stop = crops_coords + grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32) + grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) # [2, W, H] + + grid = grid.reshape([2, 1, *grid.shape[1:]]) + pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real) + return pos_embed + + +def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False): + if embed_dim % 4 != 0: + raise ValueError("embed_dim must be divisible by 4") + + # use half of dimensions to encode grid_h + emb_h = get_1d_rotary_pos_embed( + embed_dim // 2, grid[0].reshape(-1), use_real=use_real + ) # (H*W, D/2) if use_real else (H*W, D/4) + emb_w = get_1d_rotary_pos_embed( + embed_dim // 2, grid[1].reshape(-1), use_real=use_real + ) # (H*W, D/2) if use_real else (H*W, D/4) + + if use_real: + cos = torch.cat([emb_h[0], emb_w[0]], dim=1) # (H*W, D) + sin = torch.cat([emb_h[1], emb_w[1]], dim=1) # (H*W, D) + return cos, sin + else: + emb = torch.cat([emb_h, emb_w], dim=1) # (H*W, D/2) + return emb + + +def get_2d_rotary_pos_embed_lumina(embed_dim, len_h, len_w, linear_factor=1.0, ntk_factor=1.0): + if embed_dim % 4 != 0: + raise ValueError("embed_dim must be divisible by 4") + + emb_h = get_1d_rotary_pos_embed( + embed_dim // 2, len_h, linear_factor=linear_factor, ntk_factor=ntk_factor + ) # (H, D/4) + emb_w = get_1d_rotary_pos_embed( + embed_dim // 2, len_w, linear_factor=linear_factor, ntk_factor=ntk_factor + ) # (W, D/4) + emb_h = emb_h.view(len_h, 1, embed_dim // 4, 1).repeat(1, len_w, 1, 1) # (H, W, D/4, 1) + emb_w = emb_w.view(1, len_w, embed_dim // 4, 1).repeat(len_h, 1, 1, 1) # (H, W, D/4, 1) + + emb = torch.cat([emb_h, emb_w], dim=-1).flatten(2) # (H, W, D/2) + return emb + + +def get_1d_rotary_pos_embed( + dim: int, + pos: Union[np.ndarray, int], + theta: float = 10000.0, + use_real=False, + linear_factor=1.0, + ntk_factor=1.0, + repeat_interleave_real=True, + freqs_dtype=torch.float32, +): + """ + Precompute the frequency tensor for complex exponentials (cis) with given dimensions. + + This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end + index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64 + data type. + + Args: + dim (`int`): Dimension of the frequency tensor. + pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar + theta (`float`, *optional*, defaults to 10000.0): + Scaling factor for frequency computation. Defaults to 10000.0. + use_real (`bool`, *optional*): + If True, return real part and imaginary part separately. Otherwise, return complex numbers. + linear_factor (`float`, *optional*, defaults to 1.0): + Scaling factor for the context extrapolation. Defaults to 1.0. + ntk_factor (`float`, *optional*, defaults to 1.0): + Scaling factor for the NTK-Aware RoPE. Defaults to 1.0. + repeat_interleave_real (`bool`, *optional*, defaults to `True`): + If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`. + Otherwise, they are concateanted with themselves. + freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`): + the dtype of the frequency tensor. + Returns: + `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2] + """ + if dim % 2 != 0: + raise ValueError("embed_dim must be divisible by 2") + + if isinstance(pos, int): + pos = torch.arange(pos) + if isinstance(pos, np.ndarray): + pos = torch.from_numpy(pos) # type: ignore # [S] + + theta = theta * ntk_factor + freqs = ( + 1.0 + / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim)) + / linear_factor + ) # [D/2] + freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2] + if use_real and repeat_interleave_real: + # flux, hunyuan-dit, cogvideox + freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D] + freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D] + return freqs_cos, freqs_sin + elif use_real: + # stable audio + freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D] + freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D] + return freqs_cos, freqs_sin + else: + # lumina + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2] + return freqs_cis + + +def apply_rotary_emb( + x: torch.Tensor, + freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], + use_real: bool = True, + use_real_unbind_dim: int = -1, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings + to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are + reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting + tensors contain rotary embeddings and are returned as real tensors. + + Args: + x (`torch.Tensor`): + Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply + freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],) + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. + """ + if use_real: + cos, sin = freqs_cis # [S, D] + cos = cos[None, None] + sin = sin[None, None] + cos, sin = cos.to(x.device), sin.to(x.device) + + if use_real_unbind_dim == -1: + # Used for flux, cogvideox, hunyuan-dit + x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2] + x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) + elif use_real_unbind_dim == -2: + # Used for Stable Audio + x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2] + x_rotated = torch.cat([-x_imag, x_real], dim=-1) + else: + raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.") + out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + + return out + else: + # used for lumina + x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) + freqs_cis = freqs_cis.unsqueeze(2) + x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3) + + return x_out.type_as(x) + + +class FluxPosEmbed(nn.Module): + # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11 + def __init__(self, theta: int, axes_dim: List[int]): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + + def forward(self, ids: torch.Tensor) -> torch.Tensor: + n_axes = ids.shape[-1] + cos_out = [] + sin_out = [] + pos = ids.float() + is_mps = ids.device.type == "mps" + freqs_dtype = torch.float32 if is_mps else torch.float64 + for i in range(n_axes): + cos, sin = get_1d_rotary_pos_embed( + self.axes_dim[i], pos[:, i], repeat_interleave_real=True, use_real=True, freqs_dtype=freqs_dtype + ) + cos_out.append(cos) + sin_out.append(sin) + freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device) + freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device) + return freqs_cos, freqs_sin + + +class TimestepEmbedding(nn.Module): + def __init__( + self, + in_channels: int, + time_embed_dim: int, + act_fn: str = "silu", + out_dim: int = None, + post_act_fn: Optional[str] = None, + cond_proj_dim=None, + sample_proj_bias=True, + ): + super().__init__() + + self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias) + + if cond_proj_dim is not None: + self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False) + else: + self.cond_proj = None + + self.act = get_activation(act_fn) + + if out_dim is not None: + time_embed_dim_out = out_dim + else: + time_embed_dim_out = time_embed_dim + self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias) + + if post_act_fn is None: + self.post_act = None + else: + self.post_act = get_activation(post_act_fn) + + def forward(self, sample, condition=None): + if condition is not None: + sample = sample + self.cond_proj(condition) + sample = self.linear_1(sample) + + if self.act is not None: + sample = self.act(sample) + + sample = self.linear_2(sample) + + if self.post_act is not None: + sample = self.post_act(sample) + return sample + + +class Timesteps(nn.Module): + def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1): + super().__init__() + self.num_channels = num_channels + self.flip_sin_to_cos = flip_sin_to_cos + self.downscale_freq_shift = downscale_freq_shift + self.scale = scale + + def forward(self, timesteps): + t_emb = get_timestep_embedding( + timesteps, + self.num_channels, + flip_sin_to_cos=self.flip_sin_to_cos, + downscale_freq_shift=self.downscale_freq_shift, + scale=self.scale, + ) + return t_emb + + +class GaussianFourierProjection(nn.Module): + """Gaussian Fourier embeddings for noise levels.""" + + def __init__( + self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False + ): + super().__init__() + self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) + self.log = log + self.flip_sin_to_cos = flip_sin_to_cos + + if set_W_to_weight: + # to delete later + del self.weight + self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) + self.weight = self.W + del self.W + + def forward(self, x): + if self.log: + x = torch.log(x) + + x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi + + if self.flip_sin_to_cos: + out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1) + else: + out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) + return out + + +class SinusoidalPositionalEmbedding(nn.Module): + """Apply positional information to a sequence of embeddings. + + Takes in a sequence of embeddings with shape (batch_size, seq_length, embed_dim) and adds positional embeddings to + them + + Args: + embed_dim: (int): Dimension of the positional embedding. + max_seq_length: Maximum sequence length to apply positional embeddings + + """ + + def __init__(self, embed_dim: int, max_seq_length: int = 32): + super().__init__() + position = torch.arange(max_seq_length).unsqueeze(1) + div_term = torch.exp(torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim)) + pe = torch.zeros(1, max_seq_length, embed_dim) + pe[0, :, 0::2] = torch.sin(position * div_term) + pe[0, :, 1::2] = torch.cos(position * div_term) + self.register_buffer("pe", pe) + + def forward(self, x): + _, seq_length, _ = x.shape + x = x + self.pe[:, :seq_length] + return x + + +class ImagePositionalEmbeddings(nn.Module): + """ + Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the + height and width of the latent space. + + For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092 + + For VQ-diffusion: + + Output vector embeddings are used as input for the transformer. + + Note that the vector embeddings for the transformer are different than the vector embeddings from the VQVAE. + + Args: + num_embed (`int`): + Number of embeddings for the latent pixels embeddings. + height (`int`): + Height of the latent image i.e. the number of height embeddings. + width (`int`): + Width of the latent image i.e. the number of width embeddings. + embed_dim (`int`): + Dimension of the produced vector embeddings. Used for the latent pixel, height, and width embeddings. + """ + + def __init__( + self, + num_embed: int, + height: int, + width: int, + embed_dim: int, + ): + super().__init__() + + self.height = height + self.width = width + self.num_embed = num_embed + self.embed_dim = embed_dim + + self.emb = nn.Embedding(self.num_embed, embed_dim) + self.height_emb = nn.Embedding(self.height, embed_dim) + self.width_emb = nn.Embedding(self.width, embed_dim) + + def forward(self, index): + emb = self.emb(index) + + height_emb = self.height_emb(torch.arange(self.height, device=index.device).view(1, self.height)) + + # 1 x H x D -> 1 x H x 1 x D + height_emb = height_emb.unsqueeze(2) + + width_emb = self.width_emb(torch.arange(self.width, device=index.device).view(1, self.width)) + + # 1 x W x D -> 1 x 1 x W x D + width_emb = width_emb.unsqueeze(1) + + pos_emb = height_emb + width_emb + + # 1 x H x W x D -> 1 x L xD + pos_emb = pos_emb.view(1, self.height * self.width, -1) + + emb = emb + pos_emb[:, : emb.shape[1], :] + + return emb + + +class LabelEmbedding(nn.Module): + """ + Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. + + Args: + num_classes (`int`): The number of classes. + hidden_size (`int`): The size of the vector embeddings. + dropout_prob (`float`): The probability of dropping a label. + """ + + def __init__(self, num_classes, hidden_size, dropout_prob): + super().__init__() + use_cfg_embedding = dropout_prob > 0 + self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size) + self.num_classes = num_classes + self.dropout_prob = dropout_prob + + def token_drop(self, labels, force_drop_ids=None): + """ + Drops labels to enable classifier-free guidance. + """ + if force_drop_ids is None: + drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob + else: + drop_ids = torch.tensor(force_drop_ids == 1) + labels = torch.where(drop_ids, self.num_classes, labels) + return labels + + def forward(self, labels: torch.LongTensor, force_drop_ids=None): + use_dropout = self.dropout_prob > 0 + if (self.training and use_dropout) or (force_drop_ids is not None): + labels = self.token_drop(labels, force_drop_ids) + embeddings = self.embedding_table(labels) + return embeddings + + +class TextImageProjection(nn.Module): + def __init__( + self, + text_embed_dim: int = 1024, + image_embed_dim: int = 768, + cross_attention_dim: int = 768, + num_image_text_embeds: int = 10, + ): + super().__init__() + + self.num_image_text_embeds = num_image_text_embeds + self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim) + self.text_proj = nn.Linear(text_embed_dim, cross_attention_dim) + + def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor): + batch_size = text_embeds.shape[0] + + # image + image_text_embeds = self.image_embeds(image_embeds) + image_text_embeds = image_text_embeds.reshape(batch_size, self.num_image_text_embeds, -1) + + # text + text_embeds = self.text_proj(text_embeds) + + return torch.cat([image_text_embeds, text_embeds], dim=1) + + +class ImageProjection(nn.Module): + def __init__( + self, + image_embed_dim: int = 768, + cross_attention_dim: int = 768, + num_image_text_embeds: int = 32, + ): + super().__init__() + + self.num_image_text_embeds = num_image_text_embeds + self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim) + self.norm = nn.LayerNorm(cross_attention_dim) + + def forward(self, image_embeds: torch.Tensor): + batch_size = image_embeds.shape[0] + + # image + image_embeds = self.image_embeds(image_embeds) + image_embeds = image_embeds.reshape(batch_size, self.num_image_text_embeds, -1) + image_embeds = self.norm(image_embeds) + return image_embeds + + +class IPAdapterFullImageProjection(nn.Module): + def __init__(self, image_embed_dim=1024, cross_attention_dim=1024): + super().__init__() + from .attention import FeedForward + + self.ff = FeedForward(image_embed_dim, cross_attention_dim, mult=1, activation_fn="gelu") + self.norm = nn.LayerNorm(cross_attention_dim) + + def forward(self, image_embeds: torch.Tensor): + return self.norm(self.ff(image_embeds)) + + +class IPAdapterFaceIDImageProjection(nn.Module): + def __init__(self, image_embed_dim=1024, cross_attention_dim=1024, mult=1, num_tokens=1): + super().__init__() + from .attention import FeedForward + + self.num_tokens = num_tokens + self.cross_attention_dim = cross_attention_dim + self.ff = FeedForward(image_embed_dim, cross_attention_dim * num_tokens, mult=mult, activation_fn="gelu") + self.norm = nn.LayerNorm(cross_attention_dim) + + def forward(self, image_embeds: torch.Tensor): + x = self.ff(image_embeds) + x = x.reshape(-1, self.num_tokens, self.cross_attention_dim) + return self.norm(x) + + +class CombinedTimestepLabelEmbeddings(nn.Module): + def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1): + super().__init__() + + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + self.class_embedder = LabelEmbedding(num_classes, embedding_dim, class_dropout_prob) + + def forward(self, timestep, class_labels, hidden_dtype=None): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D) + + class_labels = self.class_embedder(class_labels) # (N, D) + + conditioning = timesteps_emb + class_labels # (N, D) + + return conditioning + + +class CombinedTimestepTextProjEmbeddings(nn.Module): + def __init__(self, embedding_dim, pooled_projection_dim): + super().__init__() + + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu") + + def forward(self, timestep, pooled_projection): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D) + + pooled_projections = self.text_embedder(pooled_projection) + + conditioning = timesteps_emb + pooled_projections + + return conditioning + + +class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module): + def __init__(self, embedding_dim, pooled_projection_dim): + super().__init__() + + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu") + + def forward(self, timestep, guidance, pooled_projection): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D) + + guidance_proj = self.time_proj(guidance) + guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype)) # (N, D) + + time_guidance_emb = timesteps_emb + guidance_emb + + pooled_projections = self.text_embedder(pooled_projection) + conditioning = time_guidance_emb + pooled_projections + + return conditioning + + +class CogView3CombinedTimestepSizeEmbeddings(nn.Module): + def __init__(self, embedding_dim: int, condition_dim: int, pooled_projection_dim: int, timesteps_dim: int = 256): + super().__init__() + + self.time_proj = Timesteps(num_channels=timesteps_dim, flip_sin_to_cos=True, downscale_freq_shift=0) + self.condition_proj = Timesteps(num_channels=condition_dim, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding(in_channels=timesteps_dim, time_embed_dim=embedding_dim) + self.condition_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu") + + def forward( + self, + timestep: torch.Tensor, + original_size: torch.Tensor, + target_size: torch.Tensor, + crop_coords: torch.Tensor, + hidden_dtype: torch.dtype, + ) -> torch.Tensor: + timesteps_proj = self.time_proj(timestep) + + original_size_proj = self.condition_proj(original_size.flatten()).view(original_size.size(0), -1) + crop_coords_proj = self.condition_proj(crop_coords.flatten()).view(crop_coords.size(0), -1) + target_size_proj = self.condition_proj(target_size.flatten()).view(target_size.size(0), -1) + + condition_proj = torch.cat([original_size_proj, crop_coords_proj, target_size_proj], dim=1) + + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (B, embedding_dim) + condition_emb = self.condition_embedder(condition_proj.to(dtype=hidden_dtype)) # (B, embedding_dim) + + conditioning = timesteps_emb + condition_emb + return conditioning + + +class HunyuanDiTAttentionPool(nn.Module): + # Copied from https://github.com/Tencent/HunyuanDiT/blob/cb709308d92e6c7e8d59d0dff41b74d35088db6a/hydit/modules/poolers.py#L6 + + def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim**0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.permute(1, 0, 2) # NLC -> LNC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (L+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (L+1)NC + x, _ = F.multi_head_attention_forward( + query=x[:1], + key=x, + value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False, + ) + return x.squeeze(0) + + +class HunyuanCombinedTimestepTextSizeStyleEmbedding(nn.Module): + def __init__( + self, + embedding_dim, + pooled_projection_dim=1024, + seq_len=256, + cross_attention_dim=2048, + use_style_cond_and_image_meta_size=True, + ): + super().__init__() + + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + + self.size_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + + self.pooler = HunyuanDiTAttentionPool( + seq_len, cross_attention_dim, num_heads=8, output_dim=pooled_projection_dim + ) + + # Here we use a default learned embedder layer for future extension. + self.use_style_cond_and_image_meta_size = use_style_cond_and_image_meta_size + if use_style_cond_and_image_meta_size: + self.style_embedder = nn.Embedding(1, embedding_dim) + extra_in_dim = 256 * 6 + embedding_dim + pooled_projection_dim + else: + extra_in_dim = pooled_projection_dim + + self.extra_embedder = PixArtAlphaTextProjection( + in_features=extra_in_dim, + hidden_size=embedding_dim * 4, + out_features=embedding_dim, + act_fn="silu_fp32", + ) + + def forward(self, timestep, encoder_hidden_states, image_meta_size, style, hidden_dtype=None): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, 256) + + # extra condition1: text + pooled_projections = self.pooler(encoder_hidden_states) # (N, 1024) + + if self.use_style_cond_and_image_meta_size: + # extra condition2: image meta size embedding + image_meta_size = self.size_proj(image_meta_size.view(-1)) + image_meta_size = image_meta_size.to(dtype=hidden_dtype) + image_meta_size = image_meta_size.view(-1, 6 * 256) # (N, 1536) + + # extra condition3: style embedding + style_embedding = self.style_embedder(style) # (N, embedding_dim) + + # Concatenate all extra vectors + extra_cond = torch.cat([pooled_projections, image_meta_size, style_embedding], dim=1) + else: + extra_cond = torch.cat([pooled_projections], dim=1) + + conditioning = timesteps_emb + self.extra_embedder(extra_cond) # [B, D] + + return conditioning + + +class LuminaCombinedTimestepCaptionEmbedding(nn.Module): + def __init__(self, hidden_size=4096, cross_attention_dim=2048, frequency_embedding_size=256): + super().__init__() + self.time_proj = Timesteps( + num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0 + ) + + self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=hidden_size) + + self.caption_embedder = nn.Sequential( + nn.LayerNorm(cross_attention_dim), + nn.Linear( + cross_attention_dim, + hidden_size, + bias=True, + ), + ) + + def forward(self, timestep, caption_feat, caption_mask): + # timestep embedding: + time_freq = self.time_proj(timestep) + time_embed = self.timestep_embedder(time_freq.to(dtype=self.timestep_embedder.linear_1.weight.dtype)) + + # caption condition embedding: + caption_mask_float = caption_mask.float().unsqueeze(-1) + caption_feats_pool = (caption_feat * caption_mask_float).sum(dim=1) / caption_mask_float.sum(dim=1) + caption_feats_pool = caption_feats_pool.to(caption_feat) + caption_embed = self.caption_embedder(caption_feats_pool) + + conditioning = time_embed + caption_embed + + return conditioning + + +class TextTimeEmbedding(nn.Module): + def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64): + super().__init__() + self.norm1 = nn.LayerNorm(encoder_dim) + self.pool = AttentionPooling(num_heads, encoder_dim) + self.proj = nn.Linear(encoder_dim, time_embed_dim) + self.norm2 = nn.LayerNorm(time_embed_dim) + + def forward(self, hidden_states): + hidden_states = self.norm1(hidden_states) + hidden_states = self.pool(hidden_states) + hidden_states = self.proj(hidden_states) + hidden_states = self.norm2(hidden_states) + return hidden_states + + +class TextImageTimeEmbedding(nn.Module): + def __init__(self, text_embed_dim: int = 768, image_embed_dim: int = 768, time_embed_dim: int = 1536): + super().__init__() + self.text_proj = nn.Linear(text_embed_dim, time_embed_dim) + self.text_norm = nn.LayerNorm(time_embed_dim) + self.image_proj = nn.Linear(image_embed_dim, time_embed_dim) + + def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor): + # text + time_text_embeds = self.text_proj(text_embeds) + time_text_embeds = self.text_norm(time_text_embeds) + + # image + time_image_embeds = self.image_proj(image_embeds) + + return time_image_embeds + time_text_embeds + + +class ImageTimeEmbedding(nn.Module): + def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536): + super().__init__() + self.image_proj = nn.Linear(image_embed_dim, time_embed_dim) + self.image_norm = nn.LayerNorm(time_embed_dim) + + def forward(self, image_embeds: torch.Tensor): + # image + time_image_embeds = self.image_proj(image_embeds) + time_image_embeds = self.image_norm(time_image_embeds) + return time_image_embeds + + +class ImageHintTimeEmbedding(nn.Module): + def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536): + super().__init__() + self.image_proj = nn.Linear(image_embed_dim, time_embed_dim) + self.image_norm = nn.LayerNorm(time_embed_dim) + self.input_hint_block = nn.Sequential( + nn.Conv2d(3, 16, 3, padding=1), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1), + nn.SiLU(), + nn.Conv2d(16, 32, 3, padding=1, stride=2), + nn.SiLU(), + nn.Conv2d(32, 32, 3, padding=1), + nn.SiLU(), + nn.Conv2d(32, 96, 3, padding=1, stride=2), + nn.SiLU(), + nn.Conv2d(96, 96, 3, padding=1), + nn.SiLU(), + nn.Conv2d(96, 256, 3, padding=1, stride=2), + nn.SiLU(), + nn.Conv2d(256, 4, 3, padding=1), + ) + + def forward(self, image_embeds: torch.Tensor, hint: torch.Tensor): + # image + time_image_embeds = self.image_proj(image_embeds) + time_image_embeds = self.image_norm(time_image_embeds) + hint = self.input_hint_block(hint) + return time_image_embeds, hint + + +class AttentionPooling(nn.Module): + # Copied from https://github.com/deep-floyd/IF/blob/2f91391f27dd3c468bf174be5805b4cc92980c0b/deepfloyd_if/model/nn.py#L54 + + def __init__(self, num_heads, embed_dim, dtype=None): + super().__init__() + self.dtype = dtype + self.positional_embedding = nn.Parameter(torch.randn(1, embed_dim) / embed_dim**0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype) + self.q_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype) + self.v_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype) + self.num_heads = num_heads + self.dim_per_head = embed_dim // self.num_heads + + def forward(self, x): + bs, length, width = x.size() + + def shape(x): + x = x.view(bs, -1, self.num_heads, self.dim_per_head) + x = x.transpose(1, 2) + x = x.reshape(bs * self.num_heads, -1, self.dim_per_head) + x = x.transpose(1, 2) + return x + + class_token = x.mean(dim=1, keepdim=True) + self.positional_embedding.to(x.dtype) + x = torch.cat([class_token, x], dim=1) + q = shape(self.q_proj(class_token)) + k = shape(self.k_proj(x)) + v = shape(self.v_proj(x)) + + scale = 1 / math.sqrt(math.sqrt(self.dim_per_head)) + weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + + a = torch.einsum("bts,bcs->bct", weight, v) + a = a.reshape(bs, -1, 1).transpose(1, 2) + + return a[:, 0, :] + + +def get_fourier_embeds_from_boundingbox(embed_dim, box): + """ + Args: + embed_dim: int + box: a 3-D tensor [B x N x 4] representing the bounding boxes for GLIGEN pipeline + Returns: + [B x N x embed_dim] tensor of positional embeddings + """ + + batch_size, num_boxes = box.shape[:2] + + emb = 100 ** (torch.arange(embed_dim) / embed_dim) + emb = emb[None, None, None].to(device=box.device, dtype=box.dtype) + emb = emb * box.unsqueeze(-1) + + emb = torch.stack((emb.sin(), emb.cos()), dim=-1) + emb = emb.permute(0, 1, 3, 4, 2).reshape(batch_size, num_boxes, embed_dim * 2 * 4) + + return emb + + +class GLIGENTextBoundingboxProjection(nn.Module): + def __init__(self, positive_len, out_dim, feature_type="text-only", fourier_freqs=8): + super().__init__() + self.positive_len = positive_len + self.out_dim = out_dim + + self.fourier_embedder_dim = fourier_freqs + self.position_dim = fourier_freqs * 2 * 4 # 2: sin/cos, 4: xyxy + + if isinstance(out_dim, tuple): + out_dim = out_dim[0] + + if feature_type == "text-only": + self.linears = nn.Sequential( + nn.Linear(self.positive_len + self.position_dim, 512), + nn.SiLU(), + nn.Linear(512, 512), + nn.SiLU(), + nn.Linear(512, out_dim), + ) + self.null_positive_feature = torch.nn.Parameter(torch.zeros([self.positive_len])) + + elif feature_type == "text-image": + self.linears_text = nn.Sequential( + nn.Linear(self.positive_len + self.position_dim, 512), + nn.SiLU(), + nn.Linear(512, 512), + nn.SiLU(), + nn.Linear(512, out_dim), + ) + self.linears_image = nn.Sequential( + nn.Linear(self.positive_len + self.position_dim, 512), + nn.SiLU(), + nn.Linear(512, 512), + nn.SiLU(), + nn.Linear(512, out_dim), + ) + self.null_text_feature = torch.nn.Parameter(torch.zeros([self.positive_len])) + self.null_image_feature = torch.nn.Parameter(torch.zeros([self.positive_len])) + + self.null_position_feature = torch.nn.Parameter(torch.zeros([self.position_dim])) + + def forward( + self, + boxes, + masks, + positive_embeddings=None, + phrases_masks=None, + image_masks=None, + phrases_embeddings=None, + image_embeddings=None, + ): + masks = masks.unsqueeze(-1) + + # embedding position (it may includes padding as placeholder) + xyxy_embedding = get_fourier_embeds_from_boundingbox(self.fourier_embedder_dim, boxes) # B*N*4 -> B*N*C + + # learnable null embedding + xyxy_null = self.null_position_feature.view(1, 1, -1) + + # replace padding with learnable null embedding + xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null + + # positionet with text only information + if positive_embeddings is not None: + # learnable null embedding + positive_null = self.null_positive_feature.view(1, 1, -1) + + # replace padding with learnable null embedding + positive_embeddings = positive_embeddings * masks + (1 - masks) * positive_null + + objs = self.linears(torch.cat([positive_embeddings, xyxy_embedding], dim=-1)) + + # positionet with text and image information + else: + phrases_masks = phrases_masks.unsqueeze(-1) + image_masks = image_masks.unsqueeze(-1) + + # learnable null embedding + text_null = self.null_text_feature.view(1, 1, -1) + image_null = self.null_image_feature.view(1, 1, -1) + + # replace padding with learnable null embedding + phrases_embeddings = phrases_embeddings * phrases_masks + (1 - phrases_masks) * text_null + image_embeddings = image_embeddings * image_masks + (1 - image_masks) * image_null + + objs_text = self.linears_text(torch.cat([phrases_embeddings, xyxy_embedding], dim=-1)) + objs_image = self.linears_image(torch.cat([image_embeddings, xyxy_embedding], dim=-1)) + objs = torch.cat([objs_text, objs_image], dim=1) + + return objs + + +class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module): + """ + For PixArt-Alpha. + + Reference: + https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29 + """ + + def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False): + super().__init__() + + self.outdim = size_emb_dim + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + + self.use_additional_conditions = use_additional_conditions + if use_additional_conditions: + self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim) + self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim) + + def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D) + + if self.use_additional_conditions: + resolution_emb = self.additional_condition_proj(resolution.flatten()).to(hidden_dtype) + resolution_emb = self.resolution_embedder(resolution_emb).reshape(batch_size, -1) + aspect_ratio_emb = self.additional_condition_proj(aspect_ratio.flatten()).to(hidden_dtype) + aspect_ratio_emb = self.aspect_ratio_embedder(aspect_ratio_emb).reshape(batch_size, -1) + conditioning = timesteps_emb + torch.cat([resolution_emb, aspect_ratio_emb], dim=1) + else: + conditioning = timesteps_emb + + return conditioning + + +class PixArtAlphaTextProjection(nn.Module): + """ + Projects caption embeddings. Also handles dropout for classifier-free guidance. + + Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py + """ + + def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh"): + super().__init__() + if out_features is None: + out_features = hidden_size + self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True) + if act_fn == "gelu_tanh": + self.act_1 = nn.GELU(approximate="tanh") + elif act_fn == "silu": + self.act_1 = nn.SiLU() + elif act_fn == "silu_fp32": + self.act_1 = FP32SiLU() + else: + raise ValueError(f"Unknown activation function: {act_fn}") + self.linear_2 = nn.Linear(in_features=hidden_size, out_features=out_features, bias=True) + + def forward(self, caption): + hidden_states = self.linear_1(caption) + hidden_states = self.act_1(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +class IPAdapterPlusImageProjectionBlock(nn.Module): + def __init__( + self, + embed_dims: int = 768, + dim_head: int = 64, + heads: int = 16, + ffn_ratio: float = 4, + ) -> None: + super().__init__() + from .attention import FeedForward + + self.ln0 = nn.LayerNorm(embed_dims) + self.ln1 = nn.LayerNorm(embed_dims) + self.attn = Attention( + query_dim=embed_dims, + dim_head=dim_head, + heads=heads, + out_bias=False, + ) + self.ff = nn.Sequential( + nn.LayerNorm(embed_dims), + FeedForward(embed_dims, embed_dims, activation_fn="gelu", mult=ffn_ratio, bias=False), + ) + + def forward(self, x, latents, residual): + encoder_hidden_states = self.ln0(x) + latents = self.ln1(latents) + encoder_hidden_states = torch.cat([encoder_hidden_states, latents], dim=-2) + latents = self.attn(latents, encoder_hidden_states) + residual + latents = self.ff(latents) + latents + return latents + + +class IPAdapterPlusImageProjection(nn.Module): + """Resampler of IP-Adapter Plus. + + Args: + embed_dims (int): The feature dimension. Defaults to 768. output_dims (int): The number of output channels, + that is the same + number of the channels in the `unet.config.cross_attention_dim`. Defaults to 1024. + hidden_dims (int): + The number of hidden channels. Defaults to 1280. depth (int): The number of blocks. Defaults + to 8. dim_head (int): The number of head channels. Defaults to 64. heads (int): Parallel attention heads. + Defaults to 16. num_queries (int): + The number of queries. Defaults to 8. ffn_ratio (float): The expansion ratio + of feedforward network hidden + layer channels. Defaults to 4. + """ + + def __init__( + self, + embed_dims: int = 768, + output_dims: int = 1024, + hidden_dims: int = 1280, + depth: int = 4, + dim_head: int = 64, + heads: int = 16, + num_queries: int = 8, + ffn_ratio: float = 4, + ) -> None: + super().__init__() + self.latents = nn.Parameter(torch.randn(1, num_queries, hidden_dims) / hidden_dims**0.5) + + self.proj_in = nn.Linear(embed_dims, hidden_dims) + + self.proj_out = nn.Linear(hidden_dims, output_dims) + self.norm_out = nn.LayerNorm(output_dims) + + self.layers = nn.ModuleList( + [IPAdapterPlusImageProjectionBlock(hidden_dims, dim_head, heads, ffn_ratio) for _ in range(depth)] + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Args: + x (torch.Tensor): Input Tensor. + Returns: + torch.Tensor: Output Tensor. + """ + latents = self.latents.repeat(x.size(0), 1, 1) + + x = self.proj_in(x) + + for block in self.layers: + residual = latents + latents = block(x, latents, residual) + + latents = self.proj_out(latents) + return self.norm_out(latents) + + +class IPAdapterFaceIDPlusImageProjection(nn.Module): + """FacePerceiverResampler of IP-Adapter Plus. + + Args: + embed_dims (int): The feature dimension. Defaults to 768. output_dims (int): The number of output channels, + that is the same + number of the channels in the `unet.config.cross_attention_dim`. Defaults to 1024. + hidden_dims (int): + The number of hidden channels. Defaults to 1280. depth (int): The number of blocks. Defaults + to 8. dim_head (int): The number of head channels. Defaults to 64. heads (int): Parallel attention heads. + Defaults to 16. num_tokens (int): Number of tokens num_queries (int): The number of queries. Defaults to 8. + ffn_ratio (float): The expansion ratio of feedforward network hidden + layer channels. Defaults to 4. + ffproj_ratio (float): The expansion ratio of feedforward network hidden + layer channels (for ID embeddings). Defaults to 4. + """ + + def __init__( + self, + embed_dims: int = 768, + output_dims: int = 768, + hidden_dims: int = 1280, + id_embeddings_dim: int = 512, + depth: int = 4, + dim_head: int = 64, + heads: int = 16, + num_tokens: int = 4, + num_queries: int = 8, + ffn_ratio: float = 4, + ffproj_ratio: int = 2, + ) -> None: + super().__init__() + from .attention import FeedForward + + self.num_tokens = num_tokens + self.embed_dim = embed_dims + self.clip_embeds = None + self.shortcut = False + self.shortcut_scale = 1.0 + + self.proj = FeedForward(id_embeddings_dim, embed_dims * num_tokens, activation_fn="gelu", mult=ffproj_ratio) + self.norm = nn.LayerNorm(embed_dims) + + self.proj_in = nn.Linear(hidden_dims, embed_dims) + + self.proj_out = nn.Linear(embed_dims, output_dims) + self.norm_out = nn.LayerNorm(output_dims) + + self.layers = nn.ModuleList( + [IPAdapterPlusImageProjectionBlock(embed_dims, dim_head, heads, ffn_ratio) for _ in range(depth)] + ) + + def forward(self, id_embeds: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Args: + id_embeds (torch.Tensor): Input Tensor (ID embeds). + Returns: + torch.Tensor: Output Tensor. + """ + id_embeds = id_embeds.to(self.clip_embeds.dtype) + id_embeds = self.proj(id_embeds) + id_embeds = id_embeds.reshape(-1, self.num_tokens, self.embed_dim) + id_embeds = self.norm(id_embeds) + latents = id_embeds + + clip_embeds = self.proj_in(self.clip_embeds) + x = clip_embeds.reshape(-1, clip_embeds.shape[2], clip_embeds.shape[3]) + + for block in self.layers: + residual = latents + latents = block(x, latents, residual) + + latents = self.proj_out(latents) + out = self.norm_out(latents) + if self.shortcut: + out = id_embeds + self.shortcut_scale * out + return out + + +class MultiIPAdapterImageProjection(nn.Module): + def __init__(self, IPAdapterImageProjectionLayers: Union[List[nn.Module], Tuple[nn.Module]]): + super().__init__() + self.image_projection_layers = nn.ModuleList(IPAdapterImageProjectionLayers) + + def forward(self, image_embeds: List[torch.Tensor]): + projected_image_embeds = [] + + # currently, we accept `image_embeds` as + # 1. a tensor (deprecated) with shape [batch_size, embed_dim] or [batch_size, sequence_length, embed_dim] + # 2. list of `n` tensors where `n` is number of ip-adapters, each tensor can hae shape [batch_size, num_images, embed_dim] or [batch_size, num_images, sequence_length, embed_dim] + if not isinstance(image_embeds, list): + deprecation_message = ( + "You have passed a tensor as `image_embeds`.This is deprecated and will be removed in a future release." + " Please make sure to update your script to pass `image_embeds` as a list of tensors to suppress this warning." + ) + deprecate("image_embeds not a list", "1.0.0", deprecation_message, standard_warn=False) + image_embeds = [image_embeds.unsqueeze(1)] + + if len(image_embeds) != len(self.image_projection_layers): + raise ValueError( + f"image_embeds must have the same length as image_projection_layers, got {len(image_embeds)} and {len(self.image_projection_layers)}" + ) + + for image_embed, image_projection_layer in zip(image_embeds, self.image_projection_layers): + batch_size, num_images = image_embed.shape[0], image_embed.shape[1] + image_embed = image_embed.reshape((batch_size * num_images,) + image_embed.shape[2:]) + image_embed = image_projection_layer(image_embed) + image_embed = image_embed.reshape((batch_size, num_images) + image_embed.shape[1:]) + + projected_image_embeds.append(image_embed) + + return projected_image_embeds diff --git a/MindIE/MultiModal/CogVideoX/cogvideox_5b/models/normalization.py b/MindIE/MultiModal/CogVideoX/cogvideox_5b/models/normalization.py new file mode 100644 index 0000000000..6e5c7ca0bf --- /dev/null +++ b/MindIE/MultiModal/CogVideoX/cogvideox_5b/models/normalization.py @@ -0,0 +1,527 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numbers +from typing import Dict, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from diffusers.utils import is_torch_version +from .activations import get_activation +from .embeddings import ( + CombinedTimestepLabelEmbeddings, + PixArtAlphaCombinedTimestepSizeEmbeddings, +) + + +class AdaLayerNorm(nn.Module): + r""" + Norm layer modified to incorporate timestep embeddings. + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + num_embeddings (`int`, *optional*): The size of the embeddings dictionary. + output_dim (`int`, *optional*): + norm_elementwise_affine (`bool`, defaults to `False): + norm_eps (`bool`, defaults to `False`): + chunk_dim (`int`, defaults to `0`): + """ + + def __init__( + self, + embedding_dim: int, + num_embeddings: Optional[int] = None, + output_dim: Optional[int] = None, + norm_elementwise_affine: bool = False, + norm_eps: float = 1e-5, + chunk_dim: int = 0, + ): + super().__init__() + + self.chunk_dim = chunk_dim + output_dim = output_dim or embedding_dim * 2 + + if num_embeddings is not None: + self.emb = nn.Embedding(num_embeddings, embedding_dim) + else: + self.emb = None + + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, output_dim) + self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine) + + def forward( + self, x: torch.Tensor, timestep: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None + ) -> torch.Tensor: + if self.emb is not None: + temb = self.emb(timestep) + + temb = self.linear(self.silu(temb)) + + if self.chunk_dim == 1: + # This is a bit weird why we have the order of "shift, scale" here and "scale, shift" in the + # other if-branch. This branch is specific to CogVideoX for now. + shift, scale = temb.chunk(2, dim=1) + shift = shift[:, None, :] + scale = scale[:, None, :] + else: + scale, shift = temb.chunk(2, dim=0) + + x = self.norm(x) * (1 + scale) + shift + return x + + +class FP32LayerNorm(nn.LayerNorm): + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + origin_dtype = inputs.dtype + return F.layer_norm( + inputs.float(), + self.normalized_shape, + self.weight.float() if self.weight is not None else None, + self.bias.float() if self.bias is not None else None, + self.eps, + ).to(origin_dtype) + + +class SD35AdaLayerNormZeroX(nn.Module): + r""" + Norm layer adaptive layer norm zero (AdaLN-Zero). + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + num_embeddings (`int`): The size of the embeddings dictionary. + """ + + def __init__(self, embedding_dim: int, norm_type: str = "layer_norm", bias: bool = True) -> None: + super().__init__() + + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, 9 * embedding_dim, bias=bias) + if norm_type == "layer_norm": + self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6) + else: + raise ValueError(f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm'.") + + def forward( + self, + hidden_states: torch.Tensor, + emb: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, ...]: + emb = self.linear(self.silu(emb)) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, shift_msa2, scale_msa2, gate_msa2 = emb.chunk( + 9, dim=1 + ) + norm_hidden_states = self.norm(hidden_states) + hidden_states = norm_hidden_states * (1 + scale_msa[:, None]) + shift_msa[:, None] + norm_hidden_states2 = norm_hidden_states * (1 + scale_msa2[:, None]) + shift_msa2[:, None] + return hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 + + +class AdaLayerNormZero(nn.Module): + r""" + Norm layer adaptive layer norm zero (adaLN-Zero). + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + num_embeddings (`int`): The size of the embeddings dictionary. + """ + + def __init__(self, embedding_dim: int, num_embeddings: Optional[int] = None, norm_type="layer_norm", bias=True): + super().__init__() + if num_embeddings is not None: + self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim) + else: + self.emb = None + + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias) + if norm_type == "layer_norm": + self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6) + elif norm_type == "fp32_layer_norm": + self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=False, bias=False) + else: + raise ValueError( + f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'." + ) + + def forward( + self, + x: torch.Tensor, + timestep: Optional[torch.Tensor] = None, + class_labels: Optional[torch.LongTensor] = None, + hidden_dtype: Optional[torch.dtype] = None, + emb: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + if self.emb is not None: + emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype) + emb = self.linear(self.silu(emb)) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1) + x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None] + return x, gate_msa, shift_mlp, scale_mlp, gate_mlp + + +class AdaLayerNormZeroSingle(nn.Module): + r""" + Norm layer adaptive layer norm zero (adaLN-Zero). + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + num_embeddings (`int`): The size of the embeddings dictionary. + """ + + def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True): + super().__init__() + + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias) + if norm_type == "layer_norm": + self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6) + else: + raise ValueError( + f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'." + ) + + def forward( + self, + x: torch.Tensor, + emb: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + emb = self.linear(self.silu(emb)) + shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1) + x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None] + return x, gate_msa + + +class LuminaRMSNormZero(nn.Module): + """ + Norm layer adaptive RMS normalization zero. + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + """ + + def __init__(self, embedding_dim: int, norm_eps: float, norm_elementwise_affine: bool): + super().__init__() + self.silu = nn.SiLU() + self.linear = nn.Linear( + min(embedding_dim, 1024), + 4 * embedding_dim, + bias=True, + ) + self.norm = RMSNorm(embedding_dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine) + + def forward( + self, + x: torch.Tensor, + emb: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + emb = self.linear(self.silu(emb)) + scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1) + x = self.norm(x) * (1 + scale_msa[:, None]) + + return x, gate_msa, scale_mlp, gate_mlp + + +class AdaLayerNormSingle(nn.Module): + r""" + Norm layer adaptive layer norm single (adaLN-single). + + As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3). + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + use_additional_conditions (`bool`): To use additional conditions for normalization or not. + """ + + def __init__(self, embedding_dim: int, use_additional_conditions: bool = False): + super().__init__() + + self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings( + embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions + ) + + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True) + + def forward( + self, + timestep: torch.Tensor, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + batch_size: Optional[int] = None, + hidden_dtype: Optional[torch.dtype] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + # No modulation happening here. + embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype) + return self.linear(self.silu(embedded_timestep)), embedded_timestep + + +class AdaGroupNorm(nn.Module): + r""" + GroupNorm layer modified to incorporate timestep embeddings. + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + num_embeddings (`int`): The size of the embeddings dictionary. + num_groups (`int`): The number of groups to separate the channels into. + act_fn (`str`, *optional*, defaults to `None`): The activation function to use. + eps (`float`, *optional*, defaults to `1e-5`): The epsilon value to use for numerical stability. + """ + + def __init__( + self, embedding_dim: int, out_dim: int, num_groups: int, act_fn: Optional[str] = None, eps: float = 1e-5 + ): + super().__init__() + self.num_groups = num_groups + self.eps = eps + + if act_fn is None: + self.act = None + else: + self.act = get_activation(act_fn) + + self.linear = nn.Linear(embedding_dim, out_dim * 2) + + def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor: + if self.act: + emb = self.act(emb) + emb = self.linear(emb) + emb = emb[:, :, None, None] + scale, shift = emb.chunk(2, dim=1) + + x = F.group_norm(x, self.num_groups, eps=self.eps) + x = x * (1 + scale) + shift + return x + + +class AdaLayerNormContinuous(nn.Module): + def __init__( + self, + embedding_dim: int, + conditioning_embedding_dim: int, + # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters + # because the output is immediately scaled and shifted by the projected conditioning embeddings. + # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters. + # However, this is how it was implemented in the original code, and it's rather likely you should + # set `elementwise_affine` to False. + elementwise_affine=True, + eps=1e-5, + bias=True, + norm_type="layer_norm", + ): + super().__init__() + self.silu = nn.SiLU() + self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias) + if norm_type == "layer_norm": + self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias) + elif norm_type == "rms_norm": + self.norm = RMSNorm(embedding_dim, eps, elementwise_affine) + else: + raise ValueError(f"unknown norm_type {norm_type}") + + def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor: + # convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT) + emb = self.linear(self.silu(conditioning_embedding).to(x.dtype)) + scale, shift = torch.chunk(emb, 2, dim=1) + x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :] + return x + + +class LuminaLayerNormContinuous(nn.Module): + def __init__( + self, + embedding_dim: int, + conditioning_embedding_dim: int, + # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters + # because the output is immediately scaled and shifted by the projected conditioning embeddings. + # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters. + # However, this is how it was implemented in the original code, and it's rather likely you should + # set `elementwise_affine` to False. + elementwise_affine=True, + eps=1e-5, + bias=True, + norm_type="layer_norm", + out_dim: Optional[int] = None, + ): + super().__init__() + # AdaLN + self.silu = nn.SiLU() + self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias) + if norm_type == "layer_norm": + self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias) + else: + raise ValueError(f"unknown norm_type {norm_type}") + # linear_2 + if out_dim is not None: + self.linear_2 = nn.Linear( + embedding_dim, + out_dim, + bias=bias, + ) + + def forward( + self, + x: torch.Tensor, + conditioning_embedding: torch.Tensor, + ) -> torch.Tensor: + # convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT) + emb = self.linear_1(self.silu(conditioning_embedding).to(x.dtype)) + scale = emb + x = self.norm(x) * (1 + scale)[:, None, :] + + if self.linear_2 is not None: + x = self.linear_2(x) + + return x + + +class CogView3PlusAdaLayerNormZeroTextImage(nn.Module): + r""" + Norm layer adaptive layer norm zero (adaLN-Zero). + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + num_embeddings (`int`): The size of the embeddings dictionary. + """ + + def __init__(self, embedding_dim: int, dim: int): + super().__init__() + + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, 12 * dim, bias=True) + self.norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5) + self.norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5) + + def forward( + self, + x: torch.Tensor, + context: torch.Tensor, + emb: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + emb = self.linear(self.silu(emb)) + ( + shift_msa, + scale_msa, + gate_msa, + shift_mlp, + scale_mlp, + gate_mlp, + c_shift_msa, + c_scale_msa, + c_gate_msa, + c_shift_mlp, + c_scale_mlp, + c_gate_mlp, + ) = emb.chunk(12, dim=1) + normed_x = self.norm_x(x) + normed_context = self.norm_c(context) + x = normed_x * (1 + scale_msa[:, None]) + shift_msa[:, None] + context = normed_context * (1 + c_scale_msa[:, None]) + c_shift_msa[:, None] + return x, gate_msa, shift_mlp, scale_mlp, gate_mlp, context, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp + + +class CogVideoXLayerNormZero(nn.Module): + def __init__( + self, + conditioning_dim: int, + embedding_dim: int, + elementwise_affine: bool = True, + eps: float = 1e-5, + bias: bool = True, + ) -> None: + super().__init__() + + self.silu = nn.SiLU() + self.linear = nn.Linear(conditioning_dim, 6 * embedding_dim, bias=bias) + self.norm = nn.LayerNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine) + + def forward( + self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + shift, scale, gate, enc_shift, enc_scale, enc_gate = self.linear(self.silu(temb)).chunk(6, dim=1) + hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :] + encoder_hidden_states = self.norm(encoder_hidden_states) * (1 + enc_scale)[:, None, :] + enc_shift[:, None, :] + return hidden_states, encoder_hidden_states, gate[:, None, :], enc_gate[:, None, :] + + +if is_torch_version(">=", "2.1.0"): + LayerNorm = nn.LayerNorm +else: + class LayerNorm(nn.Module): + def __init__(self, dim, eps: float = 1e-5, elementwise_affine: bool = True, bias: bool = True): + super().__init__() + + self.eps = eps + + if isinstance(dim, numbers.Integral): + dim = (dim,) + + self.dim = torch.Size(dim) + + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(dim)) + self.bias = nn.Parameter(torch.zeros(dim)) if bias else None + else: + self.weight = None + self.bias = None + + def forward(self, input): + return F.layer_norm(input, self.dim, self.weight, self.bias, self.eps) + + +class RMSNorm(nn.Module): + def __init__(self, dim, eps: float, elementwise_affine: bool = True): + super().__init__() + + self.eps = eps + + if isinstance(dim, numbers.Integral): + dim = (dim,) + + self.dim = torch.Size(dim) + + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(dim)) + else: + self.weight = None + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.eps) + + if self.weight is not None: + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + hidden_states = hidden_states * self.weight + else: + hidden_states = hidden_states.to(input_dtype) + + return hidden_states + + +class GlobalResponseNorm(nn.Module): + # Taken from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105 + def __init__(self, dim): + super().__init__() + self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim)) + self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim)) + + def forward(self, x): + gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True) + nx = gx / (gx.mean(dim=-1, keepdim=True) + 1e-6) + return self.gamma * (x * nx) + self.beta + x diff --git a/MindIE/MultiModal/CogVideoX/cogvideox_5b/models/transformers/__init__.py b/MindIE/MultiModal/CogVideoX/cogvideox_5b/models/transformers/__init__.py new file mode 100644 index 0000000000..d5899e9cd7 --- /dev/null +++ b/MindIE/MultiModal/CogVideoX/cogvideox_5b/models/transformers/__init__.py @@ -0,0 +1 @@ +from .cogvideox_transformer_3d import CogVideoXTransformer3DModel diff --git a/MindIE/MultiModal/CogVideoX/cogvideox_5b/models/transformers/cogvideox_transformer_3d.py b/MindIE/MultiModal/CogVideoX/cogvideox_5b/models/transformers/cogvideox_transformer_3d.py new file mode 100644 index 0000000000..b0a007d097 --- /dev/null +++ b/MindIE/MultiModal/CogVideoX/cogvideox_5b/models/transformers/cogvideox_transformer_3d.py @@ -0,0 +1,551 @@ +# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, Optional, Tuple, Union + +import torch +from torch import nn + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders import PeftAdapterMixin +from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers +from diffusers.utils.torch_utils import maybe_allow_in_graph +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.models.modeling_utils import ModelMixin +from mindiesd.layers.linear import QKVLinear +from ..attention import Attention, FeedForward +from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0 +from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps +from ..normalization import AdaLayerNorm, CogVideoXLayerNormZero +from ...utils import all_gather_variable_with_group, split_tensor, get_dp_world_size +from ...utils import get_sp_world_size, get_sp_group, get_dp_group, get_rank + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@maybe_allow_in_graph +class CogVideoXBlock(nn.Module): + r""" + Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model. + + Parameters: + dim (`int`): + The number of channels in the input and output. + num_attention_heads (`int`): + The number of heads to use for multi-head attention. + attention_head_dim (`int`): + The number of channels in each head. + time_embed_dim (`int`): + The number of channels in timestep embedding. + dropout (`float`, defaults to `0.0`): + The dropout probability to use. + activation_fn (`str`, defaults to `"gelu-approximate"`): + Activation function to be used in feed-forward. + attention_bias (`bool`, defaults to `False`): + Whether or not to use bias in attention projection layers. + qk_norm (`bool`, defaults to `True`): + Whether or not to use normalization after query and key projections in Attention. + norm_elementwise_affine (`bool`, defaults to `True`): + Whether to use learnable elementwise affine parameters for normalization. + norm_eps (`float`, defaults to `1e-5`): + Epsilon value for normalization layers. + final_dropout (`bool` defaults to `False`): + Whether to apply a final dropout after the last feed-forward layer. + ff_inner_dim (`int`, *optional*, defaults to `None`): + Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used. + ff_bias (`bool`, defaults to `True`): + Whether or not to use bias in Feed-forward layer. + attention_out_bias (`bool`, defaults to `True`): + Whether or not to use bias in Attention output projection layer. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + time_embed_dim: int, + dropout: float = 0.0, + activation_fn: str = "gelu-approximate", + attention_bias: bool = False, + qk_norm: bool = True, + norm_elementwise_affine: bool = True, + norm_eps: float = 1e-5, + final_dropout: bool = True, + ff_inner_dim: Optional[int] = None, + ff_bias: bool = True, + attention_out_bias: bool = True, + ): + super().__init__() + + # 1. Self Attention + self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True) + + self.attn1 = Attention( + query_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + qk_norm="layer_norm" if qk_norm else None, + eps=1e-6, + bias=attention_bias, + out_bias=attention_out_bias, + processor=CogVideoXAttnProcessor2_0(), + ) + + # 2. Feed Forward + self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True) + + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn=activation_fn, + final_dropout=final_dropout, + inner_dim=ff_inner_dim, + bias=ff_bias, + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> torch.Tensor: + text_seq_length = encoder_hidden_states.size(1) + # norm & modulate + norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1( + hidden_states, encoder_hidden_states, temb + ) + + # attention + attn_hidden_states, attn_encoder_hidden_states = self.attn1( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + ) + + hidden_states = hidden_states + gate_msa * attn_hidden_states + encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states + + # norm & modulate + norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2( + hidden_states, encoder_hidden_states, temb + ) + + # feed-forward + norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1) + ff_output = self.ff(norm_hidden_states) + + hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:] + encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length] + + return hidden_states, encoder_hidden_states + + +class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): + """ + A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo). + + Parameters: + num_attention_heads (`int`, defaults to `30`): + The number of heads to use for multi-head attention. + attention_head_dim (`int`, defaults to `64`): + The number of channels in each head. + in_channels (`int`, defaults to `16`): + The number of channels in the input. + out_channels (`int`, *optional*, defaults to `16`): + The number of channels in the output. + flip_sin_to_cos (`bool`, defaults to `True`): + Whether to flip the sin to cos in the time embedding. + time_embed_dim (`int`, defaults to `512`): + Output dimension of timestep embeddings. + text_embed_dim (`int`, defaults to `4096`): + Input dimension of text embeddings from the text encoder. + num_layers (`int`, defaults to `30`): + The number of layers of Transformer blocks to use. + dropout (`float`, defaults to `0.0`): + The dropout probability to use. + attention_bias (`bool`, defaults to `True`): + Whether or not to use bias in the attention projection layers. + sample_width (`int`, defaults to `90`): + The width of the input latents. + sample_height (`int`, defaults to `60`): + The height of the input latents. + sample_frames (`int`, defaults to `49`): + The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49 + instead of 13 because CogVideoX processed 13 latent frames at once in its default and recommended settings, + but cannot be changed to the correct value to ensure backwards compatibility. To create a transformer with + K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1). + patch_size (`int`, defaults to `2`): + The size of the patches to use in the patch embedding layer. + temporal_compression_ratio (`int`, defaults to `4`): + The compression ratio across the temporal dimension. See documentation for `sample_frames`. + max_text_seq_length (`int`, defaults to `226`): + The maximum sequence length of the input text embeddings. + activation_fn (`str`, defaults to `"gelu-approximate"`): + Activation function to use in feed-forward. + timestep_activation_fn (`str`, defaults to `"silu"`): + Activation function to use when generating the timestep embeddings. + norm_elementwise_affine (`bool`, defaults to `True`): + Whether or not to use elementwise affine in normalization layers. + norm_eps (`float`, defaults to `1e-5`): + The epsilon value to use in normalization layers. + spatial_interpolation_scale (`float`, defaults to `1.875`): + Scaling factor to apply in 3D positional embeddings across spatial dimensions. + temporal_interpolation_scale (`float`, defaults to `1.0`): + Scaling factor to apply in 3D positional embeddings across temporal dimensions. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + num_attention_heads: int = 30, + attention_head_dim: int = 64, + in_channels: int = 16, + out_channels: Optional[int] = 16, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + time_embed_dim: int = 512, + text_embed_dim: int = 4096, + num_layers: int = 30, + dropout: float = 0.0, + attention_bias: bool = True, + sample_width: int = 90, + sample_height: int = 60, + sample_frames: int = 49, + patch_size: int = 2, + temporal_compression_ratio: int = 4, + max_text_seq_length: int = 226, + activation_fn: str = "gelu-approximate", + timestep_activation_fn: str = "silu", + norm_elementwise_affine: bool = True, + norm_eps: float = 1e-5, + spatial_interpolation_scale: float = 1.875, + temporal_interpolation_scale: float = 1.0, + use_rotary_positional_embeddings: bool = False, + use_learned_positional_embeddings: bool = False, + ): + super().__init__() + inner_dim = num_attention_heads * attention_head_dim + self.num_heads = num_attention_heads + self.head_dim = attention_head_dim + + if not use_rotary_positional_embeddings and use_learned_positional_embeddings: + raise ValueError( + "There are no CogVideoX checkpoints available with disable rotary embeddings and learned positional " + "embeddings. If you're using a custom model and/or believe this should be supported, please open an " + "issue at https://github.com/huggingface/diffusers/issues." + ) + + # 1. Patch embedding + self.patch_embed = CogVideoXPatchEmbed( + patch_size=patch_size, + in_channels=in_channels, + embed_dim=inner_dim, + text_embed_dim=text_embed_dim, + bias=True, + sample_width=sample_width, + sample_height=sample_height, + sample_frames=sample_frames, + temporal_compression_ratio=temporal_compression_ratio, + max_text_seq_length=max_text_seq_length, + spatial_interpolation_scale=spatial_interpolation_scale, + temporal_interpolation_scale=temporal_interpolation_scale, + use_positional_embeddings=not use_rotary_positional_embeddings, + use_learned_positional_embeddings=use_learned_positional_embeddings, + ) + self.embedding_dropout = nn.Dropout(dropout) + + # 2. Time embeddings + self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift) + self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn) + + # 3. Define spatio-temporal transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + CogVideoXBlock( + dim=inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + time_embed_dim=time_embed_dim, + dropout=dropout, + activation_fn=activation_fn, + attention_bias=attention_bias, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + ) + for _ in range(num_layers) + ] + ) + self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine) + + # 4. Output blocks + self.norm_out = AdaLayerNorm( + embedding_dim=time_embed_dim, + output_dim=2 * inner_dim, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + chunk_dim=1, + ) + self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels) + + self.gradient_checkpointing = False + + + def _set_gradient_checkpointing(self, module, value=False): + self.gradient_checkpointing = value + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0 + def fuse_qkv_projections(self): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) + are fused. For cross-attention modules, key and value projection matrices are fused. + + + + This API is 🧪 experimental. + + + """ + self.original_attn_processors = None + + for _, attn_processor in self.attn_processors.items(): + if "Added" in str(attn_processor.__class__.__name__): + raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") + + self.original_attn_processors = self.attn_processors + + for module in self.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) + + self.set_attn_processor(FusedCogVideoXAttnProcessor2_0()) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections + def unfuse_qkv_projections(self): + """Disables the fused QKV projection if enabled. + + + + This API is 🧪 experimental. + + + + """ + if self.original_attn_processors is not None: + self.set_attn_processor(self.original_attn_processors) + + def switch_to_qkvLinear(self) -> None: + for blk in self.transformer_blocks: + blk.attn1.qkvLinear = QKVLinear(self.head_dim, self.head_dim * self.num_heads) + blk.attn1.qkvLinear.weight.data = torch.cat((blk.attn1.to_q.weight.data.transpose(1, 0).contiguous(), blk.attn1.to_k.weight.data.transpose(1, 0).contiguous(), blk.attn1.to_v.weight.data.transpose(1, 0).contiguous()), -1) + blk.attn1.qkvLinear.bias.data = torch.cat((blk.attn1.to_q.bias.data, blk.attn1.to_k.bias.data, blk.attn1.to_v.bias.data), -1) + blk.attn1.to_q = None + blk.attn1.to_k = None + blk.attn1.to_v = None + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + timestep: Union[int, float, torch.LongTensor], + timestep_cond: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ): + temporal_size = hidden_states.shape[1] + if isinstance(timestep, torch.Tensor) and timestep.ndim != 0 and timestep.shape[0] == hidden_states.shape[0]: + timestep = split_tensor(timestep, 0, get_dp_world_size(), get_dp_group()) + + hidden_states = split_tensor(hidden_states, 0, get_dp_world_size(), get_dp_group()) + hidden_states = split_tensor(hidden_states, -2, get_sp_world_size(), get_sp_group(), scale=2) + + encoder_hidden_states = split_tensor(encoder_hidden_states, 0, get_dp_world_size(), get_dp_group()) + encoder_hidden_states = split_tensor(encoder_hidden_states, -2, get_sp_world_size(), get_sp_group()) + + if image_rotary_emb is not None: + freqs_cos, freqs_sin = image_rotary_emb + + def get_rotary_emb_chunk(freqs): + dim_thw = freqs.shape[-1] + freqs = freqs.reshape(temporal_size, -1, dim_thw) + + freqs = freqs.reshape(temporal_size, -1, hidden_states.size(-1) // 2, dim_thw) + freqs = split_tensor(freqs, -3, get_sp_world_size(), get_sp_group()) + freqs = freqs.reshape(temporal_size, -1, dim_thw) + + freqs = freqs.reshape(-1, dim_thw) + return freqs + + freqs_cos = get_rotary_emb_chunk(freqs_cos) + freqs_sin = get_rotary_emb_chunk(freqs_sin) + image_rotary_emb = (freqs_cos, freqs_sin) + + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." + ) + + batch_size, num_frames, channels, height, width = hidden_states.shape + # 1. Time embedding + timesteps = timestep + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=hidden_states.dtype) + emb = self.time_embedding(t_emb, timestep_cond) + + # 2. Patch embedding + hidden_states = self.patch_embed(encoder_hidden_states, hidden_states) + hidden_states = self.embedding_dropout(hidden_states) + + text_seq_length = encoder_hidden_states.shape[1] + encoder_hidden_states = hidden_states[:, :text_seq_length] + hidden_states = hidden_states[:, text_seq_length:] + + # 3. Transformer blocks + for i, block in enumerate(self.transformer_blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + emb, + image_rotary_emb, + **ckpt_kwargs, + ) + else: + hidden_states, encoder_hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=emb, + image_rotary_emb=image_rotary_emb, + ) + + if not self.config.use_rotary_positional_embeddings: + # CogVideoX-2B + hidden_states = self.norm_final(hidden_states) + else: + # CogVideoX-5B + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + hidden_states = self.norm_final(hidden_states) + hidden_states = hidden_states[:, text_seq_length:] + + # 4. Final block + hidden_states = self.norm_out(hidden_states, temb=emb) + hidden_states = self.proj_out(hidden_states) + + # 5. Unpatchify + # Note: we use `-1` instead of `channels`: + # - It is okay to `channels` use for CogVideoX-2b and CogVideoX-5b (number of input channels is equal to output channels) + # - However, for CogVideoX-5b-I2V also takes concatenated input image latents (number of input channels is twice the output channels) + p = self.config.patch_size + output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p) + output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4) + + output = all_gather_variable_with_group(output, dim=-2, world_size=get_sp_world_size(), group=get_sp_group()) + output = all_gather_variable_with_group(output, world_size=get_dp_world_size(), group=get_dp_group()) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (output,) + return Transformer2DModelOutput(sample=output) \ No newline at end of file diff --git a/MindIE/MultiModal/CogVideoX/cogvideox_5b/pipelines/__init__.py b/MindIE/MultiModal/CogVideoX/cogvideox_5b/pipelines/__init__.py new file mode 100644 index 0000000000..1032118c1e --- /dev/null +++ b/MindIE/MultiModal/CogVideoX/cogvideox_5b/pipelines/__init__.py @@ -0,0 +1 @@ +from .pipeline_cogvideox import CogVideoXPipeline diff --git a/MindIE/MultiModal/CogVideoX/cogvideox_5b/pipelines/pipeline_cogvideox.py b/MindIE/MultiModal/CogVideoX/cogvideox_5b/pipelines/pipeline_cogvideox.py new file mode 100644 index 0000000000..323bce7435 --- /dev/null +++ b/MindIE/MultiModal/CogVideoX/cogvideox_5b/pipelines/pipeline_cogvideox.py @@ -0,0 +1,759 @@ +# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import math +from typing import Any, Callable, Dict, List, Optional, Tuple, Union +import torch +from transformers import T5EncoderModel, T5Tokenizer +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from diffusers.loaders import CogVideoXLoraLoaderMixin +from diffusers.models import AutoencoderKLCogVideoX +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler +from diffusers.utils import logging, replace_example_docstring +from diffusers.utils.torch_utils import randn_tensor +from diffusers.video_processor import VideoProcessor +from ..models import CogVideoXTransformer3DModel +from ..models.embeddings import get_3d_rotary_pos_embed +from .pipeline_output import CogVideoXPipelineOutput +from ..utils.parallel_state import get_world_size, get_rank, all_gather, all_gather_variable_with_group, split_tensor +from ..utils.parallel_state import get_dp_world_size, get_dp_rank, get_sp_rank, get_sp_world_size + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import CogVideoXPipeline + >>> from diffusers.utils import export_to_video + + >>> # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b" + >>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to("cuda") + >>> prompt = ( + ... "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. " + ... "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other " + ... "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, " + ... "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. " + ... "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical " + ... "atmosphere of this unique musical performance." + ... ) + >>> video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0] + >>> export_to_video(video, "output.mp4", fps=8) + ``` +""" + + +# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid +def get_resize_crop_region_for_grid(src, tgt_width, tgt_height): + tw = tgt_width + th = tgt_height + h, w = src + r = h / w + if r > (th / tw): + resize_height = th + resize_width = int(round(th / h * w)) + else: + resize_width = tw + resize_height = int(round(tw / w * h)) + + crop_top = int(round((th - resize_height) / 2.0)) + crop_left = int(round((tw - resize_width) / 2.0)) + + return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): + r""" + Pipeline for text-to-video generation using CogVideoX. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + text_encoder ([`T5EncoderModel`]): + Frozen text-encoder. CogVideoX uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the + [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant. + tokenizer (`T5Tokenizer`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + transformer ([`CogVideoXTransformer3DModel`]): + A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `transformer` to denoise the encoded video latents. + """ + + _optional_components = [] + model_cpu_offload_seq = "text_encoder->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + vae: AutoencoderKLCogVideoX, + transformer: CogVideoXTransformer3DModel, + scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler], + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + self.vae_scale_factor_spatial = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.vae_scale_factor_temporal = ( + self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4 + ) + self.vae_scaling_factor_image = ( + self.vae.config.scaling_factor if hasattr(self, "vae") and self.vae is not None else 0.7 + ) + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder(text_input_ids.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + shape = ( + batch_size, + (num_frames - 1) // self.vae_scale_factor_temporal + 1, + num_channels_latents, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width] + latents = 1 / self.vae_scaling_factor_image * latents + + frames = self.vae.decode(latents).sample + return frames + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def fuse_qkv_projections(self) -> None: + r"""Enables fused QKV projections.""" + self.fusing_transformer = True + self.transformer.fuse_qkv_projections() + + def unfuse_qkv_projections(self) -> None: + r"""Disable QKV projection fusion if enabled.""" + if not self.fusing_transformer: + logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.") + else: + self.transformer.unfuse_qkv_projections() + self.fusing_transformer = False + + def _prepare_rotary_positional_embeddings( + self, + height: int, + width: int, + num_frames: int, + device: torch.device, + ) -> Tuple[torch.Tensor, torch.Tensor]: + grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) # 720/8/2 + base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) # 480/8/2 + + grid_crops_coords = get_resize_crop_region_for_grid( + (grid_height, grid_width), base_size_width, base_size_height + ) + freqs_cos, freqs_sin = get_3d_rotary_pos_embed( + embed_dim=self.transformer.config.attention_head_dim, + crops_coords=grid_crops_coords, + grid_size=(grid_height, grid_width), + temporal_size=num_frames, + ) + + freqs_cos = freqs_cos.to(device=device).to(torch.bfloat16) + freqs_sin = freqs_sin.to(device=device).to(torch.bfloat16) + + return freqs_cos, freqs_sin + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 480, + width: int = 720, + num_frames: int = 49, + num_inference_steps: int = 50, + timesteps: Optional[List[int]] = None, + guidance_scale: float = 6, + use_dynamic_cfg: bool = False, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: str = "pil", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 226, + ) -> Union[CogVideoXPipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + height (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial): + The height in pixels of the generated image. This is set to 480 by default for the best results. + width (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial): + The width in pixels of the generated image. This is set to 720 by default for the best results. + num_frames (`int`, defaults to `48`): + Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will + contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where + num_seconds is 6 and fps is 8. However, since videos can be saved at any fps, the only condition that + needs to be satisfied is that of divisibility mentioned above. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `226`): + Maximum sequence length in encoded prompt. Must be consistent with + `self.transformer.config.max_text_seq_length` otherwise may lead to poor results. + + Examples: + + Returns: + [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] or `tuple`: + [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + + if num_frames > 49: + raise ValueError( + "The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation." + ) + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Default call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + negative_prompt, + do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + self._num_timesteps = len(timesteps) + + # 5. Prepare latents. + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + latent_channels, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Create rotary embeds if required + image_rotary_emb = ( + self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device) + if self.transformer.config.use_rotary_positional_embeddings + else None + ) + + # 8. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + # for DPM-solver++ + old_pred_original_sample = None + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + # predict noise model_output + if hasattr(self, "skip_strategy"): + noise_pred = self.skip_strategy( + self.transformer, + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + image_rotary_emb=image_rotary_emb, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + else: + noise_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + image_rotary_emb=image_rotary_emb, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + + noise_pred = noise_pred.float() + + # perform guidance + if use_dynamic_cfg: + self._guidance_scale = 1 + guidance_scale * ( + (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2 + ) + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + if not isinstance(self.scheduler, CogVideoXDPMScheduler): + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + else: + latents, old_pred_original_sample = self.scheduler.step( + noise_pred, + old_pred_original_sample, + t, + timesteps[i - 1] if i > 0 else None, + latents, + **extra_step_kwargs, + return_dict=False, + ) + latents = latents.to(prompt_embeds.dtype) + + if hasattr(self, "skip_strategy"): + self.skip_strategy.update_strategy(latents) + + # call the callback, if provided + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if not output_type == "latent": + video = self.decode_latents(latents.half()) + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return CogVideoXPipelineOutput(frames=video) diff --git a/MindIE/MultiModal/CogVideoX/cogvideox_5b/pipelines/pipeline_output.py b/MindIE/MultiModal/CogVideoX/cogvideox_5b/pipelines/pipeline_output.py new file mode 100644 index 0000000000..3de030dd69 --- /dev/null +++ b/MindIE/MultiModal/CogVideoX/cogvideox_5b/pipelines/pipeline_output.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass + +import torch + +from diffusers.utils import BaseOutput + + +@dataclass +class CogVideoXPipelineOutput(BaseOutput): + r""" + Output class for CogVideo pipelines. + + Args: + frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + frames: torch.Tensor diff --git a/MindIE/MultiModal/CogVideoX/cogvideox_5b/utils/__init__.py b/MindIE/MultiModal/CogVideoX/cogvideox_5b/utils/__init__.py new file mode 100644 index 0000000000..72ecaa460a --- /dev/null +++ b/MindIE/MultiModal/CogVideoX/cogvideox_5b/utils/__init__.py @@ -0,0 +1,2 @@ +from .parallel_state import get_rank, get_world_size, all_gather, split_tensor, all_gather_variable_with_group +from .parallel_state import get_dp_rank, get_dp_world_size, get_sp_rank, get_sp_world_size, get_sp_group, get_dp_group diff --git a/MindIE/MultiModal/CogVideoX/cogvideox_5b/utils/parallel_mgr.py b/MindIE/MultiModal/CogVideoX/cogvideox_5b/utils/parallel_mgr.py new file mode 100644 index 0000000000..311e99f9ac --- /dev/null +++ b/MindIE/MultiModal/CogVideoX/cogvideox_5b/utils/parallel_mgr.py @@ -0,0 +1,76 @@ +import os +import torch +import torch_npu +import torch.distributed as dist +from torch_npu._C._distributed_c10d import ProcessGroupHCCL + + +def create_sp_group(world_size, rank): + ranks = [i for i in range(world_size)] + group1 = dist.new_group(ranks=ranks[:world_size // 2], backend='hccl') + group2 = dist.new_group(ranks=ranks[world_size // 2:], backend='hccl') + if rank < world_size // 2: + subgroup = group1 + else: + subgroup = group2 + print(f'rank: {rank}, ranks: {ranks}') + return subgroup + + +def create_dp_group(world_size, rank): + ranks = [i for i in range(world_size)] + sub_ranks = [[i, j] for i, j in zip(ranks[:world_size // 2], ranks[world_size // 2:])] + groups = [dist.new_group(ranks=sub_rank, backend='hccl') for sub_rank in sub_ranks] + rank = rank if rank < world_size // 2 else rank - world_size // 2 + return groups[rank] + + +class ParallelManager: + def __init__(self): + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + self.rank = local_rank + self.world_size = world_size + if self.world_size > 1: + self.init_group() + self.sp_group = None + self.dp_group = None + self.sp_rank = 0 + self.sp_world_size = 1 + self.dp_world_size = 1 + self.dp_rank = 0 + self.do_pad = False + if self.world_size == 2: + self.init_dp() + + if self.world_size == 4 or self.world_size == 8: + self.init_dp() + self.dp_group = create_dp_group(self.world_size, self.rank) + self.sp_group = create_sp_group(self.world_size, self.rank) + self.sp_rank = dist.get_rank(group=self.sp_group) + self.sp_world_size = dist.get_world_size(group=self.sp_group) + + def init_dp(self): + self.dp_world_size = 2 + self.dp_rank = int(self.rank >= (self.world_size // self.dp_world_size)) + + + def init_group(self): + device = torch.device(f"npu:{self.rank}") + torch_npu.npu.set_device(device) + + backend = "hccl" + options = ProcessGroupHCCL.Options() + print("ProcessGroupHCCL has been Set") + if not torch.distributed.is_initialized(): + # Call the init process. + torch.distributed.init_process_group( + backend=backend, + world_size=self.world_size, + rank=self.rank, + pg_options=options, + ) + print(f"rank {self.rank} init {torch.distributed.is_initialized()}, init_process_group has been activated") + else: + print("torch.distributed is already initialized.") + diff --git a/MindIE/MultiModal/CogVideoX/cogvideox_5b/utils/parallel_state.py b/MindIE/MultiModal/CogVideoX/cogvideox_5b/utils/parallel_state.py new file mode 100644 index 0000000000..fd09991d40 --- /dev/null +++ b/MindIE/MultiModal/CogVideoX/cogvideox_5b/utils/parallel_state.py @@ -0,0 +1,168 @@ +import math +import torch +import torch.distributed as dist + +from typing import Any, Dict, List, Optional, Tuple, Union + +from .parallel_mgr import ParallelManager + +mgr = ParallelManager() + + +def get_world_size(): + return mgr.world_size + + +def get_rank(): + return mgr.rank + + +def get_dp_world_size(): + return mgr.dp_world_size + + +def get_dp_rank(): + return mgr.dp_rank + + +def get_sp_world_size(): + return mgr.sp_world_size + + +def get_sp_rank(): + return mgr.sp_rank + + +def get_sp_group(): + return mgr.sp_group + + +def get_dp_group(): + return mgr.dp_group + + +def all_gather(input_: torch.Tensor, dim: int = 0, separate_tensors: bool = False, world_size=1, group=None + ) -> Union[torch.Tensor, List[torch.Tensor]]: + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + if not (-input_.dim() <= dim < input_.dim()): + raise ValueError(f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + # Allocate output tensor. + input_size = list(input_.size()) + input_size[0] *= world_size + output_tensor = torch.empty( + input_size, dtype=input_.dtype, device=input_.device + ) + # All-gather. + torch.distributed.all_gather_into_tensor( + output_tensor, input_, group=group + ) + if dim != 0: + input_size[0] //= world_size + output_tensor = output_tensor.reshape([world_size, ] + input_size) + output_tensor = output_tensor.movedim(0, dim) + + if separate_tensors: + tensor_list = [ + output_tensor.view(-1) + .narrow(0, input_.numel() * i, input_.numel()) + .view_as(input_) + for i in range(world_size) + ] + return tensor_list + else: + input_size = list(input_.size()) + input_size[dim] = input_size[dim] * world_size + # Reshape + output_tensor = output_tensor.reshape(input_size) + return output_tensor + + +def all_gather_variable_with_group(tensor, dim=0, world_size=1, group=None): + """ + 使用指定的 group 进行 all_gather 操作,支持第一维大小不同的张量。 + + Args: + tensor (torch.Tensor): 本地张量,第一维大小可能不同。 + dim (int): 拼接的维度, 默认是0。 + group (torch.distributed.ProcessGroup): 指定的进程组。 + + Returns: + torch.Tensor: 合并后的张量。 + """ + if world_size == 1: + return tensor + world_size = dist.get_world_size(group=group) + rank = dist.get_rank(group=group) + + # 获取当前张量的第一维大小 + local_size = torch.tensor([tensor.size(dim)], dtype=torch.long, device=tensor.device) + + # 收集所有进程的大小 + size_list = [torch.zeros(1, dtype=torch.long, device=tensor.device) for _ in range(world_size)] + dist.all_gather(size_list, local_size, group=group) + sizes = [int(size.item()) for size in size_list] + + # 找到最大大小 + max_size = max(sizes) + + # 如果当前张量小于最大大小,则填充 + if tensor.size(dim) < max_size: + padding_size = list(tensor.size()) + padding_size[dim] = max_size - tensor.size(dim) + padding = torch.zeros(*padding_size, dtype=tensor.dtype, device=tensor.device) + tensor_padded = torch.cat([tensor, padding], dim=dim) + else: + tensor_padded = tensor + + # 准备一个列表来存储所有填充后的张量 + tensor_list = [torch.zeros_like(tensor_padded) for _ in range(world_size)] + + # 执行 all_gather + dist.all_gather(tensor_list, tensor_padded, group=group) + + # 去除填充并拼接 + tensors = [] + for i, t in enumerate(tensor_list): + if sizes[i] > 0: + tensors.append(t.narrow(dim, 0, sizes[i])) + return torch.cat(tensors, dim=dim) + + +def split_tensor(input_tensor: torch.Tensor, dim: int, world_size: int, group: dist.ProcessGroup, scale=1): + """ + 将 input_tensor 沿指定维度 dim 切分为 group 中各个进程的部分。 + + 参数: + input_tensor (torch.Tensor): 输入的张量。 + group (torch.distributed.ProcessGroup): 当前的通信组。 + dim (int): 切分的维度。 + + 返回: + tuple: + - torch.Tensor: 当前进程对应的切分后的张量。 + - int or None: 如果切分等长,返回切分后的长度;否则返回 None。 + """ + if world_size == 1: + return input_tensor + + world_size = dist.get_world_size(group) + rank = dist.get_rank(group) + dim_size = input_tensor.size(dim) + + # 计算每个块的大小 + if dim_size / scale % world_size == 0: + split_size = dim_size // world_size + else: + split_size = math.ceil(dim_size / world_size / 2) * 2 + + chunks = torch.split(input_tensor, split_size, dim=dim) + + # 获取当前进程对应的块 + tensor_chunk = chunks[rank] + + return tensor_chunk diff --git a/MindIE/MultiModal/CogVideoX/inference.py b/MindIE/MultiModal/CogVideoX/inference.py new file mode 100644 index 0000000000..fad06ada29 --- /dev/null +++ b/MindIE/MultiModal/CogVideoX/inference.py @@ -0,0 +1,134 @@ +import os +import argparse +import time +import random + +from typing import Literal + +import numpy as np +import torch +import torch_npu +from torch_npu.contrib import transfer_to_npu + +from diffusers import CogVideoXDPMScheduler +from diffusers.utils import export_to_video + +from cogvideox_5b import CogVideoXPipeline, CogVideoXTransformer3DModel, get_rank, get_world_size, all_gather +from mindiesd.pipeline.sampling_optm import AdaStep + + +def generate_video( + prompt: str, + model_path: str, + lora_path: str = None, + lora_rank: int = 128, + num_frames: int = 81, + width: int = 1360, + height: int = 768, + output_path: str = "./output.mp4", + image_or_video_path: str = "", + num_inference_steps: int = 50, + guidance_scale: float = 6.0, + num_videos_per_prompt: int = 1, + dtype: torch.dtype = torch.bfloat16, + generate_type: str = Literal["t2v", "i2v", "v2v"], # i2v: image to video, v2v: video to video + seed: int = 42, + fps: int = 8 +): + pipe = CogVideoXPipeline.from_pretrained(model_path, torch_dtype=dtype).to(f"npu:{get_rank()}") + transformer = CogVideoXTransformer3DModel.from_pretrained(os.path.join(model_path, 'transformer'), torch_dtype=dtype).to(f"npu:{get_rank()}") + if lora_path: + pipe.load_lora_weights(lora_path, weight_name="pytorch_lora_weights.safetensors", adapter_name="test_1") + pipe.fuse_lora(lora_scale=1 / lora_rank) + pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing") + pipe.transformer = transformer + pipe.vae = pipe.vae.half() + pipe.vae.enable_slicing() + pipe.vae.enable_tiling() + pipe.transformer.switch_to_qkvLinear() + # sampling optm + skip_strategy = AdaStep(skip_thr=0.006, max_skip_steps=1, decay_ratio=0.99, device="npu") + pipe.skip_strategy = skip_strategy + + # warm up + video_generate = pipe( + height=height, + width=width, + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + num_inference_steps=1, + num_frames=num_frames, + use_dynamic_cfg=True, + guidance_scale=guidance_scale, + generator=torch.Generator().manual_seed(seed), + output_type="pil" + ).frames[0] + + torch_npu.npu.synchronize() + start = time.time() + video_generate = pipe( + height=height, + width=width, + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + num_inference_steps=num_inference_steps, + num_frames=num_frames, + use_dynamic_cfg=True, + guidance_scale=guidance_scale, + generator=torch.Generator().manual_seed(seed), + output_type="pil" + ).frames[0] + torch_npu.npu.synchronize() + end = time.time() + print(f"Time taken for inference: {end - start} seconds") + + export_to_video(video_generate, output_path, fps=fps) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Generate a video from a text prompt using CogVideoX") + parser.add_argument("--prompt", type=str, required=True, help="The description of the video to be generated") + parser.add_argument( + "--image_or_video_path", + type=str, + default=None, + help="The path of the image to be used as the background of the video", + ) + parser.add_argument( + "--model_path", type=str, default="/data/CogVideoX-5b", help="Path of the pre-trained model use" + ) + parser.add_argument("--lora_path", type=str, default=None, help="The path of the LoRA weights to be used") + parser.add_argument("--lora_rank", type=int, default=128, help="The rank of the LoRA weights") + parser.add_argument("--output_path", type=str, default="./output.mp4", help="The path save generated video") + parser.add_argument("--guidance_scale", type=float, default=6.0, help="The scale for classifier-free guidance") + parser.add_argument("--num_inference_steps", type=int, default=50, help="Inference steps") + parser.add_argument("--num_frames", type=int, default=48, help="Number of steps for the inference process") + parser.add_argument("--width", type=int, default=720, help="Number of steps for the inference process") + parser.add_argument("--height", type=int, default=480, help="Number of steps for the inference process") + parser.add_argument("--fps", type=int, default=8, help="Number of steps for the inference process") + parser.add_argument("--num_videos_per_prompt", type=int, default=1, help="Number of videos to generate per prompt") + parser.add_argument("--generate_type", type=str, default="t2v", help="The type of video generation") + parser.add_argument("--dtype", type=str, default="bfloat16", help="The data type for computation") + parser.add_argument("--seed", type=int, default=42, help="The seed for reproducibility") + + args = parser.parse_args() + dtype = torch.float16 if args.dtype == "float16" else torch.bfloat16 + torch.npu.config.allow_internal_format = False + generate_video( + prompt=args.prompt, + model_path=args.model_path, + lora_path=args.lora_path, + lora_rank=args.lora_rank, + output_path=args.output_path, + num_frames=args.num_frames, + width=args.width, + height=args.height, + image_or_video_path=args.image_or_video_path, + num_inference_steps=args.num_inference_steps, + guidance_scale=args.guidance_scale, + num_videos_per_prompt=args.num_videos_per_prompt, + dtype=dtype, + generate_type=args.generate_type, + seed=args.seed, + fps=args.fps, + ) diff --git a/MindIE/MultiModal/CogVideoX/pta_plugin/CMakeLists.txt b/MindIE/MultiModal/CogVideoX/pta_plugin/CMakeLists.txt new file mode 100644 index 0000000000..ff66b7724e --- /dev/null +++ b/MindIE/MultiModal/CogVideoX/pta_plugin/CMakeLists.txt @@ -0,0 +1,30 @@ +cmake_minimum_required(VERSION 3.10) + +project(PTAExtensionOPS) + +execute_process( + COMMAND python3 -c "import site; print(site.getsitepackages()[0])" + OUTPUT_VARIABLE python_site_packages_path +) +string(STRIP "${python_site_packages_path}" python_site_packages_path) + +set(CMAKE_CXX_FLAGS "-fstack-protector-all -Wl,-z,relro,-z,now,-z,noexecstack -fPIE -pie ${CMAKE_CXX_FLAGS}") +set(CMAKE_CXX_FLAGS "-fabi-version=11 ${CMAKE_CXX_FLAGS}") +set(PYTORCH_INSTALL_PATH ${python_site_packages_path}/torch) +set(PYTORCH_NPU_INSTALL_PATH ${python_site_packages_path}/torch_npu) + +link_directories(${PYTORCH_INSTALL_PATH}/lib) +link_directories(${PYTORCH_NPU_INSTALL_PATH}/lib) + +add_library(PTAExtensionOPS SHARED extension_ops.cpp) + +target_compile_features(PTAExtensionOPS PRIVATE cxx_std_17) +target_compile_options(PTAExtensionOPS PRIVATE -D_GLIBCXX_USE_CXX11_ABI=0) + +include_directories(${PYTORCH_NPU_INSTALL_PATH}/include/third_party/acl/inc) +include_directories(${PYTORCH_NPU_INSTALL_PATH}/include) +include_directories(${PYTORCH_INSTALL_PATH}/include) +include_directories(${PYTORCH_INSTALL_PATH}/include/torch/csrc/distributed) +include_directories(${PYTORCH_INSTALL_PATH}/include/torch/csrc/api/include) + +target_link_libraries(PTAExtensionOPS PUBLIC c10 torch torch_cpu torch_npu ) \ No newline at end of file diff --git a/MindIE/MultiModal/CogVideoX/pta_plugin/build.sh b/MindIE/MultiModal/CogVideoX/pta_plugin/build.sh new file mode 100644 index 0000000000..95d55f5ff2 --- /dev/null +++ b/MindIE/MultiModal/CogVideoX/pta_plugin/build.sh @@ -0,0 +1,19 @@ +#!/bin/bash +if [ -n "$ASCEND_INSTALL_PATH" ]; then + _ASCEND_INSTALL_PATH=$ASCEND_INSTALL_PATH +elif [ -n "$ASCEND_HOME_PATH" ]; then + _ASCEND_INSTALL_PATH=$ASCEND_HOME_PATH +else + if [ -d "$HOME/Ascend/ascend-toolkit/latest" ]; then + _ASCEND_INSTALL_PATH=$HOME/Ascend/ascend-toolkit/latest + else + _ASCEND_INSTALL_PATH=/usr/local/Ascend/ascend-toolkit/latest + fi +fi +source $_ASCEND_INSTALL_PATH/bin/setenv.bash + +set -e +rm -rf build +mkdir -p build +cmake -B build +cmake --build build -j \ No newline at end of file diff --git a/MindIE/MultiModal/CogVideoX/pta_plugin/extension_ops.cpp b/MindIE/MultiModal/CogVideoX/pta_plugin/extension_ops.cpp new file mode 100644 index 0000000000..548a9a7365 --- /dev/null +++ b/MindIE/MultiModal/CogVideoX/pta_plugin/extension_ops.cpp @@ -0,0 +1,69 @@ +/** + * @file extension_add.cpp + * + * Copyright (C) 2024. Huawei Technologies Co., Ltd. All rights reserved. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. + */ +#include +#include + +#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/core/npu/NPUFormat.h" + +using torch::autograd::AutogradContext; +using torch::autograd::Function; +using npu_preparation = at_npu::native::OpPreparation; +using npu_utils = at_npu::native::NpuUtils; +using namespace at; + +// flash_attention_tik +// register forward implementation for NPU device +at::Tensor rope_mindie_sd_impl_npu(const at::Tensor &x, const at::Tensor &cos, const at::Tensor &sin, int64_t mode=1) +{ + at::Tensor result = at_npu::native::empty_with_format(x.sizes(),x.options(),at_npu::native::get_npu_format(x)); + + at_npu::native::OpCommand cmd; + + cmd.Name("RotaryPositionEmbedding") + .Input(x) + .Input(cos) + .Input(sin) + .Output(result) + .Attr("mode", mode) + .Run(); + + return result; +} + +// register forward implementation for Meta device +at::Tensor rope_mindie_sd_impl_meta(const at::Tensor &x, const at::Tensor &cos, const at::Tensor &sin, int64_t mode) +{ + return empty_like(x); +} + + +// register the schemas for my_op and my_op_backward in the myops namespace +TORCH_LIBRARY(mindie, m) +{ + m.def("rope_mindie_sd(Tensor query, Tensor key, Tensor value, int mode) -> Tensor"); +} + +// register forward and backward implementations for the NPU device +// the device name used by the NPU device in PyTorch 2.1 and above is PrivateUse1. +// in versions below 2.1, XLA is used. If the version is below 2.1, PrivateUse1 needs to be changed to XLA. +TORCH_LIBRARY_IMPL(mindie, PrivateUse1, m) +{ + m.impl("rope_mindie_sd", &rope_mindie_sd_impl_npu); +} + +// bind the NPU's autograd implementation to the operation +// if the version is below PyTorch 2.1, AutogradPrivateUse1 needs to be changed to AutogradXLA. + +// register forward and backward implementations for the Meta device +TORCH_LIBRARY_IMPL(mindie, Meta, m) +{ + m.impl("rope_mindie_sd", &rope_mindie_sd_impl_meta); +} \ No newline at end of file diff --git a/MindIE/MultiModal/CogVideoX/pta_plugin/test/test_rope.py b/MindIE/MultiModal/CogVideoX/pta_plugin/test/test_rope.py new file mode 100644 index 0000000000..5d6f3425b4 --- /dev/null +++ b/MindIE/MultiModal/CogVideoX/pta_plugin/test/test_rope.py @@ -0,0 +1,25 @@ +#!/usr/bin/python3 +# coding=utf-8 +# +# Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +# =============================================================================== + +import torch +import torch.nn as nn +import torch_npu + +torch.ops.load_library("../build/libPTAExtensionOPS.so") + +if __name__ == "__main__": + torch.npu.set_device(0) + x = torch.randn((2, 48, 128, 64), device="npu") + cos = torch.randn((1, 1, 128, 64), device="npu") + sin = torch.randn((1, 1, 128, 64), device="npu") + + count = 5 + for i in range(count): + output = torch.ops.mindie.rope_mindie_sd(x, cos, sin, mode=1) \ No newline at end of file diff --git a/MindIE/MultiModal/CogVideoX/requirements.txt b/MindIE/MultiModal/CogVideoX/requirements.txt new file mode 100644 index 0000000000..fc4ace5c8f --- /dev/null +++ b/MindIE/MultiModal/CogVideoX/requirements.txt @@ -0,0 +1,13 @@ +diffusers>=0.31.0 +accelerate>=1.1.1 +transformers>=4.46.2 +numpy==1.26.0 +torchvision>=0.20.0 +sentencepiece>=0.2.0 +SwissArmyTransformer>=0.4.12 +gradio>=5.5.0 +imageio>=2.35.1 +imageio-ffmpeg>=0.5.1 +openai>=1.54.0 +moviepy>=1.0.3 +scikit-video>=1.1.11 \ No newline at end of file -- Gitee