From 2df660a33ff0c54c79f0269e8fb7e1ccc478cf14 Mon Sep 17 00:00:00 2001 From: fan2956 Date: Fri, 1 Aug 2025 17:31:03 +0800 Subject: [PATCH 1/2] add wan22 --- MindIE/MultiModal/Wan2.2 | 1 + 1 file changed, 1 insertion(+) create mode 160000 MindIE/MultiModal/Wan2.2 diff --git a/MindIE/MultiModal/Wan2.2 b/MindIE/MultiModal/Wan2.2 new file mode 160000 index 0000000000..7ceec19913 --- /dev/null +++ b/MindIE/MultiModal/Wan2.2 @@ -0,0 +1 @@ +Subproject commit 7ceec199137c6a0c336138182b34efbf4036a3d6 -- Gitee From 95433f6c9c5c87ecd69db54be5d70f9905ebb692 Mon Sep 17 00:00:00 2001 From: fan2956 Date: Fri, 1 Aug 2025 17:38:28 +0800 Subject: [PATCH 2/2] add wan2.2 --- MindIE/MultiModal/Wan2.2 | 1 - MindIE/MultiModal/Wan2.2/INSTALL.md | 55 + MindIE/MultiModal/Wan2.2/LICENSE.txt | 201 ++++ MindIE/MultiModal/Wan2.2/Makefile | 5 + MindIE/MultiModal/Wan2.2/README.md | 443 +++++++ MindIE/MultiModal/Wan2.2/generate.py | 657 +++++++++++ MindIE/MultiModal/Wan2.2/pyproject.toml | 66 ++ MindIE/MultiModal/Wan2.2/requirements.txt | 15 + MindIE/MultiModal/Wan2.2/tests/README.md | 6 + MindIE/MultiModal/Wan2.2/tests/test.sh | 91 ++ MindIE/MultiModal/Wan2.2/wan/__init__.py | 5 + .../MultiModal/Wan2.2/wan/configs/__init__.py | 43 + .../Wan2.2/wan/configs/shared_config.py | 20 + .../Wan2.2/wan/configs/wan_i2v_A14B.py | 37 + .../Wan2.2/wan/configs/wan_t2v_A14B.py | 37 + .../Wan2.2/wan/configs/wan_ti2v_5B.py | 36 + .../Wan2.2/wan/distributed/__init__.py | 1 + .../MultiModal/Wan2.2/wan/distributed/comm.py | 95 ++ .../MultiModal/Wan2.2/wan/distributed/fsdp.py | 43 + .../wan/distributed/group_coordinator.py | 597 ++++++++++ .../Wan2.2/wan/distributed/parallel_mgr.py | 342 ++++++ .../wan/distributed/sequence_parallel.py | 183 +++ .../Wan2.2/wan/distributed/tp_applicator.py | 329 ++++++ .../Wan2.2/wan/distributed/ulysses.py | 47 + .../MultiModal/Wan2.2/wan/distributed/util.py | 203 ++++ MindIE/MultiModal/Wan2.2/wan/image2video.py | 464 ++++++++ .../MultiModal/Wan2.2/wan/modules/__init__.py | 19 + .../Wan2.2/wan/modules/attention.py | 194 +++ .../Wan2.2/wan/modules/attn_layer.py | 178 +++ MindIE/MultiModal/Wan2.2/wan/modules/model.py | 582 +++++++++ MindIE/MultiModal/Wan2.2/wan/modules/t5.py | 513 ++++++++ .../Wan2.2/wan/modules/tokenizers.py | 82 ++ .../MultiModal/Wan2.2/wan/modules/vae2_1.py | 663 +++++++++++ .../MultiModal/Wan2.2/wan/modules/vae2_2.py | 1051 +++++++++++++++++ MindIE/MultiModal/Wan2.2/wan/text2video.py | 403 +++++++ .../MultiModal/Wan2.2/wan/textimage2video.py | 648 ++++++++++ .../MultiModal/Wan2.2/wan/utils/__init__.py | 12 + .../MultiModal/Wan2.2/wan/utils/fm_solvers.py | 859 ++++++++++++++ .../Wan2.2/wan/utils/fm_solvers_unipc.py | 804 +++++++++++++ .../Wan2.2/wan/utils/prompt_extend.py | 542 +++++++++ .../Wan2.2/wan/utils/qwen_vl_utils.py | 363 ++++++ .../Wan2.2/wan/utils/system_prompt.py | 147 +++ MindIE/MultiModal/Wan2.2/wan/utils/utils.py | 159 +++ .../Wan2.2/wan/vae_patch_parallel.py | 737 ++++++++++++ 44 files changed, 11977 insertions(+), 1 deletion(-) delete mode 160000 MindIE/MultiModal/Wan2.2 create mode 100644 MindIE/MultiModal/Wan2.2/INSTALL.md create mode 100644 MindIE/MultiModal/Wan2.2/LICENSE.txt create mode 100644 MindIE/MultiModal/Wan2.2/Makefile create mode 100644 MindIE/MultiModal/Wan2.2/README.md create mode 100644 MindIE/MultiModal/Wan2.2/generate.py create mode 100644 MindIE/MultiModal/Wan2.2/pyproject.toml create mode 100644 MindIE/MultiModal/Wan2.2/requirements.txt create mode 100644 MindIE/MultiModal/Wan2.2/tests/README.md create mode 100644 MindIE/MultiModal/Wan2.2/tests/test.sh 
create mode 100644 MindIE/MultiModal/Wan2.2/wan/__init__.py create mode 100644 MindIE/MultiModal/Wan2.2/wan/configs/__init__.py create mode 100644 MindIE/MultiModal/Wan2.2/wan/configs/shared_config.py create mode 100644 MindIE/MultiModal/Wan2.2/wan/configs/wan_i2v_A14B.py create mode 100644 MindIE/MultiModal/Wan2.2/wan/configs/wan_t2v_A14B.py create mode 100644 MindIE/MultiModal/Wan2.2/wan/configs/wan_ti2v_5B.py create mode 100644 MindIE/MultiModal/Wan2.2/wan/distributed/__init__.py create mode 100644 MindIE/MultiModal/Wan2.2/wan/distributed/comm.py create mode 100644 MindIE/MultiModal/Wan2.2/wan/distributed/fsdp.py create mode 100644 MindIE/MultiModal/Wan2.2/wan/distributed/group_coordinator.py create mode 100644 MindIE/MultiModal/Wan2.2/wan/distributed/parallel_mgr.py create mode 100644 MindIE/MultiModal/Wan2.2/wan/distributed/sequence_parallel.py create mode 100644 MindIE/MultiModal/Wan2.2/wan/distributed/tp_applicator.py create mode 100644 MindIE/MultiModal/Wan2.2/wan/distributed/ulysses.py create mode 100644 MindIE/MultiModal/Wan2.2/wan/distributed/util.py create mode 100644 MindIE/MultiModal/Wan2.2/wan/image2video.py create mode 100644 MindIE/MultiModal/Wan2.2/wan/modules/__init__.py create mode 100644 MindIE/MultiModal/Wan2.2/wan/modules/attention.py create mode 100644 MindIE/MultiModal/Wan2.2/wan/modules/attn_layer.py create mode 100644 MindIE/MultiModal/Wan2.2/wan/modules/model.py create mode 100644 MindIE/MultiModal/Wan2.2/wan/modules/t5.py create mode 100644 MindIE/MultiModal/Wan2.2/wan/modules/tokenizers.py create mode 100644 MindIE/MultiModal/Wan2.2/wan/modules/vae2_1.py create mode 100644 MindIE/MultiModal/Wan2.2/wan/modules/vae2_2.py create mode 100644 MindIE/MultiModal/Wan2.2/wan/text2video.py create mode 100644 MindIE/MultiModal/Wan2.2/wan/textimage2video.py create mode 100644 MindIE/MultiModal/Wan2.2/wan/utils/__init__.py create mode 100644 MindIE/MultiModal/Wan2.2/wan/utils/fm_solvers.py create mode 100644 MindIE/MultiModal/Wan2.2/wan/utils/fm_solvers_unipc.py create mode 100644 MindIE/MultiModal/Wan2.2/wan/utils/prompt_extend.py create mode 100644 MindIE/MultiModal/Wan2.2/wan/utils/qwen_vl_utils.py create mode 100644 MindIE/MultiModal/Wan2.2/wan/utils/system_prompt.py create mode 100644 MindIE/MultiModal/Wan2.2/wan/utils/utils.py create mode 100644 MindIE/MultiModal/Wan2.2/wan/vae_patch_parallel.py diff --git a/MindIE/MultiModal/Wan2.2 b/MindIE/MultiModal/Wan2.2 deleted file mode 160000 index 7ceec19913..0000000000 --- a/MindIE/MultiModal/Wan2.2 +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 7ceec199137c6a0c336138182b34efbf4036a3d6 diff --git a/MindIE/MultiModal/Wan2.2/INSTALL.md b/MindIE/MultiModal/Wan2.2/INSTALL.md new file mode 100644 index 0000000000..14c6295828 --- /dev/null +++ b/MindIE/MultiModal/Wan2.2/INSTALL.md @@ -0,0 +1,55 @@ +# Installation Guide + +## Install with pip + +```bash +pip install . +pip install .[dev] # Also installs the dev tools +``` + +## Install with Poetry + +Ensure you have [Poetry](https://python-poetry.org/docs/#installation) installed on your system. + +To install all dependencies: + +```bash +poetry install +``` + +### Handling `flash-attn` Installation Issues + +If `flash-attn` fails due to **PEP 517 build issues**, you can try one of the following fixes.
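+
+Whichever of the fixes below you apply, a quick import check confirms the result (a minimal sanity check; it assumes the installed wheel exposes `flash_attn.__version__`):
+
+```bash
+python -c "import flash_attn; print(flash_attn.__version__)"
+```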
+ +#### No-Build-Isolation Installation (Recommended) +```bash +poetry run pip install --upgrade pip setuptools wheel +poetry run pip install flash-attn --no-build-isolation +poetry install +``` + +#### Install from Git (Alternative) +```bash +poetry run pip install git+https://github.com/Dao-AILab/flash-attention.git +``` + +--- + +### Running the Model + +Once the installation is complete, you can run **Wan2.2** using: + +```bash +poetry run python generate.py --task t2v-A14B --size '1280*720' --ckpt_dir ./Wan2.2-T2V-A14B --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." +``` + +#### Test +```bash +bash tests/test.sh +``` + +#### Format +```bash +black . +isort . +``` diff --git a/MindIE/MultiModal/Wan2.2/LICENSE.txt b/MindIE/MultiModal/Wan2.2/LICENSE.txt new file mode 100644 index 0000000000..261eeb9e9f --- /dev/null +++ b/MindIE/MultiModal/Wan2.2/LICENSE.txt @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. 
The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/MindIE/MultiModal/Wan2.2/Makefile b/MindIE/MultiModal/Wan2.2/Makefile new file mode 100644 index 0000000000..c95b854434 --- /dev/null +++ b/MindIE/MultiModal/Wan2.2/Makefile @@ -0,0 +1,5 @@ +.PHONY: format + +format: + isort generate.py wan + yapf -i -r *.py generate.py wan diff --git a/MindIE/MultiModal/Wan2.2/README.md b/MindIE/MultiModal/Wan2.2/README.md new file mode 100644 index 0000000000..2460cee8f7 --- /dev/null +++ b/MindIE/MultiModal/Wan2.2/README.md @@ -0,0 +1,443 @@ +# Wan2.2推理指导 +## 一、准备运行环境 + + **表 1** 版本配套表 + + | 配套 | 版本 | 环境准备指导 | + | ----- | ----- |-----| + | Python | 3.11.10 | - | + | torch | 2.1.0 | - | + +### 1.1 获取CANN&MindIE安装包&环境准备 +- 设备支持 +Atlas 800I/800T A2(8*64G)推理设备:支持的卡数最小为1 +[Atlas 800I/800T A2(8*64G)](https://www.hiascend.com/developer/download/community/result?module=pt+ie+cann&product=4&model=32) +- [环境准备指导](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha002/softwareinst/instg/instg_0001.html) + +### 1.2 CANN安装 +```shell +# 增加软件包可执行权限,{version}表示软件版本号,{arch}表示CPU架构,{soc}表示昇腾AI处理器的版本。 +chmod +x ./Ascend-cann-toolkit_{version}_linux-{arch}.run +chmod +x ./Ascend-cann-kernels-{soc}_{version}_linux.run +# 校验软件包安装文件的一致性和完整性 +./Ascend-cann-toolkit_{version}_linux-{arch}.run --check +./Ascend-cann-kernels-{soc}_{version}_linux.run --check +# 安装 +./Ascend-cann-toolkit_{version}_linux-{arch}.run --install +./Ascend-cann-kernels-{soc}_{version}_linux.run --install + +# 设置环境变量 +source /usr/local/Ascend/ascend-toolkit/set_env.sh +``` + +### 1.3 环境依赖安装 +```shell +pip3 install -r requirements.txt +``` + +### 1.4 MindIE安装 +```shell +# 增加软件包可执行权限,{version}表示软件版本号,{arch}表示CPU架构。 +chmod +x ./Ascend-mindie_${version}_linux-${arch}.run +./Ascend-mindie_${version}_linux-${arch}.run --check + +# 方式一:默认路径安装 +./Ascend-mindie_${version}_linux-${arch}.run --install +# 设置环境变量 +cd /usr/local/Ascend/mindie && source set_env.sh + +# 方式二:指定路径安装 +./Ascend-mindie_${version}_linux-${arch}.run --install-path=${AieInstallPath} +# 设置环境变量 +cd ${AieInstallPath}/mindie && source set_env.sh +``` + +### 1.5 Torch_npu安装 +下载 pytorch_v{pytorchversion}_py{pythonversion}.tar.gz +```shell +tar -xzvf pytorch_v{pytorchversion}_py{pythonversion}.tar.gz +# 解压后,会有whl包 +pip install torch_npu-{pytorchversion}.xxxx.{arch}.whl +``` + +## 二、下载权重 + +### Wan2.2 权重及配置文件说明 + +- Huggingface + +| 模型 | 链接 | +| ------------ | 
------------ | +| Wan2.2-T2V-A14B | [🤗huggingface](https://huggingface.co/Wan-AI/Wan2.2-T2V-A14B/tree/main) | +| Wan2.2-I2V-A14B | [🤗huggingface](https://huggingface.co/Wan-AI/Wan2.2-I2V-A14B/tree/main) | +| Wan2.2-TI2V-5B | [🤗huggingface](https://huggingface.co/Wan-AI/Wan2.2-TI2V-5B/tree/main) | + + +- Modelers + +| 模型 | 链接 | +| ------------ | ------------ | +| Wan2.2-T2V-A14B | [ Modelers](https://modelers.cn/models/Modelers_Park/Wan2.2-T2V-A14B) | +| Wan2.2-I2V-A14B | [ Modelers](https://modelers.cn/models/Modelers_Park/Wan2.2-I2V-A14B ) | +| Wan2.2-TI2V-5B | [ Modelers](https://modelers.cn/models/Modelers_Park/Wan2.2-TI2V-5B) | + + +## 三、Wan2.2使用 + +### 3.1 下载到本地 +```shell +git clone https://modelers.cn/MindIE/Wan2.2.git +``` + +### 3.2 Wan2.2-T2V-A14B +使用上一步下载的权重 +```shell +model_base="./Wan2.2-T2V-A14B/" +``` +#### 3.2.1 等价优化 +#### 3.2.1.1 8卡性能测试 +执行命令: +```shell +export ALGO=0 +export PYTORCH_NPU_ALLOC_CONF='expandable_segments:True' +export TASK_QUEUE_ENABLE=2 +export CPU_AFFINITY_CONF=1 +export TOKENIZERS_PARALLELISM=false + +torchrun --nproc_per_node=8 --master_port=23459 generate.py \ +--task t2v-A14B \ +--ckpt_dir ${model_base} \ +--size 1280*720 \ +--frame_num 81 \ +--sample_steps 40 \ +--dit_fsdp \ +--t5_fsdp \ +--cfg_size 1 \ +--ulysses_size 8 \ +--vae_parallel \ +--prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \ +--base_seed 0 + +``` +参数说明: +- ALGO: 为0表示默认FA算子; 设置为1表示使用高性能FA算子 +- task: 任务类型。 +- ckpt_dir: 模型的权重路径 +- size: 生成视频的分辨率,支持(1280,720)、(832,480)分辨率 +- frame_num: 生成视频的帧数 +- sample_steps: 推理步数 +- dit_fsdp: dit使能fsdp, 用以降低显存占用 +- t5_fsdp: t5使能fsdp, 用以降低显存占用 +- cfg_size: cfg并行数 +- ulysses_size: ulysses并行数 +- vae_parallel: 使能vae并行策略 +- prompt: 文本提示词 +- base_seed: 随机种子 + +#### 3.2.1.2 16卡性能测试 +执行命令: +```shell +export ALGO=0 +export PYTORCH_NPU_ALLOC_CONF='expandable_segments:True' +export TASK_QUEUE_ENABLE=2 +export CPU_AFFINITY_CONF=1 +export TOKENIZERS_PARALLELISM=false + +torchrun --nproc_per_node=16 --master_port=23459 generate.py \ +--task t2v-A14B \ +--ckpt_dir ${model_base} \ +--size 1280*720 \ +--frame_num 81 \ +--sample_steps 40 \ +--dit_fsdp \ +--t5_fsdp \ +--cfg_size 2 \ +--ulysses_size 8 \ +--vae_parallel \ +--prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \ +--base_seed 0 +``` + + +### 3.3 Wan2.2-I2V-A14B +使用上一步下载的权重 +```shell +model_base="./Wan2.2-I2V-A14B/" +``` + +#### 3.3.1 等价优化 +#### 3.3.1.1 8卡性能测试 + +执行命令: +```shell +export ALGO=0 +export PYTORCH_NPU_ALLOC_CONF='expandable_segments:True' +export TASK_QUEUE_ENABLE=2 +export CPU_AFFINITY_CONF=1 +export TOKENIZERS_PARALLELISM=false + +torchrun --nproc_per_node=8 generate.py \ +--task i2v-A14B \ +--ckpt_dir ${model_base} \ +--size 1280*720 \ +--frame_num 81 \ +--sample_steps 40 \ +--dit_fsdp \ +--t5_fsdp \ +--cfg_size 1 \ +--ulysses_size 8 \ +--vae_parallel \ +--image examples/i2v_input.JPG \ +--prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." 
\ +--base_seed 0 +``` +参数说明: +- ALGO: 为0表示默认FA算子; 设置为1表示使用高性能FA算子 +- task: 任务类型。 +- ckpt_dir: 模型的权重路径 +- size: 生成视频的分辨率,支持(1280,720)、(832,480)分辨率 +- frame_num: 生成视频的帧数 +- sample_steps: 推理步数 +- dit_fsdp: dit使能fsdp, 用以降低显存占用 +- t5_fsdp: t5使能fsdp, 用以降低显存占用 +- cfg_size: cfg并行数 +- ulysses_size: ulysses并行数 +- vae_parallel: 使能vae并行策略 +- image: 输入图片路径 +- prompt: 文本提示词 +- base_seed: 随机种子 + +#### 3.3.1.2 16卡性能测试 +执行命令: +```shell +export ALGO=0 +export PYTORCH_NPU_ALLOC_CONF='expandable_segments:True' +export TASK_QUEUE_ENABLE=2 +export CPU_AFFINITY_CONF=1 +export TOKENIZERS_PARALLELISM=false + +torchrun --nproc_per_node=16 --master_port=23459 generate.py \ +--task i2v-A14B \ +--ckpt_dir ${model_base} \ +--size 1280*720 \ +--frame_num 81 \ +--sample_steps 40 \ +--dit_fsdp \ +--t5_fsdp \ +--cfg_size 2 \ +--ulysses_size 8 \ +--vae_parallel \ +--image examples/i2v_input.JPG \ +--prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." \ +--base_seed 0 +``` + +### 3.4 Wan2.2-TI2V-5B +使用上一步下载的权重 +```shell +model_base="./Wan2.2-TI2V-5B/" +``` + +#### 3.4.1 等价优化 +#### 3.4.1.1 单卡性能测试 +执行命令: +```shell +export ALGO=0 +export PYTORCH_NPU_ALLOC_CONF='expandable_segments:True' +export TASK_QUEUE_ENABLE=2 +export CPU_AFFINITY_CONF=1 +export TOKENIZERS_PARALLELISM=false + +python generate.py \ +--task ti2v-5B \ +--ckpt_dir ${model_base} \ +--size 1280*704 \ +--frame_num 121 \ +--sample_steps 50 \ +--image examples/i2v_input.JPG \ +--prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." \ +--base_seed 0 +``` +参数说明: +- ALGO: 为0表示默认FA算子;设置为1表示使用高性能FA算子 +- task: 任务类型。 +- ckpt_dir: 模型的权重路径 +- size: 生成视频的分辨率,支持(1280,720)、(832,480)分辨率 +- frame_num: 生成视频的帧数 +- sample_steps: 推理步数 +- image: 输入图片路径 +- prompt: 文本提示词 +- base_seed: 随机种子 + + +#### 3.4.1.2 8卡性能测试 + +执行命令: +```shell +export ALGO=0 +export PYTORCH_NPU_ALLOC_CONF='expandable_segments:True' +export TASK_QUEUE_ENABLE=2 +export CPU_AFFINITY_CONF=1 +export TOKENIZERS_PARALLELISM=false +torchrun --nproc_per_node=8 generate.py \ +--task ti2v-5B \ +--ckpt_dir ${model_base} \ +--size 1280*704 \ +--frame_num 121 \ +--sample_steps 50 \ +--dit_fsdp \ +--t5_fsdp \ +--cfg_size 1 \ +--ulysses_size 8 \ +--image examples/i2v_input.JPG \ +--prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. 
A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." \ +--base_seed 0 +``` +参数说明: +- ALGO: 为0表示默认FA算子;设置为1表示使用高性能FA算子 +- task: 任务类型。 +- ckpt_dir: 模型的权重路径 +- size: 生成视频的分辨率,支持(1280,720)、(832,480)分辨率 +- frame_num: 生成视频的帧数 +- sample_steps: 推理步数 +- dit_fsdp: dit使能fsdp, 用以降低显存占用 +- t5_fsdp: t5使能fsdp, 用以降低显存占用 +- cfg_size: cfg并行数 +- ulysses_size: ulysses并行数 +- vae_parallel: 使能vae并行策略 +- image: 输入图片路径 +- prompt: 文本提示词 +- base_seed: 随机种子 + +#### 3.4.1.3 16卡性能测试 +执行命令: +```shell +export ALGO=0 +export PYTORCH_NPU_ALLOC_CONF='expandable_segments:True' +export TASK_QUEUE_ENABLE=2 +export CPU_AFFINITY_CONF=1 +export TOKENIZERS_PARALLELISM=false + +torchrun --nproc_per_node=16 --master_port=23459 generate.py \ +--task ti2v-5B \ +--ckpt_dir ${model_base} \ +--size 1280*704 \ +--frame_num 81 \ +--sample_steps 40 \ +--dit_fsdp \ +--t5_fsdp \ +--cfg_size 2 \ +--ulysses_size 8 \ +--vae_parallel \ +--image examples/i2v_input.JPG \ +--prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." \ +--base_seed 0 +``` + +#### 3.4.2 算法优化 +#### 3.4.2.1 单卡性能测试 +执行命令: +```shell +export ALGO=0 +export PYTORCH_NPU_ALLOC_CONF='expandable_segments:True' +export TASK_QUEUE_ENABLE=2 +export CPU_AFFINITY_CONF=1 +export TOKENIZERS_PARALLELISM=false + +python generate.py \ +--task ti2v-5B \ +--ckpt_dir ${model_base} \ +--size 1280*704 \ +--frame_num 121 \ +--sample_steps 50 \ +--image examples/i2v_input.JPG \ +--prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." \ +--base_seed 0 \ +--use_attentioncache \ +--start_step 20 \ +--attentioncache_interval 2 \ +--end_step 47 +``` +参数说明: +- ALGO: 为0表示默认FA算子;设置为1表示使用高性能FA算子 +- use_attentioncache: 使能attentioncache策略 +- start_step: cache开始的step +- attentioncache_interval: cache重计算间隔 +- end_step: cache结束的step + +#### 3.4.2.2 8卡性能测试 + +执行命令: +```shell +export ALGO=0 +export PYTORCH_NPU_ALLOC_CONF='expandable_segments:True' +export TASK_QUEUE_ENABLE=2 +export CPU_AFFINITY_CONF=1 +export TOKENIZERS_PARALLELISM=false +torchrun --nproc_per_node=8 generate.py \ +--task ti2v-5B \ +--ckpt_dir ${model_base} \ +--size 1280*704 \ +--frame_num 121 \ +--sample_steps 50 \ +--dit_fsdp \ +--t5_fsdp \ +--cfg_size 1 \ +--ulysses_size 8 \ +--image examples/i2v_input.JPG \ +--prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. 
A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." \ +--base_seed 0 \ +--use_attentioncache \ +--start_step 20 \ +--attentioncache_interval 2 \ +--end_step 47 +``` + +#### 3.4.2.1 16卡性能测试 +执行命令: +```shell +export ALGO=0 +export PYTORCH_NPU_ALLOC_CONF='expandable_segments:True' +export TASK_QUEUE_ENABLE=2 +export CPU_AFFINITY_CONF=1 +export TOKENIZERS_PARALLELISM=false + +torchrun --nproc_per_node=16 --master_port=23459 generate.py \ +--task ti2v-5B \ +--ckpt_dir ${model_base} \ +--size 1280*704 \ +--frame_num 81 \ +--sample_steps 40 \ +--dit_fsdp \ +--t5_fsdp \ +--cfg_size 2 \ +--ulysses_size 8 \ +--vae_parallel \ +--image examples/i2v_input.JPG \ +--prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." \ +--base_seed 0 \ +--use_attentioncache \ +--start_step 20 \ +--attentioncache_interval 2 \ +--end_step 47 +``` + + +注: +1. 若出现OOM, 可添加环境变量 `export T5_LOAD_CPU=1`,以降低显存占用 +2. 当前仅TI2V支持attentioncache + + +## 四、推理结果参考 +### Atlas 800I A2(8*64G)性能数据 + +| 模型 | cpu规格 | 规格 | 迭代次数 | E2E耗时(ALGO=0) | E2E耗时(ALGO=1) | +| :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | +| Wan2.2-T2V-A14B | 64核(arm) | 81×1280×720 | 40 | 576.04s | 435.99s | +| Wan2.2-I2V-A14B | 64核(arm) | 81×1280×720 | 40 | 577.79s | 436.42s | +| Wan2.2-TI2V-5B | 64核(arm) | 121×1280×704 | 50 | 91.92s | 84.14s | + + +## 声明 +- 本代码仓提到的数据集和模型仅作为示例,这些数据集和模型仅供您用于非商业目的,如您使用这些数据集和模型来完成示例,请您特别注意应遵守对应数据集和模型的License,如您因使用数据集或模型而产生侵权纠纷,华为不承担任何责任。 +- 如您在使用本代码仓的过程中,发现任何问题(包括但不限于功能问题、合规问题),请在本代码仓提交issue,我们将及时审视并解答。 diff --git a/MindIE/MultiModal/Wan2.2/generate.py b/MindIE/MultiModal/Wan2.2/generate.py new file mode 100644 index 0000000000..e5a514ca38 --- /dev/null +++ b/MindIE/MultiModal/Wan2.2/generate.py @@ -0,0 +1,657 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import argparse +import logging +import os +import sys +import warnings +from datetime import datetime + +warnings.filterwarnings('ignore') + +import random +import time + +import torch +import torch_npu +torch_npu.npu.set_compile_mode(jit_compile=False) +torch.npu.config.allow_internal_format=False +from torch_npu.contrib import transfer_to_npu + +import torch.distributed as dist +from PIL import Image + +import wan +from wan.configs import MAX_AREA_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS +from wan.distributed.util import init_distributed_group +from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander +from wan.utils.utils import save_video, str2bool +from wan.distributed.parallel_mgr import ParallelConfig, init_parallel_env, finalize_parallel_env +from wan.distributed.tp_applicator import TensorParallelApplicator + +from mindiesd import CacheConfig, CacheAgent + +EXAMPLE_PROMPT = { + "t2v-A14B": { + "prompt": + "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.", + }, + "i2v-A14B": { + "prompt": + "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. 
Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.", + "image": + "examples/i2v_input.JPG", + }, + "ti2v-5B": { + "prompt": + "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.", + }, +} + + +def _validate_args(args): + # Basic check + assert args.ckpt_dir is not None, "Please specify the checkpoint directory." + assert args.task in WAN_CONFIGS, f"Unsupport task: {args.task}" + assert args.task in EXAMPLE_PROMPT, f"Unsupport task: {args.task}" + + if args.prompt is None: + args.prompt = EXAMPLE_PROMPT[args.task]["prompt"] + if args.image is None and "image" in EXAMPLE_PROMPT[args.task]: + args.image = EXAMPLE_PROMPT[args.task]["image"] + + if args.task == "i2v-A14B": + assert args.image is not None, "Please specify the image path for i2v." + + cfg = WAN_CONFIGS[args.task] + + if args.sample_steps is None: + args.sample_steps = cfg.sample_steps + + if args.sample_shift is None: + args.sample_shift = cfg.sample_shift + + if args.sample_guide_scale is None: + args.sample_guide_scale = cfg.sample_guide_scale + + if args.frame_num is None: + args.frame_num = cfg.frame_num + + args.base_seed = args.base_seed if args.base_seed >= 0 else random.randint( + 0, sys.maxsize) + # Size check + assert args.size in SUPPORTED_SIZES[ + args. + task], f"Unsupport size {args.size} for task {args.task}, supported sizes are: {', '.join(SUPPORTED_SIZES[args.task])}" + + +def _parse_args(): + parser = argparse.ArgumentParser( + description="Generate a image or video from a text prompt or image using Wan" + ) + parser.add_argument( + "--task", + type=str, + default="t2v-A14B", + choices=list(WAN_CONFIGS.keys()), + help="The task to run.") + parser.add_argument( + "--size", + type=str, + default="1280*720", + choices=list(SIZE_CONFIGS.keys()), + help="The area (width*height) of the generated video. For the I2V task, the aspect ratio of the output video will follow that of the input image." + ) + parser.add_argument( + "--frame_num", + type=int, + default=None, + help="How many frames of video are generated. The number should be 4n+1" + ) + parser.add_argument( + "--ckpt_dir", + type=str, + default=None, + help="The path to the checkpoint directory.") + parser.add_argument( + "--offload_model", + type=str2bool, + default=None, + help="Whether to offload the model to CPU after each model forward, reducing GPU memory usage." 
+ ) + parser.add_argument( + "--cfg_size", + type=int, + default=1, + help="The size of the cfg parallelism in DiT.") + parser.add_argument( + "--ulysses_size", + type=int, + default=1, + help="The size of the ulysses parallelism in DiT.") + parser.add_argument( + "--ring_size", + type=int, + default=1, + help="The size of the ring attention parallelism in DiT.") + parser.add_argument( + "--tp_size", + type=int, + default=1, + help="The size of the tensor parallelism in DiT.") + parser.add_argument( + "--vae_parallel", + action="store_true", + default=False, + help="Whether to use parallel for vae.") + parser.add_argument( + "--t5_fsdp", + action="store_true", + default=False, + help="Whether to use FSDP for T5.") + parser.add_argument( + "--t5_cpu", + action="store_true", + default=False, + help="Whether to place T5 model on CPU.") + parser.add_argument( + "--dit_fsdp", + action="store_true", + default=False, + help="Whether to use FSDP for DiT.") + parser.add_argument( + "--save_file", + type=str, + default=None, + help="The file to save the generated video to.") + parser.add_argument( + "--prompt", + type=str, + default=None, + help="The prompt to generate the video from.") + parser.add_argument( + "--use_prompt_extend", + action="store_true", + default=False, + help="Whether to use prompt extend.") + parser.add_argument( + "--prompt_extend_method", + type=str, + default="local_qwen", + choices=["dashscope", "local_qwen"], + help="The prompt extend method to use.") + parser.add_argument( + "--prompt_extend_model", + type=str, + default=None, + help="The prompt extend model to use.") + parser.add_argument( + "--prompt_extend_target_lang", + type=str, + default="zh", + choices=["zh", "en"], + help="The target language of prompt extend.") + parser.add_argument( + "--base_seed", + type=int, + default=-1, + help="The seed to use for generating the video.") + parser.add_argument( + "--image", + type=str, + default=None, + help="The image to generate the video from.") + parser.add_argument( + "--sample_solver", + type=str, + default='unipc', + choices=['unipc', 'dpm++'], + help="The solver used to sample.") + parser.add_argument( + "--sample_steps", type=int, default=None, help="The sampling steps.") + parser.add_argument( + "--sample_shift", + type=float, + default=None, + help="Sampling shift factor for flow matching schedulers.") + parser.add_argument( + "--sample_guide_scale", + type=float, + default=None, + help="Classifier free guidance scale.") + parser.add_argument( + "--convert_model_dtype", + action="store_true", + default=False, + help="Whether to convert model paramerters dtype.") + parser = add_attentioncache_args(parser) + args = parser.parse_args() + + _validate_args(args) + + return args + + +def add_attentioncache_args(parser: argparse.ArgumentParser): + group = parser.add_argument_group(title="Attention Cache args") + + group.add_argument("--use_attentioncache", action='store_true') + group.add_argument("--attentioncache_ratio", type=float, default=1.2) + group.add_argument("--attentioncache_interval", type=int, default=4) + group.add_argument("--start_step", type=int, default=12) + group.add_argument("--end_step", type=int, default=37) + + return parser + + +def _init_logging(rank): + # logging + if rank == 0: + # set format + logging.basicConfig( + level=logging.INFO, + format="[%(asctime)s] %(levelname)s: %(message)s", + handlers=[logging.StreamHandler(stream=sys.stdout)]) + else: + logging.basicConfig(level=logging.ERROR) + + +def generate(args): + rank = 
int(os.getenv("RANK", 0)) + world_size = int(os.getenv("WORLD_SIZE", 1)) + local_rank = int(os.getenv("LOCAL_RANK", 0)) + device = local_rank + _init_logging(rank) + stream = torch.npu.Stream() + + if args.offload_model is None: + args.offload_model = False if world_size > 1 else True + logging.info( + f"offload_model is not specified, set to {args.offload_model}.") + if world_size > 1: + torch.cuda.set_device(local_rank) + dist.init_process_group( + backend="hccl", + init_method="env://", + rank=rank, + world_size=world_size) + else: + assert not ( + args.t5_fsdp or args.dit_fsdp + ), f"t5_fsdp and dit_fsdp are not supported in non-distributed environments." + assert not ( + args.cfg_size > 1 or args.ulysses_size > 1 or args.ring_size > 1 or args.tp_size > 1 + ), f"context parallel are not supported in non-distributed environments." + assert not ( + args.vae_parallel + ), f"vae parallel are not supported in non-distributed environments." + + if args.tp_size > 1: + raise NotImplementedError("Tensor Parallel is not supported now") + if "ti2v" not in args.task and args.use_attentioncache: + raise NotImplementedError(f"{args.task} unsupport attentioncache now") + + if args.cfg_size > 1 or args.ulysses_size > 1 or args.ring_size > 1 or args.tp_size > 1: + assert args.cfg_size * args.ulysses_size * args.ring_size * args.tp_size == world_size, f"The number of cfg_size, ulysses_size, ring_size and tp_size should be equal to the world size." + sp_degree = args.ulysses_size * args.ring_size + parallel_config = ParallelConfig( + sp_degree=sp_degree, + ulysses_degree=args.ulysses_size, + ring_degree=args.ring_size, + tp_degree=args.tp_size, + use_cfg_parallel=(args.cfg_size==2), + world_size=world_size, + ) + init_parallel_env(parallel_config) + + if args.tp_size > 1 and args.dit_fsdp: + logging.info("DiT using Tensor Parallel, disabled dit_fsdp") + args.dit_fsdp = False + + if args.use_prompt_extend: + if args.prompt_extend_method == "dashscope": + prompt_expander = DashScopePromptExpander( + model_name=args.prompt_extend_model, + task=args.task, + is_vl=args.image is not None) + elif args.prompt_extend_method == "local_qwen": + prompt_expander = QwenPromptExpander( + model_name=args.prompt_extend_model, + task=args.task, + is_vl=args.image is not None, + device=rank) + else: + raise NotImplementedError( + f"Unsupport prompt_extend_method: {args.prompt_extend_method}") + + cfg = WAN_CONFIGS[args.task] + if args.ulysses_size > 1: + assert cfg.num_heads % args.ulysses_size == 0, f"`{cfg.num_heads=}` cannot be divided evenly by `{args.ulysses_size=}`." 
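+        # Ulysses sequence parallelism swaps the sequence sharding for a head sharding via
+        # all-to-all, so the attention head count must divide evenly across the ulysses_size
+        # ranks (e.g. a hypothetical 40-head model works with ulysses_size 1, 2, 4, 5, 8, ...).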
+ + logging.info(f"Generation job args: {args}") + logging.info(f"Generation model config: {cfg}") + + if dist.is_initialized(): + base_seed = [args.base_seed] if rank == 0 else [None] + dist.broadcast_object_list(base_seed, src=0) + args.base_seed = base_seed[0] + + logging.info(f"Input prompt: {args.prompt}") + img = None + if args.image is not None: + img = Image.open(args.image).convert("RGB") + logging.info(f"Input image: {args.image}") + + # prompt extend + if args.use_prompt_extend: + logging.info("Extending prompt ...") + if rank == 0: + prompt_output = prompt_expander( + args.prompt, + image=img, + tar_lang=args.prompt_extend_target_lang, + seed=args.base_seed) + if prompt_output.status == False: + logging.info( + f"Extending prompt failed: {prompt_output.message}") + logging.info("Falling back to original prompt.") + input_prompt = args.prompt + else: + input_prompt = prompt_output.prompt + input_prompt = [input_prompt] + else: + input_prompt = [None] + if dist.is_initialized(): + dist.broadcast_object_list(input_prompt, src=0) + args.prompt = input_prompt[0] + logging.info(f"Extended prompt: {args.prompt}") + + if "t2v" in args.task: + logging.info("Creating WanT2V pipeline.") + wan_t2v = wan.WanT2V( + config=cfg, + checkpoint_dir=args.ckpt_dir, + device_id=device, + rank=rank, + t5_fsdp=args.t5_fsdp, + dit_fsdp=args.dit_fsdp, + use_sp=(args.ulysses_size > 1 or args.ring_size > 1), + t5_cpu=args.t5_cpu, + convert_model_dtype=args.convert_model_dtype, + use_vae_parallel=args.vae_parallel + ) + + transformer_low = wan_t2v.low_noise_model + transformer_high = wan_t2v.high_noise_model + + if args.tp_size > 1: + logging.info("Initializing Tensor Parallel ...") + applicator = TensorParallelApplicator(args.tp_size, device_map="cpu") + applicator.apply_to_model(transformer_low) + applicator.apply_to_model(transformer_high) + # wan_t2v.low_noise_model.to("npu") + # wan_t2v.high_noise_model.to("npu") + + if args.use_attentioncache: + config_high = CacheConfig( + method="attention_cache", + blocks_count=len(transformer_high.blocks), + steps_count=args.sample_steps, + step_start=args.start_step, + step_interval=args.attentioncache_interval, + step_end=args.end_step + ) + else: + config_high = CacheConfig( + method="attention_cache", + blocks_count=len(transformer_high.blocks), + steps_count=args.sample_steps + ) + config_low = CacheConfig( + method="attention_cache", + blocks_count=len(transformer_low.blocks), + steps_count=args.sample_steps + ) + cache_high = CacheAgent(config_high) + cache_low = CacheAgent(config_low) + + if args.dit_fsdp: + for block in transformer_high._fsdp_wrapped_module.blocks: + block._fsdp_wrapped_module.cache = cache_high + block._fsdp_wrapped_module.args = args + for block in transformer_low._fsdp_wrapped_module.blocks: + block._fsdp_wrapped_module.cache = cache_low + block._fsdp_wrapped_module.args = args + else: + for block in transformer_high.blocks: + block.cache = cache_high + block.args = args + for block in transformer_low.blocks: + block.cache = cache_low + block.args = args + + logging.info("Warm up 2 steps ...") + video = wan_t2v.generate( + args.prompt, + size=SIZE_CONFIGS[args.size], + frame_num=args.frame_num, + shift=args.sample_shift, + sample_solver=args.sample_solver, + sampling_steps=2, + guide_scale=args.sample_guide_scale, + seed=args.base_seed, + offload_model=args.offload_model) + + logging.info(f"Generating video ...") + stream.synchronize() + begin = time.time() + video = wan_t2v.generate( + args.prompt, + size=SIZE_CONFIGS[args.size], + 
frame_num=args.frame_num, + shift=args.sample_shift, + sample_solver=args.sample_solver, + sampling_steps=args.sample_steps, + guide_scale=args.sample_guide_scale, + seed=args.base_seed, + offload_model=args.offload_model) + stream.synchronize() + end = time.time() + logging.info(f"Generating video used time {end - begin: .4f}s") + + elif "ti2v" in args.task: + logging.info("Creating WanTI2V pipeline.") + wan_ti2v = wan.WanTI2V( + config=cfg, + checkpoint_dir=args.ckpt_dir, + device_id=device, + rank=rank, + t5_fsdp=args.t5_fsdp, + dit_fsdp=args.dit_fsdp, + use_sp=(args.ulysses_size > 1), + t5_cpu=args.t5_cpu, + convert_model_dtype=args.convert_model_dtype, + ) + + transformer = wan_ti2v.model + if args.tp_size > 1: + logging.info("Initializing Tensor Parallel ...") + applicator = TensorParallelApplicator(args.tp_size, device_map="cpu") + applicator.apply_to_model(transformer) + # wan_ti2v.model.to("npu") + + if args.use_attentioncache: + config = CacheConfig( + method="attention_cache", + blocks_count=len(transformer.blocks), + steps_count=args.sample_steps, + step_start=args.start_step, + step_interval=args.attentioncache_interval, + step_end=args.end_step + ) + else: + config = CacheConfig( + method="attention_cache", + blocks_count=len(transformer.blocks), + steps_count=args.sample_steps + ) + cache = CacheAgent(config) + if args.dit_fsdp: + for block in transformer._fsdp_wrapped_module.blocks: + block._fsdp_wrapped_module.cache = cache + block._fsdp_wrapped_module.args = args + else: + for block in transformer.blocks: + block.cache = cache + block.args = args + + logging.info("Warm up 2 steps ...") + video = wan_ti2v.generate( + args.prompt, + img=img, + size=SIZE_CONFIGS[args.size], + max_area=MAX_AREA_CONFIGS[args.size], + frame_num=args.frame_num, + shift=args.sample_shift, + sample_solver=args.sample_solver, + sampling_steps=2, + guide_scale=args.sample_guide_scale, + seed=args.base_seed, + offload_model=args.offload_model) + + logging.info(f"Generating video ...") + stream.synchronize() + begin = time.time() + video = wan_ti2v.generate( + args.prompt, + img=img, + size=SIZE_CONFIGS[args.size], + max_area=MAX_AREA_CONFIGS[args.size], + frame_num=args.frame_num, + shift=args.sample_shift, + sample_solver=args.sample_solver, + sampling_steps=args.sample_steps, + guide_scale=args.sample_guide_scale, + seed=args.base_seed, + offload_model=args.offload_model) + stream.synchronize() + end = time.time() + logging.info(f"Generating video used time {end - begin: .4f}s") + else: + logging.info("Creating WanI2V pipeline.") + wan_i2v = wan.WanI2V( + config=cfg, + checkpoint_dir=args.ckpt_dir, + device_id=device, + rank=rank, + t5_fsdp=args.t5_fsdp, + dit_fsdp=args.dit_fsdp, + use_sp=(args.ulysses_size > 1 or args.ring_size > 1), + t5_cpu=args.t5_cpu, + convert_model_dtype=args.convert_model_dtype, + use_vae_parallel=args.vae_parallel + ) + + transformer_low = wan_i2v.low_noise_model + transformer_high = wan_i2v.high_noise_model + if args.tp_size > 1: + logging.info("Initializing Tensor Parallel ...") + applicator = TensorParallelApplicator(args.tp_size, device_map="cpu") + applicator.apply_to_model(transformer_low) + applicator.apply_to_model(transformer_high) + # wan_i2v.low_noise_model.to("npu") + # wan_i2v.high_noise_model.to("npu") + + if args.use_attentioncache: + config_low = CacheConfig( + method="attention_cache", + blocks_count=len(transformer_low.blocks), + steps_count=args.sample_steps, + step_start=args.start_step, + step_interval=args.attentioncache_interval, + 
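+                # Per the README parameter descriptions: attention results are cached from
+                # step_start to step_end and recomputed every step_interval steps; steps
+                # outside that window presumably recompute in full.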
step_end=args.end_step + ) + else: + config_low = CacheConfig( + method="attention_cache", + blocks_count=len(transformer_low.blocks), + steps_count=args.sample_steps + ) + config_high = CacheConfig( + method="attention_cache", + blocks_count=len(transformer_high.blocks), + steps_count=args.sample_steps + ) + cache_low = CacheAgent(config_low) + cache_high = CacheAgent(config_high) + + if args.dit_fsdp: + for block in transformer_high._fsdp_wrapped_module.blocks: + block._fsdp_wrapped_module.cache = cache_high + block._fsdp_wrapped_module.args = args + for block in transformer_low._fsdp_wrapped_module.blocks: + block._fsdp_wrapped_module.cache = cache_low + block._fsdp_wrapped_module.args = args + else: + for block in transformer_high.blocks: + block.cache = cache_high + block.args = args + for block in transformer_low.blocks: + block.cache = cache_low + block.args = args + + logging.info("Warm up 2 steps ...") + video = wan_i2v.generate( + args.prompt, + img, + max_area=MAX_AREA_CONFIGS[args.size], + frame_num=args.frame_num, + shift=args.sample_shift, + sample_solver=args.sample_solver, + sampling_steps=2, + guide_scale=args.sample_guide_scale, + seed=args.base_seed, + offload_model=args.offload_model) + + logging.info("Generating video ...") + stream.synchronize() + begin = time.time() + video = wan_i2v.generate( + args.prompt, + img, + max_area=MAX_AREA_CONFIGS[args.size], + frame_num=args.frame_num, + shift=args.sample_shift, + sample_solver=args.sample_solver, + sampling_steps=args.sample_steps, + guide_scale=args.sample_guide_scale, + seed=args.base_seed, + offload_model=args.offload_model) + stream.synchronize() + end = time.time() + logging.info(f"Generating video used time {end - begin: .4f}s") + + if rank == 0: + if args.save_file is None: + formatted_time = datetime.now().strftime("%Y%m%d_%H%M%S") + formatted_prompt = args.prompt.replace(" ", "_").replace("/", + "_")[:50] + suffix = '.mp4' + args.save_file = f"{args.task}_{args.size.replace('*','x') if sys.platform=='win32' else args.size}_{args.cfg_size}_{args.ulysses_size}_{args.ring_size}_{args.tp_size}_{formatted_prompt}_{formatted_time}" + suffix + + logging.info(f"Saving generated video to {args.save_file}") + save_video( + tensor=video[None], + save_file=args.save_file, + fps=cfg.sample_fps, + nrow=1, + normalize=True, + value_range=(-1, 1)) + del video + + finalize_parallel_env() + logging.info("Finished.") + + +if __name__ == "__main__": + args = _parse_args() + generate(args) diff --git a/MindIE/MultiModal/Wan2.2/pyproject.toml b/MindIE/MultiModal/Wan2.2/pyproject.toml new file mode 100644 index 0000000000..337240afa9 --- /dev/null +++ b/MindIE/MultiModal/Wan2.2/pyproject.toml @@ -0,0 +1,66 @@ +[build-system] +requires = ["setuptools>=61.0"] +build-backend = "setuptools.build_meta" + +[project] +name = "wan" +version = "2.2.0" +description = "Wan: Open and Advanced Large-Scale Video Generative Models" +authors = [ + { name = "Wan Team", email = "wan.ai@alibabacloud.com" } +] +license = { file = "LICENSE.txt" } +readme = "README.md" +requires-python = ">=3.10,<4.0" +dependencies = [ + "torch>=2.4.0", + "torchvision>=0.19.0", + "opencv-python>=4.9.0.80", + "diffusers>=0.31.0", + "transformers>=4.49.0", + "tokenizers>=0.20.3", + "accelerate>=1.1.1", + "tqdm", + "imageio", + "easydict", + "ftfy", + "dashscope", + "imageio-ffmpeg", + "flash_attn", + "numpy>=1.23.5,<2" +] + +[project.optional-dependencies] +dev = [ + "pytest", + "black", + "flake8", + "isort", + "mypy", + "huggingface-hub[cli]" +] + +[project.urls] 
+homepage = "https://wanxai.com" +documentation = "https://github.com/Wan-Video/Wan2.2" +repository = "https://github.com/Wan-Video/Wan2.2" +huggingface = "https://huggingface.co/Wan-AI/" +modelscope = "https://modelscope.cn/organization/Wan-AI" +discord = "https://discord.gg/p5XbdQV7" + +[tool.setuptools] +packages = ["wan"] + +[tool.setuptools.package-data] +"wan" = ["**/*.py"] + +[tool.black] +line-length = 88 + +[tool.isort] +profile = "black" + +[tool.mypy] +strict = true + + diff --git a/MindIE/MultiModal/Wan2.2/requirements.txt b/MindIE/MultiModal/Wan2.2/requirements.txt new file mode 100644 index 0000000000..04b399d9c3 --- /dev/null +++ b/MindIE/MultiModal/Wan2.2/requirements.txt @@ -0,0 +1,15 @@ +torch==2.1.0 +torchvision>=0.16.0 +opencv-python>=4.9.0.80 +diffusers>=0.31.0 +transformers>=4.49.0 +tokenizers>=0.20.3 +accelerate>=1.1.1 +tqdm +imageio[ffmpeg] +easydict +ftfy +dashscope +imageio-ffmpeg +numpy>=1.23.5,<2 +yunchang==0.6.0 \ No newline at end of file diff --git a/MindIE/MultiModal/Wan2.2/tests/README.md b/MindIE/MultiModal/Wan2.2/tests/README.md new file mode 100644 index 0000000000..55019f90c9 --- /dev/null +++ b/MindIE/MultiModal/Wan2.2/tests/README.md @@ -0,0 +1,6 @@ + +Put all your models (Wan2.2-T2V-A14B, Wan2.2-I2V-A14B, Wan2.2-TI2V-5B) in a folder and specify the max GPU number you want to use. + +```bash +bash ./tests/test.sh +``` diff --git a/MindIE/MultiModal/Wan2.2/tests/test.sh b/MindIE/MultiModal/Wan2.2/tests/test.sh new file mode 100644 index 0000000000..621eb253e0 --- /dev/null +++ b/MindIE/MultiModal/Wan2.2/tests/test.sh @@ -0,0 +1,91 @@ +#!/bin/bash +set -x + +unset NCCL_DEBUG + +if [ "$#" -eq 2 ]; then + MODEL_DIR=$(realpath "$1") + GPUS=$2 +else + echo "Usage: $0 " + exit 1 +fi + +SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" +REPO_ROOT="$(dirname "$SCRIPT_DIR")" +cd "$REPO_ROOT" || exit 1 + +PY_FILE=./generate.py + + +function t2v_A14B() { + CKPT_DIR="$MODEL_DIR/Wan2.2-T2V-A14B" + + # # 1-GPU Test + # echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_A14B 1-GPU Test: " + # python $PY_FILE --task t2v-A14B --size 480*832 --ckpt_dir $CKPT_DIR + + # Multiple GPU Test + echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_A14B Multiple GPU Test: " + torchrun --nproc_per_node=$GPUS $PY_FILE --task t2v-A14B --ckpt_dir $CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS + + echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_A14B Multiple GPU Test: " + torchrun --nproc_per_node=$GPUS $PY_FILE --task t2v-A14B --ckpt_dir $CKPT_DIR --size 720*1280 --dit_fsdp --t5_fsdp --ulysses_size $GPUS + + echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_A14B Multiple GPU Test: " + torchrun --nproc_per_node=$GPUS $PY_FILE --task t2v-A14B --ckpt_dir $CKPT_DIR --size 1280*720 --dit_fsdp --t5_fsdp --ulysses_size $GPUS + + echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_A14B Multiple GPU, prompt extend local_qwen: " + torchrun --nproc_per_node=$GPUS $PY_FILE --task t2v-A14B --ckpt_dir $CKPT_DIR --size 480*832 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_model "Qwen/Qwen2.5-3B-Instruct" --prompt_extend_target_lang "en" +} + + +function i2v_A14B() { + CKPT_DIR="$MODEL_DIR/Wan2.2-I2V-A14B" + + # echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B 1-GPU Test: " + # python $PY_FILE --task i2v-A14B --size 832*480 --ckpt_dir $CKPT_DIR + + # Multiple GPU Test + echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B Multiple GPU Test: " + torchrun --nproc_per_node=$GPUS $PY_FILE --task 
i2v-A14B --ckpt_dir $CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS + + echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B Multiple GPU, prompt extend local_qwen: " + torchrun --nproc_per_node=$GPUS $PY_FILE --task i2v-A14B --ckpt_dir $CKPT_DIR --size 720*1280 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_model "Qwen/Qwen2.5-VL-3B-Instruct" --prompt_extend_target_lang "en" + + echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B Multiple GPU, prompt extend local_qwen: " + torchrun --nproc_per_node=$GPUS $PY_FILE --task i2v-A14B --ckpt_dir $CKPT_DIR --size 1280*720 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_model "Qwen/Qwen2.5-VL-3B-Instruct" --prompt_extend_target_lang "en" + + if [ -n "${DASH_API_KEY+x}" ]; then + echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B Multiple GPU, prompt extend dashscope: " + torchrun --nproc_per_node=$GPUS $PY_FILE --task i2v-A14B --ckpt_dir $CKPT_DIR --size 480*832 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_method "dashscope" + else + echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> No DASH_API_KEY found, skip the dashscope extend test." + fi +} + +function ti2v_5B() { + CKPT_DIR="$MODEL_DIR/Wan2.2-TI2V-5B" + + # # 1-GPU Test + # echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> ti2v_5B t2v 1-GPU Test: " + # python $PY_FILE --task ti2v-5B --size 1280*704 --ckpt_dir $CKPT_DIR + + # Multiple GPU Test + echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> ti2v_5B t2v Multiple GPU Test: " + torchrun --nproc_per_node=$GPUS $PY_FILE --task ti2v-5B --ckpt_dir $CKPT_DIR --size 1280*704 --dit_fsdp --t5_fsdp --ulysses_size $GPUS + + echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> ti2v_5B t2v Multiple GPU, prompt extend local_qwen: " + torchrun --nproc_per_node=$GPUS $PY_FILE --task ti2v-5B --ckpt_dir $CKPT_DIR --size 704*1280 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_model "Qwen/Qwen2.5-3B-Instruct" --prompt_extend_target_lang "en" + + echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> ti2v_5B i2v Multiple GPU Test: " + torchrun --nproc_per_node=$GPUS $PY_FILE --task ti2v-5B --ckpt_dir $CKPT_DIR --size 704*1280 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." --image "examples/i2v_input.JPG" + + echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> ti2v_5B i2v Multiple GPU, prompt extend local_qwen: " + torchrun --nproc_per_node=$GPUS $PY_FILE --task ti2v-5B --ckpt_dir $CKPT_DIR --size 1280*704 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_model "Qwen/Qwen2.5-3B-Instruct" --prompt_extend_target_lang 'en' --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. 
The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." --image "examples/i2v_input.JPG" + +} + +t2v_A14B +i2v_A14B +ti2v_5B diff --git a/MindIE/MultiModal/Wan2.2/wan/__init__.py b/MindIE/MultiModal/Wan2.2/wan/__init__.py new file mode 100644 index 0000000000..0861d669fe --- /dev/null +++ b/MindIE/MultiModal/Wan2.2/wan/__init__.py @@ -0,0 +1,5 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +from . import configs, distributed, modules +from .image2video import WanI2V +from .text2video import WanT2V +from .textimage2video import WanTI2V diff --git a/MindIE/MultiModal/Wan2.2/wan/configs/__init__.py b/MindIE/MultiModal/Wan2.2/wan/configs/__init__.py new file mode 100644 index 0000000000..8311a9c7f1 --- /dev/null +++ b/MindIE/MultiModal/Wan2.2/wan/configs/__init__.py @@ -0,0 +1,43 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import copy +import os + +os.environ['TOKENIZERS_PARALLELISM'] = 'false' + +from .wan_i2v_A14B import i2v_A14B +from .wan_t2v_A14B import t2v_A14B +from .wan_ti2v_5B import ti2v_5B + +WAN_CONFIGS = { + 't2v-A14B': t2v_A14B, + 'i2v-A14B': i2v_A14B, + 'ti2v-5B': ti2v_5B, +} + +SIZE_CONFIGS = { + '720*1280': (720, 1280), + '1280*720': (1280, 720), + '480*832': (480, 832), + '832*480': (832, 480), + '704*1280': (704, 1280), + '1280*704': (1280, 704), + '432*768': (432, 768), + '768*432': (768, 432) +} + +MAX_AREA_CONFIGS = { + '720*1280': 720 * 1280, + '1280*720': 1280 * 720, + '480*832': 480 * 832, + '832*480': 832 * 480, + '704*1280': 704 * 1280, + '1280*704': 1280 * 704, + '432*768': 432 * 768, + '768*432': 768 * 432 +} + +SUPPORTED_SIZES = { + 't2v-A14B': ('720*1280', '1280*720', '480*832', '832*480', '432*768', '768*432'), + 'i2v-A14B': ('720*1280', '1280*720', '480*832', '832*480', '432*768', '768*432'), + 'ti2v-5B': ('704*1280', '1280*704', '480*832', '832*480', '432*768', '768*432'), +} diff --git a/MindIE/MultiModal/Wan2.2/wan/configs/shared_config.py b/MindIE/MultiModal/Wan2.2/wan/configs/shared_config.py new file mode 100644 index 0000000000..c58ab04ff9 --- /dev/null +++ b/MindIE/MultiModal/Wan2.2/wan/configs/shared_config.py @@ -0,0 +1,20 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import torch +from easydict import EasyDict + +#------------------------ Wan shared config ------------------------# +wan_shared_cfg = EasyDict() + +# t5 +wan_shared_cfg.t5_model = 'umt5_xxl' +wan_shared_cfg.t5_dtype = torch.bfloat16 +wan_shared_cfg.text_len = 512 + +# transformer +wan_shared_cfg.param_dtype = torch.bfloat16 + +# inference +wan_shared_cfg.num_train_timesteps = 1000 +wan_shared_cfg.sample_fps = 16 +wan_shared_cfg.sample_neg_prompt = '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走' +wan_shared_cfg.frame_num = 81 diff --git a/MindIE/MultiModal/Wan2.2/wan/configs/wan_i2v_A14B.py b/MindIE/MultiModal/Wan2.2/wan/configs/wan_i2v_A14B.py new file mode 100644 index 0000000000..f654cc6b24 --- /dev/null +++ b/MindIE/MultiModal/Wan2.2/wan/configs/wan_i2v_A14B.py @@ -0,0 +1,37 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
+import torch +from easydict import EasyDict + +from .shared_config import wan_shared_cfg + +#------------------------ Wan I2V A14B ------------------------# + +i2v_A14B = EasyDict(__name__='Config: Wan I2V A14B') +i2v_A14B.update(wan_shared_cfg) + +i2v_A14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth' +i2v_A14B.t5_tokenizer = 'google/umt5-xxl' + +# vae +i2v_A14B.vae_checkpoint = 'Wan2.1_VAE.pth' +i2v_A14B.vae_stride = (4, 8, 8) + +# transformer +i2v_A14B.patch_size = (1, 2, 2) +i2v_A14B.dim = 5120 +i2v_A14B.ffn_dim = 13824 +i2v_A14B.freq_dim = 256 +i2v_A14B.num_heads = 40 +i2v_A14B.num_layers = 40 +i2v_A14B.window_size = (-1, -1) +i2v_A14B.qk_norm = True +i2v_A14B.cross_attn_norm = True +i2v_A14B.eps = 1e-6 +i2v_A14B.low_noise_checkpoint = 'low_noise_model' +i2v_A14B.high_noise_checkpoint = 'high_noise_model' + +# inference +i2v_A14B.sample_shift = 5.0 +i2v_A14B.sample_steps = 40 +i2v_A14B.boundary = 0.900 +i2v_A14B.sample_guide_scale = (3.5, 3.5) # low noise, high noise diff --git a/MindIE/MultiModal/Wan2.2/wan/configs/wan_t2v_A14B.py b/MindIE/MultiModal/Wan2.2/wan/configs/wan_t2v_A14B.py new file mode 100644 index 0000000000..a5220a5243 --- /dev/null +++ b/MindIE/MultiModal/Wan2.2/wan/configs/wan_t2v_A14B.py @@ -0,0 +1,37 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +from easydict import EasyDict + +from .shared_config import wan_shared_cfg + +#------------------------ Wan T2V A14B ------------------------# + +t2v_A14B = EasyDict(__name__='Config: Wan T2V A14B') +t2v_A14B.update(wan_shared_cfg) + +# t5 +t2v_A14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth' +t2v_A14B.t5_tokenizer = 'google/umt5-xxl' + +# vae +t2v_A14B.vae_checkpoint = 'Wan2.1_VAE.pth' +t2v_A14B.vae_stride = (4, 8, 8) + +# transformer +t2v_A14B.patch_size = (1, 2, 2) +t2v_A14B.dim = 5120 +t2v_A14B.ffn_dim = 13824 +t2v_A14B.freq_dim = 256 +t2v_A14B.num_heads = 40 +t2v_A14B.num_layers = 40 +t2v_A14B.window_size = (-1, -1) +t2v_A14B.qk_norm = True +t2v_A14B.cross_attn_norm = True +t2v_A14B.eps = 1e-6 +t2v_A14B.low_noise_checkpoint = 'low_noise_model' +t2v_A14B.high_noise_checkpoint = 'high_noise_model' + +# inference +t2v_A14B.sample_shift = 12.0 +t2v_A14B.sample_steps = 40 +t2v_A14B.boundary = 0.875 +t2v_A14B.sample_guide_scale = (3.0, 4.0) # low noise, high noise diff --git a/MindIE/MultiModal/Wan2.2/wan/configs/wan_ti2v_5B.py b/MindIE/MultiModal/Wan2.2/wan/configs/wan_ti2v_5B.py new file mode 100644 index 0000000000..d5d5aed0d5 --- /dev/null +++ b/MindIE/MultiModal/Wan2.2/wan/configs/wan_ti2v_5B.py @@ -0,0 +1,36 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
+from easydict import EasyDict + +from .shared_config import wan_shared_cfg + +#------------------------ Wan TI2V 5B ------------------------# + +ti2v_5B = EasyDict(__name__='Config: Wan TI2V 5B') +ti2v_5B.update(wan_shared_cfg) + +# t5 +ti2v_5B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth' +ti2v_5B.t5_tokenizer = 'google/umt5-xxl' + +# vae +ti2v_5B.vae_checkpoint = 'Wan2.2_VAE.pth' +ti2v_5B.vae_stride = (4, 16, 16) + +# transformer +ti2v_5B.patch_size = (1, 2, 2) +ti2v_5B.dim = 3072 +ti2v_5B.ffn_dim = 14336 +ti2v_5B.freq_dim = 256 +ti2v_5B.num_heads = 24 +ti2v_5B.num_layers = 30 +ti2v_5B.window_size = (-1, -1) +ti2v_5B.qk_norm = True +ti2v_5B.cross_attn_norm = True +ti2v_5B.eps = 1e-6 + +# inference +ti2v_5B.sample_fps = 24 +ti2v_5B.sample_shift = 5.0 +ti2v_5B.sample_steps = 50 +ti2v_5B.sample_guide_scale = 5.0 +ti2v_5B.frame_num = 121 diff --git a/MindIE/MultiModal/Wan2.2/wan/distributed/__init__.py b/MindIE/MultiModal/Wan2.2/wan/distributed/__init__.py new file mode 100644 index 0000000000..566f71edb8 --- /dev/null +++ b/MindIE/MultiModal/Wan2.2/wan/distributed/__init__.py @@ -0,0 +1 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. diff --git a/MindIE/MultiModal/Wan2.2/wan/distributed/comm.py b/MindIE/MultiModal/Wan2.2/wan/distributed/comm.py new file mode 100644 index 0000000000..287241e5b4 --- /dev/null +++ b/MindIE/MultiModal/Wan2.2/wan/distributed/comm.py @@ -0,0 +1,95 @@ +import torch + +import torch.distributed as dist + + +def all_to_all_4D( + input_: torch.tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None, use_sync: bool = False +) -> torch.tensor: + """ + all-to-all for QKV + + Args: + input_ (torch.tensor): a tensor sharded along dim scatter dim + scatter_idx (int): default 1 + gather_idx (int): default 2 + group : torch process group + use_sync (bool): whether to synchronize after all-to-all + + Returns: + torch.tensor: resharded tensor (bs, seqlen/P, hc, hs) + """ + assert ( + input_.dim() == 4 + ), f"input_ must be 4D tensor, got {input_.dim()} and shape {input_.shape}" + + seq_world_size = dist.get_world_size(group) + + if scatter_idx == 2 and gather_idx == 1: + # input_ (torch.tensor): a tensor sharded along dim 1 (bs, seqlen/P, hc, hs) output: (bs, seqlen, hc/P, hs) + bs, shard_seqlen, hc, hs = input_.shape + seqlen = shard_seqlen * seq_world_size + shard_hc = hc // seq_world_size + + # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them! 
+ # (bs, seqlen/P, hc, hs) -reshape-> (bs, seq_len/P, P, hc/P, hs) -transpose(0,2)-> (P, seq_len/P, bs, hc/P, hs) + input_t = ( + input_.reshape(bs, shard_seqlen, seq_world_size, shard_hc, hs) + .transpose(0, 2) + .contiguous() + ) + + output = torch.empty_like(input_t) + # https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single + # (P, seq_len/P, bs, hc/P, hs) scatter seqlen -all2all-> (P, seq_len/P, bs, hc/P, hs) scatter head + + if seq_world_size > 1: + dist.all_to_all_single(output, input_t, group=group) + if use_sync: + torch.npu.synchronize() + else: + output = input_t + # if scattering the seq-dim, transpose the heads back to the original dimension + output = output.reshape(seqlen, bs, shard_hc, hs) + + # (seq_len, bs, hc/P, hs) -reshape-> (bs, seq_len, hc/P, hs) + output = output.transpose(0, 1).contiguous().reshape(bs, seqlen, shard_hc, hs) + + return output + + elif scatter_idx == 1 and gather_idx == 2: + # input_ (torch.tensor): a tensor sharded along dim 1 (bs, seqlen, hc/P, hs) output: (bs, seqlen/P, hc, hs) + bs, seqlen, shard_hc, hs = input_.shape + hc = shard_hc * seq_world_size + shard_seqlen = seqlen // seq_world_size + seq_world_size = dist.get_world_size(group) + + # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them! + # (bs, seqlen, hc/P, hs) -reshape-> (bs, P, seq_len/P, hc/P, hs) -transpose(0, 3)-> (hc/P, P, seqlen/P, bs, hs) -transpose(0, 1) -> (P, hc/P, seqlen/P, bs, hs) + input_t = ( + input_.reshape(bs, seq_world_size, shard_seqlen, shard_hc, hs) + .transpose(0, 3) + .transpose(0, 1) + .contiguous() + .reshape(seq_world_size, shard_hc, shard_seqlen, bs, hs) + ) + + output = torch.empty_like(input_t) + # https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single + # (P, bs x hc/P, seqlen/P, hs) scatter seqlen -all2all-> (P, bs x seq_len/P, hc/P, hs) scatter head + if seq_world_size > 1: + dist.all_to_all_single(output, input_t, group=group) + if use_sync: + torch.npu.synchronize() + else: + output = input_t + + # if scattering the seq-dim, transpose the heads back to the original dimension + output = output.reshape(hc, shard_seqlen, bs, hs) + + # (hc, seqlen/N, bs, hs) -tranpose(0,2)-> (bs, seqlen/N, hc, hs) + output = output.transpose(0, 2).contiguous().reshape(bs, shard_seqlen, hc, hs) + + return output + else: + raise RuntimeError("scatter_idx must be 1 or 2 and gather_idx must be 1 or 2") diff --git a/MindIE/MultiModal/Wan2.2/wan/distributed/fsdp.py b/MindIE/MultiModal/Wan2.2/wan/distributed/fsdp.py new file mode 100644 index 0000000000..6bb496d445 --- /dev/null +++ b/MindIE/MultiModal/Wan2.2/wan/distributed/fsdp.py @@ -0,0 +1,43 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
+import gc +from functools import partial + +import torch +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp import MixedPrecision, ShardingStrategy +from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy +from torch.distributed.utils import _free_storage + + +def shard_model( + model, + device_id, + param_dtype=torch.bfloat16, + reduce_dtype=torch.float32, + buffer_dtype=torch.float32, + process_group=None, + sharding_strategy=ShardingStrategy.FULL_SHARD, + sync_module_states=True, +): + model = FSDP( + module=model, + process_group=process_group, + sharding_strategy=sharding_strategy, + auto_wrap_policy=partial( + lambda_auto_wrap_policy, lambda_fn=lambda m: m in model.blocks), + mixed_precision=MixedPrecision( + param_dtype=param_dtype, + reduce_dtype=reduce_dtype, + buffer_dtype=buffer_dtype), + device_id=device_id, + sync_module_states=sync_module_states) + return model + + +def free_model(model): + for m in model.modules(): + if isinstance(m, FSDP): + _free_storage(m._handle.flat_param.data) + del model + gc.collect() + torch.cuda.empty_cache() diff --git a/MindIE/MultiModal/Wan2.2/wan/distributed/group_coordinator.py b/MindIE/MultiModal/Wan2.2/wan/distributed/group_coordinator.py new file mode 100644 index 0000000000..796bfe1a20 --- /dev/null +++ b/MindIE/MultiModal/Wan2.2/wan/distributed/group_coordinator.py @@ -0,0 +1,597 @@ +# Copyright 2024 xDiT team. +# Adapted from +# https://github.com/vllm-project/vllm/blob/main/vllm/distributed/parallel_state.py +# Copyright 2023 The vLLM team. +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +from collections import namedtuple +from typing import Any, Dict, List, Optional, Tuple, Union +import pickle + +import torch +import torch_npu +import torch.distributed +from torch.distributed import Backend, ProcessGroup + +import logging + + +TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"]) + + +def _split_tensor_dict( + tensor_dict: Dict[str, Union[torch.Tensor, Any]], prefix: str = "" +) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]: + """Split the tensor dictionary into two parts: + 1. A list of (key, value) pairs. If the value is a tensor, it is replaced + by its metadata. + 2. A list of tensors. + + If the Tensor is nested under `tensor_dict["key1"]["key2"]`, the key of its + metadata will be "key1%key2". + """ + metadata_list: List[Tuple[str, Any]] = [] + tensor_list = [] + for key, value in tensor_dict.items(): + if "%" in key: + logging.error( + "Avoid having '%' in key " + "as it is used as a separator for nested entries." + ) + if isinstance(value, torch.Tensor): + # Note: we cannot use `value.device` here, + # because it contains not only the device type but also the device + # index (e.g. "npu:0"). We only need the device type. + # receiving side will set the device index. 
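+            # The tensor itself is not placed in the metadata; it is appended
+            # to tensor_list below and communicated separately.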
+ device = value.device.type + metadata_list.append( + (prefix + key, TensorMetadata(device, value.dtype, value.size())) + ) + tensor_list.append(value) + elif isinstance(value, dict): + if len(value) == 0: + metadata_list.append((prefix + key, value)) + inner_metadata_list, inner_tensor_list = _split_tensor_dict( + value, prefix + key + "%" + ) + metadata_list.extend(inner_metadata_list) + tensor_list.extend(inner_tensor_list) + else: + metadata_list.append((prefix + key, value)) + return metadata_list, tensor_list + + +def _update_nested_dict(nested_dict, flattened_key, value): + key_splits = flattened_key.split("%") + cur_dict = nested_dict + for k in key_splits[:-1]: + if k not in cur_dict: + cur_dict[k] = {} + cur_dict = cur_dict[k] + cur_dict[key_splits[-1]] = value + + +class GroupCoordinator: + """ + PyTorch ProcessGroup wrapper for a group of processes. + PyTorch ProcessGroup is bound to one specific communication backend, + e.g. NCCL, Gloo, MPI, etc. + GroupCoordinator takes charge of all the communication operations among + the processes in the group. It can route the communication to + a specific implementation (e.g. switch allreduce implementation + based on the tensor size and npu graph mode). + """ + + # available attributes: + rank: int # global rank + ranks: List[int] # global ranks in the group + world_size: int # size of the group + # difference between `local_rank` and `rank_in_group`: + # if we have a group of size 4 across two nodes: + # Process | Node | Rank | Local Rank | Rank in Group + # 0 | 0 | 0 | 0 | 0 + # 1 | 0 | 1 | 1 | 1 + # 2 | 1 | 2 | 0 | 2 + # 3 | 1 | 3 | 1 | 3 + local_rank: int # local rank used to assign devices + rank_in_group: int # rank inside the group + cpu_group: ProcessGroup # group for CPU communication + device_group: ProcessGroup # group for device communication + + def __init__( + self, + group_ranks: List[List[int]], + local_rank: int, + torch_distributed_backend: Union[str, Backend], + ): + + self.rank = torch.distributed.get_rank() + self.local_rank = local_rank + self.device_group = None + self.cpu_group = None + + for ranks in group_ranks: + device_group = torch.distributed.new_group( + ranks, backend=torch_distributed_backend + ) + # a group with `gloo` backend, to allow direct coordination between + # processes through the CPU. 
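+            # NOTE: new_group is a collective call, so every process must create
+            # every group in group_ranks, including ones it does not belong to;
+            # only the groups that contain this rank are kept below.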
+ cpu_group = torch.distributed.new_group(ranks, backend="gloo") + if self.rank in ranks: + self.ranks = ranks + self.world_size = len(ranks) + self.rank_in_group = ranks.index(self.rank) + self.device_group = device_group + self.cpu_group = cpu_group + + if torch.npu.is_available(): + self.device = torch.device(f"npu:{local_rank}") + else: + self.device = torch.device("cpu") + + @property + def first_rank(self): + """Return the global rank of the first process in the group""" + return self.ranks[0] + + @property + def last_rank(self): + """Return the global rank of the last process in the group""" + return self.ranks[-1] + + @property + def is_first_rank(self): + """Return whether the caller is the first process in the group""" + return self.rank == self.first_rank + + @property + def is_last_rank(self): + """Return whether the caller is the last process in the group""" + return self.rank == self.last_rank + + @property + def next_rank(self): + """Return the global rank of the process that follows the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return self.ranks[(rank_in_group + 1) % world_size] + + @property + def prev_rank(self): + """Return the global rank of the process that precedes the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return self.ranks[(rank_in_group - 1) % world_size] + + @property + def group_next_rank(self): + """Return the group rank of the process that follows the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return (rank_in_group + 1) % world_size + + @property + def group_prev_rank(self): + """Return the group rank of the process that precedes the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return (rank_in_group - 1) % world_size + + @property + def skip_rank(self): + """Return the global rank of the process that skip connects with the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return self.ranks[(world_size - rank_in_group - 1) % world_size] + + @property + def group_skip_rank(self): + """Return the group rank of the process that skip connects with the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return (world_size - rank_in_group - 1) % world_size + + def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: + """ + NOTE: This operation will be applied in-place or out-of-place. + Always assume this function modifies its input, but use the return + value as the output. + """ + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return input_ + else: + torch.distributed.all_reduce(input_, group=self.device_group) + return input_ + + def all_gather( + self, input_: torch.Tensor, dim: int = 0, separate_tensors: bool = False + ) -> Union[torch.Tensor, List[torch.Tensor]]: + world_size = self.world_size + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + # Allocate output tensor. + input_size = list(input_.size()) + input_size[0] *= world_size + output_tensor = torch.empty( + input_size, dtype=input_.dtype, device=input_.device + ) + # All-gather. 
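+        # all_gather_into_tensor concatenates each rank's shard along dim 0 of
+        # output_tensor; when dim != 0 the result is moved and reshaped below.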
+ torch.distributed.all_gather_into_tensor( + output_tensor, input_, group=self.device_group + ) + if dim != 0: + input_size[0] //= world_size + output_tensor = output_tensor.reshape([world_size, ] + input_size) + output_tensor = output_tensor.movedim(0, dim) + + if separate_tensors: + tensor_list = [ + output_tensor.view(-1) + .narrow(0, input_.numel() * i, input_.numel()) + .view_as(input_) + for i in range(world_size) + ] + return tensor_list + else: + input_size = list(input_.size()) + input_size[dim] = input_size[dim] * world_size + # Reshape + output_tensor = output_tensor.reshape(input_size) + return output_tensor + + def gather(self, input_: torch.Tensor, dst: int = 0, dim: int = -1) -> torch.Tensor: + """ + NOTE: We assume that the input tensor is on the same device across + all the ranks. + NOTE: `dst` is the local rank of the destination rank. + """ + world_size = self.world_size + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + # Allocate output tensor. + if self.rank_in_group == dst: + gather_list = [torch.empty_like(input_) for _ in range(world_size)] + else: + gather_list = None + # Gather. + torch.distributed.gather( + input_, gather_list, dst=self.ranks[dst], group=self.device_group + ) + if self.rank_in_group == dst: + output_tensor = torch.cat(gather_list, dim=dim) + else: + output_tensor = None + return output_tensor + + def broadcast(self, input_: torch.Tensor, src: int = 0): + """Broadcast the input tensor. + NOTE: `src` is the local rank of the source rank. + """ + + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return input_ + # Broadcast. + torch.distributed.broadcast( + input_, src=self.ranks[src], group=self.device_group + ) + return input_ + + def broadcast_object(self, obj: Optional[Any] = None, src: int = 0): + """Broadcast the input object. + NOTE: `src` is the local rank of the source rank. + """ + + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return obj + if self.shm_broadcaster is not None: + return self.shm_broadcaster.broadcast_object(obj) + if self.rank_in_group == src: + torch.distributed.broadcast_object_list( + [obj], src=self.ranks[src], group=self.cpu_group + ) + return obj + else: + recv = [None] + torch.distributed.broadcast_object_list( + recv, src=self.ranks[src], group=self.cpu_group + ) + return recv[0] + + def broadcast_object_list( + self, obj_list: List[Any], src: int = 0, group: Optional[ProcessGroup] = None + ): + """Broadcast the input object list. + NOTE: `src` is the local rank of the source rank. + """ + + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return obj_list + # Broadcast. 
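+        # broadcast_object_list serializes the objects on the source rank and
+        # deserializes them in place on the receivers; note that, unlike
+        # broadcast_object above, it runs over the device group here.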
+ torch.distributed.broadcast_object_list( + obj_list, src=self.ranks[src], group=self.device_group + ) + return obj_list + + def send_object(self, obj: Any, dst: int) -> None: + """Send the input object list to the destination rank.""" + """NOTE: `dst` is the local rank of the destination rank.""" + + # Serialize object to tensor and get the size as well + object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8) + + size_tensor = torch.tensor( + [object_tensor.numel()], dtype=torch.long, device="cpu" + ) + + # Send object size + + torch.distributed.send(size_tensor, dst=self.ranks[dst], group=self.cpu_group) + + # Send object + torch.distributed.send(object_tensor, dst=self.ranks[dst], group=self.cpu_group) + + return None + + def recv_object(self, src: int) -> Any: + """Receive the input object list from the source rank.""" + """NOTE: `src` is the local rank of the source rank.""" + + size_tensor = torch.empty(1, dtype=torch.long, device="cpu") + + # Receive object size + rank_size = torch.distributed.recv( + size_tensor, src=self.ranks[src], group=self.cpu_group + ) + + # Tensor to receive serialized objects into. + object_tensor = torch.empty( # type: ignore[call-overload] + size_tensor.item(), # type: ignore[arg-type] + dtype=torch.uint8, + device="cpu", + ) + + rank_object = torch.distributed.recv( + object_tensor, src=self.ranks[src], group=self.cpu_group + ) + + obj = pickle.loads(object_tensor.numpy().tobytes()) + + return obj + + def broadcast_tensor_dict( + self, + tensor_dict: Optional[Dict[str, Union[torch.Tensor, Any]]] = None, + src: int = 0, + group: Optional[ProcessGroup] = None, + metadata_group: Optional[ProcessGroup] = None, + ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: + """Broadcast the input tensor dictionary. + NOTE: `src` is the local rank of the source rank. + """ + # Bypass the function if we are using only 1 GPU. + if not torch.distributed.is_initialized() or self.world_size == 1: + return tensor_dict + + group = self.device_group + metadata_group = self.cpu_group + src = self.ranks[src] + + rank = self.rank + if rank == src: + metadata_list: List[Tuple[Any, Any]] = [] + metadata_list, tensor_list = _split_tensor_dict(tensor_dict) + # `metadata_list` lives in CPU memory. + # `broadcast_object_list` has serialization & deserialization, + # all happening on CPU. Therefore, we can use the CPU group. + self.broadcast_object(metadata_list, src=src) + async_handles = [] + for tensor in tensor_list: + if tensor.numel() == 0: + # Skip broadcasting empty tensors. + continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + handle = torch.distributed.broadcast( + tensor, src=src, group=metadata_group, async_op=True + ) + else: + # use group for GPU tensors + handle = torch.distributed.broadcast( + tensor, src=src, group=group, async_op=True + ) + async_handles.append(handle) + for async_handle in async_handles: + async_handle.wait() + + else: + metadata_list = self.broadcast_object(None, src=src) + tensor_dict = {} + async_handles = [] + for key, value in metadata_list: + if isinstance(value, TensorMetadata): + tensor = torch.empty( + value.size, dtype=value.dtype, device=value.device + ) + if tensor.numel() == 0: + # Skip broadcasting empty tensors. 
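+                        # The empty tensor is still reconstructed locally from
+                        # its metadata and stored in the result dict; only the
+                        # collective itself is skipped.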
+ _update_nested_dict(tensor_dict, key, tensor) + continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + handle = torch.distributed.broadcast( + tensor, src=src, group=metadata_group, async_op=True + ) + else: + # use group for GPU tensors + handle = torch.distributed.broadcast( + tensor, src=src, group=group, async_op=True + ) + async_handles.append(handle) + _update_nested_dict(tensor_dict, key, tensor) + else: + _update_nested_dict(tensor_dict, key, value) + for async_handle in async_handles: + async_handle.wait() + return tensor_dict + + def send_tensor_dict( + self, + tensor_dict: Dict[str, Union[torch.Tensor, Any]], + dst: Optional[int] = None, + ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: + """Send the input tensor dictionary. + NOTE: `dst` is the local rank of the source rank. + """ + # Bypass the function if we are using only 1 GPU. + if not torch.distributed.is_initialized() or self.world_size == 1: + return tensor_dict + + group = self.device_group + metadata_group = self.cpu_group + + if dst is None: + dst = self.group_next_rank + + metadata_list: List[Tuple[Any, Any]] = [] + metadata_list, tensor_list = _split_tensor_dict(tensor_dict) + # `metadata_list` lives in CPU memory. + # `send_object_list` has serialization & deserialization, + # all happening on CPU. Therefore, we can use the CPU group. + self.send_object(metadata_list, dst=dst) + for tensor in tensor_list: + if tensor.numel() == 0: + # Skip sending empty tensors. + continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + torch.distributed.send( + tensor, dst=self.ranks[dst], group=metadata_group + ) + else: + # use group for GPU tensors + torch.distributed.send(tensor, dst=self.ranks[dst], group=group) + return None + + def recv_tensor_dict( + self, src: Optional[int] = None + ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: + """Recv the input tensor dictionary. + NOTE: `src` is the local rank of the source rank. + """ + # Bypass the function if we are using only 1 GPU. + if not torch.distributed.is_initialized() or self.world_size == 1: + return None + + group = self.device_group + metadata_group = self.cpu_group + + if src is None: + src = self.group_prev_rank + + recv_metadata_list = self.recv_object(src=src) + tensor_dict: Dict[str, Any] = {} + for key, value in recv_metadata_list: + if isinstance(value, TensorMetadata): + tensor = torch.empty(value.size, dtype=value.dtype, device=value.device) + if tensor.numel() == 0: + # Skip broadcasting empty tensors. + _update_nested_dict(tensor_dict, key, tensor) + continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + torch.distributed.recv( + tensor, src=self.ranks[src], group=metadata_group + ) + else: + # use group for GPU tensors + torch.distributed.recv(tensor, src=self.ranks[src], group=group) + _update_nested_dict(tensor_dict, key, tensor) + else: + _update_nested_dict(tensor_dict, key, value) + return tensor_dict + + def barrier(self): + """Barrier synchronization among the group. + NOTE: don't use `device_group` here! `barrier` in NCCL is + terrible because it is internally a broadcast operation with + secretly created GPU tensors. It is easy to mess up the current + device. Use the CPU group instead. 
+ """ + torch.distributed.barrier(group=self.cpu_group) + + def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None: + """Sends a tensor to the destination rank in a non-blocking way""" + """NOTE: `dst` is the rank_in_group of the destination rank.""" + if dst is None: + dst = self.group_next_rank + + torch.distributed.send( + tensor, + self.ranks[dst], + group=( + self.device_groups[self.rank_in_group % 2] + if self.world_size == 2 + else self.device_group + ), + ) + + def recv( + self, size: torch.Size, dtype: torch.dtype, src: Optional[int] = None + ) -> torch.Tensor: + """Receives a tensor from the src rank.""" + """NOTE: `src` is the rank_in_group of the source rank.""" + if src is None: + src = self.group_prev_rank + + tensor = torch.empty(size, dtype=dtype, device=self.device) + torch.distributed.recv( + tensor, + self.ranks[src], + ( + self.device_groups[(self.rank_in_group + 1) % 2] + if self.world_size == 2 + else self.device_group + ), + ) + return tensor + + def destroy(self): + if self.device_group is not None: + torch.distributed.destroy_process_group(self.device_group) + self.device_group = None + if self.cpu_group is not None: + torch.distributed.destroy_process_group(self.cpu_group) + self.cpu_group = None + + +class SequenceParallelGroupCoordinator(GroupCoordinator): + def __init__( + self, + group_ranks: List[List[int]], + local_rank: int, + torch_distributed_backend: Union[str, Backend], + **kwargs, + ): + super().__init__( + group_ranks=group_ranks, + local_rank=local_rank, + torch_distributed_backend=torch_distributed_backend, + ) + self.ulysses_group = kwargs.get("ulysses_group", None) + self.ulysses_world_size = torch.distributed.get_world_size(self.ulysses_group) + self.ulysses_rank = torch.distributed.get_rank(self.ulysses_group) + + self.ring_group = kwargs.get("ring_group", None) + self.ring_world_size = torch.distributed.get_world_size(self.ring_group) + self.ring_rank = torch.distributed.get_rank(self.ring_group) \ No newline at end of file diff --git a/MindIE/MultiModal/Wan2.2/wan/distributed/parallel_mgr.py b/MindIE/MultiModal/Wan2.2/wan/distributed/parallel_mgr.py new file mode 100644 index 0000000000..bc7084c79a --- /dev/null +++ b/MindIE/MultiModal/Wan2.2/wan/distributed/parallel_mgr.py @@ -0,0 +1,342 @@ +import os +from typing import List, Optional +from dataclasses import dataclass +import torch.distributed as dist +import torch_npu +import logging +from .util import RankGenerator, generate_masked_orthogonal_rank_groups +from .group_coordinator import GroupCoordinator, SequenceParallelGroupCoordinator + +from yunchang import set_seq_parallel_pg +from yunchang.globals import PROCESS_GROUP + +_WORLD: Optional[GroupCoordinator] = None +_TP: Optional[GroupCoordinator] = None +_SP: Optional[SequenceParallelGroupCoordinator] = None +_CFG: Optional[GroupCoordinator] = None + + +@dataclass +class ParallelConfig: + tp_degree: int = 1 + sp_degree: int = 1 + ulysses_degree: int = 1 + ring_degree: int = 1 + use_cfg_parallel: bool = False + world_size: int = 1 + + def __post_init__(self): + if self.use_cfg_parallel: + self.cfg_degree = 2 + else: + self.cfg_degree = 1 + if not self.tp_degree * self.sp_degree * self.cfg_degree <= self.world_size: + logging.error( + "tp_degree * sp_degree * cfg_degree must be less than or equal to " + "world_size because of classifier free guidance" + ) + if not (self.world_size % (self.tp_degree * self.sp_degree * self.cfg_degree) == 0): + logging.error("world_size must be divisible by tp_degree * sp_degree * 
cfg_degree") + + +# * QUERY +def get_world_group() -> GroupCoordinator: + if _WORLD is None: + logging.error("world group is not initialized") + return _WORLD + + +# TP +def get_tp_group() -> GroupCoordinator: + assert _TP is not None, "tensor model parallel group is not initialized" + return _TP + + +def get_tensor_model_parallel_world_size(): + """Return world size for the tensor model parallel group.""" + return get_tp_group().world_size + + +def get_tensor_model_parallel_rank(): + """Return my rank for the tensor model parallel group.""" + return get_tp_group().rank_in_group + + +# SP +def get_sp_group() -> SequenceParallelGroupCoordinator: + if _SP is None: + logging.error("pipeline model parallel group is not initialized") + return _SP + + +def get_sequence_parallel_state(): + """Return state for the sequence parallel group.""" + return _SP is not None + + +def get_sequence_parallel_world_size(): + """Return world size for the sequence parallel group.""" + if not get_sequence_parallel_state(): + return 1 + return get_sp_group().world_size + + +def get_sequence_parallel_rank(): + """Return my rank for the sequence parallel group.""" + if not get_sequence_parallel_state(): + return 0 + return get_sp_group().rank_in_group + + +# CFG +def get_cfg_group() -> GroupCoordinator: + if _CFG is None: + logging.error("classifier_free_guidance parallel group is not initialized") + return _CFG + + +def get_cfg_state(): + """Return state for the sequence parallel group.""" + return _CFG is not None + + +def get_classifier_free_guidance_world_size(): + """Return world size for the classifier_free_guidance parallel group.""" + if not get_cfg_state(): + return 1 + return get_cfg_group().world_size + + +def get_classifier_free_guidance_rank(): + """Return my rank for the classifier_free_guidance parallel group.""" + if not get_cfg_state(): + return 0 + return get_cfg_group().rank_in_group + + +def init_world_group( + ranks: List[int], local_rank: int, backend: str +) -> GroupCoordinator: + return GroupCoordinator( + group_ranks=[ranks], + local_rank=local_rank, + torch_distributed_backend=backend, + ) + + +def init_distributed_environment( + world_size: int = -1, + rank: int = -1, + distributed_init_method: str = "env://", + local_rank: int = -1, + backend: str = "hccl", +): + logging.debug( + "world_size=%d rank=%d local_rank=%d " "distributed_init_method=%s backend=%s", + world_size, + rank, + local_rank, + distributed_init_method, + backend, + ) + if not dist.is_initialized(): + if distributed_init_method is None: + logging.error( + "distributed_init_method must be provided when initializing " + "distributed environment" + ) + # this backend is used for WORLD + dist.init_process_group( + backend=backend, + init_method=distributed_init_method, + world_size=world_size, + rank=rank, + ) + # set the local rank + # local_rank is not available in torch ProcessGroup, + # see https://github.com/pytorch/pytorch/issues/122816 + if local_rank == -1: + # local rank not set, this usually happens in single-node + # setting, where we can use rank as local rank + if distributed_init_method == "env://": + local_rank = int(os.getenv('LOCAL_RANK', 0)) + torch_npu.npu.set_device(local_rank) + else: + local_rank = rank + global _WORLD + if _WORLD is None: + ranks = list(range(dist.get_world_size())) + _WORLD = init_world_group(ranks, local_rank, backend) + else: + if not _WORLD.world_size == dist.get_world_size(): + logging.error("world group already initialized with a different world size") + + +def 
model_parallel_is_initialized(): + """Check if tensor and pipeline parallel groups are initialized.""" + return ( + _CFG is not None + and _SP is not None + and _TP is not None + ) + + +def init_model_parallel_group( + group_ranks: List[List[int]], + local_rank: int, + backend: str, + parallel_mode: str, + **kwargs, +) -> GroupCoordinator: + if parallel_mode not in [ + "tensor", + "sequence", + "classifier_free_guidance", + ]: + logging.error(f"parallel_mode {parallel_mode} is not supported") + if parallel_mode == "sequence": + return SequenceParallelGroupCoordinator( + group_ranks=group_ranks, + local_rank=local_rank, + torch_distributed_backend=backend, + **kwargs, + ) + else: + return GroupCoordinator( + group_ranks=group_ranks, + local_rank=local_rank, + torch_distributed_backend=backend, + ) + + +def initialize_model_parallel( + classifier_free_guidance_degree: int = 1, + sequence_parallel_degree: int = 1, + ulysses_degree: int = 1, + ring_degree: int = 1, + tensor_parallel_degree: int = 1, + backend: Optional[str] = None, +) -> None: + """ + Initialize model parallel groups. + + Arguments: + classifier_free_guidance_degree: number of GPUs used for Classifier Free Guidance (CFG) + sequence_parallel_degree: number of GPUs used for sequence parallelism. + tensor_parallel_degree: number of GPUs used for tensor parallelism. + backend: distributed backend of pytorch collective comm. + """ + # Get world size and rank. Ensure some consistencies. + if not dist.is_initialized(): + logging.error("dist is not initialized") + world_size: int = dist.get_world_size() + backend = backend + + if ( + world_size + != classifier_free_guidance_degree + * sequence_parallel_degree + * tensor_parallel_degree + ): + raise RuntimeError( + f"world_size ({world_size}) is not equal to " + f"sequence_parallel_degree ({sequence_parallel_degree}) x " + f"classifier_free_guidance_degree " + f"({classifier_free_guidance_degree}) x " + f"tensor_parallel_degree " + f"({tensor_parallel_degree})" + ) + + rank_generator: RankGenerator = RankGenerator( + tensor_parallel_degree, + sequence_parallel_degree, + classifier_free_guidance_degree, + "tp-sp-cfg", + ) + + global _CFG + if _CFG is not None: + logging.error("classifier_free_guidance group is already initialized") + _CFG = init_model_parallel_group( + group_ranks=rank_generator.get_ranks("cfg"), + local_rank=get_world_group().local_rank, + backend=backend, + parallel_mode="classifier_free_guidance", + ) + + global _SP + if _SP is not None: + logging.error("sequence parallel group is already initialized") + set_seq_parallel_pg( + sp_ulysses_degree=ulysses_degree, + sp_ring_degree=ring_degree, + rank=get_world_group().rank_in_group, + world_size=world_size + ) + _SP = init_model_parallel_group( + group_ranks=rank_generator.get_ranks("sp"), + local_rank=get_world_group().local_rank, + backend=backend, + parallel_mode="sequence", + ulysses_group=PROCESS_GROUP.ULYSSES_PG, + ring_group=PROCESS_GROUP.RING_PG, + ) + + global _TP + assert _TP is None, "Tensor parallel group is already initialized" + _TP = init_model_parallel_group( + group_ranks=rank_generator.get_ranks("tp"), + local_rank=get_world_group().local_rank, + backend=backend, + parallel_mode="tensor", + ) + + +def destroy_model_parallel(): + """Set the groups to none and destroy them.""" + global _CFG + if _CFG: + _CFG.destroy() + _CFG = None + + global _SP + if _SP: + _SP.destroy() + _SP = None + + global _TP + if _TP: + _TP.destroy() + _TP = None + + +def destroy_distributed_environment(): + global _WORLD + if 
_WORLD: + _WORLD.destroy() + _WORLD = None + if dist.is_initialized(): + dist.destroy_process_group() + + +def init_parallel_env(parallel_config: ParallelConfig): + if not model_parallel_is_initialized(): + logging.warning("Model parallel is not initialized, initializing...") + init_distributed_environment( + world_size=dist.get_world_size(), + rank=dist.get_rank(), + backend='hccl', + ) + initialize_model_parallel( + classifier_free_guidance_degree=parallel_config.cfg_degree, + sequence_parallel_degree=parallel_config.sp_degree, + ulysses_degree=parallel_config.ulysses_degree, + ring_degree=parallel_config.ring_degree, + tensor_parallel_degree=parallel_config.tp_degree, + ) + + +def finalize_parallel_env(): + if model_parallel_is_initialized(): + destroy_model_parallel() + destroy_distributed_environment() \ No newline at end of file diff --git a/MindIE/MultiModal/Wan2.2/wan/distributed/sequence_parallel.py b/MindIE/MultiModal/Wan2.2/wan/distributed/sequence_parallel.py new file mode 100644 index 0000000000..fd0199097c --- /dev/null +++ b/MindIE/MultiModal/Wan2.2/wan/distributed/sequence_parallel.py @@ -0,0 +1,183 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import torch +import torch.cuda.amp as amp +from .parallel_mgr import ( + get_sequence_parallel_rank, + get_sequence_parallel_world_size, + get_sp_group, +) +from ..modules.attn_layer import xFuserLongContextAttention + +from ..modules.model import sinusoidal_embedding_1d +from mindiesd import rotary_position_embedding + + +def pad_freqs(original_tensor, target_len): + seq_len, s1, s2 = original_tensor.shape + pad_size = target_len - seq_len + padding_tensor = torch.ones( + pad_size, + s1, + s2, + dtype=torch.float32, + device=original_tensor.device + ).to(original_tensor.dtype) + padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0) + return padded_tensor + + +@torch.amp.autocast('cuda', enabled=False) +def rope_apply(x, grid_sizes, freqs_list): + s, n, c = x.size(1), x.size(2), x.size(3) + output = [] + + for i, (f, h, w) in enumerate(grid_sizes.tolist()): + x_i = x[i, :s].reshape(1, s, n, c) + cos, sin = freqs_list[i] + x_i = rotary_position_embedding(x_i, cos, sin, rotated_mode="rotated_interleaved", fused=True) + output.append(x_i) + return torch.cat(output).float() + + +def sp_dit_forward( + self, + x, + t, + context, + seq_len, + y=None, +): + """ + x: A list of videos each with shape [C, T, H, W]. + t: [B]. + context: A list of text embeddings each with shape [L, C]. 
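+    seq_len: Target (padded) token length shared by all samples in the batch.
+    y: Optional conditioning inputs concatenated with x along the channel dim;
+        required when model_type == 'i2v'.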
+ """ + if self.model_type == 'i2v': + assert y is not None + # params + device = self.patch_embedding.weight.device + if self.freqs.device != device: + self.freqs = self.freqs.to(device) + + if y is not None: + x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)] + + # embeddings + x = [self.patch_embedding(u.unsqueeze(0)) for u in x] + grid_sizes = torch.stack( + [torch.tensor(u.shape[2:], dtype=torch.long) for u in x]) + x = [u.flatten(2).transpose(1, 2) for u in x] + seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long) + assert seq_lens.max() <= seq_len + x = torch.cat([ + torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1) + for u in x + ]) + + # time embeddings + if t.dim() == 1: + t = t.expand(t.size(0), seq_len) + with torch.amp.autocast('cuda', dtype=torch.bfloat16): + bt = t.size(0) + t = t.flatten() + e = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, + t).unflatten(0, (bt, seq_len)).float()) + e0 = self.time_projection(e).unflatten(2, (6, self.dim)) + + # context + context_lens = None + context = self.text_embedding( + torch.stack([ + torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) + for u in context + ])) + + # Context Parallel + x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()] + e = torch.chunk(e, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()] + e0 = torch.chunk(e0, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()] + + + if self.freqs_list is None: + c = (self.dim // self.num_heads) // 2 + s = x.shape[1] + freqs = self.freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) + freqs_list = [] + + for i, (f, h, w) in enumerate(grid_sizes.tolist()): + seq_len = f * h * w + freqs_i = torch.cat([ + freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), + freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), + freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) + ], + dim=-1).reshape(seq_len, 1, -1) + + # apply rotary embedding + sp_size = get_sequence_parallel_world_size() + sp_rank = get_sequence_parallel_rank() + freqs_i = pad_freqs(freqs_i, s * sp_size) + s_per_rank = s + freqs_i_rank = freqs_i[(sp_rank * s_per_rank):((sp_rank + 1) * s_per_rank), :, :] + cos, sin = torch.chunk(torch.view_as_real(freqs_i_rank.to(torch.complex64)), 2, dim=-1) + cos = cos.unsqueeze(0).expand(-1, -1, -1, -1, 2).flatten(-2) + sin = sin.unsqueeze(0).expand(-1, -1, -1, -1, 2).flatten(-2) + freqs_i_rank = (cos, sin) + freqs_list.append(freqs_i_rank) + self.freqs_list = freqs_list + + # arguments + kwargs = dict( + e=e0, + seq_lens=seq_lens, + grid_sizes=grid_sizes, + freqs=self.freqs_list, + context=context, + context_lens=context_lens) + + for block in self.blocks: + x = block(x, **kwargs) + + # head + x = self.head(x, e) + + # Context Parallel + x = get_sp_group().all_gather(x, dim=1) + + # unpatchify + x = self.unpatchify(x, grid_sizes) + return [u.float() for u in x] + + +def sp_attn_forward(self, x, seq_lens, grid_sizes, freqs, args, dtype=torch.bfloat16): + b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim + half_dtypes = (torch.float16, torch.bfloat16) + + def half(x): + return x if x.dtype in half_dtypes else x.to(dtype) + + # query, key, value function + def qkv_fn(x): + q = self.norm_q(self.q(x)).view(b, s, n, d) + k = self.norm_k(self.k(x)).view(b, s, n, d) + v = self.v(x).view(b, s, n, d) + return q, k, v + + q, k, v = qkv_fn(x) + q = rope_apply(q, grid_sizes, freqs) + k = rope_apply(k, grid_sizes, freqs) + + x = 
xFuserLongContextAttention(args)( + None, + query=half(q), + key=half(k), + value=half(v), + seq_lens=seq_lens, + window_size=self.window_size, + ) + + # output + x = x.flatten(2) + x = self.o(x) + return x diff --git a/MindIE/MultiModal/Wan2.2/wan/distributed/tp_applicator.py b/MindIE/MultiModal/Wan2.2/wan/distributed/tp_applicator.py new file mode 100644 index 0000000000..eb65c823f0 --- /dev/null +++ b/MindIE/MultiModal/Wan2.2/wan/distributed/tp_applicator.py @@ -0,0 +1,329 @@ +import torch +import torch.nn as nn +import torch_npu + +from ..modules.model import WanSelfAttention, WanAttentionBlock, WanRMSNorm +from .parallel_mgr import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + get_tp_group, +) +from .group_coordinator import GroupCoordinator + + +class TensorParallelApplicator: + def __init__(self, tp_size, device_map="cpu", tp_group=None): + self.tp_size = tp_size + self.tp_rank = get_tensor_model_parallel_rank() + self.tp_group = tp_group or get_tp_group() + self.device_map = device_map + + def apply_to_model(self, model): + self._apply_tp_to_attention(model) + self._apply_tp_to_ffn(model) + + def _apply_tp_to_attention(self, module): + for name, child in module.named_children(): + if isinstance(child, WanSelfAttention): + self._replace_self_attention(child) + else: + self._apply_tp_to_attention(child) + + def _replace_self_attention(self, child): + child.dim = child.dim // self.tp_size + child.num_heads = child.num_heads // self.tp_size + orig_q = child.q + orig_k = child.k + orig_v = child.v + orig_o = child.o + orig_dtype = orig_q.weight.dtype + + column_out = orig_q.out_features // self.tp_size + row_in = orig_o.in_features // self.tp_size + + child.q = ColumnParallelLinear( + orig_q.in_features, + column_out, + bias=orig_q.bias is not None, + gather_output=False, + tp_size=self.tp_size, + tp_rank=self.tp_rank, + tp_group=self.tp_group + ).to(dtype=orig_dtype).to(self.device_map) + + child.k = ColumnParallelLinear( + orig_k.in_features, + column_out, + bias=orig_k.bias is not None, + gather_output=False, + tp_size=self.tp_size, + tp_rank=self.tp_rank, + tp_group=self.tp_group + ).to(dtype=orig_dtype).to(self.device_map) + + child.v = ColumnParallelLinear( + orig_v.in_features, + column_out, + bias=orig_v.bias is not None, + gather_output=False, + tp_size=self.tp_size, + tp_rank=self.tp_rank, + tp_group=self.tp_group + ).to(dtype=orig_dtype).to(self.device_map) + + child.o = RowParallelLinear( + row_in, + orig_o.out_features, + bias=orig_o.bias is not None, + input_is_parallel=True, + tp_size=self.tp_size, + tp_rank=self.tp_rank, + tp_group=self.tp_group + ).to(dtype=orig_dtype).to(self.device_map) + + self._split_self_weights(child, orig_q, orig_k, orig_v, orig_o) + + if isinstance(child.norm_q, WanRMSNorm): + ori_norm_q = child.norm_q + child.norm_q = TensorParallelRMSNorm( + dim=child.norm_q.dim, + tp_size=self.tp_size, + tp_group=self.tp_group + ) + self._split_norm_weights(child.norm_q, ori_norm_q) + + if isinstance(child.norm_k, WanRMSNorm): + ori_norm_k = child.norm_k + child.norm_k = TensorParallelRMSNorm( + dim=child.norm_k.dim, + tp_size=self.tp_size, + tp_group=self.tp_group + ) + self._split_norm_weights(child.norm_k, ori_norm_k) + + + def _split_self_weights(self, new_layer, orig_q, orig_k, orig_v, orig_o): + q_chunk = torch.chunk(orig_q.weight.data, self.tp_size, dim=0)[self.tp_rank] + new_layer.q.weight.data = q_chunk.contiguous() + + k_chunk = torch.chunk(orig_k.weight.data, self.tp_size, dim=0)[self.tp_rank] + 
new_layer.k.weight.data = k_chunk.contiguous() + + v_chunk = torch.chunk(orig_v.weight.data, self.tp_size, dim=0)[self.tp_rank] + new_layer.v.weight.data = v_chunk.contiguous() + + o_chunk = torch.chunk(orig_o.weight.data, self.tp_size, dim=1)[self.tp_rank] + new_layer.o.weight.data = o_chunk.contiguous() + + if orig_q.bias is not None: + bias_chunk = torch.chunk(orig_q.bias.data, self.tp_size, dim=0)[self.tp_rank] + new_layer.q.bias.data = bias_chunk.contiguous() + if orig_k.bias is not None: + bias_chunk = torch.chunk(orig_k.bias.data, self.tp_size, dim=0)[self.tp_rank] + new_layer.k.bias.data = bias_chunk.contiguous() + if orig_v.bias is not None: + bias_chunk = torch.chunk(orig_v.bias.data, self.tp_size, dim=0)[self.tp_rank] + new_layer.v.bias.data = bias_chunk.contiguous() + if orig_o.bias is not None: + new_layer.o.bias.data = orig_o.bias.data.clone() / self.tp_size + + def _split_norm_weights(self, new_layer, norm): + norm_chunk = torch.chunk(norm.weight.data, self.tp_size, dim=0)[self.tp_rank] + new_layer.weight.data = norm_chunk.contiguous() + + def _replace_cross_attention(self, child): + orig_wq = child.wq + orig_wkv = child.wkv + orig_wo = child.wo + orig_dtype = orig_wq.weight.dtype + + column_out_wq = orig_wq.out_features // self.tp_size + column_out_wkv = orig_wkv.out_features // self.tp_size + row_in_wo = orig_wo.in_features // self.tp_size + + child.wq = ColumnParallelLinear( + orig_wq.in_features, + column_out_wq, + bias=orig_wq.bias is not None, + gather_output=False, + tp_size=self.tp_size, + tp_rank=self.tp_rank, + tp_group=self.tp_group + ).to(dtype=orig_dtype).to(self.device_map) + + child.wkv = ColumnParallelLinear( + orig_wkv.in_features, + column_out_wkv, + bias=orig_wkv.bias is not None, + gather_output=False, + tp_size=self.tp_size, + tp_rank=self.tp_rank, + tp_group=self.tp_group + ).to(dtype=orig_dtype).to(self.device_map) + + child.wo = RowParallelLinear( + row_in_wo, + orig_wo.out_features, + bias=orig_wo.bias is not None, + input_is_parallel=True, + tp_size=self.tp_size, + tp_rank=self.tp_rank, + tp_group=self.tp_group + ).to(dtype=orig_dtype).to(self.device_map) + + self._split_cross_attention_weights(child, orig_wq, orig_wkv, orig_wo) + child.n_heads_per_tp = child.n_heads // self.tp_size + + def _split_cross_attention_weights(self, new_layer, orig_wq, orig_wkv, orig_wo): + wq_chunk = torch.chunk(orig_wq.weight.data, self.tp_size, dim=0)[self.tp_rank] + new_layer.wq.weight.data = wq_chunk.contiguous() + if orig_wq.bias is not None: + wq_bias_chunk = torch.chunk(orig_wq.bias.data, self.tp_size, dim=0)[self.tp_rank] + new_layer.wq.bias.data = wq_bias_chunk.contiguous() + + wkv_chunk = torch.chunk(orig_wkv.weight.data, self.tp_size, dim=0)[self.tp_rank] + new_layer.wkv.weight.data = wkv_chunk.contiguous() + if orig_wkv.bias is not None: + wkv_bias_chunk = torch.chunk(orig_wkv.bias.data, self.tp_size, dim=0)[self.tp_rank] + new_layer.wkv.bias.data = wkv_bias_chunk.contiguous() + + wo_chunk = torch.chunk(orig_wo.weight.data, self.tp_size, dim=1)[self.tp_rank] + new_layer.wo.weight.data = wo_chunk.contiguous() + if orig_wo.bias is not None: + new_layer.wo.bias.data = orig_wo.bias.data.clone() / self.tp_size + + def _apply_tp_to_ffn(self, module): + for name, child in module.named_children(): + if isinstance(child, WanAttentionBlock): + self._replace_ffn_layers(child) + else: + self._apply_tp_to_ffn(child) + + def _replace_ffn_layers(self, block): + ff_layer = block.ffn + orig_gelu_linear = ff_layer[0] + inner_dim_per_tp = orig_gelu_linear.out_features // 
self.tp_size + orig_dtype = orig_gelu_linear.weight.dtype + + ff_layer[0] = ColumnParallelLinear( + in_features=orig_gelu_linear.in_features, + out_features=inner_dim_per_tp, + bias=orig_gelu_linear.bias is not None, + gather_output=False, + tp_size=self.tp_size, + tp_rank=self.tp_rank, + tp_group=self.tp_group + ).to(dtype=orig_dtype).to(self.device_map) + + orig_output_linear = ff_layer[2] + ff_layer[2] = RowParallelLinear( + in_features=inner_dim_per_tp, + out_features=orig_output_linear.out_features, + bias=orig_output_linear.bias is not None, + input_is_parallel=True, + tp_size=self.tp_size, + tp_rank=self.tp_rank, + tp_group=self.tp_group + ).to(dtype=orig_dtype).to(self.device_map) + + self._split_ffn_weights(ff_layer, orig_gelu_linear, orig_output_linear) + + def _split_ffn_weights(self, new_ffn, orig_first_linear, orig_second_linear): + with torch.no_grad(): + first_weight_chunk = torch.chunk(orig_first_linear.weight.data, self.tp_size, dim=0)[self.tp_rank] + new_ffn[0].weight.data.copy_(first_weight_chunk.contiguous()) + + if orig_first_linear.bias is not None: + first_bias_chunk = torch.chunk(orig_first_linear.bias.data, self.tp_size, dim=0)[self.tp_rank] + new_ffn[0].bias.data.copy_(first_bias_chunk.contiguous()) + + second_weight_chunk = torch.chunk(orig_second_linear.weight.data, self.tp_size, dim=1)[self.tp_rank] + new_ffn[2].weight.data.copy_(second_weight_chunk.contiguous()) + + if orig_second_linear.bias is not None: + new_ffn[2].bias.data.copy_(orig_second_linear.bias.data.clone() / self.tp_size) + + +class ColumnParallelLinear(nn.Linear): + def __init__(self, in_features, out_features, bias=True, gather_output=True, tp_size=None, tp_rank=None, tp_group=None): + self.tp_size = tp_size or get_tensor_model_parallel_world_size() + self.tp_rank = tp_rank or get_tensor_model_parallel_rank() + self.tp_group = tp_group or get_tp_group() + + super().__init__(in_features, out_features, bias=bias) + + def forward(self, x): + x = super().forward(x) + return x + + +class RowParallelLinear(nn.Linear): + def __init__(self, in_features, out_features, bias=True, input_is_parallel=True, + tp_size=None, tp_rank=None, tp_group=None, matmul_allreduce_type="torch"): + self.tp_size = tp_size or get_tensor_model_parallel_world_size() + self.tp_rank = tp_rank or get_tensor_model_parallel_rank() + self.tp_group = tp_group or get_tp_group() + self.input_is_parallel = input_is_parallel + + if matmul_allreduce_type == "atb": + try: + from atb_ops.ops.matmul_allreduce import matmul_allreduce + self.matmul_allreduce = matmul_allreduce + self.matmul_allreduce_type = "atb" + except Exception: + self.matmul_allreduce = None + self.matmul_allreduce_type = "torch" + else: + self.matmul_allreduce_type = matmul_allreduce_type + + super().__init__(in_features, out_features, bias=bias) + + def forward(self, x): + if not self.input_is_parallel: + x = torch.chunk(x, self.tp_size, dim=-1)[self.tp_rank] + + if self.matmul_allreduce_type == "atb": + if x.dim() == 2: + output = torch.empty((x.shape[0], self.weight.shape[0]), dtype=x.dtype, device=x.device) + elif x.dim() == 3: + b, s, hx = x.size() + output = torch.empty((b, s, self.weight.shape[0]), dtype=x.dtype, device=x.device) + self.matmul_allreduce(output, x, self.weight) + elif self.matmul_allreduce_type == "torch_npu": + if isinstance(self.tp_group, GroupCoordinator): + tp_pg = self.tp_group.device_group + else: + tp_pg = self.tp_group + hcom = tp_pg._get_backend(torch.device('npu')).get_hccl_comm_name + output = torch_npu.npu_mm_all_reduce_base(x, 
self.weight, hcom) + else: + x = super().forward(x) + # 执行All-Reduce聚合结果 + if isinstance(self.tp_group, GroupCoordinator): + output = self.tp_group.all_reduce(x) + else: + torch.distributed.all_reduce(x, group=self.tp_group) + output = x + return output + + +class TensorParallelRMSNorm(nn.Module): + def __init__(self, dim, tp_size, tp_group, eps=1e-6): + super().__init__() + self.tp_size = tp_size + self.tp_group = tp_group + self.variance_epsilon = eps + self.weight = nn.Parameter(torch.ones(dim // self.tp_size)) + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + if isinstance(self.tp_group, GroupCoordinator): + variance = self.tp_group.all_reduce(variance) + else: + torch.distributed.all_reduce(variance, group=self.tp_group) + variance /= self.tp_size + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + return self.weight * hidden_states.to(input_dtype) \ No newline at end of file diff --git a/MindIE/MultiModal/Wan2.2/wan/distributed/ulysses.py b/MindIE/MultiModal/Wan2.2/wan/distributed/ulysses.py new file mode 100644 index 0000000000..12d7d30a84 --- /dev/null +++ b/MindIE/MultiModal/Wan2.2/wan/distributed/ulysses.py @@ -0,0 +1,47 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import torch +import torch.distributed as dist + +from ..modules.attention import flash_attention +from .util import all_to_all + + +def distributed_attention( + q, + k, + v, + seq_lens, + window_size=(-1, -1), +): + """ + Performs distributed attention based on DeepSpeed Ulysses attention mechanism. + please refer to https://arxiv.org/pdf/2309.14509 + + Args: + q: [B, Lq // p, Nq, C1]. + k: [B, Lk // p, Nk, C1]. + v: [B, Lk // p, Nk, C2]. Nq must be divisible by Nk. + seq_lens: [B], length of each sequence in batch + window_size: (left right). If not (-1, -1), apply sliding window local attention. + """ + if not dist.is_initialized(): + raise ValueError("distributed group should be initialized.") + b = q.shape[0] + + # gather q/k/v sequence + q = all_to_all(q, scatter_dim=2, gather_dim=1) + k = all_to_all(k, scatter_dim=2, gather_dim=1) + v = all_to_all(v, scatter_dim=2, gather_dim=1) + + # apply attention + x = flash_attention( + q, + k, + v, + k_lens=seq_lens, + window_size=window_size, + ) + + # scatter q/k/v sequence + x = all_to_all(x, scatter_dim=1, gather_dim=2) + return x diff --git a/MindIE/MultiModal/Wan2.2/wan/distributed/util.py b/MindIE/MultiModal/Wan2.2/wan/distributed/util.py new file mode 100644 index 0000000000..f2eb680c8d --- /dev/null +++ b/MindIE/MultiModal/Wan2.2/wan/distributed/util.py @@ -0,0 +1,203 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +from typing import List +import logging +import torch +import torch.distributed as dist + + +def generate_masked_orthogonal_rank_groups( + world_size: int, parallel_size: List[int], mask: List[bool] +) -> List[List[int]]: + """Generate orthogonal parallel groups based on the parallel size and mask. + + Arguments: + world_size (int): world size + + parallel_size (List[int]): + The parallel size of each orthogonal parallel type. For example, if + tensor_parallel_size = 2, pipeline_model_parallel_group = 3, data_parallel_size = 4, + and the parallel mapping order is tp-pp-dp, then the parallel_size = [2, 3, 4]. + + mask (List[bool]): + The mask controls which parallel methods the generated groups represent. 
If mask[i] is + True, it means the generated group contains the i-th parallelism method. For example, + if parallel_size = [tp_size, pp_size, dp_size], and mask = [True, False , True], then + the generated group is the `tp-dp` group, if the mask = [False, True, False], then the + generated group is the `pp` group. + """ + + def prefix_product(a: List[int], init=1) -> List[int]: + r = [init] + for v in a: + init = init * v + r.append(init) + return r + + def inner_product(a: List[int], b: List[int]) -> int: + return sum([x * y for x, y in zip(a, b)]) + + def decompose(index, shape, stride=None): + """ + This function solve the math problem below: + There is an equation: + index = sum(idx[i] * stride[i]) + And given the value of index, stride. + Return the idx. + This function will used to get the pp/dp/pp_rank + from group_index and rank_in_group. + """ + if stride is None: + stride = prefix_product(shape) + idx = [(index // d) % s for s, d in zip(shape, stride)] + # stride is a prefix_product result. And the value of stride[-1] + # is not used. + if not ( + sum([x * y for x, y in zip(idx, stride[:-1])]) == index + ): + logging.error("idx {} with shape {} mismatch the return idx {}".format(index, shape, idx)) + return idx + + masked_shape = [s for s, m in zip(parallel_size, mask) if m] + unmasked_shape = [s for s, m in zip(parallel_size, mask) if not m] + + global_stride = prefix_product(parallel_size) + masked_stride = [d for d, m in zip(global_stride, mask) if m] + unmasked_stride = [d for d, m in zip(global_stride, mask) if not m] + + group_size = prefix_product(masked_shape)[-1] + num_of_group = world_size // group_size + + ranks = [] + for group_index in range(num_of_group): + # get indices from unmaksed for group_index. + decomposed_group_idx = decompose(group_index, unmasked_shape) + rank = [] + for rank_in_group in range(group_size): + # get indices from masked for rank_in_group. + decomposed_rank_idx = decompose(rank_in_group, masked_shape) + rank.append( + inner_product(decomposed_rank_idx, masked_stride) + + inner_product(decomposed_group_idx, unmasked_stride) + ) + ranks.append(rank) + return ranks + + +class RankGenerator(object): + def __init__( + self, + tp: int, + sp: int, + cfg: int, + order: str, + rank_offset: int = 0, + ) -> None: + self.tp = tp + self.sp = sp + self.cfg = cfg + self.rank_offset = rank_offset + self.world_size = tp * sp * cfg + + self.name_to_size = { + "sp": self.sp, + "cfg": self.cfg, + "tp": self.tp, + } + order = order.lower() + + for name in self.name_to_size.keys(): + if name not in order and self.name_to_size[name] != 1: + raise RuntimeError( + f"The size of ({name}) is ({self.name_to_size[name]}), but you haven't specified the order ({self.order})." + ) + elif name not in order: + order = order + "-" + name + + self.order = order + self.ordered_size = [] + + for token in order.split("-"): + self.ordered_size.append(self.name_to_size[token]) + + def get_mask(self, order: str, token: str): + ordered_token = order.split("-") + token = token.split("-") + mask = [False] * len(ordered_token) + for t in token: + mask[ordered_token.index(t)] = True + return mask + + def get_ranks(self, token): + """Get rank group by input token. + + Arguments: + token (str): + Specify the ranks type that want to get. If we want + to obtain multiple parallel types, we can use a hyphen + '-' to separate them. For example, if we want to obtain + the TP_DP group, the token should be 'tp-dp'. 
+ + independent_ep (bool: True): + This flag controls whether we treat EP and DP independently. + EP shares ranks with DP, if we want to get ranks related to + EP, we should set the flag. For example, get_ranks('dp', True) + will get DP modulo EP group, and get_ranks('dp', False) will + get full DP group. + """ + mask = self.get_mask(self.order, token) + ranks = generate_masked_orthogonal_rank_groups( + self.world_size, self.ordered_size, mask + ) + if self.rank_offset > 0: + for rank_group in ranks: + for i, _ in enumerate(rank_group): + rank_group[i] += self.rank_offset + return ranks + + +def init_distributed_group(): + """r initialize sequence parallel group. + """ + if not dist.is_initialized(): + dist.init_process_group(backend='nccl') + + +def get_rank(): + return dist.get_rank() + + +def get_world_size(): + return dist.get_world_size() + + +def all_to_all(x, scatter_dim, gather_dim, group=None, **kwargs): + """ + `scatter` along one dimension and `gather` along another. + """ + world_size = get_world_size() + if world_size > 1: + inputs = [u.contiguous() for u in x.chunk(world_size, dim=scatter_dim)] + outputs = [torch.empty_like(u) for u in inputs] + dist.all_to_all(outputs, inputs, group=group, **kwargs) + x = torch.cat(outputs, dim=gather_dim).contiguous() + return x + + +def all_gather(tensor): + world_size = dist.get_world_size() + if world_size == 1: + return [tensor] + tensor_list = [torch.empty_like(tensor) for _ in range(world_size)] + torch.distributed.all_gather(tensor_list, tensor) + return tensor_list + + +def gather_forward(input, dim): + # skip if world_size == 1 + world_size = dist.get_world_size() + if world_size == 1: + return input + + # gather sequence + output = all_gather(input) + return torch.cat(output, dim=dim).contiguous() diff --git a/MindIE/MultiModal/Wan2.2/wan/image2video.py b/MindIE/MultiModal/Wan2.2/wan/image2video.py new file mode 100644 index 0000000000..1975216a53 --- /dev/null +++ b/MindIE/MultiModal/Wan2.2/wan/image2video.py @@ -0,0 +1,464 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import gc +import logging +import math +import os +import random +import sys +import types +from contextlib import contextmanager +from functools import partial + +import numpy as np +import torch +import torch.cuda.amp as amp +import torch.distributed as dist +import torchvision.transforms.functional as TF +from tqdm import tqdm + +from .distributed.fsdp import shard_model +from .distributed.sequence_parallel import sp_attn_forward, sp_dit_forward +from .distributed.util import get_world_size +from .modules.model import WanModel +from .modules.t5 import T5EncoderModel +from .modules.vae2_1 import Wan2_1_VAE +from .utils.fm_solvers import ( + FlowDPMSolverMultistepScheduler, + get_sampling_sigmas, + retrieve_timesteps, +) +from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler +from .vae_patch_parallel import VAE_patch_parallel, set_vae_patch_parallel +from wan.distributed.parallel_mgr import ( + get_sequence_parallel_world_size, + get_classifier_free_guidance_world_size, + get_classifier_free_guidance_rank, + get_cfg_group +) + + +class WanI2V: + + def __init__( + self, + config, + checkpoint_dir, + device_id=0, + rank=0, + t5_fsdp=False, + dit_fsdp=False, + use_sp=False, + t5_cpu=False, + init_on_cpu=True, + convert_model_dtype=False, + use_vae_parallel=False, + ): + r""" + Initializes the image-to-video generation model components. 
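+
+        The pipeline wires together the T5 text encoder, the Wan2.1 VAE, and the low-noise
+        and high-noise WanModel (DiT) backbones, optionally applying FSDP sharding, sequence
+        parallelism, and VAE patch parallelism.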
+ + Args: + config (EasyDict): + Object containing model parameters initialized from config.py + checkpoint_dir (`str`): + Path to directory containing model checkpoints + device_id (`int`, *optional*, defaults to 0): + Id of target GPU device + rank (`int`, *optional*, defaults to 0): + Process rank for distributed training + t5_fsdp (`bool`, *optional*, defaults to False): + Enable FSDP sharding for T5 model + dit_fsdp (`bool`, *optional*, defaults to False): + Enable FSDP sharding for DiT model + use_sp (`bool`, *optional*, defaults to False): + Enable distribution strategy of sequence parallel. + t5_cpu (`bool`, *optional*, defaults to False): + Whether to place T5 model on CPU. Only works without t5_fsdp. + init_on_cpu (`bool`, *optional*, defaults to True): + Enable initializing Transformer Model on CPU. Only works without FSDP or USP. + convert_model_dtype (`bool`, *optional*, defaults to False): + Convert DiT model parameters dtype to 'config.param_dtype'. + Only works without FSDP. + """ + self.device = torch.device(f"cuda:{device_id}") + self.config = config + self.rank = rank + self.t5_cpu = t5_cpu + self.init_on_cpu = init_on_cpu + + self.num_train_timesteps = config.num_train_timesteps + self.boundary = config.boundary + self.param_dtype = config.param_dtype + + if t5_fsdp or dit_fsdp or use_sp: + self.init_on_cpu = False + + shard_fn = partial(shard_model, device_id=device_id) + self.text_encoder = T5EncoderModel( + text_len=config.text_len, + dtype=config.t5_dtype, + device=torch.device('cpu') if os.getenv('T5_LOAD_CPU', 0) else self.device, + checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint), + tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer), + shard_fn=shard_fn if t5_fsdp else None, + ) + + self.vae_stride = config.vae_stride + self.patch_size = config.patch_size + self.vae = Wan2_1_VAE( + vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), + device=self.device, + dtype=self.param_dtype) + if use_vae_parallel: + all_pp_group_ranks = [] + for i in range(0, dist.get_world_size() // 8): + all_pp_group_ranks.append(list(range(8 * i, 8 * (i + 1)))) + set_vae_patch_parallel(self.vae.model, 4, 2, all_pp_group_ranks= all_pp_group_ranks, decoder_decode="decoder.forward") + set_vae_patch_parallel(self.vae.model, 4, 2, all_pp_group_ranks= all_pp_group_ranks, decoder_decode="encoder.forward") + logging.info(f"Creating WanModel from {checkpoint_dir}") + self.low_noise_model = WanModel.from_pretrained( + checkpoint_dir, subfolder=config.low_noise_checkpoint) + self.low_noise_model = self._configure_model( + model=self.low_noise_model, + use_sp=use_sp, + dit_fsdp=dit_fsdp, + shard_fn=shard_fn, + convert_model_dtype=convert_model_dtype) + + self.high_noise_model = WanModel.from_pretrained( + checkpoint_dir, subfolder=config.high_noise_checkpoint) + self.high_noise_model = self._configure_model( + model=self.high_noise_model, + use_sp=use_sp, + dit_fsdp=dit_fsdp, + shard_fn=shard_fn, + convert_model_dtype=convert_model_dtype) + if use_sp: + self.sp_size = get_sequence_parallel_world_size() + else: + self.sp_size = 1 + + self.sample_neg_prompt = config.sample_neg_prompt + + def _configure_model(self, model, use_sp, dit_fsdp, shard_fn, + convert_model_dtype): + """ + Configures a model object. This includes setting evaluation modes, + applying distributed parallel strategy, and handling device placement. + + Args: + model (torch.nn.Module): + The model instance to configure. + use_sp (`bool`): + Enable distribution strategy of sequence parallel. 
+ dit_fsdp (`bool`): + Enable FSDP sharding for DiT model. + shard_fn (callable): + The function to apply FSDP sharding. + convert_model_dtype (`bool`): + Convert DiT model parameters dtype to 'config.param_dtype'. + Only works without FSDP. + + Returns: + torch.nn.Module: + The configured model. + """ + model.eval().requires_grad_(False) + + if use_sp: + for block in model.blocks: + block.self_attn.forward = types.MethodType( + sp_attn_forward, block.self_attn) + model.forward = types.MethodType(sp_dit_forward, model) + + if dist.is_initialized(): + dist.barrier() + + if dit_fsdp: + model = shard_fn(model) + else: + if convert_model_dtype: + model.to(self.param_dtype) + if not self.init_on_cpu: + model.to(self.device) + + return model + + def _prepare_model_for_timestep(self, t, boundary, offload_model): + r""" + Prepares and returns the required model for the current timestep. + + Args: + t (torch.Tensor): + current timestep. + boundary (`int`): + The timestep threshold. If `t` is at or above this value, + the `high_noise_model` is considered as the required model. + offload_model (`bool`): + A flag intended to control the offloading behavior. + + Returns: + torch.nn.Module: + The active model on the target device for the current timestep. + """ + if t.item() >= boundary: + required_model_name = 'high_noise_model' + offload_model_name = 'low_noise_model' + else: + required_model_name = 'low_noise_model' + offload_model_name = 'high_noise_model' + if offload_model or self.init_on_cpu: + if next(getattr( + self, + offload_model_name).parameters()).device.type == 'cuda': + getattr(self, offload_model_name).to('cpu') + if next(getattr( + self, + required_model_name).parameters()).device.type == 'cpu': + getattr(self, required_model_name).to(self.device) + return getattr(self, required_model_name) + + def generate(self, + input_prompt, + img, + max_area=720 * 1280, + frame_num=81, + shift=5.0, + sample_solver='unipc', + sampling_steps=40, + guide_scale=5.0, + n_prompt="", + seed=-1, + offload_model=True): + r""" + Generates video frames from input image and text prompt using diffusion process. + + Args: + input_prompt (`str`): + Text prompt for content generation. + img (PIL.Image.Image): + Input image tensor. Shape: [3, H, W] + max_area (`int`, *optional*, defaults to 720*1280): + Maximum pixel area for latent space calculation. Controls video resolution scaling + frame_num (`int`, *optional*, defaults to 81): + How many frames to sample from a video. The number should be 4n+1 + shift (`float`, *optional*, defaults to 5.0): + Noise schedule shift parameter. Affects temporal dynamics + [NOTE]: If you want to generate a 480p video, it is recommended to set the shift value to 3.0. + sample_solver (`str`, *optional*, defaults to 'unipc'): + Solver used to sample the video. + sampling_steps (`int`, *optional*, defaults to 40): + Number of diffusion sampling steps. Higher values improve quality but slow generation + guide_scale (`float` or tuple[`float`], *optional*, defaults 5.0): + Classifier-free guidance scale. Controls prompt adherence vs. creativity. + If tuple, the first guide_scale will be used for low noise model and + the second guide_scale will be used for high noise model. + n_prompt (`str`, *optional*, defaults to ""): + Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt` + seed (`int`, *optional*, defaults to -1): + Random seed for noise generation. 
If -1, use random seed + offload_model (`bool`, *optional*, defaults to True): + If True, offloads models to CPU during generation to save VRAM + + Returns: + torch.Tensor: + Generated video frames tensor. Dimensions: (C, N H, W) where: + - C: Color channels (3 for RGB) + - N: Number of frames (81) + - H: Frame height (from max_area) + - W: Frame width from max_area) + """ + # preprocess + guide_scale = (guide_scale, guide_scale) if isinstance( + guide_scale, float) else guide_scale + img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device) + + F = frame_num + h, w = img.shape[1:] + aspect_ratio = h / w + lat_h = round( + np.sqrt(max_area * aspect_ratio) // self.vae_stride[1] // + self.patch_size[1] * self.patch_size[1]) + lat_w = round( + np.sqrt(max_area / aspect_ratio) // self.vae_stride[2] // + self.patch_size[2] * self.patch_size[2]) + h = lat_h * self.vae_stride[1] + w = lat_w * self.vae_stride[2] + + max_seq_len = ((F - 1) // self.vae_stride[0] + 1) * lat_h * lat_w // ( + self.patch_size[1] * self.patch_size[2]) + max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size + + seed = seed if seed >= 0 else random.randint(0, sys.maxsize) + seed_g = torch.Generator(device=self.device) + seed_g.manual_seed(seed) + noise = torch.randn( + 16, + (F - 1) // self.vae_stride[0] + 1, + lat_h, + lat_w, + dtype=torch.float32, + generator=seed_g, + device=self.device) + + msk = torch.ones(1, F, lat_h, lat_w, device=self.device) + msk[:, 1:] = 0 + msk = torch.concat([ + torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:] + ], + dim=1) + msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w) + msk = msk.transpose(1, 2)[0] + + if n_prompt == "": + n_prompt = self.sample_neg_prompt + + # preprocess + if not self.t5_cpu: + self.text_encoder.model.to(self.device) + context = self.text_encoder([input_prompt], self.device) + context_null = self.text_encoder([n_prompt], self.device) + if offload_model: + self.text_encoder.model.cpu() + else: + context = self.text_encoder([input_prompt], torch.device('cpu')) + context_null = self.text_encoder([n_prompt], torch.device('cpu')) + context = [t.to(self.device) for t in context] + context_null = [t.to(self.device) for t in context_null] + + encode_input = torch.concat([ + torch.nn.functional.interpolate( + img[None].cpu(), size=(h, w), mode='bicubic' + ).transpose(0, 1), + torch.zeros(3, F - 1, h, w)], dim=1).to(self.device) + + with VAE_patch_parallel(): + y = self.vae.encode([ + encode_input + ])[0] + y = torch.concat([msk, y]) + + @contextmanager + def noop_no_sync(): + yield + + no_sync_low_noise = getattr(self.low_noise_model, 'no_sync', + noop_no_sync) + no_sync_high_noise = getattr(self.high_noise_model, 'no_sync', + noop_no_sync) + + # evaluation mode + with ( + torch.amp.autocast('cuda', dtype=self.param_dtype), + torch.no_grad(), + no_sync_low_noise(), + no_sync_high_noise(), + ): + boundary = self.boundary * self.num_train_timesteps + + if sample_solver == 'unipc': + sample_scheduler = FlowUniPCMultistepScheduler( + num_train_timesteps=self.num_train_timesteps, + shift=1, + use_dynamic_shifting=False) + sample_scheduler.set_timesteps( + sampling_steps, device=self.device, shift=shift) + timesteps = sample_scheduler.timesteps + elif sample_solver == 'dpm++': + sample_scheduler = FlowDPMSolverMultistepScheduler( + num_train_timesteps=self.num_train_timesteps, + shift=1, + use_dynamic_shifting=False) + sampling_sigmas = get_sampling_sigmas(sampling_steps, shift) + timesteps, _ = retrieve_timesteps( + sample_scheduler, + 
device=self.device, + sigmas=sampling_sigmas) + else: + raise NotImplementedError("Unsupported solver.") + + # sample videos + latent = noise + + arg_c = { + 'context': [context[0]], + 'seq_len': max_seq_len, + 'y': [y], + } + + arg_null = { + 'context': context_null, + 'seq_len': max_seq_len, + 'y': [y], + } + + arg_all = { + 'context': [context[0]] if get_classifier_free_guidance_rank()==0 else context_null, + 'seq_len': max_seq_len, + 'y': [y], + } + + if offload_model: + torch.cuda.empty_cache() + + for _, t in enumerate(tqdm(timesteps)): + latent_model_input = [latent.to(self.device)] + timestep = [t] + + timestep = torch.stack(timestep).to(self.device) + + model = self._prepare_model_for_timestep( + t, boundary, offload_model) + sample_guide_scale = guide_scale[1] if t.item( + ) >= boundary else guide_scale[0] + + if get_classifier_free_guidance_world_size() == 2: + noise_pred = model( + latent_model_input, t=timestep, **arg_all)[0].to( + torch.device('cpu') if offload_model else self.device) + noise_pred_cond, noise_pred_uncond = get_cfg_group().all_gather( + noise_pred, separate_tensors=True + ) + if offload_model: + torch.cuda.empty_cache() + else: + noise_pred_cond = model( + latent_model_input, t=timestep, **arg_c)[0] + if offload_model: + torch.cuda.empty_cache() + noise_pred_uncond = model( + latent_model_input, t=timestep, **arg_null)[0] + if offload_model: + torch.cuda.empty_cache() + + noise_pred = noise_pred_uncond + sample_guide_scale * ( + noise_pred_cond - noise_pred_uncond) + + temp_x0 = sample_scheduler.step( + noise_pred.unsqueeze(0), + t, + latent.unsqueeze(0), + return_dict=False, + generator=seed_g)[0] + latent = temp_x0.squeeze(0) + + x0 = [latent] + del latent_model_input, timestep + + if offload_model: + self.low_noise_model.cpu() + self.high_noise_model.cpu() + torch.cuda.empty_cache() + + if self.rank < 8: + with VAE_patch_parallel(): + videos = self.vae.decode(x0) + + del noise, latent, x0 + del sample_scheduler + if offload_model: + gc.collect() + torch.cuda.synchronize() + if dist.is_initialized(): + dist.barrier() + + return videos[0] if self.rank == 0 else None diff --git a/MindIE/MultiModal/Wan2.2/wan/modules/__init__.py b/MindIE/MultiModal/Wan2.2/wan/modules/__init__.py new file mode 100644 index 0000000000..9d9eeb8ebc --- /dev/null +++ b/MindIE/MultiModal/Wan2.2/wan/modules/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +from .attention import flash_attention +from .model import WanModel +from .t5 import T5Decoder, T5Encoder, T5EncoderModel, T5Model +from .tokenizers import HuggingfaceTokenizer +from .vae2_1 import Wan2_1_VAE +from .vae2_2 import Wan2_2_VAE + +__all__ = [ + 'Wan2_1_VAE', + 'Wan2_2_VAE', + 'WanModel', + 'T5Model', + 'T5Encoder', + 'T5Decoder', + 'T5EncoderModel', + 'HuggingfaceTokenizer', + 'flash_attention', +] diff --git a/MindIE/MultiModal/Wan2.2/wan/modules/attention.py b/MindIE/MultiModal/Wan2.2/wan/modules/attention.py new file mode 100644 index 0000000000..55dc372d28 --- /dev/null +++ b/MindIE/MultiModal/Wan2.2/wan/modules/attention.py @@ -0,0 +1,194 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
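+#
+# Attention dispatch in this module:
+#   - attention(): on NPU, routes to mindiesd.attention_forward (fused_attn_score by default,
+#     or ascend_laser_attention when the ALGO env var is 1 and q/k share a sequence length);
+#     otherwise it falls back to flash_attention() below, or to torch's
+#     scaled_dot_product_attention as a last resort.
+#   - flash_attention(): thin wrapper around the FlashAttention-2/3 varlen kernels (CUDA only).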
+import os +import torch + +try: + import flash_attn_interface + FLASH_ATTN_3_AVAILABLE = True +except ModuleNotFoundError: + FLASH_ATTN_3_AVAILABLE = False + +try: + import flash_attn + FLASH_ATTN_2_AVAILABLE = True +except ModuleNotFoundError: + FLASH_ATTN_2_AVAILABLE = False + +import warnings + +__all__ = [ + 'flash_attention', + 'attention', +] + +from mindiesd import attention_forward + + +def flash_attention( + q, + k, + v, + q_lens=None, + k_lens=None, + dropout_p=0., + softmax_scale=None, + q_scale=None, + causal=False, + window_size=(-1, -1), + deterministic=False, + dtype=torch.bfloat16, + version=None, +): + """ + q: [B, Lq, Nq, C1]. + k: [B, Lk, Nk, C1]. + v: [B, Lk, Nk, C2]. Nq must be divisible by Nk. + q_lens: [B]. + k_lens: [B]. + dropout_p: float. Dropout probability. + softmax_scale: float. The scaling of QK^T before applying softmax. + causal: bool. Whether to apply causal attention mask. + window_size: (left right). If not (-1, -1), apply sliding window local attention. + deterministic: bool. If True, slightly slower and uses more memory. + dtype: torch.dtype. Apply when dtype of q/k/v is not float16/bfloat16. + """ + half_dtypes = (torch.float16, torch.bfloat16) + assert dtype in half_dtypes + assert q.device.type == 'cuda' and q.size(-1) <= 256 + + # params + b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype + + def half(x): + return x if x.dtype in half_dtypes else x.to(dtype) + + # preprocess query + if q_lens is None: + q = half(q.flatten(0, 1)) + q_lens = torch.tensor( + [lq] * b, dtype=torch.int32).to( + device=q.device, non_blocking=True) + else: + q = half(torch.cat([u[:v] for u, v in zip(q, q_lens)])) + + # preprocess key, value + if k_lens is None: + k = half(k.flatten(0, 1)) + v = half(v.flatten(0, 1)) + k_lens = torch.tensor( + [lk] * b, dtype=torch.int32).to( + device=k.device, non_blocking=True) + else: + k = half(torch.cat([u[:v] for u, v in zip(k, k_lens)])) + v = half(torch.cat([u[:v] for u, v in zip(v, k_lens)])) + + q = q.to(v.dtype) + k = k.to(v.dtype) + + if q_scale is not None: + q = q * q_scale + + if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE: + warnings.warn( + 'Flash attention 3 is not available, use flash attention 2 instead.' + ) + + # apply attention + if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE: + # Note: dropout_p, window_size are not supported in FA3 now. 
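+        # q/k/v have been packed to (total_tokens, heads, head_dim); cu_seqlens_q/k mark the
+        # cumulative per-sample boundaries, and the result is unflattened back to [B, Lq, ...].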
+ x = flash_attn_interface.flash_attn_varlen_func( + q=q, + k=k, + v=v, + cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum( + 0, dtype=torch.int32).to(q.device, non_blocking=True), + cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum( + 0, dtype=torch.int32).to(q.device, non_blocking=True), + seqused_q=None, + seqused_k=None, + max_seqlen_q=lq, + max_seqlen_k=lk, + softmax_scale=softmax_scale, + causal=causal, + deterministic=deterministic)[0].unflatten(0, (b, lq)) + else: + assert FLASH_ATTN_2_AVAILABLE + x = flash_attn.flash_attn_varlen_func( + q=q, + k=k, + v=v, + cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum( + 0, dtype=torch.int32).to(q.device, non_blocking=True), + cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum( + 0, dtype=torch.int32).to(q.device, non_blocking=True), + max_seqlen_q=lq, + max_seqlen_k=lk, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + causal=causal, + window_size=window_size, + deterministic=deterministic).unflatten(0, (b, lq)) + + # output + return x.type(out_dtype) + + +def attention( + q, + k, + v, + q_lens=None, + k_lens=None, + dropout_p=0., + softmax_scale=None, + q_scale=None, + causal=False, + window_size=(-1, -1), + deterministic=False, + dtype=torch.bfloat16, + version=None +): + if torch.npu.is_available(): + qtype = q.dtype + q = q.to(torch.bfloat16) + k = k.to(torch.bfloat16) + v = v.to(torch.bfloat16) + if version is None and q.shape[1] == k.shape[1] and int(os.getenv('ALGO', 0)) == 1: + out = attention_forward(q, k, v, + opt_mode="manual", op_type="ascend_laser_attention", layout="BNSD") + else: + out = attention_forward(q, k, v, + opt_mode="manual", op_type="fused_attn_score", layout="BNSD") + return out.to(qtype) + elif FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE: + return flash_attention( + q=q, + k=k, + v=v, + q_lens=q_lens, + k_lens=k_lens, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + q_scale=q_scale, + causal=causal, + window_size=window_size, + deterministic=deterministic, + dtype=dtype, + version=version, + ) + else: + if q_lens is not None or k_lens is not None: + warnings.warn( + 'Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance.' 
+ ) + attn_mask = None + + q = q.transpose(1, 2).to(dtype) + k = k.transpose(1, 2).to(dtype) + v = v.transpose(1, 2).to(dtype) + + out = torch.nn.functional.scaled_dot_product_attention( + q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p) + + out = out.transpose(1, 2).contiguous() + return out diff --git a/MindIE/MultiModal/Wan2.2/wan/modules/attn_layer.py b/MindIE/MultiModal/Wan2.2/wan/modules/attn_layer.py new file mode 100644 index 0000000000..df245a1d24 --- /dev/null +++ b/MindIE/MultiModal/Wan2.2/wan/modules/attn_layer.py @@ -0,0 +1,178 @@ +import logging +import torch +from torch import Tensor +import torch_npu +import torch.distributed as dist +import math +import os +from yunchang import LongContextAttention +try: + from yunchang.kernels import AttnType +except ImportError: + raise ImportError("Please install yunchang 0.6.0 or later") +from typing import Any + +from mindiesd.layers.flash_attn.attention_forward import attention_forward + +from ..distributed.parallel_mgr import get_sp_group +from ..distributed.comm import all_to_all_4D + +logger = logging.getLogger(__name__) +MAX_TOKEN = 2147483647 + +class xFuserLongContextAttention(LongContextAttention): + ring_impl_type_supported_kv_cache = ["basic"] + + def __init__( + self, + args: Any, + scatter_idx: int = 2, + gather_idx: int = 1, + ring_impl_type: str = "basic", + use_pack_qkv: bool = False, + use_kv_cache: bool = False, + attn_type: AttnType = AttnType.FA, + ) -> None: + """ + Arguments: + scatter_idx: int = 2, the scatter dimension index for Ulysses All2All + gather_idx: int = 1, the gather dimension index for Ulysses All2All + ring_impl_type: str = "basic", the ring implementation type, currently only support "basic" + use_pack_qkv: bool = False, whether to use pack qkv in the input + use_kv_cache: bool = False, whether to use kv cache in the attention layer, which is applied in PipeFusion. + """ + super().__init__( + scatter_idx=scatter_idx, + gather_idx=gather_idx, + ring_impl_type=ring_impl_type, + use_pack_qkv=use_pack_qkv, + attn_type = attn_type, + ) + self.use_kv_cache = use_kv_cache + if ( + use_kv_cache + and ring_impl_type not in self.ring_impl_type_supported_kv_cache + ): + raise RuntimeError( + f"ring_impl_type: {ring_impl_type} do not support SP kv cache." 
+ ) + self.world_size = dist.get_world_size() + self.args = args + self.video_size = ['480*832', '832*480', '480*720', '720*480'] + + self.algo = int(os.getenv('ALGO', 0)) + + if self.args.size in self.video_size: + self.use_all_head = True + else: + self.use_all_head = False + + self.ulysses_pg = get_sp_group().ulysses_group + self.ring_pg = get_sp_group().ring_group + + def forward( + self, + attn, + query: Tensor, + key: Tensor, + value: Tensor, + seq_lens: int, + *, + joint_tensor_query=None, + joint_tensor_key=None, + joint_tensor_value=None, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + joint_strategy="none", + scale=None + ) -> Tensor: + """forward + + Arguments: + attn (Attention): the attention module + query (Tensor): query input to the layer + key (Tensor): key input to the layer + value (Tensor): value input to the layer + args: other args, + joint_tensor_query: Tensor = None, a replicated tensor among processes appended to the front or rear of query, depends the joint_strategy + joint_tensor_key: Tensor = None, a replicated tensor among processes appended to the front or rear of key, depends the joint_strategy + joint_tensor_value: Tensor = None, a replicated tensor among processes appended to the front or rear of value, depends the joint_strategy, + *args: the args same as flash_attn_interface + joint_strategy: str = "none", the joint strategy for joint attention, currently only support "front" and "rear" + + Returns: + * output (Tensor): context output + """ + + query = all_to_all_4D(input_=query, scatter_idx=2, gather_idx=1, group=self.ulysses_pg) + key = all_to_all_4D(input_=key, scatter_idx=2, gather_idx=1, group=self.ulysses_pg) + value = all_to_all_4D(input_=value, scatter_idx=2, gather_idx=1, group=self.ulysses_pg) + + if get_sp_group().ring_world_size > 1: + ring_size = get_sp_group().ring_world_size + b, s, n, d = key.shape + k_full = torch.empty([ring_size, b, s, n, d], dtype=query.dtype, device=query.device) + dist.all_gather_into_tensor(k_full, key, group=self.ring_pg) + key = k_full.permute(1, 0, 2, 3, 4).reshape(b, -1, n, d) + + v_full = torch.empty([ring_size, b, s, n, d], dtype=query.dtype, device=query.device) + dist.all_gather_into_tensor(v_full, value, group=self.ring_pg) + value = v_full.permute(1, 0, 2, 3, 4).reshape(b, -1, n, d) + + ori_seqlen = query.shape[1] + if seq_lens is not None and seq_lens < ori_seqlen: + query_layer, query_pad = query[:, :seq_lens, :, :], query[:, seq_lens:, :, :] + key_layer, key_pad = key[:, :seq_lens, :, :], key[:, seq_lens:, :, :] + value_layer, value_pad = value[:, :seq_lens, :, :], value[:, seq_lens:, :, :] + else: + query_layer, key_layer, value_layer = query, key, value + + if self.use_all_head: + if self.algo == 0: + out = attention_forward(query_layer, key_layer, value_layer, + opt_mode="manual", op_type="fused_attn_score", layout="BNSD") + elif self.algo == 1: + out = attention_forward(query_layer, key_layer, value_layer, + opt_mode="manual", op_type="ascend_laser_attention", layout="BNSD") + else: + raise ValueError(f"select flash attention algorithm only support 0, 1, but got {self.algo}") + else: + query_layer_list = query_layer.split(1, dim=2) + key_layer_list = key_layer.split(1, dim=2) + value_layer_list = value_layer.split(1, dim=2) + output = [] + for_loop = query_layer.shape[2] + for i in range(for_loop): + if self.algo == 0: + out = attention_forward(query_layer_list[i], key_layer_list[i], 
value_layer_list[i], + opt_mode="manual", op_type="fused_attn_score", layout="BNSD") + elif self.algo == 1: + out = attention_forward(query_layer_list[i], key_layer_list[i], value_layer_list[i], + opt_mode="manual", op_type="ascend_laser_attention", layout="BNSD") + else: + raise ValueError(f"select flash attention algorithm only support 0, 1, but got f{self.algo}") + + output.append(out) + out = torch.cat(output, dim=2) + + if seq_lens is not None and seq_lens < ori_seqlen: + out_pad = attention_forward(query_pad, key_pad, value_pad, + opt_mode="manual", op_type="fused_attn_score", layout="BSND") + out = torch.cat([out, out_pad], dim=1) + + if type(out) == tuple: + context_layer, _, _ = out + else: + context_layer = out + + # (bs, seq_len, head_cnt/N, head_size) -> (bs, seq_len/N, head_cnt, head_size) + # scatter 1, gather 2 + output = all_to_all_4D(input_=context_layer, scatter_idx=1, gather_idx=2, group=self.ulysses_pg) + + return output + diff --git a/MindIE/MultiModal/Wan2.2/wan/modules/model.py b/MindIE/MultiModal/Wan2.2/wan/modules/model.py new file mode 100644 index 0000000000..582a22099f --- /dev/null +++ b/MindIE/MultiModal/Wan2.2/wan/modules/model.py @@ -0,0 +1,582 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import math + +import torch +import torch_npu +import torch.nn as nn +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.modeling_utils import ModelMixin + +from .attention import flash_attention, attention + +from mindiesd import rotary_position_embedding + +__all__ = ['WanModel'] + + +def sinusoidal_embedding_1d(dim, position): + # preprocess + assert dim % 2 == 0 + half = dim // 2 + position = position.type(torch.float32) + + # calculation + sinusoid = torch.outer( + position, torch.pow(10000, -torch.arange(half).to(position).div(half))) + x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) + return x + + +@torch.amp.autocast('cuda', enabled=False) +def rope_params(max_seq_len, dim, theta=10000): + assert dim % 2 == 0 + freqs = torch.outer( + torch.arange(max_seq_len), + 1.0 / torch.pow(theta, + torch.arange(0, dim, 2).to(torch.float64).div(dim))) + freqs = torch.polar(torch.ones_like(freqs), freqs).to(torch.complex64) + return freqs + + +@torch.amp.autocast('cuda', enabled=False) +def rope_apply(x, grid_sizes, freqs_list): + s, n, c = x.size(1), x.size(2), x.size(3) + output = [] + for i, (f, h, w) in enumerate(grid_sizes.tolist()): + x_i = x[i, :s].reshape(1, s, n, c) + cos, sin = freqs_list[i] + x_i = rotary_position_embedding(x_i, cos, sin, rotated_mode="rotated_interleaved", fused=True) + output.append(x_i) + return torch.cat(output).float() + + +class WanRMSNorm(nn.Module): + + def __init__(self, dim, eps=1e-5): + super().__init__() + self.dim = dim + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + r""" + Args: + x(Tensor): Shape [B, L, C] + """ + return torch_npu.npu_rms_norm(x, self.weight, epsilon=self.eps)[0] + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps) + + +class WanLayerNorm(nn.LayerNorm): + + def __init__(self, dim, eps=1e-6, elementwise_affine=False): + super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps) + self.dim = dim + + def forward(self, x): + r""" + Args: + x(Tensor): Shape [B, L, C] + """ + return torch_npu.npu_layer_norm_eval( + x, normalized_shape=[self.dim], weight=self.weight, bias=self.bias, eps=self.eps + ) + + +class 
WanLayerNormModulate(nn.LayerNorm): + + def __init__(self, dim, eps=1e-6, elementwise_affine=False): + super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps) + self.dim = dim + + def forward(self, x, weight, scale): + r""" + Args: + x(Tensor): Shape [B, L, C] + """ + return torch_npu.npu_layer_norm_eval( + x, normalized_shape=[self.dim], weight=weight, bias=scale, eps=self.eps, + ) + + +class WanSelfAttention(nn.Module): + + def __init__(self, + dim, + num_heads, + window_size=(-1, -1), + qk_norm=True, + eps=1e-6): + assert dim % num_heads == 0 + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.window_size = window_size + self.qk_norm = qk_norm + self.eps = eps + + # layers + self.q = nn.Linear(dim, dim) + self.k = nn.Linear(dim, dim) + self.v = nn.Linear(dim, dim) + self.o = nn.Linear(dim, dim) + self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() + self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() + + def forward(self, x, seq_lens, grid_sizes, freqs, args=None): + r""" + Args: + x(Tensor): Shape [B, L, num_heads, C / num_heads] + seq_lens(Tensor): Shape [B] + grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W) + freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] + """ + b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim + + # query, key, value function + def qkv_fn(x): + q = self.norm_q(self.q(x)).view(b, s, n, d) + k = self.norm_k(self.k(x)).view(b, s, n, d) + v = self.v(x).view(b, s, n, d) + return q, k, v + + q, k, v = qkv_fn(x) + + x = attention( + q=rope_apply(q, grid_sizes, freqs), + k=rope_apply(k, grid_sizes, freqs), + v=v, + k_lens=seq_lens, + window_size=self.window_size) + + # output + x = x.flatten(2) + x = self.o(x) + return x + + +class WanCrossAttention(WanSelfAttention): + + def forward(self, x, context, context_lens): + r""" + Args: + x(Tensor): Shape [B, L1, C] + context(Tensor): Shape [B, L2, C] + context_lens(Tensor): Shape [B] + """ + b, n, d = x.size(0), self.num_heads, self.head_dim + + # compute query, key, value + q = self.norm_q(self.q(x)).view(b, -1, n, d) + k = self.norm_k(self.k(context)).view(b, -1, n, d) + v = self.v(context).view(b, -1, n, d) + + # compute attention + x = attention(q, k, v, k_lens=context_lens) + + # output + x = x.flatten(2) + x = self.o(x) + return x + + +class WanAttentionBlock(nn.Module): + + def __init__(self, + dim, + ffn_dim, + num_heads, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=False, + eps=1e-6): + super().__init__() + self.dim = dim + self.ffn_dim = ffn_dim + self.num_heads = num_heads + self.window_size = window_size + self.qk_norm = qk_norm + self.cross_attn_norm = cross_attn_norm + self.eps = eps + + # layers + self.norm1 = WanLayerNorm(dim, eps) + # self.norm1 = WanLayerNormModulate(dim, eps) + self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm, + eps) + self.norm3 = WanLayerNorm( + dim, eps, + elementwise_affine=True) if cross_attn_norm else nn.Identity() + self.cross_attn = WanCrossAttention(dim, num_heads, (-1, -1), qk_norm, + eps) + self.norm2 = WanLayerNorm(dim, eps) + # self.norm2 = WanLayerNormModulate(dim, eps) + self.ffn = nn.Sequential( + nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'), + nn.Linear(ffn_dim, dim)) + + # modulation + self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) + + # Attention Cache + self.cache = None + self.args = None + + def forward( + self, + x, + e, + seq_lens, + grid_sizes, + freqs, + 
context, + context_lens, + ): + r""" + Args: + x(Tensor): Shape [B, L, C] + e(Tensor): Shape [B, L1, 6, C] + seq_lens(Tensor): Shape [B], length of each sequence in batch + grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W) + freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] + """ + # assert e.dtype == torch.float32 + with torch.amp.autocast('cuda', dtype=torch.bfloat16): + e = (self.modulation.unsqueeze(0) + e).chunk(6, dim=2) + # assert e[0].dtype == torch.float32 + + y = self.cache.apply( + self.self_attn, + self.norm1(x) * (1 + e[1].squeeze(2)) + e[0].squeeze(2), + seq_lens, + grid_sizes, + freqs, + self.args) + with torch.amp.autocast('cuda', dtype=torch.bfloat16): + x = x + y * e[2].squeeze(2) + + # cross-attention & ffn function + def cross_attn_ffn(x, context, context_lens, e): + x = x + self.cross_attn(self.norm3(x), context, context_lens) + y = self.ffn( + self.norm2(x) * (1 + e[4].squeeze(2)) + e[3].squeeze(2)) + with torch.amp.autocast('cuda', dtype=torch.bfloat16): + x = x + y * e[5].squeeze(2) + return x + + x = cross_attn_ffn(x, context, context_lens, e) + return x + + +class Head(nn.Module): + + def __init__(self, dim, out_dim, patch_size, eps=1e-6): + super().__init__() + self.dim = dim + self.out_dim = out_dim + self.patch_size = patch_size + self.eps = eps + + # layers + out_dim = math.prod(patch_size) * out_dim + self.norm = WanLayerNorm(dim, eps) + self.head = nn.Linear(dim, out_dim) + + # modulation + self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5) + + def forward(self, x, e): + r""" + Args: + x(Tensor): Shape [B, L1, C] + e(Tensor): Shape [B, L1, C] + """ + # assert e.dtype == torch.float32 + with torch.amp.autocast('cuda', dtype=torch.bfloat16): + e = (self.modulation.unsqueeze(0) + e.unsqueeze(2)).chunk(2, dim=2) + x = ( + self.head( + self.norm(x) * (1 + e[1].squeeze(2)) + e[0].squeeze(2))) + return x + + +class WanModel(ModelMixin, ConfigMixin): + r""" + Wan diffusion backbone supporting both text-to-video and image-to-video. + """ + + ignore_for_config = [ + 'patch_size', 'cross_attn_norm', 'qk_norm', 'text_dim', 'window_size' + ] + _no_split_modules = ['WanAttentionBlock'] + + @register_to_config + def __init__(self, + model_type='t2v', + patch_size=(1, 2, 2), + text_len=512, + in_dim=16, + dim=2048, + ffn_dim=8192, + freq_dim=256, + text_dim=4096, + out_dim=16, + num_heads=16, + num_layers=32, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=True, + eps=1e-6): + r""" + Initialize the diffusion model backbone. 
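+
+        The backbone is a stack of `WanAttentionBlock` layers, each combining self-attention
+        with 3D rotary position embeddings, cross-attention over the text context, and a
+        feed-forward network, all modulated by projections of sinusoidal timestep embeddings.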
+ + Args: + model_type (`str`, *optional*, defaults to 't2v'): + Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video) + patch_size (`tuple`, *optional*, defaults to (1, 2, 2)): + 3D patch dimensions for video embedding (t_patch, h_patch, w_patch) + text_len (`int`, *optional*, defaults to 512): + Fixed length for text embeddings + in_dim (`int`, *optional*, defaults to 16): + Input video channels (C_in) + dim (`int`, *optional*, defaults to 2048): + Hidden dimension of the transformer + ffn_dim (`int`, *optional*, defaults to 8192): + Intermediate dimension in feed-forward network + freq_dim (`int`, *optional*, defaults to 256): + Dimension for sinusoidal time embeddings + text_dim (`int`, *optional*, defaults to 4096): + Input dimension for text embeddings + out_dim (`int`, *optional*, defaults to 16): + Output video channels (C_out) + num_heads (`int`, *optional*, defaults to 16): + Number of attention heads + num_layers (`int`, *optional*, defaults to 32): + Number of transformer blocks + window_size (`tuple`, *optional*, defaults to (-1, -1)): + Window size for local attention (-1 indicates global attention) + qk_norm (`bool`, *optional*, defaults to True): + Enable query/key normalization + cross_attn_norm (`bool`, *optional*, defaults to False): + Enable cross-attention normalization + eps (`float`, *optional*, defaults to 1e-6): + Epsilon value for normalization layers + """ + + super().__init__() + + assert model_type in ['t2v', 'i2v', 'ti2v'] + self.model_type = model_type + + self.patch_size = patch_size + self.text_len = text_len + self.in_dim = in_dim + self.dim = dim + self.ffn_dim = ffn_dim + self.freq_dim = freq_dim + self.text_dim = text_dim + self.out_dim = out_dim + self.num_heads = num_heads + self.num_layers = num_layers + self.window_size = window_size + self.qk_norm = qk_norm + self.cross_attn_norm = cross_attn_norm + self.eps = eps + + # embeddings + self.patch_embedding = nn.Conv3d( + in_dim, dim, kernel_size=patch_size, stride=patch_size) + self.text_embedding = nn.Sequential( + nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'), + nn.Linear(dim, dim)) + + self.time_embedding = nn.Sequential( + nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim)) + self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6)) + + # blocks + self.blocks = nn.ModuleList([ + WanAttentionBlock(dim, ffn_dim, num_heads, window_size, qk_norm, + cross_attn_norm, eps) for _ in range(num_layers) + ]) + + # head + self.head = Head(dim, out_dim, patch_size, eps) + + # buffers (don't use register_buffer otherwise dtype will be changed in to()) + assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0 + d = dim // num_heads + self.freqs = torch.cat([ + rope_params(1024, d - 4 * (d // 6)), + rope_params(1024, 2 * (d // 6)), + rope_params(1024, 2 * (d // 6)) + ], + dim=1) + + # initialize weights + self.init_weights() + + self.freqs_list = None + + def forward( + self, + x, + t, + context, + seq_len, + y=None, + ): + r""" + Forward pass through the diffusion model + + Args: + x (List[Tensor]): + List of input video tensors, each with shape [C_in, F, H, W] + t (Tensor): + Diffusion timesteps tensor of shape [B] + context (List[Tensor]): + List of text embeddings each with shape [L, C] + seq_len (`int`): + Maximum sequence length for positional encoding + y (List[Tensor], *optional*): + Conditional video inputs for image-to-video mode, same shape as x + + Returns: + List[Tensor]: + List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8] 
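+
+        Note:
+            Each input is zero-padded to `seq_len` tokens before the transformer blocks and
+            unpatchified back to its original grid size at the end. On the first call, the
+            3D RoPE cos/sin tables are built from `grid_sizes` and cached in `self.freqs_list`.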
+ """ + if self.model_type == 'i2v': + assert y is not None + # params + device = self.patch_embedding.weight.device + if self.freqs.device != device: + self.freqs = self.freqs.to(device) + + if y is not None: + x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)] + + # embeddings + x = [self.patch_embedding(u.unsqueeze(0)) for u in x] + grid_sizes = torch.stack( + [torch.tensor(u.shape[2:], dtype=torch.long) for u in x]) + x = [u.flatten(2).transpose(1, 2) for u in x] + seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long) + assert seq_lens.max() <= seq_len + x = torch.cat([ + torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], + dim=1) for u in x + ]) + + # time embeddings + if t.dim() == 1: + t = t.expand(t.size(0), seq_len) + with torch.amp.autocast('cuda', dtype=torch.bfloat16): + bt = t.size(0) + t = t.flatten() + e = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, + t).unflatten(0, (bt, seq_len)).float()) + e0 = self.time_projection(e).unflatten(2, (6, self.dim)) + # assert e.dtype == torch.float32 and e0.dtype == torch.float32 + + # context + context_lens = None + context = self.text_embedding( + torch.stack([ + torch.cat( + [u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) + for u in context + ])) + + if self.freqs_list is None: + c = (self.dim // self.num_heads) // 2 + s = x.shape[1] + freqs = self.freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) + freqs_list = [] + + for i, (f, h, w) in enumerate(grid_sizes.tolist()): + seq_len = f * h * w + freqs_i = torch.cat([ + freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), + freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), + freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) + ], + dim=-1).reshape(seq_len, 1, -1) + + cos, sin = torch.chunk(torch.view_as_real(freqs_i.to(torch.complex64)), 2, dim=-1) + cos = cos.unsqueeze(0).expand(-1, -1, -1, -1, 2).flatten(-2) + sin = sin.unsqueeze(0).expand(-1, -1, -1, -1, 2).flatten(-2) + freqs_i = (cos, sin) + freqs_list.append(freqs_i) + self.freqs_list = freqs_list + + # arguments + kwargs = dict( + e=e0, + seq_lens=seq_lens, + grid_sizes=grid_sizes, + freqs=self.freqs_list, + context=context, + context_lens=context_lens) + + for block in self.blocks: + x = block(x, **kwargs) + + # head + x = self.head(x, e) + + # unpatchify + x = self.unpatchify(x, grid_sizes) + return [u.float() for u in x] + + def unpatchify(self, x, grid_sizes): + r""" + Reconstruct video tensors from patch embeddings. + + Args: + x (List[Tensor]): + List of patchified features, each with shape [L, C_out * prod(patch_size)] + grid_sizes (Tensor): + Original spatial-temporal grid dimensions before patching, + shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches) + + Returns: + List[Tensor]: + Reconstructed video tensors with shape [C_out, F, H / 8, W / 8] + """ + + c = self.out_dim + out = [] + for u, v in zip(x, grid_sizes.tolist()): + u = u[:math.prod(v)].view(*v, *self.patch_size, c) + u = torch.einsum('fhwpqrc->cfphqwr', u) + u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)]) + out.append(u) + return out + + def init_weights(self): + r""" + Initialize model parameters using Xavier initialization. 
+ """ + + # basic init + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.zeros_(m.bias) + + # init embeddings + nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1)) + for m in self.text_embedding.modules(): + if isinstance(m, nn.Linear): + nn.init.normal_(m.weight, std=.02) + for m in self.time_embedding.modules(): + if isinstance(m, nn.Linear): + nn.init.normal_(m.weight, std=.02) + + # init output layer + nn.init.zeros_(self.head.head.weight) diff --git a/MindIE/MultiModal/Wan2.2/wan/modules/t5.py b/MindIE/MultiModal/Wan2.2/wan/modules/t5.py new file mode 100644 index 0000000000..c841b044a2 --- /dev/null +++ b/MindIE/MultiModal/Wan2.2/wan/modules/t5.py @@ -0,0 +1,513 @@ +# Modified from transformers.models.t5.modeling_t5 +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import logging +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .tokenizers import HuggingfaceTokenizer + +__all__ = [ + 'T5Model', + 'T5Encoder', + 'T5Decoder', + 'T5EncoderModel', +] + + +def fp16_clamp(x): + if x.dtype == torch.float16 and torch.isinf(x).any(): + clamp = torch.finfo(x.dtype).max - 1000 + x = torch.clamp(x, min=-clamp, max=clamp) + return x + + +def init_weights(m): + if isinstance(m, T5LayerNorm): + nn.init.ones_(m.weight) + elif isinstance(m, T5Model): + nn.init.normal_(m.token_embedding.weight, std=1.0) + elif isinstance(m, T5FeedForward): + nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5) + nn.init.normal_(m.fc1.weight, std=m.dim**-0.5) + nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5) + elif isinstance(m, T5Attention): + nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn)**-0.5) + nn.init.normal_(m.k.weight, std=m.dim**-0.5) + nn.init.normal_(m.v.weight, std=m.dim**-0.5) + nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn)**-0.5) + elif isinstance(m, T5RelativeEmbedding): + nn.init.normal_( + m.embedding.weight, std=(2 * m.num_buckets * m.num_heads)**-0.5) + + +class GELU(nn.Module): + + def forward(self, x): + return 0.5 * x * (1.0 + torch.tanh( + math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)))) + + +class T5LayerNorm(nn.Module): + + def __init__(self, dim, eps=1e-6): + super(T5LayerNorm, self).__init__() + self.dim = dim + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + + self.eps) + if self.weight.dtype in [torch.float16, torch.bfloat16]: + x = x.type_as(self.weight) + return self.weight * x + + +class T5Attention(nn.Module): + + def __init__(self, dim, dim_attn, num_heads, dropout=0.1): + assert dim_attn % num_heads == 0 + super(T5Attention, self).__init__() + self.dim = dim + self.dim_attn = dim_attn + self.num_heads = num_heads + self.head_dim = dim_attn // num_heads + + # layers + self.q = nn.Linear(dim, dim_attn, bias=False) + self.k = nn.Linear(dim, dim_attn, bias=False) + self.v = nn.Linear(dim, dim_attn, bias=False) + self.o = nn.Linear(dim_attn, dim, bias=False) + self.dropout = nn.Dropout(dropout) + + def forward(self, x, context=None, mask=None, pos_bias=None): + """ + x: [B, L1, C]. + context: [B, L2, C] or None. + mask: [B, L2] or [B, L1, L2] or None. 
+ """ + # check inputs + context = x if context is None else context + b, n, c = x.size(0), self.num_heads, self.head_dim + + # compute query, key, value + q = self.q(x).view(b, -1, n, c) + k = self.k(context).view(b, -1, n, c) + v = self.v(context).view(b, -1, n, c) + + # attention bias + attn_bias = x.new_zeros(b, n, q.size(1), k.size(1)) + if pos_bias is not None: + attn_bias += pos_bias + if mask is not None: + assert mask.ndim in [2, 3] + mask = mask.view(b, 1, 1, + -1) if mask.ndim == 2 else mask.unsqueeze(1) + attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min) + + # compute attention (T5 does not use scaling) + attn = torch.einsum('binc,bjnc->bnij', q, k) + attn_bias + attn = F.softmax(attn.float(), dim=-1).type_as(attn) + x = torch.einsum('bnij,bjnc->binc', attn, v) + + # output + x = x.reshape(b, -1, n * c) + x = self.o(x) + x = self.dropout(x) + return x + + +class T5FeedForward(nn.Module): + + def __init__(self, dim, dim_ffn, dropout=0.1): + super(T5FeedForward, self).__init__() + self.dim = dim + self.dim_ffn = dim_ffn + + # layers + self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU()) + self.fc1 = nn.Linear(dim, dim_ffn, bias=False) + self.fc2 = nn.Linear(dim_ffn, dim, bias=False) + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + x = self.fc1(x) * self.gate(x) + x = self.dropout(x) + x = self.fc2(x) + x = self.dropout(x) + return x + + +class T5SelfAttention(nn.Module): + + def __init__(self, + dim, + dim_attn, + dim_ffn, + num_heads, + num_buckets, + shared_pos=True, + dropout=0.1): + super(T5SelfAttention, self).__init__() + self.dim = dim + self.dim_attn = dim_attn + self.dim_ffn = dim_ffn + self.num_heads = num_heads + self.num_buckets = num_buckets + self.shared_pos = shared_pos + + # layers + self.norm1 = T5LayerNorm(dim) + self.attn = T5Attention(dim, dim_attn, num_heads, dropout) + self.norm2 = T5LayerNorm(dim) + self.ffn = T5FeedForward(dim, dim_ffn, dropout) + self.pos_embedding = None if shared_pos else T5RelativeEmbedding( + num_buckets, num_heads, bidirectional=True) + + def forward(self, x, mask=None, pos_bias=None): + e = pos_bias if self.shared_pos else self.pos_embedding( + x.size(1), x.size(1)) + x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e)) + x = fp16_clamp(x + self.ffn(self.norm2(x))) + return x + + +class T5CrossAttention(nn.Module): + + def __init__(self, + dim, + dim_attn, + dim_ffn, + num_heads, + num_buckets, + shared_pos=True, + dropout=0.1): + super(T5CrossAttention, self).__init__() + self.dim = dim + self.dim_attn = dim_attn + self.dim_ffn = dim_ffn + self.num_heads = num_heads + self.num_buckets = num_buckets + self.shared_pos = shared_pos + + # layers + self.norm1 = T5LayerNorm(dim) + self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout) + self.norm2 = T5LayerNorm(dim) + self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout) + self.norm3 = T5LayerNorm(dim) + self.ffn = T5FeedForward(dim, dim_ffn, dropout) + self.pos_embedding = None if shared_pos else T5RelativeEmbedding( + num_buckets, num_heads, bidirectional=False) + + def forward(self, + x, + mask=None, + encoder_states=None, + encoder_mask=None, + pos_bias=None): + e = pos_bias if self.shared_pos else self.pos_embedding( + x.size(1), x.size(1)) + x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e)) + x = fp16_clamp(x + self.cross_attn( + self.norm2(x), context=encoder_states, mask=encoder_mask)) + x = fp16_clamp(x + self.ffn(self.norm3(x))) + return x + + +class 
T5RelativeEmbedding(nn.Module): + + def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128): + super(T5RelativeEmbedding, self).__init__() + self.num_buckets = num_buckets + self.num_heads = num_heads + self.bidirectional = bidirectional + self.max_dist = max_dist + + # layers + self.embedding = nn.Embedding(num_buckets, num_heads) + + def forward(self, lq, lk): + device = self.embedding.weight.device + # rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \ + # torch.arange(lq).unsqueeze(1).to(device) + rel_pos = torch.arange(lk, device=device).unsqueeze(0) - \ + torch.arange(lq, device=device).unsqueeze(1) + rel_pos = self._relative_position_bucket(rel_pos) + rel_pos_embeds = self.embedding(rel_pos) + rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze( + 0) # [1, N, Lq, Lk] + return rel_pos_embeds.contiguous() + + def _relative_position_bucket(self, rel_pos): + # preprocess + if self.bidirectional: + num_buckets = self.num_buckets // 2 + rel_buckets = (rel_pos > 0).long() * num_buckets + rel_pos = torch.abs(rel_pos) + else: + num_buckets = self.num_buckets + rel_buckets = 0 + rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos)) + + # embeddings for small and large positions + max_exact = num_buckets // 2 + rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) / + math.log(self.max_dist / max_exact) * + (num_buckets - max_exact)).long() + rel_pos_large = torch.min( + rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1)) + rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large) + return rel_buckets + + +class T5Encoder(nn.Module): + + def __init__(self, + vocab, + dim, + dim_attn, + dim_ffn, + num_heads, + num_layers, + num_buckets, + shared_pos=True, + dropout=0.1): + super(T5Encoder, self).__init__() + self.dim = dim + self.dim_attn = dim_attn + self.dim_ffn = dim_ffn + self.num_heads = num_heads + self.num_layers = num_layers + self.num_buckets = num_buckets + self.shared_pos = shared_pos + + # layers + self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \ + else nn.Embedding(vocab, dim) + self.pos_embedding = T5RelativeEmbedding( + num_buckets, num_heads, bidirectional=True) if shared_pos else None + self.dropout = nn.Dropout(dropout) + self.blocks = nn.ModuleList([ + T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, + shared_pos, dropout) for _ in range(num_layers) + ]) + self.norm = T5LayerNorm(dim) + + # initialize weights + self.apply(init_weights) + + def forward(self, ids, mask=None): + x = self.token_embedding(ids) + x = self.dropout(x) + e = self.pos_embedding(x.size(1), + x.size(1)) if self.shared_pos else None + for block in self.blocks: + x = block(x, mask, pos_bias=e) + x = self.norm(x) + x = self.dropout(x) + return x + + +class T5Decoder(nn.Module): + + def __init__(self, + vocab, + dim, + dim_attn, + dim_ffn, + num_heads, + num_layers, + num_buckets, + shared_pos=True, + dropout=0.1): + super(T5Decoder, self).__init__() + self.dim = dim + self.dim_attn = dim_attn + self.dim_ffn = dim_ffn + self.num_heads = num_heads + self.num_layers = num_layers + self.num_buckets = num_buckets + self.shared_pos = shared_pos + + # layers + self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \ + else nn.Embedding(vocab, dim) + self.pos_embedding = T5RelativeEmbedding( + num_buckets, num_heads, bidirectional=False) if shared_pos else None + self.dropout = nn.Dropout(dropout) + self.blocks = nn.ModuleList([ + T5CrossAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, + 
shared_pos, dropout) for _ in range(num_layers) + ]) + self.norm = T5LayerNorm(dim) + + # initialize weights + self.apply(init_weights) + + def forward(self, ids, mask=None, encoder_states=None, encoder_mask=None): + b, s = ids.size() + + # causal mask + if mask is None: + mask = torch.tril(torch.ones(1, s, s).to(ids.device)) + elif mask.ndim == 2: + mask = torch.tril(mask.unsqueeze(1).expand(-1, s, -1)) + + # layers + x = self.token_embedding(ids) + x = self.dropout(x) + e = self.pos_embedding(x.size(1), + x.size(1)) if self.shared_pos else None + for block in self.blocks: + x = block(x, mask, encoder_states, encoder_mask, pos_bias=e) + x = self.norm(x) + x = self.dropout(x) + return x + + +class T5Model(nn.Module): + + def __init__(self, + vocab_size, + dim, + dim_attn, + dim_ffn, + num_heads, + encoder_layers, + decoder_layers, + num_buckets, + shared_pos=True, + dropout=0.1): + super(T5Model, self).__init__() + self.vocab_size = vocab_size + self.dim = dim + self.dim_attn = dim_attn + self.dim_ffn = dim_ffn + self.num_heads = num_heads + self.encoder_layers = encoder_layers + self.decoder_layers = decoder_layers + self.num_buckets = num_buckets + + # layers + self.token_embedding = nn.Embedding(vocab_size, dim) + self.encoder = T5Encoder(self.token_embedding, dim, dim_attn, dim_ffn, + num_heads, encoder_layers, num_buckets, + shared_pos, dropout) + self.decoder = T5Decoder(self.token_embedding, dim, dim_attn, dim_ffn, + num_heads, decoder_layers, num_buckets, + shared_pos, dropout) + self.head = nn.Linear(dim, vocab_size, bias=False) + + # initialize weights + self.apply(init_weights) + + def forward(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask): + x = self.encoder(encoder_ids, encoder_mask) + x = self.decoder(decoder_ids, decoder_mask, x, encoder_mask) + x = self.head(x) + return x + + +def _t5(name, + encoder_only=False, + decoder_only=False, + return_tokenizer=False, + tokenizer_kwargs={}, + dtype=torch.float32, + device='cpu', + **kwargs): + # sanity check + assert not (encoder_only and decoder_only) + + # params + if encoder_only: + model_cls = T5Encoder + kwargs['vocab'] = kwargs.pop('vocab_size') + kwargs['num_layers'] = kwargs.pop('encoder_layers') + _ = kwargs.pop('decoder_layers') + elif decoder_only: + model_cls = T5Decoder + kwargs['vocab'] = kwargs.pop('vocab_size') + kwargs['num_layers'] = kwargs.pop('decoder_layers') + _ = kwargs.pop('encoder_layers') + else: + model_cls = T5Model + + # init model + with torch.device(device): + model = model_cls(**kwargs) + + # set device + model = model.to(dtype=dtype, device=device) + + # init tokenizer + if return_tokenizer: + from .tokenizers import HuggingfaceTokenizer + tokenizer = HuggingfaceTokenizer(f'google/{name}', **tokenizer_kwargs) + return model, tokenizer + else: + return model + + +def umt5_xxl(**kwargs): + cfg = dict( + vocab_size=256384, + dim=4096, + dim_attn=4096, + dim_ffn=10240, + num_heads=64, + encoder_layers=24, + decoder_layers=24, + num_buckets=32, + shared_pos=False, + dropout=0.1) + cfg.update(**kwargs) + return _t5('umt5-xxl', **cfg) + + +class T5EncoderModel: + + def __init__( + self, + text_len, + dtype=torch.bfloat16, + device=torch.cuda.current_device(), + checkpoint_path=None, + tokenizer_path=None, + shard_fn=None, + ): + self.text_len = text_len + self.dtype = dtype + self.device = device + self.checkpoint_path = checkpoint_path + self.tokenizer_path = tokenizer_path + + # init model + model = umt5_xxl( + encoder_only=True, + return_tokenizer=False, + dtype=dtype, + 
device=device).eval().requires_grad_(False) + logging.info(f'loading {checkpoint_path}') + model.load_state_dict(torch.load(checkpoint_path, map_location='cpu')) + self.model = model + if shard_fn is not None: + self.model = shard_fn(self.model, sync_module_states=False) + else: + self.model.to(self.device) + # init tokenizer + self.tokenizer = HuggingfaceTokenizer( + name=tokenizer_path, seq_len=text_len, clean='whitespace') + + def __call__(self, texts, device): + ids, mask = self.tokenizer( + texts, return_mask=True, add_special_tokens=True) + ids = ids.to(device) + mask = mask.to(device) + seq_lens = mask.gt(0).sum(dim=1).long() + context = self.model(ids, mask) + return [u[:v] for u, v in zip(context, seq_lens)] diff --git a/MindIE/MultiModal/Wan2.2/wan/modules/tokenizers.py b/MindIE/MultiModal/Wan2.2/wan/modules/tokenizers.py new file mode 100644 index 0000000000..121e591c48 --- /dev/null +++ b/MindIE/MultiModal/Wan2.2/wan/modules/tokenizers.py @@ -0,0 +1,82 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import html +import string + +import ftfy +import regex as re +from transformers import AutoTokenizer + +__all__ = ['HuggingfaceTokenizer'] + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r'\s+', ' ', text) + text = text.strip() + return text + + +def canonicalize(text, keep_punctuation_exact_string=None): + text = text.replace('_', ' ') + if keep_punctuation_exact_string: + text = keep_punctuation_exact_string.join( + part.translate(str.maketrans('', '', string.punctuation)) + for part in text.split(keep_punctuation_exact_string)) + else: + text = text.translate(str.maketrans('', '', string.punctuation)) + text = text.lower() + text = re.sub(r'\s+', ' ', text) + return text.strip() + + +class HuggingfaceTokenizer: + + def __init__(self, name, seq_len=None, clean=None, **kwargs): + assert clean in (None, 'whitespace', 'lower', 'canonicalize') + self.name = name + self.seq_len = seq_len + self.clean = clean + + # init tokenizer + self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs) + self.vocab_size = self.tokenizer.vocab_size + + def __call__(self, sequence, **kwargs): + return_mask = kwargs.pop('return_mask', False) + + # arguments + _kwargs = {'return_tensors': 'pt'} + if self.seq_len is not None: + _kwargs.update({ + 'padding': 'max_length', + 'truncation': True, + 'max_length': self.seq_len + }) + _kwargs.update(**kwargs) + + # tokenization + if isinstance(sequence, str): + sequence = [sequence] + if self.clean: + sequence = [self._clean(u) for u in sequence] + ids = self.tokenizer(sequence, **_kwargs) + + # output + if return_mask: + return ids.input_ids, ids.attention_mask + else: + return ids.input_ids + + def _clean(self, text): + if self.clean == 'whitespace': + text = whitespace_clean(basic_clean(text)) + elif self.clean == 'lower': + text = whitespace_clean(basic_clean(text)).lower() + elif self.clean == 'canonicalize': + text = canonicalize(basic_clean(text)) + return text diff --git a/MindIE/MultiModal/Wan2.2/wan/modules/vae2_1.py b/MindIE/MultiModal/Wan2.2/wan/modules/vae2_1.py new file mode 100644 index 0000000000..6b3df33515 --- /dev/null +++ b/MindIE/MultiModal/Wan2.2/wan/modules/vae2_1.py @@ -0,0 +1,663 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
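+# Wan2.1 causal 3D video VAE. Frames are encoded/decoded in temporal chunks
+# (the first frame alone, then groups of four), and each CausalConv3d keeps a
+# small feature cache (CACHE_T frames) so consecutive chunks stay consistent.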
+import logging + +import torch +import torch.cuda.amp as amp +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +__all__ = [ + 'Wan2_1_VAE', +] + +CACHE_T = 2 + + +class CausalConv3d(nn.Conv3d): + """ + Causal 3d convolusion. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._padding = (0, 0, 0, + 0, 2 * self.padding[0], 0) + self.padding = (0, self.padding[1], self.padding[2]) + + def forward(self, x, cache_x=None): + padding = list(self._padding) + if cache_x is not None and self._padding[4] > 0: + cache_x = cache_x.to(x.device) + x = torch.cat([cache_x, x], dim=2) + padding[4] -= cache_x.shape[2] + x = F.pad(x, padding) + + return super().forward(x) + + +class RMS_norm(nn.Module): + + def __init__(self, dim, channel_first=True, images=True, bias=False): + super().__init__() + broadcastable_dims = (1, 1, 1) if not images else (1, 1) + shape = (dim, *broadcastable_dims) if channel_first else (dim,) + + self.channel_first = channel_first + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.ones(shape)) + self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0. + + def forward(self, x): + return F.normalize( + x, dim=(1 if self.channel_first else + -1)) * self.scale * self.gamma + self.bias + + +class Upsample(nn.Upsample): + + def forward(self, x): + """ + Fix bfloat16 support for nearest neighbor interpolation. + """ + return super().forward(x.float()).type_as(x) + + +class Resample(nn.Module): + + def __init__(self, dim, mode): + assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d', + 'downsample3d') + super().__init__() + self.dim = dim + self.mode = mode + + # layers + if mode == 'upsample2d': + self.resample = nn.Sequential( + Upsample(scale_factor=(2., 2.), mode='nearest-exact'), + nn.Conv2d(dim, dim // 2, 3, padding=1)) + elif mode == 'upsample3d': + self.resample = nn.Sequential( + Upsample(scale_factor=(2., 2.), mode='nearest-exact'), + nn.Conv2d(dim, dim // 2, 3, padding=1)) + self.time_conv = CausalConv3d( + dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) + + elif mode == 'downsample2d': + self.resample = nn.Sequential( + nn.ZeroPad2d((0, 1, 0, 1)), + nn.Conv2d(dim, dim, 3, stride=(2, 2))) + elif mode == 'downsample3d': + self.resample = nn.Sequential( + nn.ZeroPad2d((0, 1, 0, 1)), + nn.Conv2d(dim, dim, 3, stride=(2, 2))) + self.time_conv = CausalConv3d( + dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) + + else: + self.resample = nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + b, c, t, h, w = x.size() + if self.mode == 'upsample3d': + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = 'Rep' + feat_idx[0] += 1 + else: + + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[ + idx] is not None and feat_cache[idx] != 'Rep': + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + if cache_x.shape[2] < 2 and feat_cache[ + idx] is not None and feat_cache[idx] == 'Rep': + cache_x = torch.cat([ + torch.zeros_like(cache_x).to(cache_x.device), + cache_x + ], + dim=2) + if feat_cache[idx] == 'Rep': + x = self.time_conv(x) + else: + x = self.time_conv(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + + x = x.reshape(b, 2, c, t, h, w) + x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), + 3) + x = x.reshape(b, c, t * 2, h, w) + t = x.shape[2] + x = 
rearrange(x, 'b c t h w -> (b t) c h w') + x = self.resample(x) + x = rearrange(x, '(b t) c h w -> b c t h w', t=t) + + if self.mode == 'downsample3d': + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = x.clone() + feat_idx[0] += 1 + else: + + cache_x = x[:, :, -1:, :, :].clone() + # if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep': + # # cache last frame of last two chunk + # cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + + x = self.time_conv( + torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + return x + + def init_weight(self, conv): + conv_weight = conv.weight + nn.init.zeros_(conv_weight) + c1, c2, t, h, w = conv_weight.size() + one_matrix = torch.eye(c1, c2) + init_matrix = one_matrix + nn.init.zeros_(conv_weight) + #conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5 + conv_weight.data[:, :, 1, 0, 0] = init_matrix #* 0.5 + conv.weight.data.copy_(conv_weight) + nn.init.zeros_(conv.bias.data) + + def init_weight2(self, conv): + conv_weight = conv.weight.data + nn.init.zeros_(conv_weight) + c1, c2, t, h, w = conv_weight.size() + init_matrix = torch.eye(c1 // 2, c2) + #init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2) + conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix + conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix + conv.weight.data.copy_(conv_weight) + nn.init.zeros_(conv.bias.data) + + +class ResidualBlock(nn.Module): + + def __init__(self, in_dim, out_dim, dropout=0.0): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + + # layers + self.residual = nn.Sequential( + RMS_norm(in_dim, images=False), nn.SiLU(), + CausalConv3d(in_dim, out_dim, 3, padding=1), + RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout), + CausalConv3d(out_dim, out_dim, 3, padding=1)) + self.shortcut = CausalConv3d(in_dim, out_dim, 1) \ + if in_dim != out_dim else nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + h = self.shortcut(x) + for layer in self.residual: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + h + + +class AttentionBlock(nn.Module): + """ + Causal self-attention with a single head. 
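+ Attention is applied per frame over the flattened spatial positions (H * W tokens).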
+ """ + + def __init__(self, dim): + super().__init__() + self.dim = dim + + # layers + self.norm = RMS_norm(dim) + self.to_qkv = nn.Conv2d(dim, dim * 3, 1) + self.proj = nn.Conv2d(dim, dim, 1) + + # zero out the last layer params + nn.init.zeros_(self.proj.weight) + + def forward(self, x): + identity = x + b, c, t, h, w = x.size() + x = rearrange(x, 'b c t h w -> (b t) c h w') + x = self.norm(x) + # compute query, key, value + q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, + -1).permute(0, 1, 3, + 2).contiguous().chunk( + 3, dim=-1) + + # apply attention + x = F.scaled_dot_product_attention( + q, + k, + v, + ) + x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w) + + # output + x = self.proj(x) + x = rearrange(x, '(b t) c h w-> b c t h w', t=t) + return x + identity + + +class Encoder3d(nn.Module): + + def __init__(self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, False], + dropout=0.0): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + + # dimensions + dims = [dim * u for u in [1] + dim_mult] + scale = 1.0 + + # init block + self.conv1 = CausalConv3d(3, dims[0], 3, padding=1) + + # downsample blocks + downsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + for _ in range(num_res_blocks): + downsamples.append(ResidualBlock(in_dim, out_dim, dropout)) + if scale in attn_scales: + downsamples.append(AttentionBlock(out_dim)) + in_dim = out_dim + + # downsample block + if i != len(dim_mult) - 1: + mode = 'downsample3d' if temperal_downsample[ + i] else 'downsample2d' + downsamples.append(Resample(out_dim, mode=mode)) + scale /= 2.0 + self.downsamples = nn.Sequential(*downsamples) + + # middle blocks + self.middle = nn.Sequential( + ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim), + ResidualBlock(out_dim, out_dim, dropout)) + + # output blocks + self.head = nn.Sequential( + RMS_norm(out_dim, images=False), nn.SiLU(), + CausalConv3d(out_dim, z_dim, 3, padding=1)) + + def forward(self, x, feat_cache=None, feat_idx=[0]): + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + ## downsamples + for layer in self.downsamples: + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## middle + for layer in self.middle: + if isinstance(layer, ResidualBlock) and feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## head + for layer in self.head: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + + +class Decoder3d(nn.Module): + + def __init__(self, + 
dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_upsample=[False, True, True], + dropout=0.0): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_upsample = temperal_upsample + + # dimensions + dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + scale = 1.0 / 2**(len(dim_mult) - 2) + + # init block + self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1) + + # middle blocks + self.middle = nn.Sequential( + ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]), + ResidualBlock(dims[0], dims[0], dropout)) + + # upsample blocks + upsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + if i == 1 or i == 2 or i == 3: + in_dim = in_dim // 2 + for _ in range(num_res_blocks + 1): + upsamples.append(ResidualBlock(in_dim, out_dim, dropout)) + if scale in attn_scales: + upsamples.append(AttentionBlock(out_dim)) + in_dim = out_dim + + # upsample block + if i != len(dim_mult) - 1: + mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d' + upsamples.append(Resample(out_dim, mode=mode)) + scale *= 2.0 + self.upsamples = nn.Sequential(*upsamples) + + # output blocks + self.head = nn.Sequential( + RMS_norm(out_dim, images=False), nn.SiLU(), + CausalConv3d(out_dim, 3, 3, padding=1)) + + def forward(self, x, feat_cache=None, feat_idx=[0]): + ## conv1 + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + ## middle + for layer in self.middle: + if isinstance(layer, ResidualBlock) and feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## upsamples + for layer in self.upsamples: + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## head + for layer in self.head: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + + +def count_conv3d(model): + count = 0 + for m in model.modules(): + if isinstance(m, CausalConv3d): + count += 1 + return count + + +class WanVAE_(nn.Module): + + def __init__(self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, False], + dropout=0.0): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + self.temperal_upsample = temperal_downsample[::-1] + + # modules + self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks, + attn_scales, self.temperal_downsample, dropout) + self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1) + self.conv2 = CausalConv3d(z_dim, z_dim, 1) + 
self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks, + attn_scales, self.temperal_upsample, dropout) + + def forward(self, x): + mu, log_var = self.encode(x) + z = self.reparameterize(mu, log_var) + x_recon = self.decode(z) + return x_recon, mu, log_var + + def encode(self, x, scale): + self.clear_cache() + ## cache + t = x.shape[2] + iter_ = 1 + (t - 1) // 4 + ## 对encode输入的x,按时间拆分为1、4、4、4.... + for i in range(iter_): + self._enc_conv_idx = [0] + if i == 0: + out = self.encoder( + x[:, :, :1, :, :], + feat_cache=self._enc_feat_map, + feat_idx=self._enc_conv_idx) + else: + out_ = self.encoder( + x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :], + feat_cache=self._enc_feat_map, + feat_idx=self._enc_conv_idx) + out = torch.cat([out, out_], 2) + mu, log_var = self.conv1(out).chunk(2, dim=1) + if isinstance(scale[0], torch.Tensor): + mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view( + 1, self.z_dim, 1, 1, 1) + else: + mu = (mu - scale[0]) * scale[1] + self.clear_cache() + return mu + + def decode(self, z, scale): + self.clear_cache() + # z: [b,c,t,h,w] + if isinstance(scale[0], torch.Tensor): + z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view( + 1, self.z_dim, 1, 1, 1) + else: + z = z / scale[1] + scale[0] + iter_ = z.shape[2] + x = self.conv2(z) + for i in range(iter_): + self._conv_idx = [0] + if i == 0: + out = self.decoder( + x[:, :, i:i + 1, :, :], + feat_cache=self._feat_map, + feat_idx=self._conv_idx) + else: + out_ = self.decoder( + x[:, :, i:i + 1, :, :], + feat_cache=self._feat_map, + feat_idx=self._conv_idx) + out = torch.cat([out, out_], 2) + self.clear_cache() + return out + + def reparameterize(self, mu, log_var): + std = torch.exp(0.5 * log_var) + eps = torch.randn_like(std) + return eps * std + mu + + def sample(self, imgs, deterministic=False): + mu, log_var = self.encode(imgs) + if deterministic: + return mu + std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0)) + return mu + std * torch.randn_like(std) + + def clear_cache(self): + self._conv_num = count_conv3d(self.decoder) + self._conv_idx = [0] + self._feat_map = [None] * self._conv_num + #cache encode + self._enc_conv_num = count_conv3d(self.encoder) + self._enc_conv_idx = [0] + self._enc_feat_map = [None] * self._enc_conv_num + + +def _video_vae(pretrained_path=None, z_dim=None, device='cpu', **kwargs): + """ + Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL. 
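+ The model is built on the 'meta' device and its weights are materialized from `pretrained_path` via load_state_dict(..., assign=True).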
+ """ + # params + cfg = dict( + dim=96, + z_dim=z_dim, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[False, True, True], + dropout=0.0) + cfg.update(**kwargs) + + # init model + with torch.device('meta'): + model = WanVAE_(**cfg) + + # load checkpoint + logging.info(f'loading {pretrained_path}') + model.load_state_dict( + torch.load(pretrained_path, map_location=device), assign=True) + + return model + + +class Wan2_1_VAE: + + def __init__(self, + z_dim=16, + vae_pth='cache/vae_step_411000.pth', + dtype=torch.float, + device="cuda"): + self.dtype = dtype + self.device = device + + mean = [ + -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, + 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921 + ] + std = [ + 2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, + 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160 + ] + self.mean = torch.tensor(mean, dtype=dtype, device=device) + self.std = torch.tensor(std, dtype=dtype, device=device) + self.scale = [self.mean, 1.0 / self.std] + + # init model + self.model = _video_vae( + pretrained_path=vae_pth, + z_dim=z_dim, + ).eval().requires_grad_(False).to(device) + + def encode(self, videos): + """ + videos: A list of videos each with shape [C, T, H, W]. + """ + with amp.autocast(dtype=self.dtype): + return [ + self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0) + for u in videos + ] + + def decode(self, zs): + with amp.autocast(dtype=self.dtype): + return [ + self.model.decode(u.unsqueeze(0), + self.scale).float().clamp_(-1, 1).squeeze(0) + for u in zs + ] diff --git a/MindIE/MultiModal/Wan2.2/wan/modules/vae2_2.py b/MindIE/MultiModal/Wan2.2/wan/modules/vae2_2.py new file mode 100644 index 0000000000..c0b3f29bf4 --- /dev/null +++ b/MindIE/MultiModal/Wan2.2/wan/modules/vae2_2.py @@ -0,0 +1,1051 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import logging + +import torch +import torch.cuda.amp as amp +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +__all__ = [ + "Wan2_2_VAE", +] + +CACHE_T = 2 + + +class CausalConv3d(nn.Conv3d): + """ + Causal 3d convolusion. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._padding = ( + self.padding[2], + self.padding[2], + self.padding[1], + self.padding[1], + 2 * self.padding[0], + 0, + ) + self.padding = (0, 0, 0) + + def forward(self, x, cache_x=None): + padding = list(self._padding) + if cache_x is not None and self._padding[4] > 0: + cache_x = cache_x.to(x.device) + x = torch.cat([cache_x, x], dim=2) + padding[4] -= cache_x.shape[2] + x = F.pad(x, padding) + + return super().forward(x) + + +class RMS_norm(nn.Module): + + def __init__(self, dim, channel_first=True, images=True, bias=False): + super().__init__() + broadcastable_dims = (1, 1, 1) if not images else (1, 1) + shape = (dim, *broadcastable_dims) if channel_first else (dim,) + + self.channel_first = channel_first + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.ones(shape)) + self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0 + + def forward(self, x): + return (F.normalize(x, dim=(1 if self.channel_first else -1)) * + self.scale * self.gamma + self.bias) + + +class Upsample(nn.Upsample): + + def forward(self, x): + """ + Fix bfloat16 support for nearest neighbor interpolation. 
+ """ + return super().forward(x.float()).type_as(x) + + +class Resample(nn.Module): + + def __init__(self, dim, mode): + assert mode in ( + "none", + "upsample2d", + "upsample3d", + "downsample2d", + "downsample3d", + ) + super().__init__() + self.dim = dim + self.mode = mode + + # layers + if mode == "upsample2d": + self.resample = nn.Sequential( + Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), + nn.Conv2d(dim, dim, 3, padding=1), + ) + elif mode == "upsample3d": + self.resample = nn.Sequential( + Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), + nn.Conv2d(dim, dim, 3, padding=1), + # nn.Conv2d(dim, dim//2, 3, padding=1) + ) + self.time_conv = CausalConv3d( + dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) + elif mode == "downsample2d": + self.resample = nn.Sequential( + nn.ZeroPad2d((0, 1, 0, 1)), + nn.Conv2d(dim, dim, 3, stride=(2, 2))) + elif mode == "downsample3d": + self.resample = nn.Sequential( + nn.ZeroPad2d((0, 1, 0, 1)), + nn.Conv2d(dim, dim, 3, stride=(2, 2))) + self.time_conv = CausalConv3d( + dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) + else: + self.resample = nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + b, c, t, h, w = x.size() + if self.mode == "upsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = "Rep" + feat_idx[0] += 1 + else: + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if (cache_x.shape[2] < 2 and feat_cache[idx] is not None and + feat_cache[idx] != "Rep"): + # cache last frame of last two chunk + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), + cache_x, + ], + dim=2, + ) + if (cache_x.shape[2] < 2 and feat_cache[idx] is not None and + feat_cache[idx] == "Rep"): + cache_x = torch.cat( + [ + torch.zeros_like(cache_x).to(cache_x.device), + cache_x + ], + dim=2, + ) + if feat_cache[idx] == "Rep": + x = self.time_conv(x) + else: + x = self.time_conv(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + x = x.reshape(b, 2, c, t, h, w) + x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), + 3) + x = x.reshape(b, c, t * 2, h, w) + t = x.shape[2] + x = rearrange(x, "b c t h w -> (b t) c h w") + x = self.resample(x) + x = rearrange(x, "(b t) c h w -> b c t h w", t=t) + + if self.mode == "downsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = x.clone() + feat_idx[0] += 1 + else: + cache_x = x[:, :, -1:, :, :].clone() + x = self.time_conv( + torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + return x + + def init_weight(self, conv): + conv_weight = conv.weight.detach().clone() + nn.init.zeros_(conv_weight) + c1, c2, t, h, w = conv_weight.size() + one_matrix = torch.eye(c1, c2) + init_matrix = one_matrix + nn.init.zeros_(conv_weight) + conv_weight.data[:, :, 1, 0, 0] = init_matrix # * 0.5 + conv.weight = nn.Parameter(conv_weight) + nn.init.zeros_(conv.bias.data) + + def init_weight2(self, conv): + conv_weight = conv.weight.data.detach().clone() + nn.init.zeros_(conv_weight) + c1, c2, t, h, w = conv_weight.size() + init_matrix = torch.eye(c1 // 2, c2) + conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix + conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix + conv.weight = nn.Parameter(conv_weight) + nn.init.zeros_(conv.bias.data) + + +class ResidualBlock(nn.Module): + + def __init__(self, in_dim, out_dim, dropout=0.0): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + + 
# layers + self.residual = nn.Sequential( + RMS_norm(in_dim, images=False), + nn.SiLU(), + CausalConv3d(in_dim, out_dim, 3, padding=1), + RMS_norm(out_dim, images=False), + nn.SiLU(), + nn.Dropout(dropout), + CausalConv3d(out_dim, out_dim, 3, padding=1), + ) + self.shortcut = ( + CausalConv3d(in_dim, out_dim, 1) + if in_dim != out_dim else nn.Identity()) + + def forward(self, x, feat_cache=None, feat_idx=[0]): + h = self.shortcut(x) + for layer in self.residual: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), + cache_x, + ], + dim=2, + ) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + h + + +class AttentionBlock(nn.Module): + """ + Causal self-attention with a single head. + """ + + def __init__(self, dim): + super().__init__() + self.dim = dim + + # layers + self.norm = RMS_norm(dim) + self.to_qkv = nn.Conv2d(dim, dim * 3, 1) + self.proj = nn.Conv2d(dim, dim, 1) + + # zero out the last layer params + nn.init.zeros_(self.proj.weight) + + def forward(self, x): + identity = x + b, c, t, h, w = x.size() + x = rearrange(x, "b c t h w -> (b t) c h w") + x = self.norm(x) + # compute query, key, value + q, k, v = ( + self.to_qkv(x).reshape(b * t, 1, c * 3, + -1).permute(0, 1, 3, + 2).contiguous().chunk(3, dim=-1)) + + # apply attention + x = F.scaled_dot_product_attention( + q, + k, + v, + ) + x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w) + + # output + x = self.proj(x) + x = rearrange(x, "(b t) c h w-> b c t h w", t=t) + return x + identity + + +def patchify(x, patch_size): + if patch_size == 1: + return x + if x.dim() == 4: + x = rearrange( + x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size, r=patch_size) + elif x.dim() == 5: + x = rearrange( + x, + "b c f (h q) (w r) -> b (c r q) f h w", + q=patch_size, + r=patch_size, + ) + else: + raise ValueError(f"Invalid input shape: {x.shape}") + + return x + + +def unpatchify(x, patch_size): + if patch_size == 1: + return x + + if x.dim() == 4: + x = rearrange( + x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size, r=patch_size) + elif x.dim() == 5: + x = rearrange( + x, + "b (c r q) f h w -> b c f (h q) (w r)", + q=patch_size, + r=patch_size, + ) + return x + + +class AvgDown3D(nn.Module): + + def __init__( + self, + in_channels, + out_channels, + factor_t, + factor_s=1, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.factor_t = factor_t + self.factor_s = factor_s + self.factor = self.factor_t * self.factor_s * self.factor_s + + assert in_channels * self.factor % out_channels == 0 + self.group_size = in_channels * self.factor // out_channels + + def forward(self, x: torch.Tensor) -> torch.Tensor: + pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t + pad = (0, 0, 0, 0, pad_t, 0) + x = F.pad(x, pad) + B, C, T, H, W = x.shape + x = x.view( + B, + C, + T // self.factor_t, + self.factor_t, + H // self.factor_s, + self.factor_s, + W // self.factor_s, + self.factor_s, + ) + x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous() + x = x.view( + B, + C * self.factor, + T // self.factor_t, + H // self.factor_s, + W // self.factor_s, + ) + x = x.view( + B, + self.out_channels, + self.group_size, + T // self.factor_t, + H // self.factor_s, + W 
// self.factor_s, + ) + x = x.mean(dim=2) + return x + + +class DupUp3D(nn.Module): + + def __init__( + self, + in_channels: int, + out_channels: int, + factor_t, + factor_s=1, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + + self.factor_t = factor_t + self.factor_s = factor_s + self.factor = self.factor_t * self.factor_s * self.factor_s + + assert out_channels * self.factor % in_channels == 0 + self.repeats = out_channels * self.factor // in_channels + + def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor: + x = x.repeat_interleave(self.repeats, dim=1) + x = x.view( + x.size(0), + self.out_channels, + self.factor_t, + self.factor_s, + self.factor_s, + x.size(2), + x.size(3), + x.size(4), + ) + x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous() + x = x.view( + x.size(0), + self.out_channels, + x.size(2) * self.factor_t, + x.size(4) * self.factor_s, + x.size(6) * self.factor_s, + ) + if first_chunk: + x = x[:, :, self.factor_t - 1:, :, :] + return x + + +class Down_ResidualBlock(nn.Module): + + def __init__(self, + in_dim, + out_dim, + dropout, + mult, + temperal_downsample=False, + down_flag=False): + super().__init__() + + # Shortcut path with downsample + self.avg_shortcut = AvgDown3D( + in_dim, + out_dim, + factor_t=2 if temperal_downsample else 1, + factor_s=2 if down_flag else 1, + ) + + # Main path with residual blocks and downsample + downsamples = [] + for _ in range(mult): + downsamples.append(ResidualBlock(in_dim, out_dim, dropout)) + in_dim = out_dim + + # Add the final downsample block + if down_flag: + mode = "downsample3d" if temperal_downsample else "downsample2d" + downsamples.append(Resample(out_dim, mode=mode)) + + self.downsamples = nn.Sequential(*downsamples) + + def forward(self, x, feat_cache=None, feat_idx=[0]): + x_copy = x.clone() + for module in self.downsamples: + x = module(x, feat_cache, feat_idx) + + return x + self.avg_shortcut(x_copy) + + +class Up_ResidualBlock(nn.Module): + + def __init__(self, + in_dim, + out_dim, + dropout, + mult, + temperal_upsample=False, + up_flag=False): + super().__init__() + # Shortcut path with upsample + if up_flag: + self.avg_shortcut = DupUp3D( + in_dim, + out_dim, + factor_t=2 if temperal_upsample else 1, + factor_s=2 if up_flag else 1, + ) + else: + self.avg_shortcut = None + + # Main path with residual blocks and upsample + upsamples = [] + for _ in range(mult): + upsamples.append(ResidualBlock(in_dim, out_dim, dropout)) + in_dim = out_dim + + # Add the final upsample block + if up_flag: + mode = "upsample3d" if temperal_upsample else "upsample2d" + upsamples.append(Resample(out_dim, mode=mode)) + + self.upsamples = nn.Sequential(*upsamples) + + def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False): + x_main = x.clone() + for module in self.upsamples: + x_main = module(x_main, feat_cache, feat_idx) + if self.avg_shortcut is not None: + x_shortcut = self.avg_shortcut(x, first_chunk) + return x_main + x_shortcut + else: + return x_main + + +class Encoder3d(nn.Module): + + def __init__( + self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, False], + dropout=0.0, + ): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + + # dimensions + dims = [dim * u for u in [1] + dim_mult] + scale = 1.0 + + # init block + self.conv1 
= CausalConv3d(12, dims[0], 3, padding=1) + + # downsample blocks + downsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + t_down_flag = ( + temperal_downsample[i] + if i < len(temperal_downsample) else False) + downsamples.append( + Down_ResidualBlock( + in_dim=in_dim, + out_dim=out_dim, + dropout=dropout, + mult=num_res_blocks, + temperal_downsample=t_down_flag, + down_flag=i != len(dim_mult) - 1, + )) + scale /= 2.0 + self.downsamples = nn.Sequential(*downsamples) + + # middle blocks + self.middle = nn.Sequential( + ResidualBlock(out_dim, out_dim, dropout), + AttentionBlock(out_dim), + ResidualBlock(out_dim, out_dim, dropout), + ) + + # # output blocks + self.head = nn.Sequential( + RMS_norm(out_dim, images=False), + nn.SiLU(), + CausalConv3d(out_dim, z_dim, 3, padding=1), + ) + + def forward(self, x, feat_cache=None, feat_idx=[0]): + + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), + cache_x, + ], + dim=2, + ) + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + ## downsamples + for layer in self.downsamples: + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## middle + for layer in self.middle: + if isinstance(layer, ResidualBlock) and feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## head + for layer in self.head: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), + cache_x, + ], + dim=2, + ) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + + return x + + +class Decoder3d(nn.Module): + + def __init__( + self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_upsample=[False, True, True], + dropout=0.0, + ): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_upsample = temperal_upsample + + # dimensions + dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + scale = 1.0 / 2**(len(dim_mult) - 2) + # init block + self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1) + + # middle blocks + self.middle = nn.Sequential( + ResidualBlock(dims[0], dims[0], dropout), + AttentionBlock(dims[0]), + ResidualBlock(dims[0], dims[0], dropout), + ) + + # upsample blocks + upsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + t_up_flag = temperal_upsample[i] if i < len( + temperal_upsample) else False + upsamples.append( + Up_ResidualBlock( + in_dim=in_dim, + out_dim=out_dim, + dropout=dropout, + mult=num_res_blocks + 1, + temperal_upsample=t_up_flag, + up_flag=i != len(dim_mult) - 1, + )) + self.upsamples = nn.Sequential(*upsamples) + + # output blocks + self.head = nn.Sequential( + RMS_norm(out_dim, images=False), + nn.SiLU(), + CausalConv3d(out_dim, 12, 3, padding=1), + ) + + def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False): + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + 
if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), + cache_x, + ], + dim=2, + ) + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + for layer in self.middle: + if isinstance(layer, ResidualBlock) and feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## upsamples + for layer in self.upsamples: + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx, first_chunk) + else: + x = layer(x) + + ## head + for layer in self.head: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), + cache_x, + ], + dim=2, + ) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + + +def count_conv3d(model): + count = 0 + for m in model.modules(): + if isinstance(m, CausalConv3d): + count += 1 + return count + + +class WanVAE_(nn.Module): + + def __init__( + self, + dim=160, + dec_dim=256, + z_dim=16, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, False], + dropout=0.0, + ): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + self.temperal_upsample = temperal_downsample[::-1] + + # modules + self.encoder = Encoder3d( + dim, + z_dim * 2, + dim_mult, + num_res_blocks, + attn_scales, + self.temperal_downsample, + dropout, + ) + self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1) + self.conv2 = CausalConv3d(z_dim, z_dim, 1) + self.decoder = Decoder3d( + dec_dim, + z_dim, + dim_mult, + num_res_blocks, + attn_scales, + self.temperal_upsample, + dropout, + ) + + def forward(self, x, scale=[0, 1]): + mu = self.encode(x, scale) + x_recon = self.decode(mu, scale) + return x_recon, mu + + def encode(self, x, scale): + self.clear_cache() + x = patchify(x, patch_size=2) + t = x.shape[2] + iter_ = 1 + (t - 1) // 4 + for i in range(iter_): + self._enc_conv_idx = [0] + if i == 0: + out = self.encoder( + x[:, :, :1, :, :], + feat_cache=self._enc_feat_map, + feat_idx=self._enc_conv_idx, + ) + else: + out_ = self.encoder( + x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :], + feat_cache=self._enc_feat_map, + feat_idx=self._enc_conv_idx, + ) + out = torch.cat([out, out_], 2) + mu, log_var = self.conv1(out).chunk(2, dim=1) + if isinstance(scale[0], torch.Tensor): + mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view( + 1, self.z_dim, 1, 1, 1) + else: + mu = (mu - scale[0]) * scale[1] + self.clear_cache() + return mu + + def decode(self, z, scale): + self.clear_cache() + if isinstance(scale[0], torch.Tensor): + z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view( + 1, self.z_dim, 1, 1, 1) + else: + z = z / scale[1] + scale[0] + iter_ = z.shape[2] + x = self.conv2(z) + for i in range(iter_): + self._conv_idx = [0] + if i == 0: + out = self.decoder( + x[:, :, i:i + 1, :, :], + feat_cache=self._feat_map, + feat_idx=self._conv_idx, + first_chunk=True, + ) + else: + out_ = self.decoder( + x[:, :, i:i + 1, :, :], + feat_cache=self._feat_map, + feat_idx=self._conv_idx, + ) + out = torch.cat([out, out_], 2) + out = 
unpatchify(out, patch_size=2) + self.clear_cache() + return out + + def reparameterize(self, mu, log_var): + std = torch.exp(0.5 * log_var) + eps = torch.randn_like(std) + return eps * std + mu + + def sample(self, imgs, deterministic=False): + mu, log_var = self.encode(imgs) + if deterministic: + return mu + std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0)) + return mu + std * torch.randn_like(std) + + def clear_cache(self): + self._conv_num = count_conv3d(self.decoder) + self._conv_idx = [0] + self._feat_map = [None] * self._conv_num + # cache encode + self._enc_conv_num = count_conv3d(self.encoder) + self._enc_conv_idx = [0] + self._enc_feat_map = [None] * self._enc_conv_num + + +def _video_vae(pretrained_path=None, z_dim=16, dim=160, device="cpu", **kwargs): + # params + cfg = dict( + dim=dim, + z_dim=z_dim, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, True], + dropout=0.0, + ) + cfg.update(**kwargs) + + # init model + with torch.device("meta"): + model = WanVAE_(**cfg) + + # load checkpoint + logging.info(f"loading {pretrained_path}") + model.load_state_dict( + torch.load(pretrained_path, map_location=device), assign=True) + + return model + + +class Wan2_2_VAE: + + def __init__( + self, + z_dim=48, + c_dim=160, + vae_pth=None, + dim_mult=[1, 2, 4, 4], + temperal_downsample=[False, True, True], + dtype=torch.float, + device="cuda", + ): + + self.dtype = dtype + self.device = device + + mean = torch.tensor( + [ + -0.2289, + -0.0052, + -0.1323, + -0.2339, + -0.2799, + 0.0174, + 0.1838, + 0.1557, + -0.1382, + 0.0542, + 0.2813, + 0.0891, + 0.1570, + -0.0098, + 0.0375, + -0.1825, + -0.2246, + -0.1207, + -0.0698, + 0.5109, + 0.2665, + -0.2108, + -0.2158, + 0.2502, + -0.2055, + -0.0322, + 0.1109, + 0.1567, + -0.0729, + 0.0899, + -0.2799, + -0.1230, + -0.0313, + -0.1649, + 0.0117, + 0.0723, + -0.2839, + -0.2083, + -0.0520, + 0.3748, + 0.0152, + 0.1957, + 0.1433, + -0.2944, + 0.3573, + -0.0548, + -0.1681, + -0.0667, + ], + dtype=dtype, + device=device, + ) + std = torch.tensor( + [ + 0.4765, + 1.0364, + 0.4514, + 1.1677, + 0.5313, + 0.4990, + 0.4818, + 0.5013, + 0.8158, + 1.0344, + 0.5894, + 1.0901, + 0.6885, + 0.6165, + 0.8454, + 0.4978, + 0.5759, + 0.3523, + 0.7135, + 0.6804, + 0.5833, + 1.4146, + 0.8986, + 0.5659, + 0.7069, + 0.5338, + 0.4889, + 0.4917, + 0.4069, + 0.4999, + 0.6866, + 0.4093, + 0.5709, + 0.6065, + 0.6415, + 0.4944, + 0.5726, + 1.2042, + 0.5458, + 1.6887, + 0.3971, + 1.0600, + 0.3943, + 0.5537, + 0.5444, + 0.4089, + 0.7468, + 0.7744, + ], + dtype=dtype, + device=device, + ) + self.scale = [mean, 1.0 / std] + + # init model + self.model = ( + _video_vae( + pretrained_path=vae_pth, + z_dim=z_dim, + dim=c_dim, + dim_mult=dim_mult, + temperal_downsample=temperal_downsample, + ).eval().requires_grad_(False).to(device)) + + def encode(self, videos): + try: + if not isinstance(videos, list): + raise TypeError("videos should be a list") + with amp.autocast(dtype=self.dtype): + return [ + self.model.encode(u.unsqueeze(0), + self.scale).float().squeeze(0) + for u in videos + ] + except TypeError as e: + logging.info(e) + return None + + def decode(self, zs): + try: + if not isinstance(zs, list): + raise TypeError("zs should be a list") + with amp.autocast(dtype=self.dtype): + return [ + self.model.decode(u.unsqueeze(0), + self.scale).float().clamp_(-1, + 1).squeeze(0) + for u in zs + ] + except TypeError as e: + logging.info(e) + return None diff --git a/MindIE/MultiModal/Wan2.2/wan/text2video.py 
b/MindIE/MultiModal/Wan2.2/wan/text2video.py new file mode 100644 index 0000000000..635d1a1d0e --- /dev/null +++ b/MindIE/MultiModal/Wan2.2/wan/text2video.py @@ -0,0 +1,403 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import gc +import logging +import math +import os +import random +import sys +import types +from contextlib import contextmanager +from functools import partial + +import torch +import torch.cuda.amp as amp +import torch.distributed as dist +from tqdm import tqdm + +from .distributed.fsdp import shard_model +from .distributed.sequence_parallel import sp_attn_forward, sp_dit_forward +from .distributed.util import get_world_size +from .modules.model import WanModel +from .modules.t5 import T5EncoderModel +from .modules.vae2_1 import Wan2_1_VAE +from .utils.fm_solvers import ( + FlowDPMSolverMultistepScheduler, + get_sampling_sigmas, + retrieve_timesteps, +) +from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler +from .vae_patch_parallel import VAE_patch_parallel, set_vae_patch_parallel +from wan.distributed.parallel_mgr import ( + get_sequence_parallel_world_size, + get_classifier_free_guidance_world_size, + get_classifier_free_guidance_rank, + get_cfg_group, +) + +class WanT2V: + + def __init__( + self, + config, + checkpoint_dir, + device_id=0, + rank=0, + t5_fsdp=False, + dit_fsdp=False, + use_sp=False, + t5_cpu=False, + init_on_cpu=True, + convert_model_dtype=False, + use_vae_parallel=False, + ): + r""" + Initializes the Wan text-to-video generation model components. + + Args: + config (EasyDict): + Object containing model parameters initialized from config.py + checkpoint_dir (`str`): + Path to directory containing model checkpoints + device_id (`int`, *optional*, defaults to 0): + Id of target GPU device + rank (`int`, *optional*, defaults to 0): + Process rank for distributed training + t5_fsdp (`bool`, *optional*, defaults to False): + Enable FSDP sharding for T5 model + dit_fsdp (`bool`, *optional*, defaults to False): + Enable FSDP sharding for DiT model + use_sp (`bool`, *optional*, defaults to False): + Enable distribution strategy of sequence parallel. + t5_cpu (`bool`, *optional*, defaults to False): + Whether to place T5 model on CPU. Only works without t5_fsdp. + init_on_cpu (`bool`, *optional*, defaults to True): + Enable initializing Transformer Model on CPU. Only works without FSDP or USP. + convert_model_dtype (`bool`, *optional*, defaults to False): + Convert DiT model parameters dtype to 'config.param_dtype'. + Only works without FSDP. 
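+ use_vae_parallel (`bool`, *optional*, defaults to False):
+ Enable patch-parallel execution of the VAE (configured via `set_vae_patch_parallel`).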
+ """ + self.device = torch.device(f"cuda:{device_id}") + self.config = config + self.rank = rank + self.t5_cpu = t5_cpu + self.init_on_cpu = init_on_cpu + + self.num_train_timesteps = config.num_train_timesteps + self.boundary = config.boundary + self.param_dtype = config.param_dtype + + if t5_fsdp or dit_fsdp or use_sp: + self.init_on_cpu = False + + shard_fn = partial(shard_model, device_id=device_id) + self.text_encoder = T5EncoderModel( + text_len=config.text_len, + dtype=config.t5_dtype, + device=torch.device('cpu') if os.getenv('T5_LOAD_CPU', 0) else self.device, + checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint), + tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer), + shard_fn=shard_fn if t5_fsdp else None) + + self.vae_stride = config.vae_stride + self.patch_size = config.patch_size + self.vae = Wan2_1_VAE( + vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), + device=self.device, + dtype=self.param_dtype) + if use_vae_parallel: + all_pp_group_ranks = [] + for i in range(0, dist.get_world_size() // 8): + all_pp_group_ranks.append(list(range(8 * i, 8 * (i + 1)))) + set_vae_patch_parallel(self.vae.model, 4, 2, all_pp_group_ranks= all_pp_group_ranks, decoder_decode="decoder.forward") + + logging.info(f"Creating WanModel from {checkpoint_dir}") + self.low_noise_model = WanModel.from_pretrained( + checkpoint_dir, subfolder=config.low_noise_checkpoint) + self.low_noise_model = self._configure_model( + model=self.low_noise_model, + use_sp=use_sp, + dit_fsdp=dit_fsdp, + shard_fn=shard_fn, + convert_model_dtype=convert_model_dtype) + + self.high_noise_model = WanModel.from_pretrained( + checkpoint_dir, subfolder=config.high_noise_checkpoint) + self.high_noise_model = self._configure_model( + model=self.high_noise_model, + use_sp=use_sp, + dit_fsdp=dit_fsdp, + shard_fn=shard_fn, + convert_model_dtype=convert_model_dtype) + if use_sp: + self.sp_size = get_sequence_parallel_world_size() + else: + self.sp_size = 1 + + self.sample_neg_prompt = config.sample_neg_prompt + + def _configure_model(self, model, use_sp, dit_fsdp, shard_fn, + convert_model_dtype): + """ + Configures a model object. This includes setting evaluation modes, + applying distributed parallel strategy, and handling device placement. + + Args: + model (torch.nn.Module): + The model instance to configure. + use_sp (`bool`): + Enable distribution strategy of sequence parallel. + dit_fsdp (`bool`): + Enable FSDP sharding for DiT model. + shard_fn (callable): + The function to apply FSDP sharding. + convert_model_dtype (`bool`): + Convert DiT model parameters dtype to 'config.param_dtype'. + Only works without FSDP. + + Returns: + torch.nn.Module: + The configured model. + """ + model.eval().requires_grad_(False) + + if use_sp: + for block in model.blocks: + block.self_attn.forward = types.MethodType( + sp_attn_forward, block.self_attn) + model.forward = types.MethodType(sp_dit_forward, model) + + if dist.is_initialized(): + dist.barrier() + + if dit_fsdp: + model = shard_fn(model) + else: + if convert_model_dtype: + model.to(self.param_dtype) + if not self.init_on_cpu: + model.to(self.device) + + return model + + def _prepare_model_for_timestep(self, t, boundary, offload_model): + r""" + Prepares and returns the required model for the current timestep. + + Args: + t (torch.Tensor): + current timestep. + boundary (`int`): + The timestep threshold. If `t` is at or above this value, + the `high_noise_model` is considered as the required model. 
+ offload_model (`bool`): + A flag intended to control the offloading behavior. + + Returns: + torch.nn.Module: + The active model on the target device for the current timestep. + """ + if t.item() >= boundary: + required_model_name = 'high_noise_model' + offload_model_name = 'low_noise_model' + else: + required_model_name = 'low_noise_model' + offload_model_name = 'high_noise_model' + if offload_model or self.init_on_cpu: + if next(getattr( + self, + offload_model_name).parameters()).device.type == 'cuda': + getattr(self, offload_model_name).to('cpu') + if next(getattr( + self, + required_model_name).parameters()).device.type == 'cpu': + getattr(self, required_model_name).to(self.device) + return getattr(self, required_model_name) + + def generate(self, + input_prompt, + size=(1280, 720), + frame_num=81, + shift=5.0, + sample_solver='unipc', + sampling_steps=50, + guide_scale=5.0, + n_prompt="", + seed=-1, + offload_model=True): + r""" + Generates video frames from text prompt using diffusion process. + + Args: + input_prompt (`str`): + Text prompt for content generation + size (`tuple[int]`, *optional*, defaults to (1280,720)): + Controls video resolution, (width,height). + frame_num (`int`, *optional*, defaults to 81): + How many frames to sample from a video. The number should be 4n+1 + shift (`float`, *optional*, defaults to 5.0): + Noise schedule shift parameter. Affects temporal dynamics + sample_solver (`str`, *optional*, defaults to 'unipc'): + Solver used to sample the video. + sampling_steps (`int`, *optional*, defaults to 50): + Number of diffusion sampling steps. Higher values improve quality but slow generation + guide_scale (`float` or tuple[`float`], *optional*, defaults 5.0): + Classifier-free guidance scale. Controls prompt adherence vs. creativity. + If tuple, the first guide_scale will be used for low noise model and + the second guide_scale will be used for high noise model. + n_prompt (`str`, *optional*, defaults to ""): + Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt` + seed (`int`, *optional*, defaults to -1): + Random seed for noise generation. If -1, use random seed. + offload_model (`bool`, *optional*, defaults to True): + If True, offloads models to CPU during generation to save VRAM + + Returns: + torch.Tensor: + Generated video frames tensor. 
Dimensions: (C, N H, W) where: + - C: Color channels (3 for RGB) + - N: Number of frames (81) + - H: Frame height (from size) + - W: Frame width from size) + """ + # preprocess + guide_scale = (guide_scale, guide_scale) if isinstance( + guide_scale, float) else guide_scale + F = frame_num + target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1, + size[1] // self.vae_stride[1], + size[0] // self.vae_stride[2]) + + seq_len = math.ceil((target_shape[2] * target_shape[3]) / + (self.patch_size[1] * self.patch_size[2]) * + target_shape[1] / self.sp_size) * self.sp_size + + if n_prompt == "": + n_prompt = self.sample_neg_prompt + seed = seed if seed >= 0 else random.randint(0, sys.maxsize) + seed_g = torch.Generator(device=self.device) + seed_g.manual_seed(seed) + + if not self.t5_cpu: + self.text_encoder.model.to(self.device) + context = self.text_encoder([input_prompt], self.device) + context_null = self.text_encoder([n_prompt], self.device) + if offload_model: + self.text_encoder.model.cpu() + else: + context = self.text_encoder([input_prompt], torch.device('cpu')) + context_null = self.text_encoder([n_prompt], torch.device('cpu')) + context = [t.to(self.device) for t in context] + context_null = [t.to(self.device) for t in context_null] + + noise = [ + torch.randn( + target_shape[0], + target_shape[1], + target_shape[2], + target_shape[3], + dtype=torch.float32, + device=self.device, + generator=seed_g) + ] + + @contextmanager + def noop_no_sync(): + yield + + no_sync_low_noise = getattr(self.low_noise_model, 'no_sync', + noop_no_sync) + no_sync_high_noise = getattr(self.high_noise_model, 'no_sync', + noop_no_sync) + + # evaluation mode + with ( + torch.amp.autocast('cuda', dtype=self.param_dtype), + torch.no_grad(), + no_sync_low_noise(), + no_sync_high_noise(), + ): + boundary = self.boundary * self.num_train_timesteps + + if sample_solver == 'unipc': + sample_scheduler = FlowUniPCMultistepScheduler( + num_train_timesteps=self.num_train_timesteps, + shift=1, + use_dynamic_shifting=False) + sample_scheduler.set_timesteps( + sampling_steps, device=self.device, shift=shift) + timesteps = sample_scheduler.timesteps + elif sample_solver == 'dpm++': + sample_scheduler = FlowDPMSolverMultistepScheduler( + num_train_timesteps=self.num_train_timesteps, + shift=1, + use_dynamic_shifting=False) + sampling_sigmas = get_sampling_sigmas(sampling_steps, shift) + timesteps, _ = retrieve_timesteps( + sample_scheduler, + device=self.device, + sigmas=sampling_sigmas) + else: + raise NotImplementedError("Unsupported solver.") + + # sample videos + latents = noise + + arg_c = {'context': context, 'seq_len': seq_len} + arg_null = {'context': context_null, 'seq_len': seq_len} + arg_all = { + 'context': context if get_classifier_free_guidance_rank()==0 else context_null, + 'seq_len': seq_len + } + + for _, t in enumerate(tqdm(timesteps)): + latent_model_input = latents + timestep = [t] + + timestep = torch.stack(timestep) + + model = self._prepare_model_for_timestep( + t, boundary, offload_model) + sample_guide_scale = guide_scale[1] if t.item( + ) >= boundary else guide_scale[0] + + if get_classifier_free_guidance_world_size() == 2: + noise_pred = model( + latent_model_input, t=timestep, **arg_all)[0] + noise_pred_cond, noise_pred_uncond = get_cfg_group().all_gather( + noise_pred, separate_tensors=True + ) + else: + noise_pred_cond = model( + latent_model_input, t=timestep, **arg_c)[0] + noise_pred_uncond = model( + latent_model_input, t=timestep, **arg_null)[0] + + noise_pred = 
noise_pred_uncond + sample_guide_scale * ( + noise_pred_cond - noise_pred_uncond) + + temp_x0 = sample_scheduler.step( + noise_pred.unsqueeze(0), + t, + latents[0].unsqueeze(0), + return_dict=False, + generator=seed_g)[0] + latents = [temp_x0.squeeze(0)] + + x0 = latents + if offload_model: + self.low_noise_model.cpu() + self.high_noise_model.cpu() + torch.cuda.empty_cache() + if self.rank < 8: + with VAE_patch_parallel(): + videos = self.vae.decode(x0) + + del noise, latents + del sample_scheduler + if offload_model: + gc.collect() + torch.cuda.synchronize() + if dist.is_initialized(): + dist.barrier() + + return videos[0] if self.rank == 0 else None diff --git a/MindIE/MultiModal/Wan2.2/wan/textimage2video.py b/MindIE/MultiModal/Wan2.2/wan/textimage2video.py new file mode 100644 index 0000000000..500e1802fd --- /dev/null +++ b/MindIE/MultiModal/Wan2.2/wan/textimage2video.py @@ -0,0 +1,648 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import gc +import logging +import math +import os +import random +import sys +import types +from contextlib import contextmanager +from functools import partial + +import torch +import torch.cuda.amp as amp +import torch.distributed as dist +import torchvision.transforms.functional as TF +from PIL import Image +from tqdm import tqdm + +from .distributed.fsdp import shard_model +from .distributed.sequence_parallel import sp_attn_forward, sp_dit_forward +from .distributed.util import get_world_size +from .modules.model import WanModel +from .modules.t5 import T5EncoderModel +from .modules.vae2_2 import Wan2_2_VAE +from .utils.fm_solvers import ( + FlowDPMSolverMultistepScheduler, + get_sampling_sigmas, + retrieve_timesteps, +) +from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler +from .utils.utils import best_output_size, masks_like +from wan.distributed.parallel_mgr import ( + get_sequence_parallel_world_size, + get_classifier_free_guidance_world_size, + get_classifier_free_guidance_rank, + get_cfg_group +) + +class WanTI2V: + + def __init__( + self, + config, + checkpoint_dir, + device_id=0, + rank=0, + t5_fsdp=False, + dit_fsdp=False, + use_sp=False, + t5_cpu=False, + init_on_cpu=True, + convert_model_dtype=False + ): + r""" + Initializes the Wan text-to-video generation model components. + + Args: + config (EasyDict): + Object containing model parameters initialized from config.py + checkpoint_dir (`str`): + Path to directory containing model checkpoints + device_id (`int`, *optional*, defaults to 0): + Id of target GPU device + rank (`int`, *optional*, defaults to 0): + Process rank for distributed training + t5_fsdp (`bool`, *optional*, defaults to False): + Enable FSDP sharding for T5 model + dit_fsdp (`bool`, *optional*, defaults to False): + Enable FSDP sharding for DiT model + use_sp (`bool`, *optional*, defaults to False): + Enable distribution strategy of sequence parallel. + t5_cpu (`bool`, *optional*, defaults to False): + Whether to place T5 model on CPU. Only works without t5_fsdp. + init_on_cpu (`bool`, *optional*, defaults to True): + Enable initializing Transformer Model on CPU. Only works without FSDP or USP. + convert_model_dtype (`bool`, *optional*, defaults to False): + Convert DiT model parameters dtype to 'config.param_dtype'. + Only works without FSDP. 
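+
+            Example (illustrative; the checkpoint path is a placeholder):
+                >>> # pipe = WanTI2V(config, checkpoint_dir="/path/to/checkpoints")
+                >>> # video = pipe.generate("a cat surfing a wave", size=(1280, 704))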
+ """ + self.device = torch.device(f"cuda:{device_id}") + self.config = config + self.rank = rank + self.t5_cpu = t5_cpu + self.init_on_cpu = init_on_cpu + + self.num_train_timesteps = config.num_train_timesteps + self.param_dtype = config.param_dtype + + if t5_fsdp or dit_fsdp or use_sp: + self.init_on_cpu = False + + shard_fn = partial(shard_model, device_id=device_id) + self.text_encoder = T5EncoderModel( + text_len=config.text_len, + dtype=config.t5_dtype, + device=torch.device('cpu') if os.getenv('T5_LOAD_CPU', 0) else self.device, + checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint), + tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer), + shard_fn=shard_fn if t5_fsdp else None) + + self.vae_stride = config.vae_stride + self.patch_size = config.patch_size + self.vae = Wan2_2_VAE( + vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), + device=self.device, + dtype=self.param_dtype) + + logging.info(f"Creating WanModel from {checkpoint_dir}") + self.model = WanModel.from_pretrained(checkpoint_dir) + + self.model = self._configure_model( + model=self.model, + use_sp=use_sp, + dit_fsdp=dit_fsdp, + shard_fn=shard_fn, + convert_model_dtype=convert_model_dtype) + + if use_sp: + self.sp_size = get_sequence_parallel_world_size() + else: + self.sp_size = 1 + + self.sample_neg_prompt = config.sample_neg_prompt + + def _configure_model(self, model, use_sp, dit_fsdp, shard_fn, + convert_model_dtype): + """ + Configures a model object. This includes setting evaluation modes, + applying distributed parallel strategy, and handling device placement. + + Args: + model (torch.nn.Module): + The model instance to configure. + use_sp (`bool`): + Enable distribution strategy of sequence parallel. + dit_fsdp (`bool`): + Enable FSDP sharding for DiT model. + shard_fn (callable): + The function to apply FSDP sharding. + convert_model_dtype (`bool`): + Convert DiT model parameters dtype to 'config.param_dtype'. + Only works without FSDP. + + Returns: + torch.nn.Module: + The configured model. + """ + model.eval().requires_grad_(False) + + if use_sp: + for block in model.blocks: + block.self_attn.forward = types.MethodType( + sp_attn_forward, block.self_attn) + model.forward = types.MethodType(sp_dit_forward, model) + + if dist.is_initialized(): + dist.barrier() + + if dit_fsdp: + model = shard_fn(model) + else: + if convert_model_dtype: + model.to(self.param_dtype) + if not self.init_on_cpu: + model.to(self.device) + + return model + + def generate(self, + input_prompt, + img=None, + size=(1280, 704), + max_area=704 * 1280, + frame_num=81, + shift=5.0, + sample_solver='unipc', + sampling_steps=50, + guide_scale=5.0, + n_prompt="", + seed=-1, + offload_model=True): + r""" + Generates video frames from text prompt using diffusion process. + + Args: + input_prompt (`str`): + Text prompt for content generation + img (PIL.Image.Image): + Input image tensor. Shape: [3, H, W] + size (`tuple[int]`, *optional*, defaults to (1280,704)): + Controls video resolution, (width,height). + max_area (`int`, *optional*, defaults to 704*1280): + Maximum pixel area for latent space calculation. Controls video resolution scaling + frame_num (`int`, *optional*, defaults to 81): + How many frames to sample from a video. The number should be 4n+1 + shift (`float`, *optional*, defaults to 5.0): + Noise schedule shift parameter. Affects temporal dynamics + sample_solver (`str`, *optional*, defaults to 'unipc'): + Solver used to sample the video. 
+ sampling_steps (`int`, *optional*, defaults to 50): + Number of diffusion sampling steps. Higher values improve quality but slow generation + guide_scale (`float`, *optional*, defaults 5.0): + Classifier-free guidance scale. Controls prompt adherence vs. creativity. + n_prompt (`str`, *optional*, defaults to ""): + Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt` + seed (`int`, *optional*, defaults to -1): + Random seed for noise generation. If -1, use random seed. + offload_model (`bool`, *optional*, defaults to True): + If True, offloads models to CPU during generation to save VRAM + + Returns: + torch.Tensor: + Generated video frames tensor. Dimensions: (C, N H, W) where: + - C: Color channels (3 for RGB) + - N: Number of frames (81) + - H: Frame height (from size) + - W: Frame width from size) + """ + # i2v + if img is not None: + return self.i2v( + input_prompt=input_prompt, + img=img, + max_area=max_area, + frame_num=frame_num, + shift=shift, + sample_solver=sample_solver, + sampling_steps=sampling_steps, + guide_scale=guide_scale, + n_prompt=n_prompt, + seed=seed, + offload_model=offload_model) + # t2v + return self.t2v( + input_prompt=input_prompt, + size=size, + frame_num=frame_num, + shift=shift, + sample_solver=sample_solver, + sampling_steps=sampling_steps, + guide_scale=guide_scale, + n_prompt=n_prompt, + seed=seed, + offload_model=offload_model) + + def t2v(self, + input_prompt, + size=(1280, 704), + frame_num=121, + shift=5.0, + sample_solver='unipc', + sampling_steps=50, + guide_scale=5.0, + n_prompt="", + seed=-1, + offload_model=True): + r""" + Generates video frames from text prompt using diffusion process. + + Args: + input_prompt (`str`): + Text prompt for content generation + size (`tuple[int]`, *optional*, defaults to (1280,704)): + Controls video resolution, (width,height). + frame_num (`int`, *optional*, defaults to 121): + How many frames to sample from a video. The number should be 4n+1 + shift (`float`, *optional*, defaults to 5.0): + Noise schedule shift parameter. Affects temporal dynamics + sample_solver (`str`, *optional*, defaults to 'unipc'): + Solver used to sample the video. + sampling_steps (`int`, *optional*, defaults to 50): + Number of diffusion sampling steps. Higher values improve quality but slow generation + guide_scale (`float`, *optional*, defaults 5.0): + Classifier-free guidance scale. Controls prompt adherence vs. creativity. + n_prompt (`str`, *optional*, defaults to ""): + Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt` + seed (`int`, *optional*, defaults to -1): + Random seed for noise generation. If -1, use random seed. + offload_model (`bool`, *optional*, defaults to True): + If True, offloads models to CPU during generation to save VRAM + + Returns: + torch.Tensor: + Generated video frames tensor. 
Dimensions: (C, N H, W) where: + - C: Color channels (3 for RGB) + - N: Number of frames (81) + - H: Frame height (from size) + - W: Frame width from size) + """ + # preprocess + F = frame_num + target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1, + size[1] // self.vae_stride[1], + size[0] // self.vae_stride[2]) + + seq_len = math.ceil((target_shape[2] * target_shape[3]) / + (self.patch_size[1] * self.patch_size[2]) * + target_shape[1] / self.sp_size) * self.sp_size + + if n_prompt == "": + n_prompt = self.sample_neg_prompt + seed = seed if seed >= 0 else random.randint(0, sys.maxsize) + seed_g = torch.Generator(device=self.device) + seed_g.manual_seed(seed) + + if not self.t5_cpu: + self.text_encoder.model.to(self.device) + context = self.text_encoder([input_prompt], self.device) + context_null = self.text_encoder([n_prompt], self.device) + if offload_model: + self.text_encoder.model.cpu() + else: + context = self.text_encoder([input_prompt], torch.device('cpu')) + context_null = self.text_encoder([n_prompt], torch.device('cpu')) + context = [t.to(self.device) for t in context] + context_null = [t.to(self.device) for t in context_null] + + noise = [ + torch.randn( + target_shape[0], + target_shape[1], + target_shape[2], + target_shape[3], + dtype=torch.float32, + device=self.device, + generator=seed_g) + ] + + @contextmanager + def noop_no_sync(): + yield + + no_sync = getattr(self.model, 'no_sync', noop_no_sync) + + # evaluation mode + with ( + torch.amp.autocast('cuda', dtype=self.param_dtype), + torch.no_grad(), + no_sync(), + ): + + if sample_solver == 'unipc': + sample_scheduler = FlowUniPCMultistepScheduler( + num_train_timesteps=self.num_train_timesteps, + shift=1, + use_dynamic_shifting=False) + sample_scheduler.set_timesteps( + sampling_steps, device=self.device, shift=shift) + timesteps = sample_scheduler.timesteps + elif sample_solver == 'dpm++': + sample_scheduler = FlowDPMSolverMultistepScheduler( + num_train_timesteps=self.num_train_timesteps, + shift=1, + use_dynamic_shifting=False) + sampling_sigmas = get_sampling_sigmas(sampling_steps, shift) + timesteps, _ = retrieve_timesteps( + sample_scheduler, + device=self.device, + sigmas=sampling_sigmas) + else: + raise NotImplementedError("Unsupported solver.") + + # sample videos + latents = noise + mask1, mask2 = masks_like(noise, zero=False) + + arg_c = {'context': context, 'seq_len': seq_len} + arg_null = {'context': context_null, 'seq_len': seq_len} + arg_all = { + 'context': context if get_classifier_free_guidance_rank()==0 else context_null, + 'seq_len': seq_len + } + if offload_model or self.init_on_cpu: + self.model.to(self.device) + torch.cuda.empty_cache() + + for _, t in enumerate(tqdm(timesteps)): + latent_model_input = latents + timestep = [t] + + timestep = torch.stack(timestep) + + temp_ts = (mask2[0][0][:, ::2, ::2] * timestep).flatten() + temp_ts = torch.cat([ + temp_ts, + temp_ts.new_ones(seq_len - temp_ts.size(0)) * timestep + ]) + timestep = temp_ts.unsqueeze(0) + + if get_classifier_free_guidance_world_size() == 2: + noise_pred = self.model( + latent_model_input, t=timestep, **arg_all)[0] + noise_pred_cond, noise_pred_uncond = get_cfg_group().all_gather( + noise_pred, separate_tensors=True) + else: + noise_pred_cond = self.model( + latent_model_input, t=timestep, **arg_c)[0] + noise_pred_uncond = self.model( + latent_model_input, t=timestep, **arg_null)[0] + noise_pred = noise_pred_uncond + guide_scale * ( + noise_pred_cond - noise_pred_uncond) + + temp_x0 = sample_scheduler.step( + 
noise_pred.unsqueeze(0), + t, + latents[0].unsqueeze(0), + return_dict=False, + generator=seed_g)[0] + latents = [temp_x0.squeeze(0)] + x0 = latents + if offload_model: + self.model.cpu() + torch.cuda.synchronize() + torch.cuda.empty_cache() + if self.rank == 0: + videos = self.vae.decode(x0) + + del noise, latents + del sample_scheduler + if offload_model: + gc.collect() + torch.cuda.synchronize() + if dist.is_initialized(): + dist.barrier() + + return videos[0] if self.rank == 0 else None + + def i2v(self, + input_prompt, + img, + max_area=704 * 1280, + frame_num=121, + shift=5.0, + sample_solver='unipc', + sampling_steps=40, + guide_scale=5.0, + n_prompt="", + seed=-1, + offload_model=True): + r""" + Generates video frames from input image and text prompt using diffusion process. + + Args: + input_prompt (`str`): + Text prompt for content generation. + img (PIL.Image.Image): + Input image tensor. Shape: [3, H, W] + max_area (`int`, *optional*, defaults to 704*1280): + Maximum pixel area for latent space calculation. Controls video resolution scaling + frame_num (`int`, *optional*, defaults to 121): + How many frames to sample from a video. The number should be 4n+1 + shift (`float`, *optional*, defaults to 5.0): + Noise schedule shift parameter. Affects temporal dynamics + [NOTE]: If you want to generate a 480p video, it is recommended to set the shift value to 3.0. + sample_solver (`str`, *optional*, defaults to 'unipc'): + Solver used to sample the video. + sampling_steps (`int`, *optional*, defaults to 40): + Number of diffusion sampling steps. Higher values improve quality but slow generation + guide_scale (`float`, *optional*, defaults 5.0): + Classifier-free guidance scale. Controls prompt adherence vs. creativity. + n_prompt (`str`, *optional*, defaults to ""): + Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt` + seed (`int`, *optional*, defaults to -1): + Random seed for noise generation. If -1, use random seed + offload_model (`bool`, *optional*, defaults to True): + If True, offloads models to CPU during generation to save VRAM + + Returns: + torch.Tensor: + Generated video frames tensor. 
Dimensions: (C, N H, W) where: + - C: Color channels (3 for RGB) + - N: Number of frames (121) + - H: Frame height (from max_area) + - W: Frame width (from max_area) + """ + # preprocess + ih, iw = img.height, img.width + dh, dw = self.patch_size[1] * self.vae_stride[1], self.patch_size[ + 2] * self.vae_stride[2] + ow, oh = best_output_size(iw, ih, dw, dh, max_area) + + scale = max(ow / iw, oh / ih) + img = img.resize((round(iw * scale), round(ih * scale)), Image.LANCZOS) + + # center-crop + x1 = (img.width - ow) // 2 + y1 = (img.height - oh) // 2 + img = img.crop((x1, y1, x1 + ow, y1 + oh)) + assert img.width == ow and img.height == oh + + # to tensor + img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device).unsqueeze(1) + + F = frame_num + seq_len = ((F - 1) // self.vae_stride[0] + 1) * ( + oh // self.vae_stride[1]) * (ow // self.vae_stride[2]) // ( + self.patch_size[1] * self.patch_size[2]) + seq_len = int(math.ceil(seq_len / self.sp_size)) * self.sp_size + + seed = seed if seed >= 0 else random.randint(0, sys.maxsize) + seed_g = torch.Generator(device=self.device) + seed_g.manual_seed(seed) + noise = torch.randn( + self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1, + oh // self.vae_stride[1], + ow // self.vae_stride[2], + dtype=torch.float32, + generator=seed_g, + device=self.device) + + if n_prompt == "": + n_prompt = self.sample_neg_prompt + + # preprocess + if not self.t5_cpu: + self.text_encoder.model.to(self.device) + context = self.text_encoder([input_prompt], self.device) + context_null = self.text_encoder([n_prompt], self.device) + if offload_model: + self.text_encoder.model.cpu() + else: + context = self.text_encoder([input_prompt], torch.device('cpu')) + context_null = self.text_encoder([n_prompt], torch.device('cpu')) + context = [t.to(self.device) for t in context] + context_null = [t.to(self.device) for t in context_null] + + z = self.vae.encode([img]) + + @contextmanager + def noop_no_sync(): + yield + + no_sync = getattr(self.model, 'no_sync', noop_no_sync) + + # evaluation mode + with ( + torch.amp.autocast('cuda', dtype=self.param_dtype), + torch.no_grad(), + no_sync(), + ): + + if sample_solver == 'unipc': + sample_scheduler = FlowUniPCMultistepScheduler( + num_train_timesteps=self.num_train_timesteps, + shift=1, + use_dynamic_shifting=False) + sample_scheduler.set_timesteps( + sampling_steps, device=self.device, shift=shift) + timesteps = sample_scheduler.timesteps + elif sample_solver == 'dpm++': + sample_scheduler = FlowDPMSolverMultistepScheduler( + num_train_timesteps=self.num_train_timesteps, + shift=1, + use_dynamic_shifting=False) + sampling_sigmas = get_sampling_sigmas(sampling_steps, shift) + timesteps, _ = retrieve_timesteps( + sample_scheduler, + device=self.device, + sigmas=sampling_sigmas) + else: + raise NotImplementedError("Unsupported solver.") + + # sample videos + latent = noise + mask1, mask2 = masks_like([noise], zero=True) + latent = (1. 
- mask2[0]) * z[0] + mask2[0] * latent + + arg_c = { + 'context': [context[0]], + 'seq_len': seq_len, + } + + arg_null = { + 'context': context_null, + 'seq_len': seq_len, + } + + arg_all = { + 'context': [context[0]] if get_classifier_free_guidance_rank()==0 else context_null, + 'seq_len': seq_len + } + + if offload_model or self.init_on_cpu: + self.model.to(self.device) + torch.cuda.empty_cache() + + for _, t in enumerate(tqdm(timesteps)): + latent_model_input = [latent.to(self.device)] + timestep = [t] + + timestep = torch.stack(timestep).to(self.device) + + temp_ts = (mask2[0][0][:, ::2, ::2] * timestep).flatten() + temp_ts = torch.cat([ + temp_ts, + temp_ts.new_ones(seq_len - temp_ts.size(0)) * timestep + ]) + timestep = temp_ts.unsqueeze(0) + + if get_classifier_free_guidance_world_size() == 2: + noise_pred = self.model( + latent_model_input, t=timestep, **arg_all)[0].to( + torch.device('cpu') if offload_model else self.device) + noise_pred_cond, noise_pred_uncond = get_cfg_group().all_gather( + noise_pred, separate_tensors=True) + if offload_model: + torch.cuda.empty_cache() + else: + noise_pred_cond = self.model( + latent_model_input, t=timestep, **arg_c)[0] + if offload_model: + torch.cuda.empty_cache() + noise_pred_uncond = self.model( + latent_model_input, t=timestep, **arg_null)[0] + if offload_model: + torch.cuda.empty_cache() + noise_pred = noise_pred_uncond + guide_scale * ( + noise_pred_cond - noise_pred_uncond) + + temp_x0 = sample_scheduler.step( + noise_pred.unsqueeze(0), + t, + latent.unsqueeze(0), + return_dict=False, + generator=seed_g)[0] + latent = temp_x0.squeeze(0) + latent = (1. - mask2[0]) * z[0] + mask2[0] * latent + + x0 = [latent] + del latent_model_input, timestep + + if offload_model: + self.model.cpu() + torch.cuda.synchronize() + torch.cuda.empty_cache() + + if self.rank == 0: + videos = self.vae.decode(x0) + + del noise, latent, x0 + del sample_scheduler + if offload_model: + gc.collect() + torch.cuda.synchronize() + if dist.is_initialized(): + dist.barrier() + + return videos[0] if self.rank == 0 else None diff --git a/MindIE/MultiModal/Wan2.2/wan/utils/__init__.py b/MindIE/MultiModal/Wan2.2/wan/utils/__init__.py new file mode 100644 index 0000000000..5b173105eb --- /dev/null +++ b/MindIE/MultiModal/Wan2.2/wan/utils/__init__.py @@ -0,0 +1,12 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +from .fm_solvers import ( + FlowDPMSolverMultistepScheduler, + get_sampling_sigmas, + retrieve_timesteps, +) +from .fm_solvers_unipc import FlowUniPCMultistepScheduler + +__all__ = [ + 'HuggingfaceTokenizer', 'get_sampling_sigmas', 'retrieve_timesteps', + 'FlowDPMSolverMultistepScheduler', 'FlowUniPCMultistepScheduler' +] diff --git a/MindIE/MultiModal/Wan2.2/wan/utils/fm_solvers.py b/MindIE/MultiModal/Wan2.2/wan/utils/fm_solvers.py new file mode 100644 index 0000000000..17bef85000 --- /dev/null +++ b/MindIE/MultiModal/Wan2.2/wan/utils/fm_solvers.py @@ -0,0 +1,859 @@ +# Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +# Convert dpm solver for flow matching +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
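+#
+# This module provides three pieces used by the Wan pipelines:
+#   - get_sampling_sigmas(sampling_steps, shift): a linear sigma schedule remapped as
+#     sigma' = shift * sigma / (1 + (shift - 1) * sigma)
+#   - retrieve_timesteps(scheduler, ...): applies custom timesteps or sigmas to a
+#     scheduler and returns (timesteps, num_inference_steps)
+#   - FlowDPMSolverMultistepScheduler: multistep DPM-Solver / DPM-Solver++ adapted to
+#     flow-matching ("flow_prediction") models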
+ +import inspect +import math +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.schedulers.scheduling_utils import ( + KarrasDiffusionSchedulers, + SchedulerMixin, + SchedulerOutput, +) +from diffusers.utils import deprecate, is_scipy_available +from diffusers.utils.torch_utils import randn_tensor + +if is_scipy_available(): + pass + + +def get_sampling_sigmas(sampling_steps, shift): + sigma = np.linspace(1, 0, sampling_steps + 1)[:sampling_steps] + sigma = (shift * sigma / (1 + (shift - 1) * sigma)) + + return sigma + + +def retrieve_timesteps( + scheduler, + num_inference_steps=None, + device=None, + timesteps=None, + sigmas=None, + **kwargs, +): + if timesteps is not None and sigmas is not None: + raise ValueError( + "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values" + ) + if timesteps is not None: + accepts_timesteps = "timesteps" in set( + inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set( + inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class FlowDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): + """ + `FlowDPMSolverMultistepScheduler` is a fast dedicated high-order solver for diffusion ODEs. + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. This determines the resolution of the diffusion process. + solver_order (`int`, defaults to 2): + The DPMSolver order which can be `1`, `2`, or `3`. It is recommended to use `solver_order=2` for guided + sampling, and `solver_order=3` for unconditional sampling. This affects the number of model outputs stored + and used in multistep updates. + prediction_type (`str`, defaults to "flow_prediction"): + Prediction type of the scheduler function; must be `flow_prediction` for this scheduler, which predicts + the flow of the diffusion process. + shift (`float`, *optional*, defaults to 1.0): + A factor used to adjust the sigmas in the noise schedule. It modifies the step sizes during the sampling + process. + use_dynamic_shifting (`bool`, defaults to `False`): + Whether to apply dynamic shifting to the timesteps based on image resolution. If `True`, the shifting is + applied on the fly. 
+ thresholding (`bool`, defaults to `False`): + Whether to use the "dynamic thresholding" method. This method adjusts the predicted sample to prevent + saturation and improve photorealism. + dynamic_thresholding_ratio (`float`, defaults to 0.995): + The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. + sample_max_value (`float`, defaults to 1.0): + The threshold value for dynamic thresholding. Valid only when `thresholding=True` and + `algorithm_type="dpmsolver++"`. + algorithm_type (`str`, defaults to `dpmsolver++`): + Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, `sde-dpmsolver` or `sde-dpmsolver++`. The + `dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927) + paper, and the `dpmsolver++` type implements the algorithms in the + [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or + `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion. + solver_type (`str`, defaults to `midpoint`): + Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the + sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers. + lower_order_final (`bool`, defaults to `True`): + Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can + stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10. + euler_at_final (`bool`, defaults to `False`): + Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail + richness. This can stabilize the sampling of the SDE variant of DPMSolver for small number of inference + steps, but sometimes may result in blurring. + final_sigmas_type (`str`, *optional*, defaults to "zero"): + The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final + sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. + lambda_min_clipped (`float`, defaults to `-inf`): + Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the + cosine (`squaredcos_cap_v2`) noise schedule. + variance_type (`str`, *optional*): + Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output + contains the predicted Gaussian variance. + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + solver_order: int = 2, + prediction_type: str = "flow_prediction", + shift: Optional[float] = 1.0, + use_dynamic_shifting=False, + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + algorithm_type: str = "dpmsolver++", + solver_type: str = "midpoint", + lower_order_final: bool = True, + euler_at_final: bool = False, + final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min" + lambda_min_clipped: float = -float("inf"), + variance_type: Optional[str] = None, + invert_sigmas: bool = False, + ): + if algorithm_type in ["dpmsolver", "sde-dpmsolver"]: + deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. 
Choose from `dpmsolver++` or `sde-dpmsolver++` instead" + deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", + deprecation_message) + + # settings for DPM-Solver + if algorithm_type not in [ + "dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++" + ]: + if algorithm_type == "deis": + self.register_to_config(algorithm_type="dpmsolver++") + else: + raise NotImplementedError( + f"{algorithm_type} is not implemented for {self.__class__}") + + if solver_type not in ["midpoint", "heun"]: + if solver_type in ["logrho", "bh1", "bh2"]: + self.register_to_config(solver_type="midpoint") + else: + raise NotImplementedError( + f"{solver_type} is not implemented for {self.__class__}") + + if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++" + ] and final_sigmas_type == "zero": + raise ValueError( + f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please choose `sigma_min` instead." + ) + + # setable values + self.num_inference_steps = None + alphas = np.linspace(1, 1 / num_train_timesteps, + num_train_timesteps)[::-1].copy() + sigmas = 1.0 - alphas + sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32) + + if not use_dynamic_shifting: + # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution + sigmas = shift * sigmas / (1 + + (shift - 1) * sigmas) # pyright: ignore + + self.sigmas = sigmas + self.timesteps = sigmas * num_train_timesteps + + self.model_outputs = [None] * solver_order + self.lower_order_nums = 0 + self._step_index = None + self._begin_index = None + + # self.sigmas = self.sigmas.to( + # "cpu") # to avoid too much CPU/GPU communication + self.sigma_min = self.sigmas[-1].item() + self.sigma_max = self.sigmas[0].item() + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + # Modified from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.set_timesteps + def set_timesteps( + self, + num_inference_steps: Union[int, None] = None, + device: Union[str, torch.device] = None, + sigmas: Optional[List[float]] = None, + mu: Optional[Union[float, None]] = None, + shift: Optional[Union[float, None]] = None, + ): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + Args: + num_inference_steps (`int`): + Total number of the spacing of the time steps. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 
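+            sigmas (`List[float]`, *optional*):
+                Custom sigma schedule. If `None`, a linear schedule from `sigma_max`
+                to `sigma_min` is used.
+            mu (`float`, *optional*):
+                Shift parameter for dynamic shifting; required when
+                `use_dynamic_shifting=True`.
+            shift (`float`, *optional*):
+                Static shift factor applied to the sigmas; falls back to
+                `self.config.shift` when not given.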
+ """ + + if self.config.use_dynamic_shifting and mu is None: + raise ValueError( + " you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`" + ) + + if sigmas is None: + sigmas = np.linspace(self.sigma_max, self.sigma_min, + num_inference_steps + + 1).copy()[:-1] # pyright: ignore + + if self.config.use_dynamic_shifting: + sigmas = self.time_shift(mu, 1.0, sigmas) # pyright: ignore + else: + if shift is None: + shift = self.config.shift + sigmas = shift * sigmas / (1 + + (shift - 1) * sigmas) # pyright: ignore + + if self.config.final_sigmas_type == "sigma_min": + sigma_last = ((1 - self.alphas_cumprod[0]) / + self.alphas_cumprod[0])**0.5 + elif self.config.final_sigmas_type == "zero": + sigma_last = 0 + else: + raise ValueError( + f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" + ) + + timesteps = sigmas * self.config.num_train_timesteps + sigmas = np.concatenate([sigmas, [sigma_last] + ]).astype(np.float32) # pyright: ignore + + self.sigmas = torch.from_numpy(sigmas) + self.timesteps = torch.from_numpy(timesteps).to( + device=device, dtype=torch.int64) + + self.num_inference_steps = len(timesteps) + + self.model_outputs = [ + None, + ] * self.config.solver_order + self.lower_order_nums = 0 + + self._step_index = None + self._begin_index = None + # self.sigmas = self.sigmas.to( + # "cpu") # to avoid too much CPU/GPU communication + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." 
+ https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, *remaining_dims = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float( + ) # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile( + abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + s = s.unsqueeze( + 1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = torch.clamp( + sample, -s, s + ) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, *remaining_dims) + sample = sample.to(dtype) + + return sample + + # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._sigma_to_t + def _sigma_to_t(self, sigma): + return sigma * self.config.num_train_timesteps + + def _sigma_to_alpha_sigma_t(self, sigma): + return 1 - sigma, sigma + + # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.set_timesteps + def time_shift(self, mu: float, sigma: float, t: torch.Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1)**sigma) + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.convert_model_output + def convert_model_output( + self, + model_output: torch.Tensor, + *args, + sample: torch.Tensor = None, + **kwargs, + ) -> torch.Tensor: + """ + Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is + designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an + integral of the data prediction model. + + The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise + prediction and data prediction models. + + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + Returns: + `torch.Tensor`: + The converted model output. + """ + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError( + "missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + # DPM-Solver++ needs to solve an integral of the data prediction model. + if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]: + if self.config.prediction_type == "flow_prediction": + sigma_t = self.sigmas[self.step_index] + x0_pred = sample - sigma_t * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`," + " `v_prediction`, or `flow_prediction` for the FlowDPMSolverMultistepScheduler." + ) + + if self.config.thresholding: + x0_pred = self._threshold_sample(x0_pred) + + return x0_pred + + # DPM-Solver needs to solve an integral of the noise prediction model. 
+ elif self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]: + if self.config.prediction_type == "flow_prediction": + sigma_t = self.sigmas[self.step_index] + epsilon = sample - (1 - sigma_t) * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`," + " `v_prediction` or `flow_prediction` for the FlowDPMSolverMultistepScheduler." + ) + + if self.config.thresholding: + sigma_t = self.sigmas[self.step_index] + x0_pred = sample - sigma_t * model_output + x0_pred = self._threshold_sample(x0_pred) + epsilon = model_output + x0_pred + + return epsilon + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.dpm_solver_first_order_update + def dpm_solver_first_order_update( + self, + model_output: torch.Tensor, + *args, + sample: torch.Tensor = None, + noise: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """ + One step for the first-order DPMSolver (equivalent to DDIM). + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. + """ + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop( + "prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError( + " missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[ + self.step_index] # pyright: ignore + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s) + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s = torch.log(alpha_s) - torch.log(sigma_s) + + h = lambda_t - lambda_s + if self.config.algorithm_type == "dpmsolver++": + x_t = (sigma_t / + sigma_s) * sample - (alpha_t * + (torch.exp(-h) - 1.0)) * model_output + elif self.config.algorithm_type == "dpmsolver": + x_t = (alpha_t / + alpha_s) * sample - (sigma_t * + (torch.exp(h) - 1.0)) * model_output + elif self.config.algorithm_type == "sde-dpmsolver++": + assert noise is not None + x_t = ((sigma_t / sigma_s * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise) + elif self.config.algorithm_type == "sde-dpmsolver": + assert noise is not None + x_t = ((alpha_t / alpha_s) * sample - 2.0 * + (sigma_t * (torch.exp(h) - 1.0)) * model_output + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise) + return x_t # pyright: ignore + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_second_order_update + def multistep_dpm_solver_second_order_update( + self, + model_output_list: List[torch.Tensor], + *args, + sample: torch.Tensor = None, + noise: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """ + One 
step for the second-order multistep DPMSolver. + Args: + model_output_list (`List[torch.Tensor]`): + The direct outputs from learned diffusion model at current and latter timesteps. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. + """ + timestep_list = args[0] if len(args) > 0 else kwargs.pop( + "timestep_list", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop( + "prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError( + " missing `sample` as a required keyward argument") + if timestep_list is not None: + deprecate( + "timestep_list", + "1.0.0", + "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma_t, sigma_s0, sigma_s1 = ( + self.sigmas[self.step_index + 1], # pyright: ignore + self.sigmas[self.step_index], + self.sigmas[self.step_index - 1], # pyright: ignore + ) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) + + m0, m1 = model_output_list[-1], model_output_list[-2] + + h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 + r0 = h_0 / h + D0, D1 = m0, (1.0 / r0) * (m0 - m1) + if self.config.algorithm_type == "dpmsolver++": + # See https://arxiv.org/abs/2211.01095 for detailed derivations + if self.config.solver_type == "midpoint": + x_t = ((sigma_t / sigma_s0) * sample - + (alpha_t * (torch.exp(-h) - 1.0)) * D0 - 0.5 * + (alpha_t * (torch.exp(-h) - 1.0)) * D1) + elif self.config.solver_type == "heun": + x_t = ((sigma_t / sigma_s0) * sample - + (alpha_t * (torch.exp(-h) - 1.0)) * D0 + + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1) + elif self.config.algorithm_type == "dpmsolver": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + if self.config.solver_type == "midpoint": + x_t = ((alpha_t / alpha_s0) * sample - + (sigma_t * (torch.exp(h) - 1.0)) * D0 - 0.5 * + (sigma_t * (torch.exp(h) - 1.0)) * D1) + elif self.config.solver_type == "heun": + x_t = ((alpha_t / alpha_s0) * sample - + (sigma_t * (torch.exp(h) - 1.0)) * D0 - + (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1) + elif self.config.algorithm_type == "sde-dpmsolver++": + assert noise is not None + if self.config.solver_type == "midpoint": + x_t = ((sigma_t / sigma_s0 * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + 0.5 * + (alpha_t * (1 - torch.exp(-2.0 * h))) * D1 + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise) + elif self.config.solver_type == "heun": + x_t = ((sigma_t / sigma_s0 * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / + (-2.0 * h) + 1.0)) * D1 + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise) + elif self.config.algorithm_type == "sde-dpmsolver": + assert noise is not None + if self.config.solver_type == "midpoint": + x_t = ((alpha_t / alpha_s0) * sample - 2.0 * + 
(sigma_t * (torch.exp(h) - 1.0)) * D0 - + (sigma_t * (torch.exp(h) - 1.0)) * D1 + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise) + elif self.config.solver_type == "heun": + x_t = ((alpha_t / alpha_s0) * sample - 2.0 * + (sigma_t * (torch.exp(h) - 1.0)) * D0 - 2.0 * + (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise) + return x_t # pyright: ignore + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_third_order_update + def multistep_dpm_solver_third_order_update( + self, + model_output_list: List[torch.Tensor], + *args, + sample: torch.Tensor = None, + **kwargs, + ) -> torch.Tensor: + """ + One step for the third-order multistep DPMSolver. + Args: + model_output_list (`List[torch.Tensor]`): + The direct outputs from learned diffusion model at current and latter timesteps. + sample (`torch.Tensor`): + A current instance of a sample created by diffusion process. + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. + """ + + timestep_list = args[0] if len(args) > 0 else kwargs.pop( + "timestep_list", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop( + "prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError( + " missing`sample` as a required keyward argument") + if timestep_list is not None: + deprecate( + "timestep_list", + "1.0.0", + "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma_t, sigma_s0, sigma_s1, sigma_s2 = ( + self.sigmas[self.step_index + 1], # pyright: ignore + self.sigmas[self.step_index], + self.sigmas[self.step_index - 1], # pyright: ignore + self.sigmas[self.step_index - 2], # pyright: ignore + ) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) + alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) + lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2) + + m0, m1, m2 = model_output_list[-1], model_output_list[ + -2], model_output_list[-3] + + h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2 + r0, r1 = h_0 / h, h_1 / h + D0 = m0 + D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2) + D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1) + D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1) + if self.config.algorithm_type == "dpmsolver++": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + x_t = ((sigma_t / sigma_s0) * sample - + (alpha_t * (torch.exp(-h) - 1.0)) * D0 + + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 - + (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2) + elif self.config.algorithm_type == "dpmsolver": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + x_t = ((alpha_t / alpha_s0) * sample - (sigma_t * + (torch.exp(h) - 1.0)) * D0 - + (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 - + (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2) 
+ return x_t # pyright: ignore + + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() + + def _init_step_index(self, timestep): + """ + Initialize the step_index counter for the scheduler. + """ + + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + # Modified from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.step + def step( + self, + model_output: torch.Tensor, + timestep: Union[int, torch.Tensor], + sample: torch.Tensor, + generator=None, + variance_noise: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with + the multistep DPMSolver. + Args: + model_output (`torch.Tensor`): + The direct output from learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + generator (`torch.Generator`, *optional*): + A random number generator. + variance_noise (`torch.Tensor`): + Alternative to generating noise with `generator` by directly providing the noise for the variance + itself. Useful for methods such as [`LEdits++`]. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. 
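+
+            Example (as called from the Wan sampling loops in this patch):
+
+                prev_latent = scheduler.step(noise_pred.unsqueeze(0), t,
+                                             latent.unsqueeze(0),
+                                             return_dict=False,
+                                             generator=seed_g)[0]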
+ """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if self.step_index is None: + self._init_step_index(timestep) + + # Improve numerical stability for small number of steps + lower_order_final = (self.step_index == len(self.timesteps) - 1) and ( + self.config.euler_at_final or + (self.config.lower_order_final and len(self.timesteps) < 15) or + self.config.final_sigmas_type == "zero") + lower_order_second = ((self.step_index == len(self.timesteps) - 2) and + self.config.lower_order_final and + len(self.timesteps) < 15) + + model_output = self.convert_model_output(model_output, sample=sample) + for i in range(self.config.solver_order - 1): + self.model_outputs[i] = self.model_outputs[i + 1] + self.model_outputs[-1] = model_output + + # Upcast to avoid precision issues when computing prev_sample + sample = sample.to(torch.float32) + if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++" + ] and variance_noise is None: + noise = randn_tensor( + model_output.shape, + generator=generator, + device=model_output.device, + dtype=torch.float32) + elif self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]: + noise = variance_noise.to( + device=model_output.device, + dtype=torch.float32) # pyright: ignore + else: + noise = None + + if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final: + prev_sample = self.dpm_solver_first_order_update( + model_output, sample=sample, noise=noise) + elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second: + prev_sample = self.multistep_dpm_solver_second_order_update( + self.model_outputs, sample=sample, noise=noise) + else: + prev_sample = self.multistep_dpm_solver_third_order_update( + self.model_outputs, sample=sample) + + if self.lower_order_nums < self.config.solver_order: + self.lower_order_nums += 1 + + # Cast sample back to expected dtype + prev_sample = prev_sample.to(model_output.dtype) + + # upon completion increase step index by one + self._step_index += 1 # pyright: ignore + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.scale_model_input + def scale_model_input(self, sample: torch.Tensor, *args, + **kwargs) -> torch.Tensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + Args: + sample (`torch.Tensor`): + The input sample. + Returns: + `torch.Tensor`: + A scaled input sample. 
+        """
+        return sample
+
+    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise
+    def add_noise(
+        self,
+        original_samples: torch.Tensor,
+        noise: torch.Tensor,
+        timesteps: torch.IntTensor,
+    ) -> torch.Tensor:
+        # Make sure sigmas and timesteps have the same device and dtype as original_samples
+        sigmas = self.sigmas.to(
+            device=original_samples.device, dtype=original_samples.dtype)
+        if original_samples.device.type == "mps" and torch.is_floating_point(
+                timesteps):
+            # mps does not support float64
+            schedule_timesteps = self.timesteps.to(
+                original_samples.device, dtype=torch.float32)
+            timesteps = timesteps.to(
+                original_samples.device, dtype=torch.float32)
+        else:
+            schedule_timesteps = self.timesteps.to(original_samples.device)
+            timesteps = timesteps.to(original_samples.device)
+
+        # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index
+        if self.begin_index is None:
+            step_indices = [
+                self.index_for_timestep(t, schedule_timesteps)
+                for t in timesteps
+            ]
+        elif self.step_index is not None:
+            # add_noise is called after first denoising step (for inpainting)
+            step_indices = [self.step_index] * timesteps.shape[0]
+        else:
+            # add noise is called before first denoising step to create initial latent(img2img)
+            step_indices = [self.begin_index] * timesteps.shape[0]
+
+        sigma = sigmas[step_indices].flatten()
+        while len(sigma.shape) < len(original_samples.shape):
+            sigma = sigma.unsqueeze(-1)
+
+        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
+        noisy_samples = alpha_t * original_samples + sigma_t * noise
+        return noisy_samples
+
+    def __len__(self):
+        return self.config.num_train_timesteps
diff --git a/MindIE/MultiModal/Wan2.2/wan/utils/fm_solvers_unipc.py b/MindIE/MultiModal/Wan2.2/wan/utils/fm_solvers_unipc.py
new file mode 100644
index 0000000000..adb5206f25
--- /dev/null
+++ b/MindIE/MultiModal/Wan2.2/wan/utils/fm_solvers_unipc.py
@@ -0,0 +1,804 @@
+# Copied from https://github.com/huggingface/diffusers/blob/v0.31.0/src/diffusers/schedulers/scheduling_unipc_multistep.py
+# Convert unipc for flow matching
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+
+import math
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.schedulers.scheduling_utils import (KarrasDiffusionSchedulers,
+                                                   SchedulerMixin,
+                                                   SchedulerOutput)
+from diffusers.utils import deprecate, is_scipy_available
+
+if is_scipy_available():
+    import scipy.stats
+
+
+class FlowUniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
+    """
+    `UniPCMultistepScheduler` is a training-free framework designed for the fast sampling of diffusion models.
+
+    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
+    methods the library implements for all schedulers such as loading and saving.
+
+    Args:
+        num_train_timesteps (`int`, defaults to 1000):
+            The number of diffusion steps to train the model.
+        solver_order (`int`, default `2`):
+            The UniPC order which can be any positive integer. The effective order of accuracy is `solver_order + 1`
+            due to the UniC. It is recommended to use `solver_order=2` for guided sampling, and `solver_order=3` for
+            unconditional sampling.
+ prediction_type (`str`, defaults to "flow_prediction"): + Prediction type of the scheduler function; must be `flow_prediction` for this scheduler, which predicts + the flow of the diffusion process. + thresholding (`bool`, defaults to `False`): + Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such + as Stable Diffusion. + dynamic_thresholding_ratio (`float`, defaults to 0.995): + The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. + sample_max_value (`float`, defaults to 1.0): + The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`. + predict_x0 (`bool`, defaults to `True`): + Whether to use the updating algorithm on the predicted x0. + solver_type (`str`, default `bh2`): + Solver type for UniPC. It is recommended to use `bh1` for unconditional sampling when steps < 10, and `bh2` + otherwise. + lower_order_final (`bool`, default `True`): + Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can + stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10. + disable_corrector (`list`, default `[]`): + Decides which step to disable the corrector to mitigate the misalignment between `epsilon_theta(x_t, c)` + and `epsilon_theta(x_t^c, c)` which can influence convergence for a large guidance scale. Corrector is + usually disabled during the first few steps. + solver_p (`SchedulerMixin`, default `None`): + Any other scheduler that if specified, the algorithm becomes `solver_p + UniC`. + use_karras_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`, + the sigmas are determined according to a sequence of noise levels {σi}. + use_exponential_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process. + timestep_spacing (`str`, defaults to `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + steps_offset (`int`, defaults to 0): + An offset added to the inference steps, as required by some model families. + final_sigmas_type (`str`, defaults to `"zero"`): + The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final + sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. 
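+
+    Example:
+        An illustrative construction; the argument values below are arbitrary
+        and shown only to indicate the call shape:
+
+            scheduler = FlowUniPCMultistepScheduler(
+                num_train_timesteps=1000, solver_order=2, shift=5.0)
+            scheduler.set_timesteps(num_inference_steps=50)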
+ """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + solver_order: int = 2, + prediction_type: str = "flow_prediction", + shift: Optional[float] = 1.0, + use_dynamic_shifting=False, + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + predict_x0: bool = True, + solver_type: str = "bh2", + lower_order_final: bool = True, + disable_corrector: List[int] = [], + solver_p: SchedulerMixin = None, + timestep_spacing: str = "linspace", + steps_offset: int = 0, + final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min" + ): + + if solver_type not in ["bh1", "bh2"]: + if solver_type in ["midpoint", "heun", "logrho"]: + self.register_to_config(solver_type="bh2") + else: + raise NotImplementedError( + f"{solver_type} is not implemented for {self.__class__}") + + self.predict_x0 = predict_x0 + # setable values + self.num_inference_steps = None + alphas = np.linspace(1, 1 / num_train_timesteps, + num_train_timesteps)[::-1].copy() + sigmas = 1.0 - alphas + sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32) + + if not use_dynamic_shifting: + # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution + sigmas = shift * sigmas / (1 + + (shift - 1) * sigmas) # pyright: ignore + + self.sigmas = sigmas + self.timesteps = sigmas * num_train_timesteps + + self.model_outputs = [None] * solver_order + self.timestep_list = [None] * solver_order + self.lower_order_nums = 0 + self.disable_corrector = disable_corrector + self.solver_p = solver_p + self.last_sample = None + self._step_index = None + self._begin_index = None + + self.sigmas = self.sigmas.to( + "cpu") # to avoid too much CPU/GPU communication + self.sigma_min = self.sigmas[-1].item() + self.sigma_max = self.sigmas[0].item() + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + # Modified from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.set_timesteps + def set_timesteps( + self, + num_inference_steps: Union[int, None] = None, + device: Union[str, torch.device] = None, + sigmas: Optional[List[float]] = None, + mu: Optional[Union[float, None]] = None, + shift: Optional[Union[float, None]] = None, + ): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + Args: + num_inference_steps (`int`): + Total number of the spacing of the time steps. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 
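+
+        Example:
+            An illustrative call; the values are arbitrary:
+
+                scheduler.set_timesteps(num_inference_steps=40, device="cpu", shift=3.0)
+                # scheduler.timesteps now holds 40 integer timesteps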
+ """ + + if self.config.use_dynamic_shifting and mu is None: + raise ValueError( + " you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`" + ) + + if sigmas is None: + sigmas = np.linspace(self.sigma_max, self.sigma_min, + num_inference_steps + + 1).copy()[:-1] # pyright: ignore + + if self.config.use_dynamic_shifting: + sigmas = self.time_shift(mu, 1.0, sigmas) # pyright: ignore + else: + if shift is None: + shift = self.config.shift + sigmas = shift * sigmas / (1 + + (shift - 1) * sigmas) # pyright: ignore + + if self.config.final_sigmas_type == "sigma_min": + sigma_last = ((1 - self.alphas_cumprod[0]) / + self.alphas_cumprod[0])**0.5 + elif self.config.final_sigmas_type == "zero": + sigma_last = 0 + else: + raise ValueError( + f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" + ) + + timesteps = sigmas * self.config.num_train_timesteps + sigmas = np.concatenate([sigmas, [sigma_last] + ]).astype(np.float32) # pyright: ignore + + self.sigmas = torch.from_numpy(sigmas) + self.timesteps = torch.from_numpy(timesteps).to( + device=device, dtype=torch.int64) + + self.num_inference_steps = len(timesteps) + + self.model_outputs = [ + None, + ] * self.config.solver_order + self.lower_order_nums = 0 + self.last_sample = None + if self.solver_p: + self.solver_p.set_timesteps(self.num_inference_steps, device=device) + + # add an index counter for schedulers that allow duplicated timesteps + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to( + "cpu") # to avoid too much CPU/GPU communication + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." 
+ + https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, *remaining_dims = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float( + ) # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile( + abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + s = s.unsqueeze( + 1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = torch.clamp( + sample, -s, s + ) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, *remaining_dims) + sample = sample.to(dtype) + + return sample + + # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._sigma_to_t + def _sigma_to_t(self, sigma): + return sigma * self.config.num_train_timesteps + + def _sigma_to_alpha_sigma_t(self, sigma): + return 1 - sigma, sigma + + # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.set_timesteps + def time_shift(self, mu: float, sigma: float, t: torch.Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1)**sigma) + + def convert_model_output( + self, + model_output: torch.Tensor, + *args, + sample: torch.Tensor = None, + **kwargs, + ) -> torch.Tensor: + r""" + Convert the model output to the corresponding type the UniPC algorithm needs. + + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.Tensor`: + The converted model output. + """ + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError( + "missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + + if self.predict_x0: + if self.config.prediction_type == "flow_prediction": + sigma_t = self.sigmas[self.step_index] + x0_pred = sample - sigma_t * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`," + " `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler." + ) + + if self.config.thresholding: + x0_pred = self._threshold_sample(x0_pred) + + return x0_pred + else: + if self.config.prediction_type == "flow_prediction": + sigma_t = self.sigmas[self.step_index] + epsilon = sample - (1 - sigma_t) * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`," + " `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler." 
+ ) + + if self.config.thresholding: + sigma_t = self.sigmas[self.step_index] + x0_pred = sample - sigma_t * model_output + x0_pred = self._threshold_sample(x0_pred) + epsilon = model_output + x0_pred + + return epsilon + + def multistep_uni_p_bh_update( + self, + model_output: torch.Tensor, + *args, + sample: torch.Tensor = None, + order: int = None, # pyright: ignore + **kwargs, + ) -> torch.Tensor: + """ + One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if is specified. + + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model at the current timestep. + prev_timestep (`int`): + The previous discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + order (`int`): + The order of UniP at this timestep (corresponds to the *p* in UniPC-p). + + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. + """ + prev_timestep = args[0] if len(args) > 0 else kwargs.pop( + "prev_timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError( + " missing `sample` as a required keyward argument") + if order is None: + if len(args) > 2: + order = args[2] + else: + raise ValueError( + " missing `order` as a required keyward argument") + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + model_output_list = self.model_outputs + + s0 = self.timestep_list[-1] + m0 = model_output_list[-1] + x = sample + + if self.solver_p: + x_t = self.solver_p.step(model_output, s0, x).prev_sample + return x_t + + sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[ + self.step_index] # pyright: ignore + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + + h = lambda_t - lambda_s0 + device = sample.device + + rks = [] + D1s = [] + for i in range(1, order): + si = self.step_index - i # pyright: ignore + mi = model_output_list[-(i + 1)] + alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si]) + lambda_si = torch.log(alpha_si) - torch.log(sigma_si) + rk = (lambda_si - lambda_s0) / h + rks.append(rk) + D1s.append((mi - m0) / rk) # pyright: ignore + + rks.append(1.0) + rks = torch.tensor(rks, device=device) + + R = [] + b = [] + + hh = -h if self.predict_x0 else h + h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1 + h_phi_k = h_phi_1 / hh - 1 + + factorial_i = 1 + + if self.config.solver_type == "bh1": + B_h = hh + elif self.config.solver_type == "bh2": + B_h = torch.expm1(hh) + else: + raise NotImplementedError() + + for i in range(1, order + 1): + R.append(torch.pow(rks, i - 1)) + b.append(h_phi_k * factorial_i / B_h) + factorial_i *= i + 1 + h_phi_k = h_phi_k / hh - 1 / factorial_i + + R = torch.stack(R) + b = torch.tensor(b, device=device) + + if len(D1s) > 0: + D1s = torch.stack(D1s, dim=1) # (B, K) + # for order 2, we use a simplified version + if order == 2: + rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device) + else: + R_inv = torch.inverse(R[:-1, :-1]) + rhos_p = torch.matmul(R_inv, b[:-1]).to(device).to(x.dtype) + # rhos_p = torch.linalg.solve(R[:-1, :-1], + # b[:-1]).to(device).to(x.dtype) + else: + D1s = None + + if self.predict_x0: + 
x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 + if D1s is not None: + pred_res = torch.einsum("k,bkc...->bc...", rhos_p, + D1s) # pyright: ignore + else: + pred_res = 0 + x_t = x_t_ - alpha_t * B_h * pred_res + else: + x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 + if D1s is not None: + pred_res = torch.einsum("k,bkc...->bc...", rhos_p, + D1s) # pyright: ignore + else: + pred_res = 0 + x_t = x_t_ - sigma_t * B_h * pred_res + + x_t = x_t.to(x.dtype) + return x_t + + def multistep_uni_c_bh_update( + self, + this_model_output: torch.Tensor, + *args, + last_sample: torch.Tensor = None, + this_sample: torch.Tensor = None, + order: int = None, # pyright: ignore + **kwargs, + ) -> torch.Tensor: + """ + One step for the UniC (B(h) version). + + Args: + this_model_output (`torch.Tensor`): + The model outputs at `x_t`. + this_timestep (`int`): + The current timestep `t`. + last_sample (`torch.Tensor`): + The generated sample before the last predictor `x_{t-1}`. + this_sample (`torch.Tensor`): + The generated sample after the last predictor `x_{t}`. + order (`int`): + The `p` of UniC-p at this step. The effective order of accuracy should be `order + 1`. + + Returns: + `torch.Tensor`: + The corrected sample tensor at the current timestep. + """ + this_timestep = args[0] if len(args) > 0 else kwargs.pop( + "this_timestep", None) + if last_sample is None: + if len(args) > 1: + last_sample = args[1] + else: + raise ValueError( + " missing`last_sample` as a required keyward argument") + if this_sample is None: + if len(args) > 2: + this_sample = args[2] + else: + raise ValueError( + " missing`this_sample` as a required keyward argument") + if order is None: + if len(args) > 3: + order = args[3] + else: + raise ValueError( + " missing`order` as a required keyward argument") + if this_timestep is not None: + deprecate( + "this_timestep", + "1.0.0", + "Passing `this_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + model_output_list = self.model_outputs + + m0 = model_output_list[-1] + x = last_sample + x_t = this_sample + model_t = this_model_output + + sigma_t, sigma_s0 = self.sigmas[self.step_index], self.sigmas[ + self.step_index - 1] # pyright: ignore + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + + h = lambda_t - lambda_s0 + device = this_sample.device + + rks = [] + D1s = [] + for i in range(1, order): + si = self.step_index - (i + 1) # pyright: ignore + mi = model_output_list[-(i + 1)] + alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si]) + lambda_si = torch.log(alpha_si) - torch.log(sigma_si) + rk = (lambda_si - lambda_s0) / h + rks.append(rk) + D1s.append((mi - m0) / rk) # pyright: ignore + + rks.append(1.0) + rks = torch.tensor(rks, device=device) + + R = [] + b = [] + + hh = -h if self.predict_x0 else h + h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1 + h_phi_k = h_phi_1 / hh - 1 + + factorial_i = 1 + + if self.config.solver_type == "bh1": + B_h = hh + elif self.config.solver_type == "bh2": + B_h = torch.expm1(hh) + else: + raise NotImplementedError() + + for i in range(1, order + 1): + R.append(torch.pow(rks, i - 1)) + b.append(h_phi_k * factorial_i / B_h) + factorial_i *= i + 1 + h_phi_k = h_phi_k / hh - 1 / factorial_i + + R = torch.stack(R) + b = torch.tensor(b, device=device) + + if 
len(D1s) > 0: + D1s = torch.stack(D1s, dim=1) + else: + D1s = None + + # for order 1, we use a simplified version + if order == 1: + rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device) + else: + R_inv = torch.inverse(R) + rhos_c = torch.matmul(R_inv, b).to(device).to(x.dtype) + # rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype) + + if self.predict_x0: + x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 + if D1s is not None: + corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s) + else: + corr_res = 0 + D1_t = model_t - m0 + x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t) + else: + x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 + if D1s is not None: + corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s) + else: + corr_res = 0 + D1_t = model_t - m0 + x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t) + x_t = x_t.to(x.dtype) + return x_t + + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index + def _init_step_index(self, timestep): + """ + Initialize the step_index counter for the scheduler. + """ + + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def step(self, + model_output: torch.Tensor, + timestep: Union[int, torch.Tensor], + sample: torch.Tensor, + return_dict: bool = True, + generator=None) -> Union[SchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with + the multistep UniPC. + + Args: + model_output (`torch.Tensor`): + The direct output from learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. + + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. 
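+
+        Example:
+            A minimal predictor-corrector sampling sketch; `scheduler`, `model`
+            and `sample` are placeholders assumed for illustration:
+
+                scheduler.set_timesteps(num_inference_steps=50)
+                for t in scheduler.timesteps:
+                    model_out = model(sample, t)  # user-supplied denoiser
+                    sample = scheduler.step(model_out, t, sample).prev_sample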
+ + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if self.step_index is None: + self._init_step_index(timestep) + + use_corrector = ( + self.step_index > 0 and + self.step_index - 1 not in self.disable_corrector and + self.last_sample is not None # pyright: ignore + ) + + model_output_convert = self.convert_model_output( + model_output, sample=sample) + if use_corrector: + sample = self.multistep_uni_c_bh_update( + this_model_output=model_output_convert, + last_sample=self.last_sample, + this_sample=sample, + order=self.this_order, + ) + + for i in range(self.config.solver_order - 1): + self.model_outputs[i] = self.model_outputs[i + 1] + self.timestep_list[i] = self.timestep_list[i + 1] + + self.model_outputs[-1] = model_output_convert + self.timestep_list[-1] = timestep # pyright: ignore + + if self.config.lower_order_final: + this_order = min(self.config.solver_order, + len(self.timesteps) - + self.step_index) # pyright: ignore + else: + this_order = self.config.solver_order + + self.this_order = min(this_order, + self.lower_order_nums + 1) # warmup for multistep + assert self.this_order > 0 + + self.last_sample = sample + prev_sample = self.multistep_uni_p_bh_update( + model_output=model_output, # pass the original non-converted model output, in case solver-p is used + sample=sample, + order=self.this_order, + ) + + if self.lower_order_nums < self.config.solver_order: + self.lower_order_nums += 1 + + # upon completion increase step index by one + self._step_index += 1 # pyright: ignore + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def scale_model_input(self, sample: torch.Tensor, *args, + **kwargs) -> torch.Tensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.Tensor`): + The input sample. + + Returns: + `torch.Tensor`: + A scaled input sample. 
+ """ + return sample + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.IntTensor, + ) -> torch.Tensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to( + device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point( + timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to( + original_samples.device, dtype=torch.float32) + timesteps = timesteps.to( + original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index + if self.begin_index is None: + step_indices = [ + self.index_for_timestep(t, schedule_timesteps) + for t in timesteps + ] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timesteps.shape[0] + else: + # add noise is called before first denoising step to create initial latent(img2img) + step_indices = [self.begin_index] * timesteps.shape[0] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + noisy_samples = alpha_t * original_samples + sigma_t * noise + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/MindIE/MultiModal/Wan2.2/wan/utils/prompt_extend.py b/MindIE/MultiModal/Wan2.2/wan/utils/prompt_extend.py new file mode 100644 index 0000000000..9d40d9c8b0 --- /dev/null +++ b/MindIE/MultiModal/Wan2.2/wan/utils/prompt_extend.py @@ -0,0 +1,542 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
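+# Prompt-extension helpers for Wan2.2: DashScopePromptExpander rewrites prompts
+# through the DashScope API, QwenPromptExpander runs a local Qwen2.5(-VL) model,
+# and both return a PromptOutput carrying the expanded prompt and request status.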
+import json +import logging +import math +import os +import random +import sys +import tempfile +from dataclasses import dataclass +from http import HTTPStatus +from typing import Optional, Union + +import dashscope +import torch +from PIL import Image + +try: + from flash_attn import flash_attn_varlen_func + FLASH_VER = 2 +except ModuleNotFoundError: + flash_attn_varlen_func = None # in compatible with CPU machines + FLASH_VER = None + +from .system_prompt import * + +DEFAULT_SYS_PROMPTS = { + "t2v-A14B": { + "zh": T2V_A14B_ZH_SYS_PROMPT, + "en": T2V_A14B_EN_SYS_PROMPT, + }, + "i2v-A14B": { + "zh": I2V_A14B_ZH_SYS_PROMPT, + "en": I2V_A14B_EN_SYS_PROMPT, + "empty": { + "zh": I2V_A14B_EMPTY_ZH_SYS_PROMPT, + "en": I2V_A14B_EMPTY_EN_SYS_PROMPT, + } + }, + "ti2v-5B": { + "t2v": { + "zh": T2V_A14B_ZH_SYS_PROMPT, + "en": T2V_A14B_EN_SYS_PROMPT, + }, + "i2v": { + "zh": I2V_A14B_ZH_SYS_PROMPT, + "en": I2V_A14B_EN_SYS_PROMPT, + } + }, +} + + +@dataclass +class PromptOutput(object): + status: bool + prompt: str + seed: int + system_prompt: str + message: str + + def add_custom_field(self, key: str, value) -> None: + self.__setattr__(key, value) + + +class PromptExpander: + + def __init__(self, model_name, task, is_vl=False, device=0, **kwargs): + self.model_name = model_name + self.task = task + self.is_vl = is_vl + self.device = device + + def extend_with_img(self, + prompt, + system_prompt, + image=None, + seed=-1, + *args, + **kwargs): + pass + + def extend(self, prompt, system_prompt, seed=-1, *args, **kwargs): + pass + + def decide_system_prompt(self, tar_lang="zh", prompt=None): + assert self.task is not None + if "ti2v" in self.task: + if self.is_vl: + return DEFAULT_SYS_PROMPTS[self.task]["i2v"][tar_lang] + else: + return DEFAULT_SYS_PROMPTS[self.task]["t2v"][tar_lang] + if "i2v" in self.task and len(prompt) == 0: + return DEFAULT_SYS_PROMPTS[self.task]["empty"][tar_lang] + return DEFAULT_SYS_PROMPTS[self.task][tar_lang] + + def __call__(self, + prompt, + system_prompt=None, + tar_lang="zh", + image=None, + seed=-1, + *args, + **kwargs): + if system_prompt is None: + system_prompt = self.decide_system_prompt( + tar_lang=tar_lang, prompt=prompt) + if seed < 0: + seed = random.randint(0, sys.maxsize) + if image is not None and self.is_vl: + return self.extend_with_img( + prompt, system_prompt, image=image, seed=seed, *args, **kwargs) + elif not self.is_vl: + return self.extend(prompt, system_prompt, seed, *args, **kwargs) + else: + raise NotImplementedError + + +class DashScopePromptExpander(PromptExpander): + + def __init__(self, + api_key=None, + model_name=None, + task=None, + max_image_size=512 * 512, + retry_times=4, + is_vl=False, + **kwargs): + ''' + Args: + api_key: The API key for Dash Scope authentication and access to related services. + model_name: Model name, 'qwen-plus' for extending prompts, 'qwen-vl-max' for extending prompt-images. + task: Task name. This is required to determine the default system prompt. + max_image_size: The maximum size of the image; unit unspecified (e.g., pixels, KB). Please specify the unit based on actual usage. + retry_times: Number of retry attempts in case of request failure. + is_vl: A flag indicating whether the task involves visual-language processing. + **kwargs: Additional keyword arguments that can be passed to the function or method. 
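+
+        Example:
+            An illustrative use; it assumes a valid DashScope key is available
+            via `DASH_API_KEY` or `api_key`, and the prompt text is arbitrary:
+
+                expander = DashScopePromptExpander(task="t2v-A14B")
+                result = expander("a cat surfing on a sunny beach", tar_lang="en")
+                print(result.status, result.prompt)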
+ ''' + if model_name is None: + model_name = 'qwen-plus' if not is_vl else 'qwen-vl-max' + super().__init__(model_name, task, is_vl, **kwargs) + if api_key is not None: + dashscope.api_key = api_key + elif 'DASH_API_KEY' in os.environ and os.environ[ + 'DASH_API_KEY'] is not None: + dashscope.api_key = os.environ['DASH_API_KEY'] + else: + raise ValueError("DASH_API_KEY is not set") + if 'DASH_API_URL' in os.environ and os.environ[ + 'DASH_API_URL'] is not None: + dashscope.base_http_api_url = os.environ['DASH_API_URL'] + else: + dashscope.base_http_api_url = 'https://dashscope.aliyuncs.com/api/v1' + self.api_key = api_key + + self.max_image_size = max_image_size + self.model = model_name + self.retry_times = retry_times + + def extend(self, prompt, system_prompt, seed=-1, *args, **kwargs): + messages = [{ + 'role': 'system', + 'content': system_prompt + }, { + 'role': 'user', + 'content': prompt + }] + + exception = None + for _ in range(self.retry_times): + try: + response = dashscope.Generation.call( + self.model, + messages=messages, + seed=seed, + result_format='message', # set the result to be "message" format. + ) + assert response.status_code == HTTPStatus.OK, response + expanded_prompt = response['output']['choices'][0]['message'][ + 'content'] + return PromptOutput( + status=True, + prompt=expanded_prompt, + seed=seed, + system_prompt=system_prompt, + message=json.dumps(response, ensure_ascii=False)) + except Exception as e: + exception = e + return PromptOutput( + status=False, + prompt=prompt, + seed=seed, + system_prompt=system_prompt, + message=str(exception)) + + def extend_with_img(self, + prompt, + system_prompt, + image: Union[Image.Image, str] = None, + seed=-1, + *args, + **kwargs): + if isinstance(image, str): + image = Image.open(image).convert('RGB') + w = image.width + h = image.height + area = min(w * h, self.max_image_size) + aspect_ratio = h / w + resized_h = round(math.sqrt(area * aspect_ratio)) + resized_w = round(math.sqrt(area / aspect_ratio)) + image = image.resize((resized_w, resized_h)) + with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as f: + image.save(f.name) + fname = f.name + image_path = f"file://{f.name}" + prompt = f"{prompt}" + messages = [ + { + 'role': 'system', + 'content': [{ + "text": system_prompt + }] + }, + { + 'role': 'user', + 'content': [{ + "text": prompt + }, { + "image": image_path + }] + }, + ] + response = None + result_prompt = prompt + exception = None + status = False + for _ in range(self.retry_times): + try: + response = dashscope.MultiModalConversation.call( + self.model, + messages=messages, + seed=seed, + result_format='message', # set the result to be "message" format. 
+ ) + assert response.status_code == HTTPStatus.OK, response + result_prompt = response['output']['choices'][0]['message'][ + 'content'][0]['text'].replace('\n', '\\n') + status = True + break + except Exception as e: + exception = e + result_prompt = result_prompt.replace('\n', '\\n') + os.remove(fname) + + return PromptOutput( + status=status, + prompt=result_prompt, + seed=seed, + system_prompt=system_prompt, + message=str(exception) if not status else json.dumps( + response, ensure_ascii=False)) + + +class QwenPromptExpander(PromptExpander): + model_dict = { + "QwenVL2.5_3B": "Qwen/Qwen2.5-VL-3B-Instruct", + "QwenVL2.5_7B": "Qwen/Qwen2.5-VL-7B-Instruct", + "Qwen2.5_3B": "Qwen/Qwen2.5-3B-Instruct", + "Qwen2.5_7B": "Qwen/Qwen2.5-7B-Instruct", + "Qwen2.5_14B": "Qwen/Qwen2.5-14B-Instruct", + } + + def __init__(self, + model_name=None, + task=None, + device=0, + is_vl=False, + **kwargs): + ''' + Args: + model_name: Use predefined model names such as 'QwenVL2.5_7B' and 'Qwen2.5_14B', + which are specific versions of the Qwen model. Alternatively, you can use the + local path to a downloaded model or the model name from Hugging Face." + Detailed Breakdown: + Predefined Model Names: + * 'QwenVL2.5_7B' and 'Qwen2.5_14B' are specific versions of the Qwen model. + Local Path: + * You can provide the path to a model that you have downloaded locally. + Hugging Face Model Name: + * You can also specify the model name from Hugging Face's model hub. + task: Task name. This is required to determine the default system prompt. + is_vl: A flag indicating whether the task involves visual-language processing. + **kwargs: Additional keyword arguments that can be passed to the function or method. + ''' + if model_name is None: + model_name = 'Qwen2.5_14B' if not is_vl else 'QwenVL2.5_7B' + super().__init__(model_name, task, is_vl, device, **kwargs) + if (not os.path.exists(self.model_name)) and (self.model_name + in self.model_dict): + self.model_name = self.model_dict[self.model_name] + + if self.is_vl: + # default: Load the model on the available device(s) + from transformers import ( + AutoProcessor, + AutoTokenizer, + Qwen2_5_VLForConditionalGeneration, + ) + try: + from .qwen_vl_utils import process_vision_info + except: + from qwen_vl_utils import process_vision_info + self.process_vision_info = process_vision_info + min_pixels = 256 * 28 * 28 + max_pixels = 1280 * 28 * 28 + self.processor = AutoProcessor.from_pretrained( + self.model_name, + min_pixels=min_pixels, + max_pixels=max_pixels, + use_fast=True) + self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained( + self.model_name, + torch_dtype=torch.bfloat16 if FLASH_VER == 2 else + torch.float16 if "AWQ" in self.model_name else "auto", + attn_implementation="flash_attention_2" + if FLASH_VER == 2 else None, + device_map="cpu") + else: + from transformers import AutoModelForCausalLM, AutoTokenizer + self.model = AutoModelForCausalLM.from_pretrained( + self.model_name, + torch_dtype=torch.float16 + if "AWQ" in self.model_name else "auto", + attn_implementation="flash_attention_2" + if FLASH_VER == 2 else None, + device_map="cpu") + self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) + + def extend(self, prompt, system_prompt, seed=-1, *args, **kwargs): + self.model = self.model.to(self.device) + messages = [{ + "role": "system", + "content": system_prompt + }, { + "role": "user", + "content": prompt + }] + text = self.tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True) + model_inputs = 
self.tokenizer([text], + return_tensors="pt").to(self.model.device) + + generated_ids = self.model.generate(**model_inputs, max_new_tokens=512) + generated_ids = [ + output_ids[len(input_ids):] for input_ids, output_ids in zip( + model_inputs.input_ids, generated_ids) + ] + + expanded_prompt = self.tokenizer.batch_decode( + generated_ids, skip_special_tokens=True)[0] + self.model = self.model.to("cpu") + return PromptOutput( + status=True, + prompt=expanded_prompt, + seed=seed, + system_prompt=system_prompt, + message=json.dumps({"content": expanded_prompt}, + ensure_ascii=False)) + + def extend_with_img(self, + prompt, + system_prompt, + image: Union[Image.Image, str] = None, + seed=-1, + *args, + **kwargs): + self.model = self.model.to(self.device) + messages = [{ + 'role': 'system', + 'content': [{ + "type": "text", + "text": system_prompt + }] + }, { + "role": + "user", + "content": [ + { + "type": "image", + "image": image, + }, + { + "type": "text", + "text": prompt + }, + ], + }] + + # Preparation for inference + text = self.processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True) + image_inputs, video_inputs = self.process_vision_info(messages) + inputs = self.processor( + text=[text], + images=image_inputs, + videos=video_inputs, + padding=True, + return_tensors="pt", + ) + inputs = inputs.to(self.device) + + # Inference: Generation of the output + generated_ids = self.model.generate(**inputs, max_new_tokens=512) + generated_ids_trimmed = [ + out_ids[len(in_ids):] + for in_ids, out_ids in zip(inputs.input_ids, generated_ids) + ] + expanded_prompt = self.processor.batch_decode( + generated_ids_trimmed, + skip_special_tokens=True, + clean_up_tokenization_spaces=False)[0] + self.model = self.model.to("cpu") + return PromptOutput( + status=True, + prompt=expanded_prompt, + seed=seed, + system_prompt=system_prompt, + message=json.dumps({"content": expanded_prompt}, + ensure_ascii=False)) + + +if __name__ == "__main__": + logging.basicConfig( + level=logging.INFO, + format="[%(asctime)s] %(levelname)s: %(message)s", + handlers=[logging.StreamHandler(stream=sys.stdout)]) + + seed = 100 + prompt = "夏日海滩度假风格,一只戴着墨镜的白色猫咪坐在冲浪板上。猫咪毛发蓬松,表情悠闲,直视镜头。背景是模糊的海滩景色,海水清澈,远处有绿色的山丘和蓝天白云。猫咪的姿态自然放松,仿佛在享受海风和阳光。近景特写,强调猫咪的细节和海滩的清新氛围。" + en_prompt = "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." 
+ image = "./examples/i2v_input.JPG" + + def test(method, + prompt, + model_name, + task, + image=None, + en_prompt=None, + seed=None): + prompt_expander = method( + model_name=model_name, task=task, is_vl=image is not None) + result = prompt_expander(prompt, image=image, tar_lang="zh") + logging.info(f"zh prompt -> zh: {result.prompt}") + result = prompt_expander(prompt, image=image, tar_lang="en") + logging.info(f"zh prompt -> en: {result.prompt}") + if en_prompt is not None: + result = prompt_expander(en_prompt, image=image, tar_lang="zh") + logging.info(f"en prompt -> zh: {result.prompt}") + result = prompt_expander(en_prompt, image=image, tar_lang="en") + logging.info(f"en prompt -> en: {result.prompt}") + + ds_model_name = None + ds_vl_model_name = None + qwen_model_name = None + qwen_vl_model_name = None + + for task in ["t2v-A14B", "i2v-A14B", "ti2v-5B"]: + # test prompt extend + if "t2v" in task or "ti2v" in task: + # test dashscope api + logging.info(f"-" * 40) + logging.info(f"Testing {task} dashscope prompt extend") + test( + DashScopePromptExpander, + prompt, + ds_model_name, + task, + image=None, + en_prompt=en_prompt, + seed=seed) + + # test qwen api + logging.info(f"-" * 40) + logging.info(f"Testing {task} qwen prompt extend") + test( + QwenPromptExpander, + prompt, + qwen_model_name, + task, + image=None, + en_prompt=en_prompt, + seed=seed) + + # test prompt-image extend + if "i2v" in task: + # test dashscope api + logging.info(f"-" * 40) + logging.info(f"Testing {task} dashscope vl prompt extend") + test( + DashScopePromptExpander, + prompt, + ds_vl_model_name, + task, + image=image, + en_prompt=en_prompt, + seed=seed) + + # test qwen api + logging.info(f"-" * 40) + logging.info(f"Testing {task} qwen vl prompt extend") + test( + QwenPromptExpander, + prompt, + qwen_vl_model_name, + task, + image=image, + en_prompt=en_prompt, + seed=seed) + + # test empty prompt extend + if "i2v-A14B" in task: + # test dashscope api + logging.info(f"-" * 40) + logging.info(f"Testing {task} dashscope vl empty prompt extend") + test( + DashScopePromptExpander, + "", + ds_vl_model_name, + task, + image=image, + en_prompt=None, + seed=seed) + + # test qwen api + logging.info(f"-" * 40) + logging.info(f"Testing {task} qwen vl empty prompt extend") + test( + QwenPromptExpander, + "", + qwen_vl_model_name, + task, + image=image, + en_prompt=None, + seed=seed) diff --git a/MindIE/MultiModal/Wan2.2/wan/utils/qwen_vl_utils.py b/MindIE/MultiModal/Wan2.2/wan/utils/qwen_vl_utils.py new file mode 100644 index 0000000000..bf0e832861 --- /dev/null +++ b/MindIE/MultiModal/Wan2.2/wan/utils/qwen_vl_utils.py @@ -0,0 +1,363 @@ +# Copied from https://github.com/kq-chen/qwen-vl-utils +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
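+# Vision pre-processing helpers for the Qwen2.5-VL prompt expander: smart_resize
+# keeps image sizes divisible by the model patch factor, fetch_image/fetch_video
+# load and resize the visual inputs, and process_vision_info collects them from
+# chat-style messages.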
+from __future__ import annotations + +import base64 +import logging +import math +import os +import sys +import time +import warnings +from functools import lru_cache +from io import BytesIO + +import requests +import torch +import torchvision +from packaging import version +from PIL import Image +from torchvision import io, transforms +from torchvision.transforms import InterpolationMode + +logger = logging.getLogger(__name__) + +IMAGE_FACTOR = 28 +MIN_PIXELS = 4 * 28 * 28 +MAX_PIXELS = 16384 * 28 * 28 +MAX_RATIO = 200 + +VIDEO_MIN_PIXELS = 128 * 28 * 28 +VIDEO_MAX_PIXELS = 768 * 28 * 28 +VIDEO_TOTAL_PIXELS = 24576 * 28 * 28 +FRAME_FACTOR = 2 +FPS = 2.0 +FPS_MIN_FRAMES = 4 +FPS_MAX_FRAMES = 768 + + +def round_by_factor(number: int, factor: int) -> int: + """Returns the closest integer to 'number' that is divisible by 'factor'.""" + return round(number / factor) * factor + + +def ceil_by_factor(number: int, factor: int) -> int: + """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'.""" + return math.ceil(number / factor) * factor + + +def floor_by_factor(number: int, factor: int) -> int: + """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'.""" + return math.floor(number / factor) * factor + + +def smart_resize(height: int, + width: int, + factor: int = IMAGE_FACTOR, + min_pixels: int = MIN_PIXELS, + max_pixels: int = MAX_PIXELS) -> tuple[int, int]: + """ + Rescales the image so that the following conditions are met: + + 1. Both dimensions (height and width) are divisible by 'factor'. + + 2. The total number of pixels is within the range ['min_pixels', 'max_pixels']. + + 3. The aspect ratio of the image is maintained as closely as possible. + """ + if max(height, width) / min(height, width) > MAX_RATIO: + raise ValueError( + f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}" + ) + h_bar = max(factor, round_by_factor(height, factor)) + w_bar = max(factor, round_by_factor(width, factor)) + if h_bar * w_bar > max_pixels: + beta = math.sqrt((height * width) / max_pixels) + h_bar = floor_by_factor(height / beta, factor) + w_bar = floor_by_factor(width / beta, factor) + elif h_bar * w_bar < min_pixels: + beta = math.sqrt(min_pixels / (height * width)) + h_bar = ceil_by_factor(height * beta, factor) + w_bar = ceil_by_factor(width * beta, factor) + return h_bar, w_bar + + +def fetch_image(ele: dict[str, str | Image.Image], + size_factor: int = IMAGE_FACTOR) -> Image.Image: + if "image" in ele: + image = ele["image"] + else: + image = ele["image_url"] + image_obj = None + if isinstance(image, Image.Image): + image_obj = image + elif image.startswith("http://") or image.startswith("https://"): + image_obj = Image.open(requests.get(image, stream=True).raw) + elif image.startswith("file://"): + image_obj = Image.open(image[7:]) + elif image.startswith("data:image"): + if "base64," in image: + _, base64_data = image.split("base64,", 1) + data = base64.b64decode(base64_data) + image_obj = Image.open(BytesIO(data)) + else: + image_obj = Image.open(image) + if image_obj is None: + raise ValueError( + f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}" + ) + image = image_obj.convert("RGB") + ## resize + if "resized_height" in ele and "resized_width" in ele: + resized_height, resized_width = smart_resize( + ele["resized_height"], + ele["resized_width"], + factor=size_factor, + ) + else: + width, height = image.size + min_pixels = 
ele.get("min_pixels", MIN_PIXELS) + max_pixels = ele.get("max_pixels", MAX_PIXELS) + resized_height, resized_width = smart_resize( + height, + width, + factor=size_factor, + min_pixels=min_pixels, + max_pixels=max_pixels, + ) + image = image.resize((resized_width, resized_height)) + + return image + + +def smart_nframes( + ele: dict, + total_frames: int, + video_fps: int | float, +) -> int: + """calculate the number of frames for video used for model inputs. + + Args: + ele (dict): a dict contains the configuration of video. + support either `fps` or `nframes`: + - nframes: the number of frames to extract for model inputs. + - fps: the fps to extract frames for model inputs. + - min_frames: the minimum number of frames of the video, only used when fps is provided. + - max_frames: the maximum number of frames of the video, only used when fps is provided. + total_frames (int): the original total number of frames of the video. + video_fps (int | float): the original fps of the video. + + Raises: + ValueError: nframes should in interval [FRAME_FACTOR, total_frames]. + + Returns: + int: the number of frames for video used for model inputs. + """ + assert not ("fps" in ele and + "nframes" in ele), "Only accept either `fps` or `nframes`" + if "nframes" in ele: + nframes = round_by_factor(ele["nframes"], FRAME_FACTOR) + else: + fps = ele.get("fps", FPS) + min_frames = ceil_by_factor( + ele.get("min_frames", FPS_MIN_FRAMES), FRAME_FACTOR) + max_frames = floor_by_factor( + ele.get("max_frames", min(FPS_MAX_FRAMES, total_frames)), + FRAME_FACTOR) + nframes = total_frames / video_fps * fps + nframes = min(max(nframes, min_frames), max_frames) + nframes = round_by_factor(nframes, FRAME_FACTOR) + if not (FRAME_FACTOR <= nframes and nframes <= total_frames): + raise ValueError( + f"nframes should in interval [{FRAME_FACTOR}, {total_frames}], but got {nframes}." + ) + return nframes + + +def _read_video_torchvision(ele: dict,) -> torch.Tensor: + """read video using torchvision.io.read_video + + Args: + ele (dict): a dict contains the configuration of video. + support keys: + - video: the path of video. support "file://", "http://", "https://" and local path. + - video_start: the start time of video. + - video_end: the end time of video. + Returns: + torch.Tensor: the video tensor with shape (T, C, H, W). + """ + video_path = ele["video"] + if version.parse(torchvision.__version__) < version.parse("0.19.0"): + if "http://" in video_path or "https://" in video_path: + warnings.warn( + "torchvision < 0.19.0 does not support http/https video path, please upgrade to 0.19.0." + ) + if "file://" in video_path: + video_path = video_path[7:] + st = time.time() + video, audio, info = io.read_video( + video_path, + start_pts=ele.get("video_start", 0.0), + end_pts=ele.get("video_end", None), + pts_unit="sec", + output_format="TCHW", + ) + total_frames, video_fps = video.size(0), info["video_fps"] + logger.info( + f"torchvision: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s" + ) + nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps) + idx = torch.linspace(0, total_frames - 1, nframes).round().long() + video = video[idx] + return video + + +def is_decord_available() -> bool: + import importlib.util + + return importlib.util.find_spec("decord") is not None + + +def _read_video_decord(ele: dict,) -> torch.Tensor: + """read video using decord.VideoReader + + Args: + ele (dict): a dict contains the configuration of video. + support keys: + - video: the path of video. 
support "file://", "http://", "https://" and local path. + - video_start: the start time of video. + - video_end: the end time of video. + Returns: + torch.Tensor: the video tensor with shape (T, C, H, W). + """ + import decord + video_path = ele["video"] + st = time.time() + vr = decord.VideoReader(video_path) + # TODO: support start_pts and end_pts + if 'video_start' in ele or 'video_end' in ele: + raise NotImplementedError( + "not support start_pts and end_pts in decord for now.") + total_frames, video_fps = len(vr), vr.get_avg_fps() + logger.info( + f"decord: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s" + ) + nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps) + idx = torch.linspace(0, total_frames - 1, nframes).round().long().tolist() + video = vr.get_batch(idx).asnumpy() + video = torch.tensor(video).permute(0, 3, 1, 2) # Convert to TCHW format + return video + + +VIDEO_READER_BACKENDS = { + "decord": _read_video_decord, + "torchvision": _read_video_torchvision, +} + +FORCE_QWENVL_VIDEO_READER = os.getenv("FORCE_QWENVL_VIDEO_READER", None) + + +@lru_cache(maxsize=1) +def get_video_reader_backend() -> str: + if FORCE_QWENVL_VIDEO_READER is not None: + video_reader_backend = FORCE_QWENVL_VIDEO_READER + elif is_decord_available(): + video_reader_backend = "decord" + else: + video_reader_backend = "torchvision" + logger.info( + f"qwen-vl-utils using {video_reader_backend} to read video.", + file=sys.stderr) + return video_reader_backend + + +def fetch_video( + ele: dict, + image_factor: int = IMAGE_FACTOR) -> torch.Tensor | list[Image.Image]: + if isinstance(ele["video"], str): + video_reader_backend = get_video_reader_backend() + video = VIDEO_READER_BACKENDS[video_reader_backend](ele) + nframes, _, height, width = video.shape + + min_pixels = ele.get("min_pixels", VIDEO_MIN_PIXELS) + total_pixels = ele.get("total_pixels", VIDEO_TOTAL_PIXELS) + max_pixels = max( + min(VIDEO_MAX_PIXELS, total_pixels / nframes * FRAME_FACTOR), + int(min_pixels * 1.05)) + max_pixels = ele.get("max_pixels", max_pixels) + if "resized_height" in ele and "resized_width" in ele: + resized_height, resized_width = smart_resize( + ele["resized_height"], + ele["resized_width"], + factor=image_factor, + ) + else: + resized_height, resized_width = smart_resize( + height, + width, + factor=image_factor, + min_pixels=min_pixels, + max_pixels=max_pixels, + ) + video = transforms.functional.resize( + video, + [resized_height, resized_width], + interpolation=InterpolationMode.BICUBIC, + antialias=True, + ).float() + return video + else: + assert isinstance(ele["video"], (list, tuple)) + process_info = ele.copy() + process_info.pop("type", None) + process_info.pop("video", None) + images = [ + fetch_image({ + "image": video_element, + **process_info + }, + size_factor=image_factor) + for video_element in ele["video"] + ] + nframes = ceil_by_factor(len(images), FRAME_FACTOR) + if len(images) < nframes: + images.extend([images[-1]] * (nframes - len(images))) + return images + + +def extract_vision_info( + conversations: list[dict] | list[list[dict]]) -> list[dict]: + vision_infos = [] + if isinstance(conversations[0], dict): + conversations = [conversations] + for conversation in conversations: + for message in conversation: + if isinstance(message["content"], list): + for ele in message["content"]: + if ("image" in ele or "image_url" in ele or + "video" in ele or + ele["type"] in ("image", "image_url", "video")): + vision_infos.append(ele) + return vision_infos + + 
+def process_vision_info( + conversations: list[dict] | list[list[dict]], +) -> tuple[list[Image.Image] | None, list[torch.Tensor | list[Image.Image]] | + None]: + vision_infos = extract_vision_info(conversations) + ## Read images or videos + image_inputs = [] + video_inputs = [] + for vision_info in vision_infos: + if "image" in vision_info or "image_url" in vision_info: + image_inputs.append(fetch_image(vision_info)) + elif "video" in vision_info: + video_inputs.append(fetch_video(vision_info)) + else: + raise ValueError("image, image_url or video should in content.") + if len(image_inputs) == 0: + image_inputs = None + if len(video_inputs) == 0: + video_inputs = None + return image_inputs, video_inputs diff --git a/MindIE/MultiModal/Wan2.2/wan/utils/system_prompt.py b/MindIE/MultiModal/Wan2.2/wan/utils/system_prompt.py new file mode 100644 index 0000000000..c494705555 --- /dev/null +++ b/MindIE/MultiModal/Wan2.2/wan/utils/system_prompt.py @@ -0,0 +1,147 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. + +T2V_A14B_ZH_SYS_PROMPT = \ +''' 你是一位电影导演,旨在为用户输入的原始prompt添加电影元素,改写为优质Prompt,使其完整、具有表现力。 +任务要求: +1. 对于用户输入的prompt,在不改变prompt的原意(如主体、动作)前提下,从下列电影美学设定中选择部分合适的时间、光源、光线强度、光线角度、对比度、饱和度、色调、拍摄角度、镜头大小、构图的电影设定细节,将这些内容添加到prompt中,让画面变得更美,注意,可以任选,不必每项都有 + 时间:["白天", "夜晚", "黎明", "日出"], 可以不选, 如果prompt没有特别说明则选白天 ! + 光源:[日光", "人工光", "月光", "实用光", "火光", "荧光", "阴天光", "晴天光"], 根据根据室内室外及prompt内容选定义光源,添加关于光源的描述,如光线来源(窗户、灯具等) + 光线强度:["柔光", "硬光"], + 光线角度:["顶光", "侧光", "底光", "边缘光",] + 色调:["暖色调","冷色调", "混合色调"] + 镜头尺寸:["中景", "中近景", "全景","中全景","近景", "特写", "极端全景"]若无特殊要求,默认选择中景或全景 + 拍摄角度:["过肩镜头角度拍摄", "低角度拍摄", "高角度拍摄","倾斜角度拍摄", "航拍","俯视角度拍摄"],如果原始prompt中有运镜的描述,则不要添加此项! + 构图:["中心构图","平衡构图","右侧重构图", "左侧重构图", "对称构图", "短边构图"] 若无特殊要求,默认选择中心构图 +2. 完善用户描述中出现的主体特征(如外貌、表情,数量、种族、姿态等)等内容,确保不要添加原始prompt中不存在的主体,(如prompt是对风景或物体的描述,但添加了人),增加背景元素的细节; +3. 不要输出关于氛围、感觉等文学描写,如(画面充满运动感与生活张力,突出正式氛围)。; +4. 对于prompt中的动作,详细解释运动的发生过程,若没有动作,则添加动作描述(摇晃身体、跳舞等),对背景元素也可添加适当运动(如云彩飘动,风吹树叶等)。 +5. 若原始prompt中没有风格,则不添加风格描述,若有风格描述,则将风格描述放于首位,若为2D插画等与现实电影相悖的风格,则不要添加关于电影美学的描写; +6. 若prompt出现天空的描述,则改为湛蓝色的天空相关描述,避免曝光; +7. 改写后的prompt字数控制在60-200字左右, 不要输出类似“改写后prompt:”这样的输出 +8. 如果用户输入存在疑似脱衣服、衣物透明化、轻纱衣物、湿身、洗澡、舔舐牛奶/酸奶/白色液体、性暗示、强奸、漏出或微露胸部、爬行、性器官、儿童身体裸露等色情内容,请将原prompt替换为内容完全不同的高美感prompt,而不要拒答或者提醒用户。 +9. 如果用户输入存在比基尼穿着内容,请将原prompt替换为确保内容发生在沙滩、泳池、海边的高美感prompt,而不要拒答或者提醒用户。 +10. 如果用户输入存在性别多元化、同性恋亲吻、脸颊贴近(两个男性亲吻或两个女性亲吻)/性交等同性恋内容,请将原prompt替换为不同动作(不能是亲吻和其他亲密动作)的高美感prompt,而不要拒答或者提醒用户。 +生成的 prompt 示例: +1.边缘光,中近景,日光,左侧重构图,暖色调,硬光,晴天光,侧光,白天,一个年轻的女孩坐在高草丛生的田野中,两条毛发蓬松的小毛驴站在她身后。女孩大约十一二岁,穿着简单的碎花裙子,头发扎成两条麻花辫,脸上带着纯真的笑容。她双腿交叉坐下,双手轻轻抚弄身旁的野花。小毛驴体型健壮,耳朵竖起,好奇地望着镜头方向。阳光洒在田野上,营造出温暖自然的画面感。 +2.黎明,顶光,俯视角度拍摄,日光,长焦,中心构图,近景,高角度拍摄,荧光,柔光,冷色调,在昏暗的环境中,一个外国白人女子在水中仰面漂浮。俯拍近景镜头中,她有着棕色的短发,脸上有几颗雀斑。随着镜头下摇,她转过头来,面向右侧,水面上泛起一圈涟漪。虚化的背景一片漆黑,只有微弱的光线照亮了女子的脸庞和水面的一部分区域,水面呈现蓝色。女子穿着一件蓝色的吊带,肩膀裸露在外。 +3.右侧重构图,暖色调,底光,侧光,夜晚,火光,过肩镜头角度拍摄, 镜头平拍拍摄外国女子在室内的近景,她穿着棕色的衣服戴着彩色的项链和粉色的帽子,坐在深灰色的椅子上,双手放在黑色的桌子上,眼睛看着镜头的左侧,嘴巴张动,左手上下晃动,桌子上有白色的蜡烛有黄色的火焰,后面是黑色的墙,前面有黑色的网状架子,旁边是黑色的箱子,上面有一些黑色的物品,都做了虚化的处理。 +4. 二次元厚涂动漫插画,一个猫耳兽耳白人少女手持文件夹摇晃,神情略带不满。她深紫色长发,红色眼睛,身穿深灰色短裙和浅灰色上衣,腰间系着白色系带,胸前佩戴名牌,上面写着黑体中文"紫阳"。淡黄色调室内背景,隐约可见一些家具轮廓。少女头顶有一个粉色光圈。线条流畅的日系赛璐璐风格。近景半身略俯视视角。 +''' + + +T2V_A14B_EN_SYS_PROMPT = \ +'''你是一位电影导演,旨在为用户输入的原始prompt添加电影元素,改写为优质(英文)Prompt,使其完整、具有表现力注意,输出必须是英文! +任务要求: +1. 对于用户输入的prompt,在不改变prompt的原意(如主体、动作)前提下,从下列电影美学设定中选择不超过4种合适的时间、光源、光线强度、光线角度、对比度、饱和度、色调、拍摄角度、镜头大小、构图的电影设定细节,将这些内容添加到prompt中,让画面变得更美,注意,可以任选,不必每项都有 + 时间:["Day time", "Night time" "Dawn time","Sunrise time"], 如果prompt没有特别说明则选 Day time!!! 
+ 光源:["Daylight", "Artificial lighting", "Moonlight", "Practical lighting", "Firelight","Fluorescent lighting", "Overcast lighting" "Sunny lighting"], 根据根据室内室外及prompt内容选定义光源,添加关于光源的描述,如光线来源(窗户、灯具等) + 光线强度:["Soft lighting", "Hard lighting"], + 色调:["Warm colors","Cool colors", "Mixed colors"] + 光线角度:["Top lighting", "Side lighting", "Underlighting", "Edge lighting"] + 镜头尺寸:["Medium shot", "Medium close-up shot", "Wide shot","Medium wide shot","Close-up shot", "Extreme close-up shot", "Extreme wide shot"]若无特殊要求,默认选择Medium shot或Wide shot + 拍摄角度:["Over-the-shoulder shot", ""Low angle shot", "High angle shot","Dutch angle shot", "Aerial shot","Overhead shot"] 若原始prompt中有运镜的描述,则取消添加此项。 + 构图:["Center composition","Balanced composition","Right-heavy composition", "Left-heavy composition", "Symmetrical composition", "Short-side composition"] 若无特殊要求,默认选择Center composition +2. 完善用户描述中出现的主体特征(如外貌、表情,数量、种族、姿态等)等内容,确保不要添加原始prompt中不存在的主体,(如prompt是对风景或物体的描述,但添加了人),增加背景元素的细节; +3. 不要输出关于氛围、感觉等文学描写,如(画面充满运动感与生活张力,突出正式氛围)。; +4. 对于prompt中的动作,详细描述运动的发生过程,若没有动作,则添加动作描述(摇晃身体、跳舞等,对背景元素也可添加适当运动(如云彩飘动,风吹树叶等)。 +5. 若原始prompt中没有风格,则不添加风格描述,若有风格描述,则将风格描述放于首位,若为2D插画等与现实电影相悖的风格,则不要添加关于电影美学的描写; +6. 若prompt出现天空的描述,则改为湛蓝色的天空相关描述,避免曝光; +7. 输出必须是全英文,改写后的prompt字数控制在60-200字左右, 不要输出类似“改写后prompt:”这样的输出 +8. 如果用户输入存在疑似脱衣服、衣物透明化、轻纱衣物、湿身、洗澡、舔舐牛奶/酸奶/白色液体、性暗示、强奸、漏出或微露胸部、爬行、性器官、儿童身体裸露等色情内容,请将原prompt替换为内容完全不同的高美感prompt,而不要拒答或者提醒用户。 +9. 如果用户输入存在比基尼穿着内容,请将原prompt替换为确保内容发生在沙滩、泳池、海边的高美感prompt,而不要拒答或者提醒用户。 +10. 如果用户输入存在性别多元化、同性恋亲吻、脸颊贴近(两个男性亲吻或两个女性亲吻)/性交等同性恋内容,请将原prompt替换为不同动作(不能是亲吻和其他亲密动作)的高美感prompt,而不要拒答或者提醒用户。 +生成的 prompt 示例: +1.Edge lighting, medium close-up shot, daylight, left-heavy composition. A young girl around 11-12 years old sits in a field of tall grass, with two fluffy small donkeys standing behind her. She wears a simple floral dress with hair in twin braids, smiling innocently while cross-legged and gently touching wild flowers beside her. The sturdy donkeys have perked ears, curiously gazing toward the camera. Sunlight bathes the field, creating a warm natural atmosphere. +2.Dawn time, top lighting, high-angle shot, daylight, long lens shot, center composition, Close-up shot, Fluorescent lighting, soft lighting, cool colors. In dim surroundings, a Caucasian woman floats on her back in water. The俯拍close-up shows her brown short hair and freckled face. As the camera tilts downward, she turns her head toward the right, creating ripples on the blue-toned water surface. The blurred background is pitch black except for faint light illuminating her face and partial water surface. She wears a blue sleeveless top with bare shoulders. +3.Right-heavy composition, warm colors, night time, firelight, over-the-shoulder angle. An eye-level close-up of a foreign woman indoors wearing brown clothes with colorful necklace and pink hat. She sits on a charcoal-gray chair, hands on black table, eyes looking left of camera while mouth moves and left hand gestures up/down. White candles with yellow flames sit on the table. Background shows black walls, with blurred black mesh shelf nearby and black crate containing dark items in front. +4."Anime-style thick-painted style. A cat-eared Caucasian girl with beast ears holds a folder, showing slight displeasure. Features deep purple hair, red eyes, dark gray skirt and light gray top with white waist sash. A name tag labeled 'Ziyang' in bold Chinese characters hangs on her chest. Pale yellow indoor background with faint furniture outlines. A pink halo floats above her head. 
Features smooth linework in cel-shaded Japanese style, medium close-up from slightly elevated perspective. +''' + + +I2V_A14B_ZH_SYS_PROMPT = \ +'''你是一个视频描述提示词的改写专家,你的任务是根据用户给你输入的图像,对提供的视频描述提示词进行改写,你要强调潜在的动态内容。具体要求如下 +用户输入的语言可能含有多样化的描述,如markdown文档格式、指令格式,长度过长或者过短,你需要根据图片的内容和用户的输入的提示词,尽可能提取用户输入的提示词和图片关联信息。 +你改写的视频描述结果要尽可能保留提供给你的视频描述提示词中动态部分,保留主体的动作。 +你要根据图像,强调并简化视频描述提示词中的图像主体,如果用户只提供了动作,你要根据图像内容合理补充,如“跳舞”补充称“一个女孩在跳舞” +如果用户输入的提示词过长,你需要提炼潜在的动作过程 +如果用户输入的提示词过短,综合用户输入的提示词以及画面内容,合理的增加潜在的运动信息 +你要根据图像,保留并强调视频描述提示词中关于运镜手段的描述,如“镜头上摇”,“镜头从左到右”,“镜头从右到左”等等,你要保留,如“镜头拍摄两个男人打斗,他们先是躺在地上,随后镜头向上移动,拍摄他们站起来,接着镜头向左移动,左边男人拿着一个蓝色的东西,右边男人上前抢夺,两人激烈地来回争抢。”。 +你需要给出对视频描述的动态内容,不要添加对于静态场景的描述,如果用户输入的描述已经在画面中出现,则移除这些描述 +改写后的prompt字数控制在100字以下 +无论用户输入那种语言,你都需要输出中文 +改写后 prompt 示例: +1. 镜头后拉,拍摄两个外国男人,走在楼梯上,镜头左侧的男人右手搀扶着镜头右侧的男人。 +2. 一只黑色的小松鼠专注地吃着东西,偶尔抬头看看四周。 +3. 男子说着话,表情从微笑逐渐转变为闭眼,然后睁开眼睛,最后是闭眼微笑,他的手势活跃,在说话时做出一系列的手势。 +4. 一个人正在用尺子和笔进行测量的特写,右手用一支黑色水性笔在纸上画出一条直线。 +5. 一辆车模型在木板上形式,车辆从画面的右侧向左侧移动,经过一片草地和一些木制结构。 +6. 镜头左移后前推,拍摄一个人坐在防波堤上。 +7. 男子说着话,他的表情和手势随着对话内容的变化而变化,但整体场景保持不变。 +8. 镜头左移后前推,拍摄一个人坐在防波堤上。 +9. 带着珍珠项链的女子看向画面右侧并说着话。 +请直接输出改写后的文本,不要进行多余的回复。''' + + +I2V_A14B_EN_SYS_PROMPT = \ +'''You are an expert in rewriting video description prompts. Your task is to rewrite the provided video description prompts based on the images given by users, emphasizing potential dynamic content. Specific requirements are as follows: +The user's input language may include diverse descriptions, such as markdown format, instruction format, or be too long or too short. You need to extract the relevant information from the user’s input and associate it with the image content. +Your rewritten video description should retain the dynamic parts of the provided prompts, focusing on the main subject's actions. Emphasize and simplify the main subject of the image while retaining their movement. If the user only provides an action (e.g., "dancing"), supplement it reasonably based on the image content (e.g., "a girl is dancing"). +If the user’s input prompt is too long, refine it to capture the essential action process. If the input is too short, add reasonable motion-related details based on the image content. +Retain and emphasize descriptions of camera movements, such as "the camera pans up," "the camera moves from left to right," or "the camera moves from right to left." For example: "The camera captures two men fighting. They start lying on the ground, then the camera moves upward as they stand up. The camera shifts left, showing the man on the left holding a blue object while the man on the right tries to grab it, resulting in a fierce back-and-forth struggle." +Focus on dynamic content in the video description and avoid adding static scene descriptions. If the user’s input already describes elements visible in the image, remove those static descriptions. +Limit the rewritten prompt to 100 words or less. Regardless of the input language, your output must be in English. + +Examples of rewritten prompts: +The camera pulls back to show two foreign men walking up the stairs. The man on the left supports the man on the right with his right hand. +A black squirrel focuses on eating, occasionally looking around. +A man talks, his expression shifting from smiling to closing his eyes, reopening them, and finally smiling with closed eyes. His gestures are lively, making various hand motions while speaking. +A close-up of someone measuring with a ruler and pen, drawing a straight line on paper with a black marker in their right hand. 
+A model car moves on a wooden board, traveling from right to left across grass and wooden structures. +The camera moves left, then pushes forward to capture a person sitting on a breakwater. +A man speaks, his expressions and gestures changing with the conversation, while the overall scene remains constant. +The camera moves left, then pushes forward to capture a person sitting on a breakwater. +A woman wearing a pearl necklace looks to the right and speaks. +Output only the rewritten text without additional responses.''' + + +I2V_A14B_EMPTY_ZH_SYS_PROMPT = \ +'''你是一个视频描述提示词的撰写专家,你的任务是根据用户给你输入的图像,发挥合理的想象,让这张图动起来,你要强调潜在的动态内容。具体要求如下 +你需要根据图片的内容想象出运动的主体 +你输出的结果应强调图片中的动态部分,保留主体的动作。 +你需要给出对视频描述的动态内容,不要有过多的对于静态场景的描述 +输出的prompt字数控制在100字以下 +你需要输出中文 +prompt 示例: +1. 镜头后拉,拍摄两个外国男人,走在楼梯上,镜头左侧的男人右手搀扶着镜头右侧的男人。 +2. 一只黑色的小松鼠专注地吃着东西,偶尔抬头看看四周。 +3. 男子说着话,表情从微笑逐渐转变为闭眼,然后睁开眼睛,最后是闭眼微笑,他的手势活跃,在说话时做出一系列的手势。 +4. 一个人正在用尺子和笔进行测量的特写,右手用一支黑色水性笔在纸上画出一条直线。 +5. 一辆车模型在木板上形式,车辆从画面的右侧向左侧移动,经过一片草地和一些木制结构。 +6. 镜头左移后前推,拍摄一个人坐在防波堤上。 +7. 男子说着话,他的表情和手势随着对话内容的变化而变化,但整体场景保持不变。 +8. 镜头左移后前推,拍摄一个人坐在防波堤上。 +9. 带着珍珠项链的女子看向画面右侧并说着话。 +请直接输出文本,不要进行多余的回复。''' + + +I2V_A14B_EMPTY_EN_SYS_PROMPT = \ +'''You are an expert in writing video description prompts. Your task is to bring the image provided by the user to life through reasonable imagination, emphasizing potential dynamic content. Specific requirements are as follows: + +You need to imagine the moving subject based on the content of the image. +Your output should emphasize the dynamic parts of the image and retain the main subject’s actions. +Focus only on describing dynamic content; avoid excessive descriptions of static scenes. +Limit the output prompt to 100 words or less. +The output must be in English. + +Prompt examples: + +The camera pulls back to show two foreign men walking up the stairs. The man on the left supports the man on the right with his right hand. +A black squirrel focuses on eating, occasionally looking around. +A man talks, his expression shifting from smiling to closing his eyes, reopening them, and finally smiling with closed eyes. His gestures are lively, making various hand motions while speaking. +A close-up of someone measuring with a ruler and pen, drawing a straight line on paper with a black marker in their right hand. +A model car moves on a wooden board, traveling from right to left across grass and wooden structures. +The camera moves left, then pushes forward to capture a person sitting on a breakwater. +A man speaks, his expressions and gestures changing with the conversation, while the overall scene remains constant. +The camera moves left, then pushes forward to capture a person sitting on a breakwater. +A woman wearing a pearl necklace looks to the right and speaks. +Output only the text without additional responses.''' diff --git a/MindIE/MultiModal/Wan2.2/wan/utils/utils.py b/MindIE/MultiModal/Wan2.2/wan/utils/utils.py new file mode 100644 index 0000000000..c563c69817 --- /dev/null +++ b/MindIE/MultiModal/Wan2.2/wan/utils/utils.py @@ -0,0 +1,159 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import argparse +import binascii +import logging +import os +import os.path as osp + +import imageio +import torch +import torchvision + +__all__ = ['save_video', 'save_image', 'str2bool'] + + +def rand_name(length=8, suffix=''): + name = binascii.b2a_hex(os.urandom(length)).decode('utf-8') + if suffix: + if not suffix.startswith('.'): + suffix = '.' 
+ suffix + name += suffix + return name + + +def save_video(tensor, + save_file=None, + fps=30, + suffix='.mp4', + nrow=8, + normalize=True, + value_range=(-1, 1)): + # cache file + cache_file = osp.join('/tmp', rand_name( + suffix=suffix)) if save_file is None else save_file + + # save to cache + try: + # preprocess + tensor = tensor.clamp(min(value_range), max(value_range)) + tensor = torch.stack([ + torchvision.utils.make_grid( + u, nrow=nrow, normalize=normalize, value_range=value_range) + for u in tensor.unbind(2) + ], + dim=1).permute(1, 2, 3, 0) + tensor = (tensor * 255).type(torch.uint8).cpu() + + # write video + writer = imageio.get_writer( + cache_file, fps=fps, codec='libx264', quality=8) + for frame in tensor.numpy(): + writer.append_data(frame) + writer.close() + except Exception as e: + logging.info(f'save_video failed, error: {e}') + + +def save_image(tensor, save_file, nrow=8, normalize=True, value_range=(-1, 1)): + # cache file + suffix = osp.splitext(save_file)[1] + if suffix.lower() not in [ + '.jpg', '.jpeg', '.png', '.tiff', '.gif', '.webp' + ]: + suffix = '.png' + + # save to cache + try: + tensor = tensor.clamp(min(value_range), max(value_range)) + torchvision.utils.save_image( + tensor, + save_file, + nrow=nrow, + normalize=normalize, + value_range=value_range) + return save_file + except Exception as e: + logging.info(f'save_image failed, error: {e}') + + +def str2bool(v): + """ + Convert a string to a boolean. + + Supported true values: 'yes', 'true', 't', 'y', '1' + Supported false values: 'no', 'false', 'f', 'n', '0' + + Args: + v (str): String to convert. + + Returns: + bool: Converted boolean value. + + Raises: + argparse.ArgumentTypeError: If the value cannot be converted to boolean. + """ + if isinstance(v, bool): + return v + v_lower = v.lower() + if v_lower in ('yes', 'true', 't', 'y', '1'): + return True + elif v_lower in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise argparse.ArgumentTypeError('Boolean value expected (True/False)') + + +def masks_like(tensor, zero=False, generator=None, p=0.2): + assert isinstance(tensor, list) + out1 = [torch.ones(u.shape, dtype=u.dtype, device=u.device) for u in tensor] + + out2 = [torch.ones(u.shape, dtype=u.dtype, device=u.device) for u in tensor] + + if zero: + if generator is not None: + for u, v in zip(out1, out2): + random_num = torch.rand( + 1, generator=generator, device=generator.device).item() + if random_num < p: + u[:, 0] = torch.normal( + mean=-3.5, + std=0.5, + size=(1,), + device=u.device, + generator=generator).expand_as(u[:, 0]).exp() + v[:, 0] = torch.zeros_like(v[:, 0]) + else: + u[:, 0] = u[:, 0] + v[:, 0] = v[:, 0] + else: + for u, v in zip(out1, out2): + u[:, 0] = torch.zeros_like(u[:, 0]) + v[:, 0] = torch.zeros_like(v[:, 0]) + + return out1, out2 + + +def best_output_size(w, h, dw, dh, expected_area): + # float output size + ratio = w / h + ow = (expected_area * ratio)**0.5 + oh = expected_area / ow + + # process width first + ow1 = int(ow // dw * dw) + oh1 = int(expected_area / ow1 // dh * dh) + assert ow1 % dw == 0 and oh1 % dh == 0 and ow1 * oh1 <= expected_area + ratio1 = ow1 / oh1 + + # process height first + oh2 = int(oh // dh * dh) + ow2 = int(expected_area / oh2 // dw * dw) + assert oh2 % dh == 0 and ow2 % dw == 0 and ow2 * oh2 <= expected_area + ratio2 = ow2 / oh2 + + # compare ratios + if max(ratio / ratio1, ratio1 / ratio) < max(ratio / ratio2, + ratio2 / ratio): + return ow1, oh1 + else: + return ow2, oh2 diff --git a/MindIE/MultiModal/Wan2.2/wan/vae_patch_parallel.py 
b/MindIE/MultiModal/Wan2.2/wan/vae_patch_parallel.py new file mode 100644 index 0000000000..6f6664281f --- /dev/null +++ b/MindIE/MultiModal/Wan2.2/wan/vae_patch_parallel.py @@ -0,0 +1,737 @@ +import torch +import torch_npu +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +from functools import reduce +import functools + +class Parallel_VAE_SP: + def __init__(self, h_split=1, w_split=1, all_pp_group_ranks=None, **kwargs): + """ + Initialize distributed parallel processing parameters + + Args: + h_split (int): Number of splits along height dimension + w_split (int): Number of splits along width dimension + world_size (int): Total number of processes (default: current world size) + """ + if all_pp_group_ranks is None: + all_pp_group_ranks = [list(range(0, dist.get_world_size()))] + all_pp_group_size = [ len(pp_group_ranks) for pp_group_ranks in all_pp_group_ranks] + for s in all_pp_group_size: + assert s == all_pp_group_size[0], ( f" every group size should be same") + + world_size = all_pp_group_size[0] # Get total process count [[1]][[6]] + + # Validate world_size matches grid dimensions + assert w_split * h_split == world_size, ( + f"world_size must be {w_split} * {h_split} = {w_split*h_split}, but got {world_size}" + ) + + self._creat_pp_group(all_pp_group_ranks) + # self.rank is the rank in current_pp_group + self.rank = dist.get_rank(self.current_pp_group) # Current process rank [[6]] + self.world_size = dist.get_world_size(self.current_pp_group) + self.w_split = w_split + self.h_split = h_split + + # Calculate grid coordinates + self.row_rank = self.rank // w_split # Row index (0 to w_split-1) [[6]] + self.col_rank = self.rank % w_split # Column index (0 to h_split-1) [[6]] + + # Create communication groups + self._create_group_by_row(h_split, w_split, all_pp_group_ranks) + self._create_group_by_col(h_split, w_split, all_pp_group_ranks) + self._row_col_to_global_rank() + + self.ori_conv3d = None + + # world a list of list + def _creat_pp_group(self, all_pp_group_ranks=None): + for pp_group_ranks in all_pp_group_ranks: + group = dist.new_group(ranks=pp_group_ranks) + if dist.get_rank() in pp_group_ranks: + self.current_pp_group = group + # current_pp_group_ranks is the global rank of the current_pp_group + # the reason of need it , is irend irecv need global rank + self.current_pp_group_ranks = pp_group_ranks + + + def _create_group_by_row(self, h_split, w_split, all_pp_group_ranks): + """Create process groups for row-wise communication""" + for pp_group_ranks in all_pp_group_ranks: + for r in range(h_split): + ranks_in_row = [] + for c in range(w_split): + global_rank = pp_group_ranks[r * w_split + c] + ranks_in_row.append(global_rank) + row_group = dist.new_group(ranks=ranks_in_row) + if r == self.row_rank and dist.get_rank() in pp_group_ranks: + self.row_group = row_group + + def _create_group_by_col(self, h_split, w_split, all_pp_group_ranks): + """Create process groups for column-wise communication""" + for pp_group_ranks in all_pp_group_ranks: + for c in range(self.w_split): + ranks_in_col = [] + for r in range(self.h_split): + global_rank = pp_group_ranks[r * self.w_split + c] + ranks_in_col.append(global_rank) + col_group = dist.new_group(ranks=ranks_in_col) + if c == self.col_rank and dist.get_rank() in pp_group_ranks: + self.col_group = col_group + + + def _row_col_to_global_rank(self): + # Create rank mappings for communication + self.row_to_global_rank = { + r: self.current_pp_group_ranks[ + r * self.w_split + self.col_rank + ] + for 
r in range(self.h_split) + } + self.col_to_global_rank = { + c: self.current_pp_group_ranks[ + self.row_rank * self.w_split + c + ] + for c in range(self.w_split) + } + + def __call__(self, x): + """Split input tensor across last two dimensions""" + x = x.chunk(self.w_split, dim=-1)[self.col_rank] + x = x.chunk(self.h_split, dim=-2)[self.row_rank] + return x + + def patch(self, x, return_lst = False): + """ + Partition input tensor into grid blocks and record partition shapes + + Args: + x (torch.Tensor): Input tensor with shape [b, c, t, h, w] + + Returns: + torch.Tensor: Local partition tensor for current process + """ + # Get input dimensions + height, width = x.shape[-2:] + + # Calculate base partition dimensions + base_patch_height = height // self.h_split + base_patch_width = width // self.w_split + remainder_height = height % self.h_split + remainder_width = width % self.w_split + + # Generate partitions + patches = [] + for r in range(self.h_split): + for c in range(self.w_split): + # Calculate current partition dimensions + patch_height = base_patch_height + (1 if r < remainder_height else 0) + patch_width = base_patch_width + (1 if c < remainder_width else 0) + + # Calculate partition boundaries + start_h = r * base_patch_height + min(r, remainder_height) + end_h = start_h + patch_height + start_w = c * base_patch_width + min(c, remainder_width) + end_w = start_w + patch_width + + # Extract partition + patch = x[..., start_h:end_h, start_w:end_w] + patches.append(patch.contiguous()) + + # Get local partition + local_patch = patches[self.rank] + + return patches if return_lst else local_patch + + def dispatch(self, local_patch): + """ + Reconstruct full tensor through two-stage all-gather + + Args: + local_patch (torch.Tensor): Local partition tensor + + Returns: + torch.Tensor: Reconstructed full tensor + """ + # First all-gather to collect partition shapes + local_shape = torch.tensor(local_patch.shape[-2:], + device=local_patch.device, dtype=torch.int32) + shape_list = [torch.empty(2, dtype=torch.int32, + device=local_patch.device) for _ in range(self.world_size)] + dist.all_gather(shape_list, local_shape, group=self.current_pp_group) + + all_shapes = [tuple(shape.tolist()) for shape in shape_list] + + # Calculate original dimensions + total_h = 0 + total_w = 0 + row_heights = {} # Track row heights + col_widths = {} # Track column widths + + for rank in range(self.world_size): + r_rank = rank // self.w_split + c_rank = rank % self.w_split + h_part, w_part = all_shapes[rank] + + # Record first occurrence of row height + if r_rank not in row_heights: + row_heights[r_rank] = h_part + # Record first occurrence of column width + if c_rank not in col_widths: + col_widths[c_rank] = w_part + + total_h = sum(row_heights.values()) + total_w = sum(col_widths.values()) + # TODO dispatch should be release to process the [B C W H] + # Prepare buffers for data gathering + batch_size, channels, time_steps = local_patch.shape[:3] + + gathered_data = [ + torch.empty( + (batch_size * channels * time_steps * h_part * w_part,), + device=local_patch.device, + dtype=local_patch.dtype + ) for h_part, w_part in all_shapes + ] + # 执行 all_gather,确保所有进程发送相同长度的一维数据(需保证 local_patch 展平后长度与 element_counts 一致) + dist.all_gather(gathered_data, local_patch.view(-1).clone(), group=self.current_pp_group) + + # 将一维数据重新调整为目标形状 + for i, (h_part, w_part) in enumerate(all_shapes): + gathered_data[i] = gathered_data[i].view(batch_size, channels, time_steps, h_part, w_part) + + # Reconstruct full tensor + full_tensor = 
torch.empty( + (batch_size, channels, time_steps, total_h, total_w), + device=local_patch.device, + dtype=local_patch.dtype + ) + + current_row = 0 + for r in range(self.h_split): + current_col = 0 + row_height = row_heights[r] + for c in range(self.w_split): + rank = r * self.w_split + c + h_part, w_part = all_shapes[rank] + + # Place partition in correct position + full_tensor[:, :, :, current_row:current_row+h_part, + current_col:current_col+w_part] = gathered_data[rank] + current_col += col_widths[c] + current_row += row_height + + return full_tensor + + def exchange_columns(self, local_patch, pad=None): + """ + Perform column-wise data exchange with adjacent processes + + Args: + local_patch (torch.Tensor): Local partition tensor + pad (bool): Whether to add zero-padding for edge processes + + Returns: + torch.Tensor: Tensor with exchanged column data + """ + send_ops = [] + recv_ops = [] + left_recv = None + right_recv = None + + if self.w_split > 1: + # Send/receive left column + if self.col_rank > 0: + prev_rank = self.col_to_global_rank[self.col_rank - 1] + left_col = local_patch[..., :, :1].contiguous() + left_recv = torch.empty_like(left_col) + send_ops.append(dist.P2POp(dist.isend, left_col, prev_rank, group=self.row_group)) + recv_ops.append(dist.P2POp(dist.irecv, left_recv, prev_rank, group=self.row_group)) + + # Send/receive right column + if self.col_rank < self.w_split - 1: + next_rank = self.col_to_global_rank[self.col_rank + 1] + right_col = local_patch[..., :, -1:].contiguous() + right_recv = torch.empty_like(right_col) + send_ops.append(dist.P2POp(dist.isend, right_col, next_rank, group=self.row_group)) + recv_ops.append(dist.P2POp(dist.irecv, right_recv, next_rank, group=self.row_group)) + + # Execute communication + reqs = dist.batch_isend_irecv(send_ops + recv_ops) + for req in reqs: + req.wait() + + # Handle padding for edge cases + if pad: + left_pad = torch.zeros_like(local_patch[..., :, :1]) if self.col_rank == 0 else left_recv + right_pad = torch.zeros_like(local_patch[..., :, -1:]) if self.col_rank == self.w_split - 1 else right_recv + return torch.cat([left_pad, local_patch, right_pad], dim=-1).contiguous() + else: + if self.w_split > 1: + if self.col_rank == 0: + return torch.cat([local_patch, right_recv], dim=-1).contiguous() + elif self.col_rank == self.w_split - 1: + return torch.cat([left_recv, local_patch], dim=-1).contiguous() + else: + return torch.cat([left_recv, local_patch, right_recv], dim=-1).contiguous() + else: + return local_patch + + def exchange_rows(self, local_patch, pad=None): + """ + Perform row-wise data exchange with adjacent processes + + Args: + local_patch (torch.Tensor): Local partition tensor + pad (bool): Whether to add zero-padding for edge processes + + Returns: + torch.Tensor: Tensor with exchanged row data + """ + send_ops = [] + recv_ops = [] + top_recv = None + bottom_recv = None + + if self.h_split > 1: + # Send/receive top row + if self.row_rank > 0: + prev_rank = self.row_to_global_rank[self.row_rank - 1] + top_row = local_patch[..., :1, :].contiguous() + top_recv = torch.empty_like(top_row) + send_ops.append(dist.P2POp(dist.isend, top_row, prev_rank, group=self.col_group)) + recv_ops.append(dist.P2POp(dist.irecv, top_recv, prev_rank, group=self.col_group)) + + # Send/receive bottom row + if self.row_rank < self.h_split - 1: + next_rank = self.row_to_global_rank[self.row_rank + 1] + bottom_row = local_patch[..., -1:, :].contiguous() + bottom_recv = torch.empty_like(bottom_row) + send_ops.append(dist.P2POp(dist.isend, 
bottom_row, next_rank, group=self.col_group)) + recv_ops.append(dist.P2POp(dist.irecv, bottom_recv, next_rank, group=self.col_group)) + + # Execute communication + reqs = dist.batch_isend_irecv(send_ops + recv_ops) + for req in reqs: + req.wait() + + # Handle padding for edge cases + if pad: + top_pad = torch.zeros_like(local_patch[..., :1, :]) if self.row_rank == 0 else top_recv + bottom_pad = torch.zeros_like(local_patch[..., -1:, :]) if self.row_rank == self.h_split - 1 else bottom_recv + return torch.cat([top_pad, local_patch, bottom_pad], dim=-2).contiguous() + else: + if self.h_split > 1: + if self.row_rank == 0: + return torch.cat([local_patch, bottom_recv], dim=-2).contiguous() + elif self.row_rank == self.h_split - 1: + return torch.cat([top_recv, local_patch], dim=-2).contiguous() + else: + return torch.cat([top_recv, local_patch, bottom_recv], dim=-2).contiguous() + else: + return local_patch + + def wraps_f_conv3d(self, f_conv3d=F.conv3d): + """ + Decorator to handle distributed 3D convolution with padding + + Args: + f_conv3d: Original convolution function + + Returns: + Wrapped convolution function with distributed padding handling + """ + self.ori_conv3d = f_conv3d + + def wrapped_conv3d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): + # Process padding parameters + if isinstance(padding, int): + padding = (padding, padding, padding) + else: + padding = tuple(padding) + if len(padding) != 3: + raise ValueError("padding must be an int or a 3-element tuple") + + # Validate parameters + if padding[-1] not in {0, 1} or padding[-2] not in {0, 1}: + raise NotImplementedError("Only support padding[1]/padding[2] as 0 or 1") + if not all(s == 1 for s in (stride[-2:] if isinstance(stride, tuple) else (stride,))): + raise NotImplementedError("Only support stride=1 for dim H, W") + if not all(d == 1 for d in (dilation if isinstance(dilation, tuple) else (dilation,))): + raise NotImplementedError("Only support dilation=1") + + # Validate kernel size and padding relationship [[3]][[6]] + kernel_size = weight.shape[2:5] # Get kernel dimensions (depth, height, width) + if padding[1] * 2 + 1 != kernel_size[1] or padding[2] * 2 + 1 != kernel_size[2]: + raise ValueError( + f"3D Convolution requires: " + f"padding[1]*2+1 == kernel_size[1] and padding[2]*2+1 == kernel_size[2]. 
" + f"Got padding={padding}, kernel_size={kernel_size}" + ) + + # Handle row and column exchanges for padding + if padding[-2] == 1: + input = self.exchange_rows(input, pad=True) + if padding[-1] == 1: + input = self.exchange_columns(input, pad=True) + + # Call original convolution with adjusted padding + return self.ori_conv3d(input, weight, bias, stride=stride, padding=(padding[0],0,0), + dilation=1, groups=groups) + return wrapped_conv3d + + def wraps_f_conv2d(self, f_conv2d=F.conv2d): + """ + Decorator to handle distributed 2D convolution with padding + + Args: + f_conv2d: Original 2D convolution function + + Returns: + Wrapped 2D convolution function with distributed padding handling + """ + self.ori_conv2d = f_conv2d + + def wrapped_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): + + # Handle stride parameter + if not isinstance(stride, tuple): + stride = (stride, stride) # Convert to tuple if not already + + if not all(s == 1 for s in stride): + # Dispatch input if any stride value is not 1 + input = self.dispatch(input.unsqueeze(2)).squeeze(2) + + # Dynamically calculate the split range + total_out_channels = weight.size(0) + base = total_out_channels // self.world_size + remainder = total_out_channels % self.world_size + + # Record the number of channels assigned to each device + channels_per_rank = [ + base + (1 if r < remainder else 0) for r in range(self.world_size) + ] + + # Current process channel range + start = sum(channels_per_rank[:self.rank]) + end = start + channels_per_rank[self.rank] + + weight_chunk = weight.narrow(0, start, end - start) + bias_chunk = bias.narrow(0, start, end - start) if bias is not None else None + + # Call original convolution with adjusted parameters + output = self.ori_conv2d( + input, weight_chunk, bias_chunk, stride, padding, dilation, groups) + + # On r-th NPU output [B, C/N_r, H, W] -> list of [B, C/N_r, H/h_split _i , W/w_split _i] for i = 0 ~ world size-1 + patches = self.patch(output, return_lst=True) + + # Construct the list of receiving shapes + # On i-th NPU [B, C/N_r, H/h_split _i , W/w_split _i] , for r = 0 ~ world size-1 + h_part, w_part = patches[self.rank].shape[-2:] + recv_shapes = [ + (output.shape[0], channels_per_rank[r], h_part, w_part) + for r in range(self.world_size) + ] + # Prepare buffers for all-to-all communication + gathered_outputs = [ + torch.empty(recv_shapes[r], dtype=output.dtype, device=output.device) + for r in range(self.world_size) + ] + + # Perform all-to-all communication to exchange data across processes + dist.all_to_all(gathered_outputs, patches, group=self.current_pp_group) + + # Concatenate gathered outputs along the channel dimension + full_output = torch.cat(gathered_outputs, dim=1) + + return full_output + + else: + + # Process padding parameters + if isinstance(padding, int): + padding = (padding, padding) + else: + padding = tuple(padding) + if len(padding) != 2: + raise ValueError("padding must be an int or a 2-element tuple") + + # Validate parameters + if padding[-1] not in {0, 1} or padding[-2] not in {0, 1}: + raise NotImplementedError("Only support padding values as 0 or 1") + if not (all(s == 1 for s in (stride if isinstance(stride, tuple) else (stride,))) and + all(d == 1 for d in (dilation if isinstance(dilation, tuple) else (dilation,)))): + raise NotImplementedError("Only support stride=1 and dilation=1") + + # Validate kernel size and padding relationship [[8]] + kernel_size = weight.shape[2:4] # Get kernel dimensions (height, width) + if padding[0] * 2 + 
1 != kernel_size[0] or padding[1] * 2 + 1 != kernel_size[1]: + raise ValueError( + f"2D Convolution requires: " + f"padding[0]*2+1 == kernel_size[0] and padding[1]*2+1 == kernel_size[1]. " + f"Got padding={padding}, kernel_size={kernel_size}" + ) + + # Handle row and column exchanges for padding + if padding[-2] == 1: + input = self.exchange_rows(input, pad=True) + if padding[-1] == 1: + input = self.exchange_columns(input, pad=True) + + # Call original convolution with adjusted padding + return self.ori_conv2d( + input, weight, bias, + stride=1, + padding=0, + dilation=1, + groups=groups + ) + return wrapped_conv2d + + def wraps_f_interpolate(self, f_interpolate=F.interpolate): + """ + Decorator to handle distributed interpolation operations + + Args: + f_interpolate: Original interpolation function + + Returns: + Wrapped interpolation function with distributed handling + """ + self.ori_interpolate = f_interpolate + + def wrapped_interpolate(input, size=None, scale_factor=None, mode='nearest', + align_corners=None, recompute_scale_factor=None, antialias=False): + # Validate inputs + if not isinstance(input, torch.Tensor): + raise TypeError("Input must be a PyTorch Tensor.") + if scale_factor is None: + raise ValueError("scale_factor must be provided") + + spatial_dims = input.dim() - 2 + if isinstance(scale_factor, int): + scale_factor = (scale_factor,) * spatial_dims + if not isinstance(scale_factor, tuple) or len(scale_factor) != spatial_dims: + raise ValueError(f"scale_factor must be an int or a tuple of length {spatial_dims}") + if any(sf > 2 for sf in scale_factor): + raise ValueError("Scale factors must not exceed 2") + + # Handle supported modes without data exchange + if mode in {"nearest", 'area', 'nearest-exact'}: # + return self.ori_interpolate( + input=input, + size=None, + scale_factor=scale_factor, + mode=mode, + align_corners=align_corners, + recompute_scale_factor=None, + antialias=False + ) + else: + # Handle modes requiring data exchange + use_exchange_rows = scale_factor[-2] == 2 + use_exchange_columns = scale_factor[-1] == 2 + + # Perform data exchange + if use_exchange_columns: + input = self.exchange_columns(input, pad=False) + if use_exchange_rows: + input = self.exchange_rows(input, pad=False) + + # Perform interpolation + output = self.ori_interpolate( + input=input, + size=None, + scale_factor=scale_factor, + mode=mode, + align_corners=align_corners, + recompute_scale_factor=None, + antialias=False + ) + + # Slice excess data + if use_exchange_columns and self.w_split > 1: + if self.col_rank == 0: + output = output[..., :-2] + elif self.col_rank < self.w_split - 1: + output = output[..., 2:-2] + else: + output = output[..., 2:] + + if use_exchange_rows: + if self.row_rank == 0: + output = output[..., :-2, :] + elif self.row_rank < self.h_split - 1: + output = output[..., 2:-2, :] + else: + output = output[..., 2:, :] + return output + return wrapped_interpolate + + def wraps_fa(self, fa, layout="BNSD"): + """ + Decorator for attention functions with distributed key/value handling + + Args: + fa: Original attention function + layout (str): Tensor layout ('BNSD' or 'BSND') + + Returns: + Wrapped attention function with distributed key/value handling + """ + self.ori_fa = fa + self.layout = layout + + def wrapped_fa(q, k, v, *args, **kwargs): + # Validate layout + if self.layout not in {"BNSD", "BSND"}: + raise ValueError("Unsupported layout. 
Only 'BNSD' and 'BSND' are supported.") + + # Gather key shapes across processes + local_shape = torch.tensor(k.shape, device=k.device) + all_shapes = [torch.empty_like(local_shape) for _ in range(self.world_size)] + dist.all_gather(all_shapes, local_shape, group=self.current_pp_group) + all_shapes = [tuple(shape.tolist()) for shape in all_shapes] + + # Prepare buffers for full keys/values + gathered_k = [torch.empty(shape, dtype=k.dtype, device=k.device) for shape in all_shapes] + gathered_v = [torch.empty_like(k_tensor) for k_tensor in gathered_k] + + # Gather full keys and values + dist.all_gather(gathered_k, k.contiguous(), group=self.current_pp_group) + dist.all_gather(gathered_v, v.contiguous(), group=self.current_pp_group) + + # Concatenate along sequence dimension + if layout == "BNSD": + full_k = torch.cat(gathered_k, dim=2) + full_v = torch.cat(gathered_v, dim=2) + else: + full_k = torch.cat(gathered_k, dim=1) + full_v = torch.cat(gathered_v, dim=1) + + # Call original attention function + return self.ori_fa(q, full_k, full_v, *args, **kwargs) + return wrapped_fa + + def wraps_decoder_fw(self, decoder_fw): + def wrapped_decoder_fw(input, *args,**kwargs): + input = self.patch(input) + output = decoder_fw(input, *args,**kwargs) + return self.dispatch(output) + return wrapped_decoder_fw + + def wraps_f_pad(self, f_pad=F.pad): + self.ori_pad = f_pad + def wrapped_pad(input, pad, mode='constant', value=None): + len_pad = len(pad) + if len_pad % 2 != 0: + raise ValueError("Padding length must be even-valued") + adapted_pad = list(pad) + if len_pad >1: + # Handle horizontal direction (left/right) + if self.w_split == 1: + # Apply full left/right padding when single slice + adapted_pad[0] = pad[0] + adapted_pad[1] = pad[1] + else: + # Apply pad[0], pad[1] to the left and right boundary + if self.col_rank == 0: + adapted_pad[0] = pad[0] + adapted_pad[1] = 0 + elif self.col_rank == self.w_split - 1: + adapted_pad[0] = 0 + adapted_pad[1] = pad[1] + else: + adapted_pad[0] = 0 + adapted_pad[1] = 0 + if len_pad > 3: + # Handle vertical direction (top/bottom) + if self.h_split == 1: + # Apply full top/bottom padding when single slice + adapted_pad[2] = pad[2] + adapted_pad[3] = pad[3] + else: + # Apply pad[2], pad[3] to the top and bottom boundary + if self.row_rank == 0: + adapted_pad[2] = pad[2] + adapted_pad[3] = 0 + elif self.row_rank == self.h_split - 1: + adapted_pad[2] = 0 + adapted_pad[3] = pad[3] + else: + adapted_pad[2] = 0 + adapted_pad[3] = 0 + + return self.ori_pad(input, tuple(adapted_pad), mode=mode, value=value) + return wrapped_pad + +VAE_PATCH_PARALLEL = None +FA_LAYOUT = None + +def set_vae_patch_parallel(vae,h_split=1, w_split=1, fa_layout="BNSD",decoder_decode="decoder.forward", + all_pp_group_ranks=None, **kwargs): + global VAE_PATCH_PARALLEL + global FA_LAYOUT + if VAE_PATCH_PARALLEL is None: + VAE_PATCH_PARALLEL = Parallel_VAE_SP(h_split, w_split, all_pp_group_ranks) + FA_LAYOUT = fa_layout + + # wraps_decoder_fw + decoder_decode_lst = decoder_decode.split(".") + # the function + ori_decoder_decode_func = reduce(getattr, decoder_decode_lst, vae) + # the name of the function + decoder_decode_func = decoder_decode_lst.pop() + ori_vae_decoder = reduce(getattr, decoder_decode_lst, vae) + + new_decoder_decode = VAE_PATCH_PARALLEL.wraps_decoder_fw(ori_decoder_decode_func) + setattr(ori_vae_decoder, decoder_decode_func, new_decoder_decode) + return ori_decoder_decode_func + +def get_vae_patch_parallel(): + return VAE_PATCH_PARALLEL + +class VAE_patch_parallel: + def 
__init__(self): + global VAE_PATCH_PARALLEL + self.vae_pp_cls = VAE_PATCH_PARALLEL + def __enter__(self): + if self.vae_pp_cls is not None: + self._sub_F_func() + self._sub_FA() + + def __exit__(self,t,v,trace): + if self.vae_pp_cls is not None: + self._revert_F_func() + self._revert_FA() + + def _sub_F_func(self): + F.conv3d = self.vae_pp_cls.wraps_f_conv3d(F.conv3d) + F.conv2d = self.vae_pp_cls.wraps_f_conv2d(F.conv2d) + F.interpolate = self.vae_pp_cls.wraps_f_interpolate(F.interpolate) + F.pad = self.vae_pp_cls.wraps_f_pad(F.pad) + + def _sub_FA(self): + global FA_LAYOUT + F.scaled_dot_product_attention = self.vae_pp_cls.wraps_fa( + F.scaled_dot_product_attention, layout=FA_LAYOUT) + + def _revert_F_func(self): + """Restore original PyTorch functions after context exit""" + if self.vae_pp_cls.ori_conv3d is not None: + F.conv3d = self.vae_pp_cls.ori_conv3d + if self.vae_pp_cls.ori_conv2d is not None: + F.conv2d = self.vae_pp_cls.ori_conv2d + if self.vae_pp_cls.ori_interpolate is not None: + F.interpolate = self.vae_pp_cls.ori_interpolate + if self.vae_pp_cls.ori_pad is not None: + F.pad = self.vae_pp_cls.ori_pad + + def _revert_FA(self): + """Restore original attention function after context exit""" + if self.vae_pp_cls.ori_fa is not None: + F.scaled_dot_product_attention = self.vae_pp_cls.ori_fa \ No newline at end of file -- Gitee
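A minimal usage sketch for the patch-parallel VAE decode added above, assuming four ranks arranged as a 2x2 spatial grid; `build_vae()`, `latents`, and the process-group bootstrap are hypothetical placeholders rather than part of this patch:

    import torch
    from wan.vae_patch_parallel import VAE_patch_parallel, set_vae_patch_parallel

    # Hypothetical setup: torch.distributed is assumed to be initialised by the
    # launcher with 4 ranks, which form a single 2-row x 2-column patch group.
    vae = build_vae()  # placeholder: any object exposing a decoder.forward method
    set_vae_patch_parallel(
        vae,
        h_split=2,
        w_split=2,
        fa_layout="BNSD",
        decoder_decode="decoder.forward",
        all_pp_group_ranks=[[0, 1, 2, 3]],
    )

    # Each rank passes the *full* latent tensor; the wrapped decoder.forward
    # slices it with patch(), decodes the local tile while F.conv2d / F.conv3d /
    # F.interpolate / F.pad / scaled_dot_product_attention are swapped for their
    # patch-aware wrappers, and reassembles the full output with dispatch().
    latents = torch.randn(1, 16, 8, 90, 160)  # placeholder [B, C, T, H, W] latents
    with VAE_patch_parallel():
        video = vae.decoder(latents)  # dispatches to the wrapped forward installed above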