+
+# Hunyuan-DiT : A Powerful Multi-Resolution Diffusion Transformer with Fine-Grained Chinese Understanding
+
+
+
+
+
+
+
+
+
+
+
+-----
+
+This repo contains PyTorch model definitions, pre-trained weights and inference/sampling code for our paper exploring Hunyuan-DiT. You can find more visualizations on our [project page](https://dit.hunyuan.tencent.com/).
+
+> [**Hunyuan-DiT: A Powerful Multi-Resolution Diffusion Transformer with Fine-Grained Chinese Understanding**](https://arxiv.org/abs/2405.08748)
+
+> [**DialogGen: Multi-modal Interactive Dialogue System for Multi-turn Text-to-Image Generation**](https://arxiv.org/abs/2403.08857)
+
+## 🔥🔥🔥 News!!
+* Jun 27, 2024: :art: Hunyuan-Captioner is released, providing fine-grained caption for training data. See [mllm](./mllm) for details.
+* Jun 27, 2024: :tada: Support LoRa and ControlNet in diffusers. See [diffusers](./diffusers) for details.
+* Jun 27, 2024: :tada: 6GB GPU VRAM Inference scripts are released. See [lite](./lite) for details.
+* Jun 19, 2024: :tada: ControlNet is released, supporting canny, pose and depth control. See [training/inference codes](#controlnet) for details.
+* Jun 13, 2024: :zap: HYDiT-v1.1 version is released, which mitigates the issue of image oversaturation and alleviates the watermark issue. Please check [HunyuanDiT-v1.1 ](https://huggingface.co/Tencent-Hunyuan/HunyuanDiT-v1.1) and
+[Distillation-v1.1](https://huggingface.co/Tencent-Hunyuan/Distillation-v1.1) for more details.
+* Jun 13, 2024: :truck: The training code is released, offering [full-parameter training](#full-parameter-training) and [LoRA training](#lora).
+* Jun 06, 2024: :tada: Hunyuan-DiT is now available in ComfyUI. Please check [ComfyUI](#using-comfyui) for more details.
+* Jun 06, 2024: 🚀 We introduce Distillation version for Hunyuan-DiT acceleration, which achieves **50%** acceleration on NVIDIA GPUs. Please check [Distillation](https://huggingface.co/Tencent-Hunyuan/Distillation) for more details.
+* Jun 05, 2024: 🤗 Hunyuan-DiT is now available in 🤗 Diffusers! Please check the [example](#using--diffusers) below.
+* Jun 04, 2024: :globe_with_meridians: Support Tencent Cloud links to download the pretrained models! Please check the [links](#-download-pretrained-models) below.
+* May 22, 2024: 🚀 We introduce TensorRT version for Hunyuan-DiT acceleration, which achieves **47%** acceleration on NVIDIA GPUs. Please check [TensorRT-libs](https://huggingface.co/Tencent-Hunyuan/TensorRT-libs) for instructions.
+* May 22, 2024: 💬 We support demo running multi-turn text2image generation now. Please check the [script](#using-gradio) below.
+
+## 🤖 Try it on the web
+
+Welcome to our web-based [**Tencent Hunyuan Bot**](https://hunyuan.tencent.com/bot/chat), where you can explore our innovative products! Just input the suggested prompts below or any other **imaginative prompts containing drawing-related keywords** to activate the Hunyuan text-to-image generation feature. Unleash your creativity and create any picture you desire, **all for free!**
+
+You can use simple prompts similar to natural language text
+
+> 画一只穿着西装的猪
+>
+> draw a pig in a suit
+>
+> 生成一幅画,赛博朋克风,跑车
+>
+> generate a painting, cyberpunk style, sports car
+
+or multi-turn language interactions to create the picture.
+
+> 画一个木制的鸟
+>
+> draw a wooden bird
+>
+> 变成玻璃的
+>
+> turn into glass
+
+## 📑 Open-source Plan
+
+- Hunyuan-DiT (Text-to-Image Model)
+ - [x] Inference
+ - [x] Checkpoints
+ - [x] Distillation Version
+ - [x] TensorRT Version
+ - [x] Training
+ - [x] Lora
+ - [x] Controlnet (Pose, Canny, Depth)
+ - [x] Hunyuan-Captioner (Re-caption the raw image-text pairs)
+ - [x] 6GB GPU VRAM Inference
+ - [ ] IP-adapter
+ - [ ] Hunyuan-DiT-S checkpoints (0.7B model)
+- Mllm
+ - Hunyuan-Captioner
+ - [x] Inference
+ - [Hunyuan-DialogGen](https://github.com/Centaurusalpha/DialogGen) (Prompt Enhancement Model)
+ - [x] Inference
+- [X] Web Demo (Gradio)
+- [x] Multi-turn T2I Demo (Gradio)
+- [X] Cli Demo
+- [X] ComfyUI
+- [X] Diffusers
+- [ ] Kohya
+- [ ] WebUI
+
+
+## Contents
+- [Hunyuan-DiT](#hunyuan-dit--a-powerful-multi-resolution-diffusion-transformer-with-fine-grained-chinese-understanding)
+ - [Abstract](#abstract)
+ - [🎉 Hunyuan-DiT Key Features](#-hunyuan-dit-key-features)
+ - [Chinese-English Bilingual DiT Architecture](#chinese-english-bilingual-dit-architecture)
+ - [Multi-turn Text2Image Generation](#multi-turn-text2image-generation)
+ - [📈 Comparisons](#-comparisons)
+ - [🎥 Visualization](#-visualization)
+ - [📜 Requirements](#-requirements)
+ - [🛠 Dependencies and Installation](#%EF%B8%8F-dependencies-and-installation)
+ - [🧱 Download Pretrained Models](#-download-pretrained-models)
+ - [:truck: Training](#truck-training)
+ - [Data Preparation](#data-preparation)
+ - [Full Parameter Training](#full-parameter-training)
+ - [LoRA](#lora)
+ - [🔑 Inference](#-inference)
+ - [6GB GPU VRAM Inference](#6gb-gpu-vram-inference)
+ - [Using Gradio](#using-gradio)
+ - [Using Diffusers](#using--diffusers)
+ - [Using Command Line](#using-command-line)
+ - [More Configurations](#more-configurations)
+ - [Using ComfyUI](#using-comfyui)
+ - [:building_construction: Adatper](#building_construction-adapter)
+ - [ControlNet](#controlnet)
+ - [:art: Hunyuan-Captioner](#art-hunyuan-captioner)
+ - [🚀 Acceleration (for Linux)](#-acceleration-for-linux)
+ - [🔗 BibTeX](#-bibtex)
+
+## **Abstract**
+
+We present Hunyuan-DiT, a text-to-image diffusion transformer with fine-grained understanding of both English and Chinese. To construct Hunyuan-DiT, we carefully designed the transformer structure, text encoder, and positional encoding. We also build from scratch a whole data pipeline to update and evaluate data for iterative model optimization. For fine-grained language understanding, we train a Multimodal Large Language Model to refine the captions of the images. Finally, Hunyuan-DiT can perform multi-round multi-modal dialogue with users, generating and refining images according to the context.
+Through our carefully designed holistic human evaluation protocol with more than 50 professional human evaluators, Hunyuan-DiT sets a new state-of-the-art in Chinese-to-image generation compared with other open-source models.
+
+
+## 🎉 **Hunyuan-DiT Key Features**
+### **Chinese-English Bilingual DiT Architecture**
+Hunyuan-DiT is a diffusion model in the latent space, as depicted in figure below. Following the Latent Diffusion Model, we use a pre-trained Variational Autoencoder (VAE) to compress the images into low-dimensional latent spaces and train a diffusion model to learn the data distribution with diffusion models. Our diffusion model is parameterized with a transformer. To encode the text prompts, we leverage a combination of pre-trained bilingual (English and Chinese) CLIP and multilingual T5 encoder.
+
+
+
+
+### Multi-turn Text2Image Generation
+Understanding natural language instructions and performing multi-turn interaction with users are important for a
+text-to-image system. It can help build a dynamic and iterative creation process that bring the user’s idea into reality
+step by step. In this section, we will detail how we empower Hunyuan-DiT with the ability to perform multi-round
+conversations and image generation. We train MLLM to understand the multi-round user dialogue
+and output the new text prompt for image generation.
+
+
+
+
+## 📈 Comparisons
+In order to comprehensively compare the generation capabilities of HunyuanDiT and other models, we constructed a 4-dimensional test set, including Text-Image Consistency, Excluding AI Artifacts, Subject Clarity, Aesthetic. More than 50 professional evaluators performs the evaluation.
+
+
+
+* **Multi-turn Text2Image Generation**
+
+https://github.com/Tencent/tencent.github.io/assets/27557933/94b4dcc3-104d-44e1-8bb2-dc55108763d1
+
+
+
+---
+
+## 📜 Requirements
+
+This repo consists of DialogGen (a prompt enhancement model) and Hunyuan-DiT (a text-to-image model).
+
+The following table shows the requirements for running the models (batch size = 1):
+
+| Model | --load-4bit (DialogGen) | GPU Peak Memory | GPU |
+|:-----------------------:|:-----------------------:|:---------------:|:---------------:|
+| DialogGen + Hunyuan-DiT | ✘ | 32G | A100 |
+| DialogGen + Hunyuan-DiT | ✔ | 22G | A100 |
+| Hunyuan-DiT | - | 11G | A100 |
+| Hunyuan-DiT | - | 14G | RTX3090/RTX4090 |
+
+* An NVIDIA GPU with CUDA support is required.
+ * We have tested V100 and A100 GPUs.
+ * **Minimum**: The minimum GPU memory required is 11GB.
+ * **Recommended**: We recommend using a GPU with 32GB of memory for better generation quality.
+* Tested operating system: Linux
+
+## 🛠️ Dependencies and Installation
+
+Begin by cloning the repository:
+```shell
+git clone https://github.com/tencent/HunyuanDiT
+cd HunyuanDiT
+```
+
+### Installation Guide for Linux
+
+We provide an `environment.yml` file for setting up a Conda environment.
+Conda's installation instructions are available [here](https://docs.anaconda.com/free/miniconda/index.html).
+
+We recommend CUDA versions 11.7 and 12.0+.
+
+```shell
+# 1. Prepare conda environment
+conda env create -f environment.yml
+
+# 2. Activate the environment
+conda activate HunyuanDiT
+
+# 3. Install pip dependencies
+python -m pip install -r requirements.txt
+
+# 4. (Optional) Install flash attention v2 for acceleration (requires CUDA 11.6 or above)
+python -m pip install git+https://github.com/Dao-AILab/flash-attention.git@v2.1.2.post3
+```
+
+## 🧱 Download Pretrained Models
+To download the model, first install the huggingface-cli. (Detailed instructions are available [here](https://huggingface.co/docs/huggingface_hub/guides/cli).)
+
+```shell
+python -m pip install "huggingface_hub[cli]"
+```
+
+Then download the model using the following commands:
+
+```shell
+# Create a directory named 'ckpts' where the model will be saved, fulfilling the prerequisites for running the demo.
+mkdir ckpts
+# Use the huggingface-cli tool to download the model.
+# The download time may vary from 10 minutes to 1 hour depending on network conditions.
+huggingface-cli download Tencent-Hunyuan/HunyuanDiT --local-dir ./ckpts
+```
+
+
+💡Tips for using huggingface-cli (network problem)
+
+##### 1. Using HF-Mirror
+
+If you encounter slow download speeds in China, you can try a mirror to speed up the download process. For example,
+
+```shell
+HF_ENDPOINT=https://hf-mirror.com huggingface-cli download Tencent-Hunyuan/HunyuanDiT --local-dir ./ckpts
+```
+
+##### 2. Resume Download
+
+`huggingface-cli` supports resuming downloads. If the download is interrupted, you can just rerun the download
+command to resume the download process.
+
+Note: If an `No such file or directory: 'ckpts/.huggingface/.gitignore.lock'` like error occurs during the download
+process, you can ignore the error and rerun the download command.
+
+
+
+---
+
+All models will be automatically downloaded. For more information about the model, visit the Hugging Face repository [here](https://huggingface.co/Tencent-Hunyuan/HunyuanDiT).
+
+| Model | #Params | Huggingface Download URL | Tencent Cloud Download URL |
+|:------------------:|:-------:|:-------------------------------------------------------------------------------------------------------:|:-----------------------------------------------------------------------------------------------:|
+| mT5 | 1.6B | [mT5](https://huggingface.co/Tencent-Hunyuan/HunyuanDiT/tree/main/t2i/mt5) | [mT5](https://dit.hunyuan.tencent.com/download/HunyuanDiT/mt5.zip) |
+| CLIP | 350M | [CLIP](https://huggingface.co/Tencent-Hunyuan/HunyuanDiT/tree/main/t2i/clip_text_encoder) | [CLIP](https://dit.hunyuan.tencent.com/download/HunyuanDiT/clip_text_encoder.zip) |
+| Tokenizer | - | [Tokenizer](https://huggingface.co/Tencent-Hunyuan/HunyuanDiT/tree/main/t2i/tokenizer) | [Tokenizer](https://dit.hunyuan.tencent.com/download/HunyuanDiT/tokenizer.zip) |
+| DialogGen | 7.0B | [DialogGen](https://huggingface.co/Tencent-Hunyuan/HunyuanDiT/tree/main/dialoggen) | [DialogGen](https://dit.hunyuan.tencent.com/download/HunyuanDiT/dialoggen.zip) |
+| sdxl-vae-fp16-fix | 83M | [sdxl-vae-fp16-fix](https://huggingface.co/Tencent-Hunyuan/HunyuanDiT/tree/main/t2i/sdxl-vae-fp16-fix) | [sdxl-vae-fp16-fix](https://dit.hunyuan.tencent.com/download/HunyuanDiT/sdxl-vae-fp16-fix.zip) |
+| Hunyuan-DiT-v1.0 | 1.5B | [Hunyuan-DiT](https://huggingface.co/Tencent-Hunyuan/HunyuanDiT/tree/main/t2i/model) | [Hunyuan-DiT-v1.0](https://dit.hunyuan.tencent.com/download/HunyuanDiT/model.zip) |
+| Hunyuan-DiT-v1.1 | 1.5B | [Hunyuan-DiT-v1.1](https://huggingface.co/Tencent-Hunyuan/HunyuanDiT-v1.1/tree/main/t2i/model) | [Hunyuan-DiT-v1.1](https://dit.hunyuan.tencent.com/download/HunyuanDiT/model-v1_1.zip) |
+| Data demo | - | - | [Data demo](https://dit.hunyuan.tencent.com/download/HunyuanDiT/data_demo.zip) |
+
+## :truck: Training
+
+### Data Preparation
+
+ Refer to the commands below to prepare the training data.
+
+ 1. Install dependencies
+
+ We offer an efficient data management library, named IndexKits, supporting the management of reading hundreds of millions of data during training, see more in [docs](./IndexKits/README.md).
+ ```shell
+ # 1 Install dependencies
+ cd HunyuanDiT
+ pip install -e ./IndexKits
+ ```
+ 2. Data download
+
+ Feel free to download the [data demo](https://dit.hunyuan.tencent.com/download/HunyuanDiT/data_demo.zip).
+ ```shell
+ # 2 Data download
+ wget -O ./dataset/data_demo.zip https://dit.hunyuan.tencent.com/download/HunyuanDiT/data_demo.zip
+ unzip ./dataset/data_demo.zip -d ./dataset
+ mkdir ./dataset/porcelain/arrows ./dataset/porcelain/jsons
+ ```
+ 3. Data conversion
+
+ Create a CSV file for training data with the fields listed in the table below.
+
+ | Fields | Required | Description | Example |
+ |:---------------:| :------: |:----------------:|:-----------:|
+ | `image_path` | Required | image path | `./dataset/porcelain/images/0.png` |
+ | `text_zh` | Required | text | 青花瓷风格,一只蓝色的鸟儿站在蓝色的花瓶上,周围点缀着白色花朵,背景是白色 |
+ | `md5` | Optional | image md5 (Message Digest Algorithm 5) | `d41d8cd98f00b204e9800998ecf8427e` |
+ | `width` | Optional | image width | `1024 ` |
+ | `height` | Optional | image height | ` 1024 ` |
+
+ > ⚠️ Optional fields like MD5, width, and height can be omitted. If omitted, the script below will automatically calculate them. This process can be time-consuming when dealing with large-scale training data.
+
+ We utilize [Arrow](https://github.com/apache/arrow) for training data format, offering a standard and efficient in-memory data representation. A conversion script is provided to transform CSV files into Arrow format.
+ ```shell
+ # 3 Data conversion
+ python ./hydit/data_loader/csv2arrow.py ./dataset/porcelain/csvfile/image_text.csv ./dataset/porcelain/arrows
+ ```
+
+ 4. Data Selection and Configuration File Creation
+
+ We configure the training data through YAML files. In these files, you can set up standard data processing strategies for filtering, copying, deduplicating, and more regarding the training data. For more details, see [./IndexKits](IndexKits/docs/MakeDataset.md).
+
+ For a sample file, please refer to [file](./dataset/yamls/porcelain.yaml). For a full parameter configuration file, see [file](./IndexKits/docs/MakeDataset.md).
+
+
+ 5. Create training data index file using YAML file.
+
+ ```shell
+ # Single Resolution Data Preparation
+ idk base -c dataset/yamls/porcelain.yaml -t dataset/porcelain/jsons/porcelain.json
+
+ # Multi Resolution Data Preparation
+ idk multireso -c dataset/yamls/porcelain_mt.yaml -t dataset/porcelain/jsons/porcelain_mt.json
+ ```
+
+ The directory structure for `porcelain` dataset is:
+
+ ```shell
+ cd ./dataset
+
+ porcelain
+ ├──images/ (image files)
+ │ ├──0.png
+ │ ├──1.png
+ │ ├──......
+ ├──csvfile/ (csv files containing text-image pairs)
+ │ ├──image_text.csv
+ ├──arrows/ (arrow files containing all necessary training data)
+ │ ├──00000.arrow
+ │ ├──00001.arrow
+ │ ├──......
+ ├──jsons/ (final training data index files which read data from arrow files during training)
+ │ ├──porcelain.json
+ │ ├──porcelain_mt.json
+ ```
+
+### Full-parameter Training
+
+ To leverage DeepSpeed in training, you have the flexibility to control **single-node** / **multi-node** training by adjusting parameters such as `--hostfile` and `--master_addr`. For more details, see [link](https://www.deepspeed.ai/getting-started/#resource-configuration-multi-node).
+
+ ```shell
+ # Single Resolution Training
+ PYTHONPATH=./ sh hydit/train.sh --index-file dataset/porcelain/jsons/porcelain.json
+
+ # Multi Resolution Training
+ PYTHONPATH=./ sh hydit/train.sh --index-file dataset/porcelain/jsons/porcelain_mt.json --multireso --reso-step 64
+ ```
+
+### LoRA
+
+
+
+We provide training and inference scripts for LoRA, detailed in the [./lora](./lora/README.md).
+
+ ```shell
+ # Training for porcelain LoRA.
+ PYTHONPATH=./ sh lora/train_lora.sh --index-file dataset/porcelain/jsons/porcelain.json
+
+ # Inference using trained LORA weights.
+ python sample_t2i.py --prompt "青花瓷风格,一只小狗" --no-enhance --lora-ckpt log_EXP/001-lora_porcelain_ema_rank64/checkpoints/0001000.pt
+ ```
+ We offer two types of trained LoRA weights for `porcelain` and `jade`, see details at [links](https://huggingface.co/Tencent-Hunyuan/HYDiT-LoRA)
+ ```shell
+ cd HunyuanDiT
+ # Use the huggingface-cli tool to download the model.
+ huggingface-cli download Tencent-Hunyuan/HYDiT-LoRA --local-dir ./ckpts/t2i/lora
+
+ # Quick start
+ python sample_t2i.py --prompt "青花瓷风格,一只猫在追蝴蝶" --no-enhance --load-key ema --lora-ckpt ./ckpts/t2i/lora/porcelain
+ ```
+
+
+
Examples of training data
+
+
+
+
+
+
+
+
+
+
青花瓷风格,一只蓝色的鸟儿站在蓝色的花瓶上,周围点缀着白色花朵,背景是白色 (Porcelain style, a blue bird stands on a blue vase, surrounded by white flowers, with a white background.
+)
+
青花瓷风格,这是一幅蓝白相间的陶瓷盘子,上面描绘着一只狐狸和它的幼崽在森林中漫步,背景是白色 (Porcelain style, this is a blue and white ceramic plate depicting a fox and its cubs strolling in the forest, with a white background.)
+
青花瓷风格,在黑色背景上,一只蓝色的狼站在蓝白相间的盘子上,周围是树木和月亮 (Porcelain style, on a black background, a blue wolf stands on a blue and white plate, surrounded by trees and the moon.)
+
青花瓷风格,在蓝色背景上,一只蓝色蝴蝶和白色花朵被放置在中央 (Porcelain style, on a blue background, a blue butterfly and white flowers are placed in the center.)
+
+
+
Examples of inference results
+
+
+
+
+
+
+
+
+
青花瓷风格,苏州园林 (Porcelain style, Suzhou Gardens.)
+
青花瓷风格,一朵荷花 (Porcelain style, a lotus flower.)
+
青花瓷风格,一只羊(Porcelain style, a sheep.)
+
青花瓷风格,一个女孩在雨中跳舞(Porcelain style, a girl dancing in the rain.)
+
+
+
+
+
+## 🔑 Inference
+
+### 6GB GPU VRAM Inference
+Running HunyuanDiT in under 6GB GPU VRAM is available now based on [diffusers](https://huggingface.co/docs/diffusers/main/en/api/pipelines/hunyuandit). Here we provide instructions and demo for your quick start.
+
+> The 6GB version supports Nvidia Ampere architecture series graphics cards such as RTX 3070/3080/4080/4090, A100, and so on.
+
+The only thing you need do is to install the following library:
+
+```bash
+pip install -U bitsandbytes
+pip install git+https://github.com/huggingface/diffusers
+pip install torch==2.0.0
+```
+
+Then you can enjoy your HunyuanDiT text-to-image journey under 6GB GPU VRAM directly!
+
+Here is a demo for you.
+
+```bash
+cd HunyuanDiT
+
+# Quick start
+model_id=Tencent-Hunyuan/HunyuanDiT-v1.1-Diffusers-Distilled
+prompt=一个宇航员在骑马
+infer_steps=50
+guidance_scale=6
+python3 lite/inference.py ${model_id} ${prompt} ${infer_steps} ${guidance_scale}
+```
+
+More details can be found in [./lite](lite/README.md).
+
+
+### Using Gradio
+
+Make sure the conda environment is activated before running the following command.
+
+```shell
+# By default, we start a Chinese UI.
+python app/hydit_app.py
+
+# Using Flash Attention for acceleration.
+python app/hydit_app.py --infer-mode fa
+
+# You can disable the enhancement model if the GPU memory is insufficient.
+# The enhancement will be unavailable until you restart the app without the `--no-enhance` flag.
+python app/hydit_app.py --no-enhance
+
+# Start with English UI
+python app/hydit_app.py --lang en
+
+# Start a multi-turn T2I generation UI.
+# If your GPU memory is less than 32GB, use '--load-4bit' to enable 4-bit quantization, which requires at least 22GB of memory.
+python app/multiTurnT2I_app.py
+```
+Then the demo can be accessed through http://0.0.0.0:443. It should be noted that the 0.0.0.0 here needs to be X.X.X.X with your server IP.
+
+### Using 🤗 Diffusers
+
+Please install PyTorch version 2.0 or higher in advance to satisfy the requirements of the specified version of the diffusers library.
+
+Install 🤗 diffusers, ensuring that the version is at least 0.28.1:
+
+```shell
+pip install git+https://github.com/huggingface/diffusers.git
+```
+or
+```shell
+pip install diffusers
+```
+
+You can generate images with both Chinese and English prompts using the following Python script:
+```py
+import torch
+from diffusers import HunyuanDiTPipeline
+
+pipe = HunyuanDiTPipeline.from_pretrained("Tencent-Hunyuan/HunyuanDiT-Diffusers", torch_dtype=torch.float16)
+pipe.to("cuda")
+
+# You may also use English prompt as HunyuanDiT supports both English and Chinese
+# prompt = "An astronaut riding a horse"
+prompt = "一个宇航员在骑马"
+image = pipe(prompt).images[0]
+```
+You can use our distilled model to generate images even faster:
+
+```py
+import torch
+from diffusers import HunyuanDiTPipeline
+
+pipe = HunyuanDiTPipeline.from_pretrained("Tencent-Hunyuan/HunyuanDiT-Diffusers-Distilled", torch_dtype=torch.float16)
+pipe.to("cuda")
+
+# You may also use English prompt as HunyuanDiT supports both English and Chinese
+# prompt = "An astronaut riding a horse"
+prompt = "一个宇航员在骑马"
+image = pipe(prompt, num_inference_steps=25).images[0]
+```
+More details can be found in [HunyuanDiT-Diffusers-Distilled](https://huggingface.co/Tencent-Hunyuan/HunyuanDiT-Diffusers-Distilled)
+
+**More functions:** For other functions like LoRA and ControlNet, please have a look at the README of [./diffusers](diffusers).
+
+### Using Command Line
+
+We provide several commands to quick start:
+
+```shell
+# Prompt Enhancement + Text-to-Image. Torch mode
+python sample_t2i.py --prompt "渔舟唱晚"
+
+# Only Text-to-Image. Torch mode
+python sample_t2i.py --prompt "渔舟唱晚" --no-enhance
+
+# Only Text-to-Image. Flash Attention mode
+python sample_t2i.py --infer-mode fa --prompt "渔舟唱晚"
+
+# Generate an image with other image sizes.
+python sample_t2i.py --prompt "渔舟唱晚" --image-size 1280 768
+
+# Prompt Enhancement + Text-to-Image. DialogGen loads with 4-bit quantization, but it may loss performance.
+python sample_t2i.py --prompt "渔舟唱晚" --load-4bit
+
+```
+
+More example prompts can be found in [example_prompts.txt](example_prompts.txt)
+
+### More Configurations
+
+We list some more useful configurations for easy usage:
+
+| Argument | Default | Description |
+|:---------------:|:---------:|:---------------------------------------------------:|
+| `--prompt` | None | The text prompt for image generation |
+| `--image-size` | 1024 1024 | The size of the generated image |
+| `--seed` | 42 | The random seed for generating images |
+| `--infer-steps` | 100 | The number of steps for sampling |
+| `--negative` | - | The negative prompt for image generation |
+| `--infer-mode` | torch | The inference mode (torch, fa, or trt) |
+| `--sampler` | ddpm | The diffusion sampler (ddpm, ddim, or dpmms) |
+| `--no-enhance` | False | Disable the prompt enhancement model |
+| `--model-root` | ckpts | The root directory of the model checkpoints |
+| `--load-key` | ema | Load the student model or EMA model (ema or module) |
+| `--load-4bit` | Fasle | Load DialogGen model with 4bit quantization |
+
+### Using ComfyUI
+
+We provide several commands to quick start:
+
+```shell
+# Download comfyui code
+git clone https://github.com/comfyanonymous/ComfyUI.git
+
+# Install torch, torchvision, torchaudio
+pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 --index-url https://download.pytorch.org/whl/cu117
+
+# Install Comfyui essential python package.
+cd ComfyUI
+pip install -r requirements.txt
+
+# ComfyUI has been successfully installed!
+
+# Download model weight as before or link the existing model folder to ComfyUI.
+python -m pip install "huggingface_hub[cli]"
+mkdir models/hunyuan
+huggingface-cli download Tencent-Hunyuan/HunyuanDiT --local-dir ./models/hunyuan/ckpts
+
+# Move to the ComfyUI custom_nodes folder and copy comfyui-hydit folder from HunyuanDiT Repo.
+cd custom_nodes
+cp -r ${HunyuanDiT}/comfyui-hydit ./
+cd comfyui-hydit
+
+# Install some essential python Package.
+pip install -r requirements.txt
+
+# Our tool has been successfully installed!
+
+# Go to ComfyUI main folder
+cd ../..
+# Run the ComfyUI Lauch command
+python main.py --listen --port 80
+
+# Running ComfyUI successfully!
+```
+More details can be found in [./comfyui-hydit](comfyui-hydit/README.md)
+
+## :building_construction: Adapter
+
+### ControlNet
+
+We provide training scripts for ControlNet, detailed in the [./controlnet](./controlnet/README.md).
+
+ ```shell
+ # Training for canny ControlNet.
+ PYTHONPATH=./ sh hydit/train_controlnet.sh
+ ```
+ We offer three types of trained ControlNet weights for `canny` `depth` and `pose`, see details at [links](https://huggingface.co/Tencent-Hunyuan/HYDiT-ControlNet)
+ ```shell
+ cd HunyuanDiT
+ # Use the huggingface-cli tool to download the model.
+ # We recommend using distilled weights as the base model for ControlNet inference, as our provided pretrained weights are trained on them.
+ huggingface-cli download Tencent-Hunyuan/HYDiT-ControlNet --local-dir ./ckpts/t2i/controlnet
+ huggingface-cli download Tencent-Hunyuan/Distillation-v1.1 ./pytorch_model_distill.pt --local-dir ./ckpts/t2i/model
+
+ # Quick start
+ python3 sample_controlnet.py --no-enhance --load-key distill --infer-steps 50 --control-type canny --prompt "在夜晚的酒店门前,一座古老的中国风格的狮子雕像矗立着,它的眼睛闪烁着光芒,仿佛在守护着这座建筑。背景是夜晚的酒店前,构图方式是特写,平视,居中构图。这张照片呈现了真实摄影风格,蕴含了中国雕塑文化,同时展现了神秘氛围" --condition-image-path controlnet/asset/input/canny.jpg --control-weight 1.0
+ ```
+
+
+
+
Condition Input
+
+
+
+
Canny ControlNet
+
Depth ControlNet
+
Pose ControlNet
+
+
+
+
在夜晚的酒店门前,一座古老的中国风格的狮子雕像矗立着,它的眼睛闪烁着光芒,仿佛在守护着这座建筑。背景是夜晚的酒店前,构图方式是特写,平视,居中构图。这张照片呈现了真实摄影风格,蕴含了中国雕塑文化,同时展现了神秘氛围 (At night, an ancient Chinese-style lion statue stands in front of the hotel, its eyes gleaming as if guarding the building. The background is the hotel entrance at night, with a close-up, eye-level, and centered composition. This photo presents a realistic photographic style, embodies Chinese sculpture culture, and reveals a mysterious atmosphere.)
+
在茂密的森林中,一只黑白相间的熊猫静静地坐在绿树红花中,周围是山川和海洋。背景是白天的森林,光线充足 (In the dense forest, a black and white panda sits quietly in green trees and red flowers, surrounded by mountains, rivers, and the ocean. The background is the forest in a bright environment.)
+
一位亚洲女性,身穿绿色上衣,戴着紫色头巾和紫色围巾,站在黑板前。背景是黑板。照片采用近景、平视和居中构图的方式呈现真实摄影风格 (An Asian woman, dressed in a green top, wearing a purple headscarf and a purple scarf, stands in front of a blackboard. The background is the blackboard. The photo is presented in a close-up, eye-level, and centered composition, adopting a realistic photographic style)
+
+
+
+
+
+
+
+
+
+
+
ControlNet Output
+
+
+
+
+
+
+
+
+
+
+## :art: Hunyuan-Captioner
+Hunyuan-Captioner meets the need of text-to-image techniques by maintaining a high degree of image-text consistency. It can generate high-quality image descriptions from a variety of angles, including object description, objects relationships, background information, image style, etc. Our code is based on [LLaVA](https://github.com/haotian-liu/LLaVA) implementation.
+
+### Examples
+
+
+
+### Instructions
+a. Install dependencies
+
+The dependencies and installation are basically the same as the [**base model**](https://huggingface.co/Tencent-Hunyuan/HunyuanDiT-v1.1).
+
+b. Data download
+```shell
+cd HunyuanDiT
+wget -O ./dataset/data_demo.zip https://dit.hunyuan.tencent.com/download/HunyuanDiT/data_demo.zip
+unzip ./dataset/data_demo.zip -d ./dataset
+mkdir ./dataset/porcelain/arrows ./dataset/porcelain/jsons
+```
+
+c. Model download
+```shell
+# Use the huggingface-cli tool to download the model.
+huggingface-cli download Tencent-Hunyuan/HunyuanCaptioner --local-dir ./ckpts/captioner
+```
+
+### Inference
+
+Current supported prompt templates:
+
+|Mode | Prompt template |Description |
+| --- | --- | --- |
+|caption_zh | 描述这张图片 |Caption in Chinese |
+|insert_content | 根据提示词“{}”,描述这张图片 |Insert specific knowledge into caption|
+|caption_en | Please describe the content of this image |Caption in English |
+| | | |
+
+
+a. Single picture inference in Chinese
+
+```bash
+python mllm/caption_demo.py --mode "caption_zh" --image_file "mllm/images/demo1.png" --model_path "./ckpts/captioner"
+```
+
+b. Insert specific knowledge into caption
+
+```bash
+python mllm/caption_demo.py --mode "insert_content" --content "宫保鸡丁" --image_file "mllm/images/demo2.png" --model_path "./ckpts/captioner"
+```
+
+c. Single picture inference in English
+
+```bash
+python mllm/caption_demo.py --mode "caption_en" --image_file "mllm/images/demo3.png" --model_path "./ckpts/captioner"
+```
+
+d. Multiple pictures inference in Chinese
+
+```bash
+### Convert multiple pictures to csv file.
+python mllm/make_csv.py --img_dir "mllm/images" --input_file "mllm/images/demo.csv"
+
+### Multiple pictures inference
+python mllm/caption_demo.py --mode "caption_zh" --input_file "mllm/images/demo.csv" --output_file "mllm/images/demo_res.csv" --model_path "./ckpts/captioner"
+```
+
+(Optional) To convert the output csv file to Arrow format, please refer to [Data Preparation #3](#data-preparation) for detailed instructions.
+
+
+### Gradio
+To launch a Gradio demo locally, please run the following commands one by one. For more detailed instructions, please refer to [LLaVA](https://github.com/haotian-liu/LLaVA).
+```bash
+cd mllm
+python -m llava.serve.controller --host 0.0.0.0 --port 10000
+
+python -m llava.serve.gradio_web_server --controller http://0.0.0.0:10000 --model-list-mode reload --port 443
+
+python -m llava.serve.model_worker --host 0.0.0.0 --controller http://0.0.0.0:10000 --port 40000 --worker http://0.0.0.0:40000 --model-path "./ckpts/captioner" --model-name LlavaMistral
+```
+Then the demo can be accessed through http://0.0.0.0:443. It should be noted that the 0.0.0.0 here needs to be X.X.X.X with your server IP.
+
+## 🚀 Acceleration (for Linux)
+
+- We provide TensorRT version of HunyuanDiT for inference acceleration (faster than flash attention).
+See [Tencent-Hunyuan/TensorRT-libs](https://huggingface.co/Tencent-Hunyuan/TensorRT-libs) for more details.
+
+- We provide Distillation version of HunyuanDiT for inference acceleration.
+See [Tencent-Hunyuan/Distillation](https://huggingface.co/Tencent-Hunyuan/Distillation) for more details.
+
+## 🔗 BibTeX
+If you find [Hunyuan-DiT](https://arxiv.org/abs/2405.08748) or [DialogGen](https://arxiv.org/abs/2403.08857) useful for your research and applications, please cite using this BibTeX:
+
+```BibTeX
+@misc{li2024hunyuandit,
+ title={Hunyuan-DiT: A Powerful Multi-Resolution Diffusion Transformer with Fine-Grained Chinese Understanding},
+ author={Zhimin Li and Jianwei Zhang and Qin Lin and Jiangfeng Xiong and Yanxin Long and Xinchi Deng and Yingfang Zhang and Xingchao Liu and Minbin Huang and Zedong Xiao and Dayou Chen and Jiajun He and Jiahao Li and Wenyue Li and Chen Zhang and Rongwei Quan and Jianxiang Lu and Jiabin Huang and Xiaoyan Yuan and Xiaoxiao Zheng and Yixuan Li and Jihong Zhang and Chao Zhang and Meng Chen and Jie Liu and Zheng Fang and Weiyan Wang and Jinbao Xue and Yangyu Tao and Jianchen Zhu and Kai Liu and Sihuan Lin and Yifu Sun and Yun Li and Dongdong Wang and Mingtao Chen and Zhichao Hu and Xiao Xiao and Yan Chen and Yuhong Liu and Wei Liu and Di Wang and Yong Yang and Jie Jiang and Qinglin Lu},
+ year={2024},
+ eprint={2405.08748},
+ archivePrefix={arXiv},
+ primaryClass={cs.CV}
+}
+
+@article{huang2024dialoggen,
+ title={DialogGen: Multi-modal Interactive Dialogue System for Multi-turn Text-to-Image Generation},
+ author={Huang, Minbin and Long, Yanxin and Deng, Xinchi and Chu, Ruihang and Xiong, Jiangfeng and Liang, Xiaodan and Cheng, Hong and Lu, Qinglin and Liu, Wei},
+ journal={arXiv preprint arXiv:2403.08857},
+ year={2024}
+}
+```
+
+## Start History
+
+
+
+
+
+
+
+
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/app/hydit_app.py b/PyTorch/built-in/mlm/HunyuanDiT/app/hydit_app.py
new file mode 100644
index 0000000000000000000000000000000000000000..695fcb3eea4dad50bdb28f6c553a4834c2499706
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/app/hydit_app.py
@@ -0,0 +1,169 @@
+import gradio as gr
+import pandas as pd
+from pathlib import Path
+from PIL import Image
+import sys
+sys.path.insert(0, str(Path(__file__).parent.parent))
+
+from hydit.constants import SAMPLER_FACTORY
+from sample_t2i import inferencer
+
+ROOT = Path(__file__).parent.parent
+SAMPLERS = list(SAMPLER_FACTORY.keys())
+SIZES = {
+ "square": (1024, 1024),
+ "landscape": (768, 1280),
+ "portrait": (1280, 768),
+}
+
+def get_strings(lang):
+ lang_file = Path(f"app/lang/{lang}.csv")
+ strings = pd.read_csv(lang_file, header=0)
+ strings = strings.set_index("key")['value'].to_dict()
+ return strings
+
+
+args, gen, enhancer = inferencer()
+strings = get_strings(args.lang)
+
+
+def infer(
+ prompt,
+ negative_prompt,
+ seed,
+ cfg_scale,
+ infer_steps,
+ oriW, oriH,
+ sampler,
+ size,
+ enhance
+):
+ if enhance and enhancer is not None:
+ success, enhanced_prompt = enhancer(prompt)
+ if not success:
+ fail_image = Image.open(ROOT / 'app/fail.png')
+ return fail_image
+ else:
+ enhanced_prompt = None
+
+ height, width = SIZES[size]
+ results = gen.predict(prompt,
+ height=height,
+ width=width,
+ seed=seed,
+ enhanced_prompt=enhanced_prompt,
+ negative_prompt=negative_prompt,
+ infer_steps=infer_steps,
+ guidance_scale=cfg_scale,
+ batch_size=1,
+ src_size_cond=(oriW, oriH),
+ sampler=sampler,
+ )
+ image = results['images'][0]
+ return image
+
+
+def ui():
+ block = gr.Blocks()
+
+ description = f"""
+ # {strings['title']}
+
+ ## {strings['desc']}
+
+ """
+
+ with block:
+ with gr.Row():
+ gr.Markdown(description)
+ with gr.Row():
+ with gr.Column():
+ with gr.Row():
+ size = gr.Radio(
+ label=strings['size'], choices=[
+ (strings['square'], 'square'),
+ (strings['landscape'], 'landscape'),
+ (strings['portrait'], 'portrait'),
+ ],
+ value="square"
+ )
+ prompt = gr.Textbox(label=strings['prompt'], value=strings['default prompt'], lines=3)
+ with gr.Row():
+ infer_steps = gr.Slider(
+ label=strings['infer steps'], minimum=1, maximum=200, value=100, step=1,
+ )
+ seed = gr.Number(
+ label=strings['seed'], minimum=-1, maximum=1_000_000_000, value=42, step=1, precision=0,
+ )
+ enhance = gr.Checkbox(
+ label=strings['enhance'], value=enhancer is not None, interactive=True,
+ )
+
+ with gr.Accordion(
+ strings['accordion'], open=False
+ ):
+ with gr.Row():
+ negative_prompt = gr.Textbox(label=strings['negative_prompt'],
+ value=gen.default_negative_prompt,
+ lines=2,
+ )
+ with gr.Row():
+ sampler = gr.Dropdown(SAMPLERS, label=strings['sampler'], value="ddpm")
+ cfg_scale = gr.Slider(
+ label=strings['cfg'], minimum=1.0, maximum=16.0, value=6.0, step=1
+ )
+ oriW = gr.Number(
+ label=strings['width cond'], minimum=1024, maximum=4096, value=1024, step=64, precision=0,
+ min_width=80,
+ )
+ oriH = gr.Number(
+ label=strings['height cond'], minimum=1024, maximum=4096, value=1024, step=64, precision=0,
+ min_width=80,
+ )
+ with gr.Row():
+ advanced_button = gr.Button(strings['run'])
+ with gr.Column():
+ default_img = Image.open(ROOT / 'app/default.png')
+ output_img = gr.Image(
+ label=strings['generated image'],
+ interactive=False,
+ format='png',
+ value=default_img,
+ )
+ advanced_button.click(
+ fn=infer,
+ inputs=[
+ prompt, negative_prompt, seed, cfg_scale, infer_steps,
+ oriW, oriH, sampler, size, enhance,
+ ],
+ outputs=output_img,
+ )
+
+ with gr.Row():
+ gr.Examples([
+ ['一只小猫'],
+ ['现实主义风格,画面主要描述一个巴洛克风格的花瓶,带有金色的装饰边框,花瓶上盛开着各种色彩鲜艳的花,白色背景'],
+ ['一只聪明的狐狸走在阔叶树林里, 旁边是一条小溪, 细节真实, 摄影'],
+ ['飞流直下三千尺,疑是银河落九天'],
+ ['一只长靴猫手持亮银色的宝剑,身着铠甲,眼神坚毅,站在一堆金币上,背景是暗色调的洞穴,图像上有金币的光影点缀。'],
+ ['麻婆豆腐'],
+ ['苏州园林'],
+ ['一颗新鲜的草莓特写,红色的外表,表面布满许多种子,背景是淡绿色的叶子'],
+ ['请将“杞人忧天”的样子画出来'],
+ ['枯藤老树昏鸦,小桥流水人家'],
+ ['湖水清澈,天空湛蓝,阳光灿烂。一只优雅的白天鹅在湖边游泳。它周围有几只小鸭子,看起来非常可爱,整个画面给人一种宁静祥和的感觉。'],
+ ['一朵鲜艳的红色玫瑰花,花瓣撒有一些水珠,晶莹剔透,特写镜头'],
+ ['臭豆腐'],
+ ['九寨沟'],
+ ['俗语“鲤鱼跃龙门”'],
+ ['风格是写实,画面主要描述一个亚洲戏曲艺术家正在表演,她穿着华丽的戏服,脸上戴着精致的面具,身姿优雅,背景是古色古香的舞台,镜头是近景'],
+ ],
+ [prompt],
+ label=strings['examples']
+ )
+ return block
+
+
+if __name__ == "__main__":
+ interface = ui()
+ interface.launch(server_name="0.0.0.0", server_port=443, share=True)
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/app/multiTurnT2I_app.py b/PyTorch/built-in/mlm/HunyuanDiT/app/multiTurnT2I_app.py
new file mode 100644
index 0000000000000000000000000000000000000000..540fdde236638e132a8774e69684febd36b1bdb7
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/app/multiTurnT2I_app.py
@@ -0,0 +1,270 @@
+# -- coding: utf-8 --
+#!/usr/bin/env python
+import gradio as gr
+from PIL import Image
+import sys
+import os
+sys.path.append(os.getcwd())
+import json
+import numpy as np
+from pathlib import Path
+import io
+import hashlib
+import requests
+import base64
+import pandas as pd
+from sample_t2i import inferencer
+from mllm.dialoggen_demo import init_dialoggen_model, eval_model
+
+SIZES = {
+ "正方形(square, 1024x1024)": (1024, 1024),
+ "风景(landscape, 1280x768)": (768, 1280),
+ "人像(portrait, 768x1280)": (1280, 768),
+}
+
+global_seed=np.random.randint(0, 10000)
+
+# Helper Functions
+def image_to_base64(image_path):
+ with open(image_path, "rb") as image_file:
+ encoded_image = base64.b64encode(image_file.read()).decode()
+ return encoded_image
+
+def get_strings(lang):
+ lang_file = Path(f"app/lang/{lang}.csv")
+ strings = pd.read_csv(lang_file, header=0)
+ strings = strings.set_index("key")['value'].to_dict()
+ return strings
+
+def get_image_md5(image):
+ image_data = io.BytesIO()
+ image.save(image_data, format="PNG")
+ image_data = image_data.getvalue()
+ md5_hash = hashlib.md5(image_data).hexdigest()
+ return md5_hash
+
+
+# mllm调用
+def request_dialogGen(server_url='http://0.0.0.0:8080',history_messages=[], question="画一个木制的鸟",image=""):
+ if image != "":
+ image = base64.b64encode(open(image, "rb").read()).decode()
+ print("history_messages before request",history_messages)
+ headers = {
+ 'accept': 'application/json',
+ 'Content-Type': 'application/json'
+ }
+ data = {
+ "text": question,
+ "image": image, # "image为空字符串,则进行文本对话"
+ "history": history_messages,
+ }
+ response = requests.post(server_url, headers=headers, json=data)
+ print("response",response)
+ response = response.json()
+ print(response)
+ response_text = response["result"]
+ history_messages = response["history"]
+ print("history_messages before request",history_messages)
+ return history_messages, response_text
+
+
+# 画图
+def image_generation(
+ prompt, infer_steps, seed, image_size
+):
+ print(f"prompt sent to T2I model: {prompt}, infer_steps: {infer_steps}, seed: {seed}, size: {image_size}")
+ height, width = SIZES[image_size]
+ results = gen.predict(prompt,
+ height=height,
+ width=width,
+ seed=seed,
+ infer_steps=infer_steps,
+ batch_size=1,
+ )
+ image = results['images'][0]
+ file_name = get_image_md5(image)
+ # Save images
+ save_dir = Path('results')
+ save_dir.mkdir(exist_ok=True)
+ save_path = f'results/multiRound_{file_name}.png'
+ image.save(save_path)
+ encoded_image = image_to_base64(save_path)
+
+ return encoded_image
+
+# 图文对话
+def chat(history_messages, input_text):
+
+ history_messages, response_text = request_dialogGen(history_messages=history_messages, question=input_text)
+ return history_messages, response_text
+#
+def pipeline(input_text, state, infer_steps, seed, image_size):
+
+ # 忽略空输入
+ if len(input_text) == 0:
+ return state, state[0]
+
+ conversation = state[0]
+ history_messages = state[1]
+
+ system_prompt = '请先判断用户的意图,若为画图则在输出前加入<画图>:'
+ print(f"input history:{history_messages}")
+ if not isinstance(history_messages, list) and len(history_messages.messages) >= 2:
+ response, history_messages = enhancer(input_text, return_history=True, history=history_messages, skip_special=True)
+ else:
+ response, history_messages = enhancer(input_text, return_history=True, history=history_messages, skip_special=False)
+
+ history_messages.messages[-1][-1] = response
+
+ if '<画图>' in response:
+ intention_draw = True
+ else:
+ intention_draw = False
+
+ print(f"response:{response}")
+ print("-" * 80)
+ print(f"history_messages:{history_messages}")
+ print(f"intention_draw:{intention_draw}")
+ if intention_draw:
+ prompt = response.split('<画图>')[-1]
+ # 画图
+ image_url = image_generation(prompt, infer_steps, seed, image_size)
+ response = f'
在夜晚的酒店门前,一座古老的中国风格的狮子雕像矗立着,它的眼睛闪烁着光芒,仿佛在守护着这座建筑。背景是夜晚的酒店前,构图方式是特写,平视,居中构图。这张照片呈现了真实摄影风格,蕴含了中国雕塑文化,同时展现了神秘氛围 (At night, an ancient Chinese-style lion statue stands in front of the hotel, its eyes gleaming as if guarding the building. The background is the hotel entrance at night, with a close-up, eye-level, and centered composition. This photo presents a realistic photographic style, embodies Chinese sculpture culture, and reveals a mysterious atmosphere.)
+
在茂密的森林中,一只黑白相间的熊猫静静地坐在绿树红花中,周围是山川和海洋。背景是白天的森林,光线充足 (In the dense forest, a black and white panda sits quietly in green trees and red flowers, surrounded by mountains, rivers, and the ocean. The background is the forest in a bright environment.)
+
一位亚洲女性,身穿绿色上衣,戴着紫色头巾和紫色围巾,站在黑板前。背景是黑板。照片采用近景、平视和居中构图的方式呈现真实摄影风格 (An Asian woman, dressed in a green top, wearing a purple headscarf and a purple scarf, stands in front of a blackboard. The background is the blackboard. The photo is presented in a close-up, eye-level, and centered composition, adopting a realistic photographic style)
+
+
+
+
+
+
+
+
+
+
+
ControlNet Output
+
+
+
+
+
+
+
+
+
+
+
+
+### Training
+
+We utilize [**DWPose**](https://github.com/IDEA-Research/DWPose) for pose extraction. Please follow their guidelines to download the checkpoints and save them to `hydit/annotator/ckpts` directory. We provide serveral commands to quick install:
+```bash
+mkdir ./hydit/annotator/ckpts
+wget -O ./hydit/annotator/ckpts/dwpose.zip https://dit.hunyuan.tencent.com/download/HunyuanDiT/dwpose.zip
+unzip ./hydit/annotator/ckpts/dwpose.zip -d ./hydit/annotator/ckpts/
+```
+Additionally, ensure that you install the related dependencies.
+```bash
+pip install matplotlib==3.7.5
+pip install onnxruntime_gpu==1.16.3
+pip install opencv-python==4.8.1.78
+```
+
+
+We provide three types of weights for ControlNet training, `ema`, `module` and `distill`, and you can choose according to the actual effects. By default, we use `distill` weights.
+
+Here is an example, we load the `distill` weights into the main model and conduct ControlNet training.
+
+If you want to load the `module` weights into the main model, just remove the `--ema-to-module` parameter.
+
+If apply multiple resolution training, you need to add the `--multireso` and `--reso-step 64` parameter.
+
+```bash
+task_flag="canny_controlnet" # task flag is used to identify folders.
+control_type=canny
+resume=./ckpts/t2i/model/ # checkpoint root for resume
+index_file=path/to/your/index_file
+results_dir=./log_EXP # save root for results
+batch_size=1 # training batch size
+image_size=1024 # training image resolution
+grad_accu_steps=2 # gradient accumulation
+warmup_num_steps=0 # warm-up steps
+lr=0.0001 # learning rate
+ckpt_every=10000 # create a ckpt every a few steps.
+ckpt_latest_every=5000 # create a ckpt named `latest.pt` every a few steps.
+
+
+sh $(dirname "$0")/run_g_controlnet.sh \
+ --task-flag ${task_flag} \
+ --control-type ${control_type} \
+ --noise-schedule scaled_linear --beta-start 0.00085 --beta-end 0.03 \
+ --predict-type v_prediction \
+ --multireso \
+ --reso-step 64 \
+ --ema-to-module \
+ --uncond-p 0.44 \
+ --uncond-p-t5 0.44 \
+ --index-file ${index_file} \
+ --random-flip \
+ --lr ${lr} \
+ --batch-size ${batch_size} \
+ --image-size ${image_size} \
+ --global-seed 999 \
+ --grad-accu-steps ${grad_accu_steps} \
+ --warmup-num-steps ${warmup_num_steps} \
+ --use-flash-attn \
+ --use-fp16 \
+ --use-ema \
+ --ema-dtype fp32 \
+ --results-dir ${results_dir} \
+ --resume-split \
+ --resume ${resume} \
+ --ckpt-every ${ckpt_every} \
+ --ckpt-latest-every ${ckpt_latest_every} \
+ --log-every 10 \
+ --deepspeed \
+ --deepspeed-optimizer \
+ --use-zero-stage 2 \
+ "$@"
+```
+
+Recommended parameter settings
+
+| Parameter | Description | Recommended Parameter Value | Note|
+|:---------------:|:---------:|:---------------------------------------------------:|:--:|
+| `--batch-size` | Training batch size | 1 | Depends on GPU memory|
+| `--grad-accu-steps` | Size of gradient accumulation | 2 | - |
+| `--lr` | Learning rate | 0.0001 | - |
+| `--control-type` | ControlNet condition type, support 3 types now (canny, depth and pose) | / | - |
+
+
+### Inference
+You can use the following command line for inference.
+
+a. You can use a float to specify the weight for all layers, **or use a list to separately specify the weight for each layer**, for example, '[1.0 * (0.825 ** float(19 - i)) for i in range(19)]'
+```bash
+python3 sample_controlnet.py --control-weight [1.0 * (0.825 ** float(19 - i)) for i in range(19)] --no-enhance --load-key distill --infer-steps 50 --control-type canny --prompt "在夜晚的酒店门前,一座古老的中国风格的狮子雕像矗立着,它的眼睛闪烁着光芒,仿佛在守护着这座建筑。背景是夜晚的酒店前,构图方式是特写,平视,居中构图。这张照片呈现了真实摄影风格,蕴含了中国雕塑文化,同时展现了神秘氛围" --condition-image-path controlnet/asset/input/canny.jpg
+```
+
+b. Using canny ControlNet during inference
+
+```bash
+python3 sample_controlnet.py --no-enhance --load-key distill --infer-steps 50 --control-type canny --prompt "在夜晚的酒店门前,一座古老的中国风格的狮子雕像矗立着,它的眼睛闪烁着光芒,仿佛在守护着这座建筑。背景是夜晚的酒店前,构图方式是特写,平视,居中构图。这张照片呈现了真实摄影风格,蕴含了中国雕塑文化,同时展现了神秘氛围" --condition-image-path controlnet/asset/input/canny.jpg --control-weight 1.0
+```
+
+c. Using pose ControlNet during inference
+
+```bash
+python3 sample_controlnet.py --no-enhance --load-key distill --infer-steps 50 --control-type depth --prompt "在茂密的森林中,一只黑白相间的熊猫静静地坐在绿树红花中,周围是山川和海洋。背景是白天的森林,光线充足" --condition-image-path controlnet/asset/input/depth.jpg --control-weight 1.0
+```
+
+d. Using depth ControlNet during inference
+
+```bash
+python3 sample_controlnet.py --no-enhance --load-key distill --infer-steps 50 --control-type pose --prompt "一位亚洲女性,身穿绿色上衣,戴着紫色头巾和紫色围巾,站在黑板前。背景是黑板。照片采用近景、平视和居中构图的方式呈现真实摄影风格" --condition-image-path controlnet/asset/input/pose.jpg --control-weight 1.0
+```
+
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/dataset/yamls/porcelain.yaml b/PyTorch/built-in/mlm/HunyuanDiT/dataset/yamls/porcelain.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..95e8c65ce52316778728c27d62967bbd31a6dc49
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/dataset/yamls/porcelain.yaml
@@ -0,0 +1,15 @@
+source:
+ - ./dataset/porcelain/arrows/00000.arrow
+
+filter:
+ column:
+ - name: height
+ type: int
+ action: ge
+ target: 1024
+ default: 1024
+ - name: width
+ type: int
+ action: ge
+ target: 1024
+ default: 1024
\ No newline at end of file
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/dataset/yamls/porcelain_mt.yaml b/PyTorch/built-in/mlm/HunyuanDiT/dataset/yamls/porcelain_mt.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..96ca828f9cb3d30d3c6a6a4698dc8a369a27709d
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/dataset/yamls/porcelain_mt.yaml
@@ -0,0 +1,5 @@
+src:
+ - ./dataset/porcelain/jsons/porcelain.json
+base_size: 1024
+reso_step: 64
+min_size: 1024
\ No newline at end of file
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/diffusers/README.md b/PyTorch/built-in/mlm/HunyuanDiT/diffusers/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..3a45574203ef91c31654cfa49e6c750cbfb9d658
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/diffusers/README.md
@@ -0,0 +1,131 @@
+# Hunyuan-DiT + 🤗 Diffusers
+
+You can use Hunyuan-DiT in 🤗 Diffusers library. Before using the pipelines, please install the latest version of 🤗 Diffusers with
+```bash
+pip install git+https://github.com/huggingface/diffusers.git
+```
+
+## Inference with th Base Model
+
+You can generate images with both Chinese and English prompts using the following Python script:
+```py
+import torch
+from diffusers import HunyuanDiTPipeline
+
+pipe = HunyuanDiTPipeline.from_pretrained("Tencent-Hunyuan/HunyuanDiT-v1.1-Diffusers", torch_dtype=torch.float16)
+pipe.to("cuda")
+
+# You may also use English prompt as HunyuanDiT supports both English and Chinese
+# prompt = "An astronaut riding a horse"
+prompt = "一个宇航员在骑马"
+image = pipe(prompt).images[0]
+```
+You can use our distilled model to generate images even faster:
+
+```py
+import torch
+from diffusers import HunyuanDiTPipeline
+
+pipe = HunyuanDiTPipeline.from_pretrained("Tencent-Hunyuan/HunyuanDiT-v1.1-Diffusers-Distilled", torch_dtype=torch.float16)
+pipe.to("cuda")
+
+# You may also use English prompt as HunyuanDiT supports both English and Chinese
+# prompt = "An astronaut riding a horse"
+prompt = "一个宇航员在骑马"
+image = pipe(prompt, num_inference_steps=25).images[0]
+```
+More details can be found in [HunyuanDiT-Diffusers-Distilled](https://huggingface.co/Tencent-Hunyuan/HunyuanDiT-Diffusers-Distilled)
+
+## LoRA
+LoRA can be integrated with Hunyuan-DiT inside the 🤗 Diffusers framework.
+The following example loads and uses the pre-trained LoRA. To try it, please start by downloading our pre-trained LoRA checkpoints,
+```bash
+huggingface-cli download Tencent-Hunyuan/HYDiT-LoRA --local-dir ./ckpts/t2i/lora
+```
+Then run the following code snippet to use the jade LoRA:
+```python
+import torch
+from diffusers import HunyuanDiTPipeline
+
+### convert checkpoint to diffusers format
+num_layers = 40
+def load_hunyuan_dit_lora(transformer_state_dict, lora_state_dict, lora_scale):
+ for i in range(num_layers):
+ Wqkv = torch.matmul(lora_state_dict[f"blocks.{i}.attn1.Wqkv.lora_B.weight"], lora_state_dict[f"blocks.{i}.attn1.Wqkv.lora_A.weight"])
+ q, k, v = torch.chunk(Wqkv, 3, dim=0)
+ transformer_state_dict[f"blocks.{i}.attn1.to_q.weight"] += lora_scale * q
+ transformer_state_dict[f"blocks.{i}.attn1.to_k.weight"] += lora_scale * k
+ transformer_state_dict[f"blocks.{i}.attn1.to_v.weight"] += lora_scale * v
+
+ out_proj = torch.matmul(lora_state_dict[f"blocks.{i}.attn1.out_proj.lora_B.weight"], lora_state_dict[f"blocks.{i}.attn1.out_proj.lora_A.weight"])
+ transformer_state_dict[f"blocks.{i}.attn1.to_out.0.weight"] += lora_scale * out_proj
+
+ q_proj = torch.matmul(lora_state_dict[f"blocks.{i}.attn2.q_proj.lora_B.weight"], lora_state_dict[f"blocks.{i}.attn2.q_proj.lora_A.weight"])
+ transformer_state_dict[f"blocks.{i}.attn2.to_q.weight"] += lora_scale * q_proj
+
+ kv_proj = torch.matmul(lora_state_dict[f"blocks.{i}.attn2.kv_proj.lora_B.weight"], lora_state_dict[f"blocks.{i}.attn2.kv_proj.lora_A.weight"])
+ k, v = torch.chunk(kv_proj, 2, dim=0)
+ transformer_state_dict[f"blocks.{i}.attn2.to_k.weight"] += lora_scale * k
+ transformer_state_dict[f"blocks.{i}.attn2.to_v.weight"] += lora_scale * v
+
+ out_proj = torch.matmul(lora_state_dict[f"blocks.{i}.attn2.out_proj.lora_B.weight"], lora_state_dict[f"blocks.{i}.attn2.out_proj.lora_A.weight"])
+ transformer_state_dict[f"blocks.{i}.attn2.to_out.0.weight"] += lora_scale * out_proj
+
+ q_proj = torch.matmul(lora_state_dict["pooler.q_proj.lora_B.weight"], lora_state_dict["pooler.q_proj.lora_A.weight"])
+ transformer_state_dict["time_extra_emb.pooler.q_proj.weight"] += lora_scale * q_proj
+
+ return transformer_state_dict
+
+### use the diffusers pipeline with lora
+pipe = HunyuanDiTPipeline.from_pretrained("Tencent-Hunyuan/HunyuanDiT-v1.1-Diffusers", torch_dtype=torch.float16)
+pipe.to("cuda")
+
+from safetensors import safe_open
+
+lora_state_dict = {}
+with safe_open("./ckpts/t2i/lora/jade/adapter_model.safetensors", framework="pt", device=0) as f:
+ for k in f.keys():
+ lora_state_dict[k[17:]] = f.get_tensor(k) # remove 'basemodel.model'
+
+transformer_state_dict = pipe.transformer.state_dict()
+transformer_state_dict = load_hunyuan_dit_lora(transformer_state_dict, lora_state_dict, lora_scale=1.0)
+pipe.transformer.load_state_dict(transformer_state_dict)
+
+prompt = "玉石绘画风格,一只猫在追蝴蝶"
+image = pipe(
+ prompt,
+ num_inference_steps=100,
+ guidance_scale=6.0,
+).images[0]
+image.save('img.png')
+```
+
+You can control the strength of LoRA by changing the `lora_scale` parameter.
+
+## ControlNet
+Hunyuan-DiT + ControlNet is supported in 🤗 Diffusers. The following example shows how to use Hunyuan-DiT + Canny ControlNet.
+```py
+from diffusers import HunyuanDiT2DControlNetModel, HunyuanDiTControlNetPipeline
+import torch
+controlnet = HunyuanDiT2DControlNetModel.from_pretrained("Tencent-Hunyuan/HunyuanDiT-v1.1-ControlNet-Diffusers-Canny", torch_dtype=torch.float16)
+
+pipe = HunyuanDiTControlNetPipeline.from_pretrained("Tencent-Hunyuan/HunyuanDiT-v1.1-Diffusers", controlnet=controlnet, torch_dtype=torch.float16)
+pipe.to("cuda")
+
+from diffusers.utils import load_image
+cond_image = load_image('https://huggingface.co/Tencent-Hunyuan/HunyuanDiT-v1.1-ControlNet-Diffusers-Canny/resolve/main/canny.jpg?download=true')
+
+## You may also use English prompt as HunyuanDiT supports both English and Chinese
+prompt="在夜晚的酒店门前,一座古老的中国风格的狮子雕像矗立着,它的眼睛闪烁着光芒,仿佛在守护着这座建筑。背景是夜晚的酒店前,构图方式是特写,平视,居中构图。这张照片呈现了真实摄影风格,蕴含了中国雕塑文化,同时展现了神秘氛围"
+#prompt="At night, an ancient Chinese-style lion statue stands in front of the hotel, its eyes gleaming as if guarding the building. The background is the hotel entrance at night, with a close-up, eye-level, and centered composition. This photo presents a realistic photographic style, embodies Chinese sculpture culture, and reveals a mysterious atmosphere."
+image = pipe(
+ prompt,
+ height=1024,
+ width=1024,
+ control_image=cond_image,
+ num_inference_steps=50,
+).images[0]
+```
+
+There are other pre-trained ControlNets available. Please have a look at [the official huggingface website of Tencent Hunyuan Team](https://huggingface.co/Tencent-Hunyuan)
+
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/environment.yml b/PyTorch/built-in/mlm/HunyuanDiT/environment.yml
new file mode 100644
index 0000000000000000000000000000000000000000..f43b45b7fbd182dc2cb6c82dec2cbecc562b1a17
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/environment.yml
@@ -0,0 +1,8 @@
+name: HunyuanDiT
+channels:
+ - pytorch
+ - nvidia
+dependencies:
+ - python=3.8.12
+ - pytorch=1.13.1
+ - pip
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/example_prompts.txt b/PyTorch/built-in/mlm/HunyuanDiT/example_prompts.txt
new file mode 100644
index 0000000000000000000000000000000000000000..f590be43c46e4c50f91a28f84821047241066df2
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/example_prompts.txt
@@ -0,0 +1,28 @@
+一只聪明的狐狸走在阔叶树林里, 旁边是一条小溪, 细节真实, 摄影
+湖水清澈,天空湛蓝,阳光灿烂。一只优雅的白天鹅在湖边游泳。它周围有几只小鸭子,看起来非常可爱,整个画面给人一种宁静祥和的感觉。
+太阳微微升起,花园里的玫瑰花瓣上露珠晶莹剔透,一只瓢虫正在爬向露珠,背景是清晨的花园,微距镜头
+一位女明星,中国人,头发是黑色,衣服是纯白色短袖,人物风格清新,城市背景
+后印象主义风格,一条古老的石板路上面散落着金黄色的树叶。路旁的风车在静谧地转动,后面竖着两个风车。背景是一片向日葵田,蓝天上飘着几朵白云
+一幅细致的油画描绘了一只年轻獾轻轻嗅着一朵明亮的黄色玫瑰时错综复杂的皮毛。背景是一棵大树干的粗糙纹理,獾的爪子轻轻地挖进树皮。在柔和的背景中,一个宁静的瀑布倾泻而下,它的水在绿色植物中闪烁着蓝色。
+渔舟唱晚
+请将杞人忧天的样子画出来
+一只长靴猫手持亮银色的宝剑,身着铠甲,眼神坚毅,站在一堆金币上,背景是暗色调的洞穴,图像上有金币的光影点缀。
+插画风格,一只狐狸和一只刺猬坐在水边的石头上,刺猬手里拿着一杯茶,狐狸旁边放着一个玻璃杯。周围是茂密的绿色植物和树木,阳光透过树叶洒在水面上,画面宁静温馨。
+泥塑风格,一座五彩斑斓的花园在画面中展现,各种各样的花朵,绿色的叶子和一只正在嬉戏的小猫形成了一幅生动的图像,背景是蓝天和白云
+枯藤老树昏鸦,小桥流水人家
+一张细致的照片捕捉到了一尊雕像的形象,这尊雕像酷似一位古代法老,头上出人意料地戴着一副青铜蒸汽朋克护目镜。这座雕像穿着复古时髦,一件清爽的白色T恤和一件合身的黑色皮夹克,与传统的头饰形成鲜明对比。背景是简单的纯色,突出了雕像的非传统服装和蒸汽朋克眼镜的复杂细节。
+一朵鲜艳的红色玫瑰花,花瓣撒有一些水珠,晶莹剔透,特写镜头,
+一只可爱的猫, 细节真实, 摄影
+飞流直下三千尺,疑是银河落九天
+成语“鲤鱼跃龙门”
+一颗新鲜的草莓特写,红色的外表,表面布满许多种子,背景是淡绿色的叶子
+九寨沟
+摄影风格,在画面中心是一盘热气腾腾的麻婆豆腐,豆腐呈白色,上面撒着一层红色的辣酱,有些许绿色的葱花点缀,背景是深色木质餐桌,桌子上放有辣椒和葱花作为点缀。
+一位年轻女子站在春季的火车站月台上。她身着蓝灰色长风衣,白色衬衫。她的深棕色头发扎成低马尾,几缕碎发随风飘扬。她的眼神充满期待,阳光洒在她温暖的脸庞上。
+一只优雅的白鹤在湖边静静地站立,它的身体纯白色,翅膀轻轻展开,背景是湖面和远处的山脉
+国画风格,苏州园林中的小桥流水,周围是郁郁葱葱的树,池塘里有几朵绽放的荷花,背景是宁静的江南水乡
+现实主义风格,画面主要描述一个巴洛克风格的花瓶,带有金色的装饰边框,花瓶上盛开着各种色彩鲜艳的花,白色背景
+醉后不知天在水,满船清梦压星河
+长城
+一个亚洲中年男士在夕阳下的公园长椅上静坐。他穿着一件深蓝色的针织毛衣和灰色裤子。他的头发略显花白,手中拿着一本敞开的书。面带微笑,眼神温和,周围是落日余晖和四周的绿树。
+风格是写实,画面主要描述一个亚洲戏曲艺术家正在表演,她穿着华丽的戏服,脸上戴着精致的面具,身姿优雅,背景是古色古香的舞台,镜头是近景
\ No newline at end of file
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/gpu_sample.sh b/PyTorch/built-in/mlm/HunyuanDiT/gpu_sample.sh
new file mode 100644
index 0000000000000000000000000000000000000000..f7721e48aa57c29a43307c8d61075bb89412e2d3
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/gpu_sample.sh
@@ -0,0 +1 @@
+python sample_t2i.py --infer-mode fa --prompt "渔舟唱晚" --image-size 1280 768 --no-enhance
\ No newline at end of file
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/gpu_train.sh b/PyTorch/built-in/mlm/HunyuanDiT/gpu_train.sh
new file mode 100644
index 0000000000000000000000000000000000000000..95005fbcdcc032dfd647d57b054500c435722772
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/gpu_train.sh
@@ -0,0 +1,2 @@
+export CUBLAS_WORKSPACE_CONFIG=:4096:8
+PYTHONPATH=./ sh hydit/train.sh
\ No newline at end of file
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/hydit/__init__.py b/PyTorch/built-in/mlm/HunyuanDiT/hydit/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/hydit/annotator/dwpose/__init__.py b/PyTorch/built-in/mlm/HunyuanDiT/hydit/annotator/dwpose/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d17d0033b243cad48f26f5555d3dadac266a8f6
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/hydit/annotator/dwpose/__init__.py
@@ -0,0 +1,143 @@
+# Openpose
+# Original from CMU https://github.com/CMU-Perceptual-Computing-Lab/openpose
+# 2nd Edited by https://github.com/Hzzone/pytorch-openpose
+# 3rd Edited by ControlNet
+# 4th Edited by ControlNet (added face and correct hands)
+
+import os
+import random
+
+os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
+
+import torch
+import numpy as np
+from . import util
+from .wholebody import Wholebody
+
+
+def draw_pose(pose, H, W, draw_body=True):
+ bodies = pose['bodies']
+ faces = pose['faces']
+ hands = pose['hands']
+ candidate = bodies['candidate']
+ subset = bodies['subset']
+ canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8)
+
+ if draw_body:
+ canvas = util.draw_bodypose(canvas, candidate, subset)
+
+ canvas = util.draw_handpose(canvas, hands)
+
+ canvas = util.draw_facepose(canvas, faces)
+
+ return canvas
+
+
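+# Return the axis-aligned bounding box [x_min, y_min, x_max, y_max] of the valid keypoints (zeros if none are valid).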
+def keypoint2bbox(keypoints):
+    valid_keypoints = keypoints[keypoints[:, 0] >= 0] # ignore invalid keypoints (marked with negative coordinates)
+ if len(valid_keypoints) == 0:
+ return np.zeros(4)
+ x_min, y_min = np.min(valid_keypoints, axis=0)
+ x_max, y_max = np.max(valid_keypoints, axis=0)
+
+ return np.array([x_min, y_min, x_max, y_max])
+
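+# Enlarge each bbox around its center by expansion_rate, clipping the result to the image bounds.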
+def expand_bboxes(bboxes, expansion_rate=0.5, image_shape=(0, 0)):
+ expanded_bboxes = []
+ for bbox in bboxes:
+ x_min, y_min, x_max, y_max = map(int, bbox)
+ width = x_max - x_min
+ height = y_max - y_min
+
+        # expand the width and height
+ new_width = width * (1 + expansion_rate)
+ new_height = height * (1 + expansion_rate)
+
+        # compute the new bounding box coordinates
+ x_min_new = max(0, x_min - (new_width - width) / 2)
+ x_max_new = min(image_shape[1], x_max + (new_width - width) / 2)
+ y_min_new = max(0, y_min - (new_height - height) / 2)
+ y_max_new = min(image_shape[0], y_max + (new_height - height) / 2)
+
+ expanded_bboxes.append([x_min_new, y_min_new, x_max_new, y_max_new])
+
+ return expanded_bboxes
+
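+# Rasterize a float mask of shape (image_height, image_width) with 1.0 inside every bbox and 0.0 elsewhere.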
+def create_mask(image_width, image_height, bboxs):
+ mask = np.zeros((image_height, image_width), dtype=np.float32)
+ for bbox in bboxs:
+ x1, y1, x2, y2 = map(int, bbox)
+ mask[y1:y2+1, x1:x2+1] = 1.0
+ return mask
+
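+# Keypoint confidence threshold: keypoints scoring below this value are treated as invisible.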
+threshold = 0.4
+class DWposeDetector:
+ def __init__(self):
+
+ self.pose_estimation = Wholebody()
+
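+    # Run whole-body pose estimation on oriImg. Depending on the flags, return body+hand keypoints with confidences (return_yolo),
+    # the raw pose dict (return_index), a pose rendering plus a hand mask (return_mask), or a pose rendering plus a hand bbox.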
+ def __call__(self, oriImg, return_index=False, return_yolo=False, return_mask=False):
+ oriImg = oriImg.copy()
+ H, W, C = oriImg.shape
+ with torch.no_grad():
+ candidate, subset = self.pose_estimation(oriImg)
+ candidate = np.zeros((1, 134, 2), dtype=np.float32) if candidate is None else candidate
+ subset = np.zeros((1, 134), dtype=np.float32) if subset is None else subset
+ nums, keys, locs = candidate.shape
+ candidate[..., 0] /= float(W)
+ candidate[..., 1] /= float(H)
+ # import pdb; pdb.set_trace()
+ if return_yolo:
+ candidate[subset < threshold] = -0.1
+ subset = np.expand_dims(subset >= threshold, axis=-1)
+ keypoint = np.concatenate([candidate, subset], axis=-1)
+
+ # return pose + hand
+ return np.concatenate([keypoint[:, :18], keypoint[:, 92:]], axis=1)
+
+ body = candidate[:, :18].copy()
+ body = body.reshape(nums * 18, locs)
+ score = subset[:, :18]
+ for i in range(len(score)):
+ for j in range(len(score[i])):
+ if score[i][j] > threshold:
+ score[i][j] = int(18 * i + j)
+ else:
+ score[i][j] = -1
+
+ un_visible = subset < threshold
+ candidate[un_visible] = -1
+
+ foot = candidate[:, 18:24]
+
+ faces = candidate[:, 24:92]
+
+ hands1 = candidate[:, 92:113]
+ hands2 = candidate[:, 113:]
+ hands = np.vstack([hands1, hands2])
+
+ # import pdb; pdb.set_trace()
+ hands_ = hands[hands.max(axis=(1, 2)) > 0]
+ if len(hands_) == 0:
+ bbox = np.array([0, 0, 0, 0]).astype(int)
+ else:
+ hand_random = random.choice(hands_)
+ bbox = (keypoint2bbox(hand_random) * H).astype(int) # [0, 1] -> [h, w]
+
+
+
+ bodies = dict(candidate=body, subset=score)
+ pose = dict(bodies=bodies, hands=hands, faces=faces)
+
+ if return_mask:
+ bbox = [(keypoint2bbox(hand) * H).astype(int) for hand in hands_]
+ # bbox = expand_bboxes(bbox, expansion_rate=0.5, image_shape=(H, W))
+ mask = create_mask(W, H, bbox)
+ return draw_pose(pose, H, W), mask
+
+ if return_index:
+ return pose
+ else:
+ return draw_pose(pose, H, W), bbox
+
+
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/hydit/annotator/dwpose/onnxdet.py b/PyTorch/built-in/mlm/HunyuanDiT/hydit/annotator/dwpose/onnxdet.py
new file mode 100644
index 0000000000000000000000000000000000000000..48056f54338de43cbd3430339b957922f37fc8c8
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/hydit/annotator/dwpose/onnxdet.py
@@ -0,0 +1,129 @@
+import cv2
+import numpy as np
+
+import onnxruntime
+
+
+def nms(boxes, scores, nms_thr):
+ """Single class NMS implemented in Numpy."""
+ x1 = boxes[:, 0]
+ y1 = boxes[:, 1]
+ x2 = boxes[:, 2]
+ y2 = boxes[:, 3]
+
+ areas = (x2 - x1 + 1) * (y2 - y1 + 1)
+ order = scores.argsort()[::-1]
+
+ keep = []
+ while order.size > 0:
+ i = order[0]
+ keep.append(i)
+ xx1 = np.maximum(x1[i], x1[order[1:]])
+ yy1 = np.maximum(y1[i], y1[order[1:]])
+ xx2 = np.minimum(x2[i], x2[order[1:]])
+ yy2 = np.minimum(y2[i], y2[order[1:]])
+
+ w = np.maximum(0.0, xx2 - xx1 + 1)
+ h = np.maximum(0.0, yy2 - yy1 + 1)
+ inter = w * h
+ ovr = inter / (areas[i] + areas[order[1:]] - inter)
+
+ inds = np.where(ovr <= nms_thr)[0]
+ order = order[inds + 1]
+
+ return keep
+
+
+def multiclass_nms(boxes, scores, nms_thr, score_thr):
+ """Multiclass NMS implemented in Numpy. Class-aware version."""
+ final_dets = []
+ num_classes = scores.shape[1]
+ for cls_ind in range(num_classes):
+ cls_scores = scores[:, cls_ind]
+ valid_score_mask = cls_scores > score_thr
+ if valid_score_mask.sum() == 0:
+ continue
+ else:
+ valid_scores = cls_scores[valid_score_mask]
+ valid_boxes = boxes[valid_score_mask]
+ keep = nms(valid_boxes, valid_scores, nms_thr)
+ if len(keep) > 0:
+ cls_inds = np.ones((len(keep), 1)) * cls_ind
+ dets = np.concatenate(
+ [valid_boxes[keep], valid_scores[keep, None], cls_inds], 1
+ )
+ final_dets.append(dets)
+ if len(final_dets) == 0:
+ return None
+ return np.concatenate(final_dets, 0)
+
+
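+# Decode raw YOLOX outputs on the anchor grid: xy = (pred_xy + grid) * stride, wh = exp(pred_wh) * stride.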
+def demo_postprocess(outputs, img_size, p6=False):
+ grids = []
+ expanded_strides = []
+ strides = [8, 16, 32] if not p6 else [8, 16, 32, 64]
+
+ hsizes = [img_size[0] // stride for stride in strides]
+ wsizes = [img_size[1] // stride for stride in strides]
+
+ for hsize, wsize, stride in zip(hsizes, wsizes, strides):
+ xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize))
+ grid = np.stack((xv, yv), 2).reshape(1, -1, 2)
+ grids.append(grid)
+ shape = grid.shape[:2]
+ expanded_strides.append(np.full((*shape, 1), stride))
+
+ grids = np.concatenate(grids, 1)
+ expanded_strides = np.concatenate(expanded_strides, 1)
+ outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides
+ outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides
+
+ return outputs
+
+
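+# Resize the image to fit input_size while keeping the aspect ratio, pad with value 114, and return it as CHW float32 plus the resize ratio.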
+def preprocess(img, input_size, swap=(2, 0, 1)):
+ if len(img.shape) == 3:
+ padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114
+ else:
+ padded_img = np.ones(input_size, dtype=np.uint8) * 114
+
+ r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1])
+ resized_img = cv2.resize(
+ img,
+ (int(img.shape[1] * r), int(img.shape[0] * r)),
+ interpolation=cv2.INTER_LINEAR,
+ ).astype(np.uint8)
+ padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img
+
+ padded_img = padded_img.transpose(swap)
+ padded_img = np.ascontiguousarray(padded_img, dtype=np.float32)
+ return padded_img, r
+
+
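+# Run the YOLOX ONNX detector and return class-0 (person) boxes in xyxy format with score > 0.3, or None if nothing is detected.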
+def inference_detector(session, oriImg):
+ input_shape = (640, 640)
+ img, ratio = preprocess(oriImg, input_shape)
+
+ ort_inputs = {session.get_inputs()[0].name: img[None, :, :, :]}
+ output = session.run(None, ort_inputs)
+ predictions = demo_postprocess(output[0], input_shape)[0]
+
+ boxes = predictions[:, :4]
+ scores = predictions[:, 4:5] * predictions[:, 5:]
+
+ boxes_xyxy = np.ones_like(boxes)
+ boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2] / 2.
+ boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3] / 2.
+ boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2] / 2.
+ boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3] / 2.
+ boxes_xyxy /= ratio
+ dets = multiclass_nms(boxes_xyxy, scores, nms_thr=0.45, score_thr=0.1)
+ if dets is not None:
+ final_boxes, final_scores, final_cls_inds = dets[:, :4], dets[:, 4], dets[:, 5]
+ isscore = final_scores > 0.3
+ iscat = final_cls_inds == 0
+ isbbox = [i and j for (i, j) in zip(isscore, iscat)]
+ final_boxes = final_boxes[isbbox]
+ return final_boxes
+ else:
+ return None
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/hydit/annotator/dwpose/onnxpose.py b/PyTorch/built-in/mlm/HunyuanDiT/hydit/annotator/dwpose/onnxpose.py
new file mode 100644
index 0000000000000000000000000000000000000000..79cd4a06241123af81ea22446a4ca8816716443f
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/hydit/annotator/dwpose/onnxpose.py
@@ -0,0 +1,360 @@
+from typing import List, Tuple
+
+import cv2
+import numpy as np
+import onnxruntime as ort
+
+def preprocess(
+ img: np.ndarray, out_bbox, input_size: Tuple[int, int] = (192, 256)
+) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+ """Do preprocessing for RTMPose model inference.
+
+ Args:
+        img (np.ndarray): Input image in shape (H, W, C).
+ input_size (tuple): Input image size in shape (w, h).
+
+ Returns:
+ tuple:
+ - resized_img (np.ndarray): Preprocessed image.
+ - center (np.ndarray): Center of image.
+ - scale (np.ndarray): Scale of image.
+ """
+ # get shape of image
+ img_shape = img.shape[:2]
+ out_img, out_center, out_scale = [], [], []
+ if len(out_bbox) == 0:
+ out_bbox = [[0, 0, img_shape[1], img_shape[0]]]
+ for i in range(len(out_bbox)):
+ x0 = out_bbox[i][0]
+ y0 = out_bbox[i][1]
+ x1 = out_bbox[i][2]
+ y1 = out_bbox[i][3]
+ bbox = np.array([x0, y0, x1, y1])
+
+ # get center and scale
+ center, scale = bbox_xyxy2cs(bbox, padding=1.25)
+
+ # do affine transformation
+ resized_img, scale = top_down_affine(input_size, scale, center, img)
+
+ # normalize image
+ mean = np.array([123.675, 116.28, 103.53])
+ std = np.array([58.395, 57.12, 57.375])
+ resized_img = (resized_img - mean) / std
+
+ out_img.append(resized_img)
+ out_center.append(center)
+ out_scale.append(scale)
+
+ return out_img, out_center, out_scale
+
+
+def inference(sess: ort.InferenceSession, img: np.ndarray) -> np.ndarray:
+ """Inference RTMPose model.
+
+ Args:
+ sess (ort.InferenceSession): ONNXRuntime session.
+        img (List[np.ndarray]): Preprocessed image crops.
+
+ Returns:
+ outputs (np.ndarray): Output of RTMPose model.
+ """
+ all_out = []
+ # build input
+ for i in range(len(img)):
+ input = [img[i].transpose(2, 0, 1)]
+
+ # build output
+ sess_input = {sess.get_inputs()[0].name: input}
+ sess_output = []
+ for out in sess.get_outputs():
+ sess_output.append(out.name)
+
+ # run model
+ outputs = sess.run(sess_output, sess_input)
+ all_out.append(outputs)
+
+ return all_out
+
+
+def postprocess(outputs: List[np.ndarray],
+ model_input_size: Tuple[int, int],
+ center: Tuple[int, int],
+ scale: Tuple[int, int],
+ simcc_split_ratio: float = 2.0
+ ) -> Tuple[np.ndarray, np.ndarray]:
+ """Postprocess for RTMPose model output.
+
+ Args:
+ outputs (np.ndarray): Output of RTMPose model.
+ model_input_size (tuple): RTMPose model Input image size.
+ center (tuple): Center of bbox in shape (x, y).
+ scale (tuple): Scale of bbox in shape (w, h).
+ simcc_split_ratio (float): Split ratio of simcc.
+
+ Returns:
+ tuple:
+ - keypoints (np.ndarray): Rescaled keypoints.
+ - scores (np.ndarray): Model predict scores.
+ """
+ all_key = []
+ all_score = []
+ for i in range(len(outputs)):
+ # use simcc to decode
+ simcc_x, simcc_y = outputs[i]
+ keypoints, scores = decode(simcc_x, simcc_y, simcc_split_ratio)
+
+ # rescale keypoints
+ keypoints = keypoints / model_input_size * scale[i] + center[i] - scale[i] / 2
+ all_key.append(keypoints[0])
+ all_score.append(scores[0])
+
+ return np.array(all_key), np.array(all_score)
+
+
+def bbox_xyxy2cs(bbox: np.ndarray,
+ padding: float = 1.) -> Tuple[np.ndarray, np.ndarray]:
+    """Transform the bbox format from (x1, y1, x2, y2) into (center, scale)
+
+ Args:
+ bbox (ndarray): Bounding box(es) in shape (4,) or (n, 4), formatted
+ as (left, top, right, bottom)
+        padding (float): BBox padding factor that will be multiplied by the scale.
+ Default: 1.0
+
+ Returns:
+ tuple: A tuple containing center and scale.
+ - np.ndarray[float32]: Center (x, y) of the bbox in shape (2,) or
+ (n, 2)
+ - np.ndarray[float32]: Scale (w, h) of the bbox in shape (2,) or
+ (n, 2)
+ """
+ # convert single bbox from (4, ) to (1, 4)
+ dim = bbox.ndim
+ if dim == 1:
+ bbox = bbox[None, :]
+
+ # get bbox center and scale
+ x1, y1, x2, y2 = np.hsplit(bbox, [1, 2, 3])
+ center = np.hstack([x1 + x2, y1 + y2]) * 0.5
+ scale = np.hstack([x2 - x1, y2 - y1]) * padding
+
+ if dim == 1:
+ center = center[0]
+ scale = scale[0]
+
+ return center, scale
+
+
+def _fix_aspect_ratio(bbox_scale: np.ndarray,
+ aspect_ratio: float) -> np.ndarray:
+ """Extend the scale to match the given aspect ratio.
+
+ Args:
+        bbox_scale (np.ndarray): The image scale (w, h) in shape (2, )
+ aspect_ratio (float): The ratio of ``w/h``
+
+ Returns:
+ np.ndarray: The reshaped image scale in (2, )
+ """
+ w, h = np.hsplit(bbox_scale, [1])
+ bbox_scale = np.where(w > h * aspect_ratio,
+ np.hstack([w, w / aspect_ratio]),
+ np.hstack([h * aspect_ratio, h]))
+ return bbox_scale
+
+
+def _rotate_point(pt: np.ndarray, angle_rad: float) -> np.ndarray:
+ """Rotate a point by an angle.
+
+ Args:
+ pt (np.ndarray): 2D point coordinates (x, y) in shape (2, )
+ angle_rad (float): rotation angle in radian
+
+ Returns:
+ np.ndarray: Rotated point in shape (2, )
+ """
+ sn, cs = np.sin(angle_rad), np.cos(angle_rad)
+ rot_mat = np.array([[cs, -sn], [sn, cs]])
+ return rot_mat @ pt
+
+
+def _get_3rd_point(a: np.ndarray, b: np.ndarray) -> np.ndarray:
+ """To calculate the affine matrix, three pairs of points are required. This
+ function is used to get the 3rd point, given 2D points a & b.
+
+ The 3rd point is defined by rotating vector `a - b` by 90 degrees
+ anticlockwise, using b as the rotation center.
+
+ Args:
+ a (np.ndarray): The 1st point (x,y) in shape (2, )
+ b (np.ndarray): The 2nd point (x,y) in shape (2, )
+
+ Returns:
+ np.ndarray: The 3rd point.
+ """
+ direction = a - b
+ c = b + np.r_[-direction[1], direction[0]]
+ return c
+
+
+def get_warp_matrix(center: np.ndarray,
+ scale: np.ndarray,
+ rot: float,
+ output_size: Tuple[int, int],
+ shift: Tuple[float, float] = (0., 0.),
+ inv: bool = False) -> np.ndarray:
+ """Calculate the affine transformation matrix that can warp the bbox area
+ in the input image to the output size.
+
+ Args:
+ center (np.ndarray[2, ]): Center of the bounding box (x, y).
+ scale (np.ndarray[2, ]): Scale of the bounding box
+ wrt [width, height].
+ rot (float): Rotation angle (degree).
+ output_size (np.ndarray[2, ] | list(2,)): Size of the
+ destination heatmaps.
+ shift (0-100%): Shift translation ratio wrt the width/height.
+ Default (0., 0.).
+ inv (bool): Option to inverse the affine transform direction.
+ (inv=False: src->dst or inv=True: dst->src)
+
+ Returns:
+ np.ndarray: A 2x3 transformation matrix
+ """
+ shift = np.array(shift)
+ src_w = scale[0]
+ dst_w = output_size[0]
+ dst_h = output_size[1]
+
+ # compute transformation matrix
+ rot_rad = np.deg2rad(rot)
+ src_dir = _rotate_point(np.array([0., src_w * -0.5]), rot_rad)
+ dst_dir = np.array([0., dst_w * -0.5])
+
+ # get four corners of the src rectangle in the original image
+ src = np.zeros((3, 2), dtype=np.float32)
+ src[0, :] = center + scale * shift
+ src[1, :] = center + src_dir + scale * shift
+ src[2, :] = _get_3rd_point(src[0, :], src[1, :])
+
+ # get four corners of the dst rectangle in the input image
+ dst = np.zeros((3, 2), dtype=np.float32)
+ dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
+ dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
+ dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :])
+
+ if inv:
+ warp_mat = cv2.getAffineTransform(np.float32(dst), np.float32(src))
+ else:
+ warp_mat = cv2.getAffineTransform(np.float32(src), np.float32(dst))
+
+ return warp_mat
+
+
+def top_down_affine(input_size: dict, bbox_scale: dict, bbox_center: dict,
+ img: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+ """Get the bbox image as the model input by affine transform.
+
+ Args:
+ input_size (dict): The input size of the model.
+ bbox_scale (dict): The bbox scale of the img.
+ bbox_center (dict): The bbox center of the img.
+ img (np.ndarray): The original image.
+
+ Returns:
+        tuple: A tuple containing the transformed image and the bbox scale.
+ - np.ndarray[float32]: img after affine transform.
+ - np.ndarray[float32]: bbox scale after affine transform.
+ """
+ w, h = input_size
+ warp_size = (int(w), int(h))
+
+ # reshape bbox to fixed aspect ratio
+ bbox_scale = _fix_aspect_ratio(bbox_scale, aspect_ratio=w / h)
+
+ # get the affine matrix
+ center = bbox_center
+ scale = bbox_scale
+ rot = 0
+ warp_mat = get_warp_matrix(center, scale, rot, output_size=(w, h))
+
+ # do affine transform
+ img = cv2.warpAffine(img, warp_mat, warp_size, flags=cv2.INTER_LINEAR)
+
+ return img, bbox_scale
+
+
+def get_simcc_maximum(simcc_x: np.ndarray,
+ simcc_y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+ """Get maximum response location and value from simcc representations.
+
+ Note:
+ instance number: N
+ num_keypoints: K
+ heatmap height: H
+ heatmap width: W
+
+ Args:
+ simcc_x (np.ndarray): x-axis SimCC in shape (K, Wx) or (N, K, Wx)
+ simcc_y (np.ndarray): y-axis SimCC in shape (K, Wy) or (N, K, Wy)
+
+ Returns:
+ tuple:
+ - locs (np.ndarray): locations of maximum heatmap responses in shape
+ (K, 2) or (N, K, 2)
+ - vals (np.ndarray): values of maximum heatmap responses in shape
+ (K,) or (N, K)
+ """
+ N, K, Wx = simcc_x.shape
+ simcc_x = simcc_x.reshape(N * K, -1)
+ simcc_y = simcc_y.reshape(N * K, -1)
+
+ # get maximum value locations
+ x_locs = np.argmax(simcc_x, axis=1)
+ y_locs = np.argmax(simcc_y, axis=1)
+ locs = np.stack((x_locs, y_locs), axis=-1).astype(np.float32)
+ max_val_x = np.amax(simcc_x, axis=1)
+ max_val_y = np.amax(simcc_y, axis=1)
+
+ # get maximum value across x and y axis
+ mask = max_val_x > max_val_y
+ max_val_x[mask] = max_val_y[mask]
+ vals = max_val_x
+ locs[vals <= 0.] = -1
+
+ # reshape
+ locs = locs.reshape(N, K, 2)
+ vals = vals.reshape(N, K)
+
+ return locs, vals
+
+
+def decode(simcc_x: np.ndarray, simcc_y: np.ndarray,
+ simcc_split_ratio) -> Tuple[np.ndarray, np.ndarray]:
+    """Decode keypoint locations and scores from SimCC representations.
+
+ Args:
+ simcc_x (np.ndarray[K, Wx]): model predicted simcc in x.
+ simcc_y (np.ndarray[K, Wy]): model predicted simcc in y.
+ simcc_split_ratio (int): The split ratio of simcc.
+
+ Returns:
+        tuple: A tuple containing keypoints and scores.
+ - np.ndarray[float32]: keypoints in shape (K, 2) or (n, K, 2)
+ - np.ndarray[float32]: scores in shape (K,) or (n, K)
+ """
+ keypoints, scores = get_simcc_maximum(simcc_x, simcc_y)
+ keypoints /= simcc_split_ratio
+
+ return keypoints, scores
+
+
+def inference_pose(session, out_bbox, oriImg):
+ h, w = session.get_inputs()[0].shape[2:]
+ model_input_size = (w, h)
+ resized_img, center, scale = preprocess(oriImg, out_bbox, model_input_size)
+ outputs = inference(session, resized_img)
+ keypoints, scores = postprocess(outputs, model_input_size, center, scale)
+
+ return keypoints, scores
\ No newline at end of file
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/hydit/annotator/dwpose/util.py b/PyTorch/built-in/mlm/HunyuanDiT/hydit/annotator/dwpose/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..a1453ab66d3d7a2f161aa6fd45acf70729990863
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/hydit/annotator/dwpose/util.py
@@ -0,0 +1,298 @@
+import math
+import numpy as np
+import matplotlib
+import cv2
+
+
+eps = 0.01
+
+
+def smart_resize(x, s):
+ Ht, Wt = s
+ if x.ndim == 2:
+ Ho, Wo = x.shape
+ Co = 1
+ else:
+ Ho, Wo, Co = x.shape
+ if Co == 3 or Co == 1:
+ k = float(Ht + Wt) / float(Ho + Wo)
+ return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4)
+ else:
+ return np.stack([smart_resize(x[:, :, i], s) for i in range(Co)], axis=2)
+
+
+def smart_resize_k(x, fx, fy):
+ if x.ndim == 2:
+ Ho, Wo = x.shape
+ Co = 1
+ else:
+ Ho, Wo, Co = x.shape
+ Ht, Wt = Ho * fy, Wo * fx
+ if Co == 3 or Co == 1:
+ k = float(Ht + Wt) / float(Ho + Wo)
+ return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4)
+ else:
+ return np.stack([smart_resize_k(x[:, :, i], fx, fy) for i in range(Co)], axis=2)
+
+
+def padRightDownCorner(img, stride, padValue):
+ h = img.shape[0]
+ w = img.shape[1]
+
+ pad = 4 * [None]
+ pad[0] = 0 # up
+ pad[1] = 0 # left
+ pad[2] = 0 if (h % stride == 0) else stride - (h % stride) # down
+ pad[3] = 0 if (w % stride == 0) else stride - (w % stride) # right
+
+ img_padded = img
+ pad_up = np.tile(img_padded[0:1, :, :]*0 + padValue, (pad[0], 1, 1))
+ img_padded = np.concatenate((pad_up, img_padded), axis=0)
+ pad_left = np.tile(img_padded[:, 0:1, :]*0 + padValue, (1, pad[1], 1))
+ img_padded = np.concatenate((pad_left, img_padded), axis=1)
+ pad_down = np.tile(img_padded[-2:-1, :, :]*0 + padValue, (pad[2], 1, 1))
+ img_padded = np.concatenate((img_padded, pad_down), axis=0)
+ pad_right = np.tile(img_padded[:, -2:-1, :]*0 + padValue, (1, pad[3], 1))
+ img_padded = np.concatenate((img_padded, pad_right), axis=1)
+
+ return img_padded, pad
+
+
+def transfer(model, model_weights):
+ transfered_model_weights = {}
+ for weights_name in model.state_dict().keys():
+ transfered_model_weights[weights_name] = model_weights['.'.join(weights_name.split('.')[1:])]
+ return transfered_model_weights
+
+
+def draw_bodypose(canvas, candidate, subset):
+ H, W, C = canvas.shape
+ candidate = np.array(candidate)
+ subset = np.array(subset)
+
+ stickwidth = 4
+
+ limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \
+ [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \
+ [1, 16], [16, 18], [3, 17], [6, 18]]
+
+ colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \
+ [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \
+ [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]
+
+ for i in range(17):
+ for n in range(len(subset)):
+ index = subset[n][np.array(limbSeq[i]) - 1]
+ if -1 in index:
+ continue
+ Y = candidate[index.astype(int), 0] * float(W)
+ X = candidate[index.astype(int), 1] * float(H)
+ mX = np.mean(X)
+ mY = np.mean(Y)
+ length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
+ angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
+ polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
+ # import pdb; pdb.set_trace()
+ cv2.fillConvexPoly(canvas, polygon, colors[i])
+
+ canvas = (canvas * 0.6).astype(np.uint8)
+
+ for i in range(18):
+ for n in range(len(subset)):
+ index = int(subset[n][i])
+ if index == -1:
+ continue
+ x, y = candidate[index][0:2]
+ x = int(x * W)
+ y = int(y * H)
+ cv2.circle(canvas, (int(x), int(y)), 4, colors[i], thickness=-1)
+
+ return canvas
+
+
+def draw_handpose(canvas, all_hand_peaks):
+ H, W, C = canvas.shape
+
+ edges = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8], [0, 9], [9, 10], \
+ [10, 11], [11, 12], [0, 13], [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], [18, 19], [19, 20]]
+
+ for peaks in all_hand_peaks:
+ peaks = np.array(peaks)
+
+ for ie, e in enumerate(edges):
+ x1, y1 = peaks[e[0]]
+ x2, y2 = peaks[e[1]]
+ x1 = int(x1 * W)
+ y1 = int(y1 * H)
+ x2 = int(x2 * W)
+ y2 = int(y2 * H)
+ if x1 > eps and y1 > eps and x2 > eps and y2 > eps:
+ cv2.line(canvas, (x1, y1), (x2, y2), matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * 255, thickness=2)
+
+        for i, keypoint in enumerate(peaks):
+            x, y = keypoint
+ x = int(x * W)
+ y = int(y * H)
+ if x > eps and y > eps:
+ cv2.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1)
+ return canvas
+
+
+def draw_facepose(canvas, all_lmks):
+ H, W, C = canvas.shape
+ for lmks in all_lmks:
+ lmks = np.array(lmks)
+ for lmk in lmks:
+ x, y = lmk
+ x = int(x * W)
+ y = int(y * H)
+ if x > eps and y > eps:
+ cv2.circle(canvas, (x, y), 3, (255, 255, 255), thickness=-1)
+ return canvas
+
+
+# detect hand according to body pose keypoints
+# please refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/hand/handDetector.cpp
+def handDetect(candidate, subset, oriImg):
+ # right hand: wrist 4, elbow 3, shoulder 2
+ # left hand: wrist 7, elbow 6, shoulder 5
+ ratioWristElbow = 0.33
+ detect_result = []
+ image_height, image_width = oriImg.shape[0:2]
+ for person in subset.astype(int):
+ # if any of three not detected
+ has_left = np.sum(person[[5, 6, 7]] == -1) == 0
+ has_right = np.sum(person[[2, 3, 4]] == -1) == 0
+ if not (has_left or has_right):
+ continue
+ hands = []
+ #left hand
+ if has_left:
+ left_shoulder_index, left_elbow_index, left_wrist_index = person[[5, 6, 7]]
+ x1, y1 = candidate[left_shoulder_index][:2]
+ x2, y2 = candidate[left_elbow_index][:2]
+ x3, y3 = candidate[left_wrist_index][:2]
+ hands.append([x1, y1, x2, y2, x3, y3, True])
+ # right hand
+ if has_right:
+ right_shoulder_index, right_elbow_index, right_wrist_index = person[[2, 3, 4]]
+ x1, y1 = candidate[right_shoulder_index][:2]
+ x2, y2 = candidate[right_elbow_index][:2]
+ x3, y3 = candidate[right_wrist_index][:2]
+ hands.append([x1, y1, x2, y2, x3, y3, False])
+
+ for x1, y1, x2, y2, x3, y3, is_left in hands:
+ # pos_hand = pos_wrist + ratio * (pos_wrist - pos_elbox) = (1 + ratio) * pos_wrist - ratio * pos_elbox
+ # handRectangle.x = posePtr[wrist*3] + ratioWristElbow * (posePtr[wrist*3] - posePtr[elbow*3]);
+ # handRectangle.y = posePtr[wrist*3+1] + ratioWristElbow * (posePtr[wrist*3+1] - posePtr[elbow*3+1]);
+ # const auto distanceWristElbow = getDistance(poseKeypoints, person, wrist, elbow);
+ # const auto distanceElbowShoulder = getDistance(poseKeypoints, person, elbow, shoulder);
+ # handRectangle.width = 1.5f * fastMax(distanceWristElbow, 0.9f * distanceElbowShoulder);
+ x = x3 + ratioWristElbow * (x3 - x2)
+ y = y3 + ratioWristElbow * (y3 - y2)
+ distanceWristElbow = math.sqrt((x3 - x2) ** 2 + (y3 - y2) ** 2)
+ distanceElbowShoulder = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)
+ width = 1.5 * max(distanceWristElbow, 0.9 * distanceElbowShoulder)
+ # x-y refers to the center --> offset to topLeft point
+ # handRectangle.x -= handRectangle.width / 2.f;
+ # handRectangle.y -= handRectangle.height / 2.f;
+ x -= width / 2
+ y -= width / 2 # width = height
+ # overflow the image
+ if x < 0: x = 0
+ if y < 0: y = 0
+ width1 = width
+ width2 = width
+ if x + width > image_width: width1 = image_width - x
+ if y + width > image_height: width2 = image_height - y
+ width = min(width1, width2)
+            # discard hand boxes smaller than 20 pixels
+ if width >= 20:
+ detect_result.append([int(x), int(y), int(width), is_left])
+
+ '''
+ return value: [[x, y, w, True if left hand else False]].
+ width=height since the network require squared input.
+ x, y is the coordinate of top left
+ '''
+ return detect_result
+
+
+# Written by Lvmin
+def faceDetect(candidate, subset, oriImg):
+ # left right eye ear 14 15 16 17
+ detect_result = []
+ image_height, image_width = oriImg.shape[0:2]
+ for person in subset.astype(int):
+ has_head = person[0] > -1
+ if not has_head:
+ continue
+
+ has_left_eye = person[14] > -1
+ has_right_eye = person[15] > -1
+ has_left_ear = person[16] > -1
+ has_right_ear = person[17] > -1
+
+ if not (has_left_eye or has_right_eye or has_left_ear or has_right_ear):
+ continue
+
+ head, left_eye, right_eye, left_ear, right_ear = person[[0, 14, 15, 16, 17]]
+
+ width = 0.0
+ x0, y0 = candidate[head][:2]
+
+ if has_left_eye:
+ x1, y1 = candidate[left_eye][:2]
+ d = max(abs(x0 - x1), abs(y0 - y1))
+ width = max(width, d * 3.0)
+
+ if has_right_eye:
+ x1, y1 = candidate[right_eye][:2]
+ d = max(abs(x0 - x1), abs(y0 - y1))
+ width = max(width, d * 3.0)
+
+ if has_left_ear:
+ x1, y1 = candidate[left_ear][:2]
+ d = max(abs(x0 - x1), abs(y0 - y1))
+ width = max(width, d * 1.5)
+
+ if has_right_ear:
+ x1, y1 = candidate[right_ear][:2]
+ d = max(abs(x0 - x1), abs(y0 - y1))
+ width = max(width, d * 1.5)
+
+ x, y = x0, y0
+
+ x -= width
+ y -= width
+
+ if x < 0:
+ x = 0
+
+ if y < 0:
+ y = 0
+
+ width1 = width * 2
+ width2 = width * 2
+
+ if x + width > image_width:
+ width1 = image_width - x
+
+ if y + width > image_height:
+ width2 = image_height - y
+
+ width = min(width1, width2)
+
+ if width >= 20:
+ detect_result.append([int(x), int(y), int(width)])
+
+ return detect_result
+
+
+# get max index of 2d array
+def npmax(array):
+ arrayindex = array.argmax(1)
+ arrayvalue = array.max(1)
+ i = arrayvalue.argmax()
+ j = arrayindex[i]
+ return i, j
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/hydit/annotator/dwpose/wholebody.py b/PyTorch/built-in/mlm/HunyuanDiT/hydit/annotator/dwpose/wholebody.py
new file mode 100644
index 0000000000000000000000000000000000000000..108fa670ceaee0a263291b87674032ea3b663066
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/hydit/annotator/dwpose/wholebody.py
@@ -0,0 +1,56 @@
+import os
+import cv2
+import numpy as np
+
+import onnxruntime as ort
+from .onnxdet import inference_detector
+from .onnxpose import inference_pose
+
+class Wholebody:
+ def __init__(self):
+ rank = int(os.getenv('LOCAL_RANK', '0'))
+ device = f'cuda:{rank}'
+ providers = ['CPUExecutionProvider'] if device == 'cpu' else [("CUDAExecutionProvider", {"device_id": rank})]
+ onnx_det = 'hydit/annotator/ckpts/yolox_l.onnx'
+ onnx_pose = 'hydit/annotator/ckpts/dw-ll_ucoco_384.onnx'
+
+ self.session_det = ort.InferenceSession(path_or_bytes=onnx_det, providers=providers)
+ self.session_pose = ort.InferenceSession(path_or_bytes=onnx_pose, providers=providers)
+
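+    # Detect people, estimate 133 whole-body keypoints per detection, synthesize a neck joint from the shoulders,
+    # and reorder the body joints from the MMPose convention to the OpenPose convention.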
+ def __call__(self, oriImg):
+ det_result = inference_detector(self.session_det, oriImg)
+ if det_result is None:
+ return None, None
+ keypoints, scores = inference_pose(self.session_pose, det_result, oriImg)
+
+ keypoints_info = np.concatenate(
+ (keypoints, scores[..., None]), axis=-1)
+ # compute neck joint
+ neck = np.mean(keypoints_info[:, [5, 6]], axis=1)
+ # neck score when visualizing pred
+ neck[:, 2:4] = np.logical_and(
+ keypoints_info[:, 5, 2:4] > 0.3,
+ keypoints_info[:, 6, 2:4] > 0.3).astype(int)
+ new_keypoints_info = np.insert(
+ keypoints_info, 17, neck, axis=1)
+ mmpose_idx = [
+ 17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3
+ ]
+ openpose_idx = [
+ 1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17
+ ]
+ new_keypoints_info[:, openpose_idx] = \
+ new_keypoints_info[:, mmpose_idx]
+ keypoints_info = new_keypoints_info
+
+ keypoints, scores = keypoints_info[
+ ..., :2], keypoints_info[..., 2]
+
+ return keypoints, scores
+
+ def to(self, device):
+ self.session_det.set_providers([device])
+ self.session_pose.set_providers([device])
+ return self
+
+
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/hydit/annotator/glyph.py b/PyTorch/built-in/mlm/HunyuanDiT/hydit/annotator/glyph.py
new file mode 100644
index 0000000000000000000000000000000000000000..8bb67813897e90e051a218fcf157ee7261dc59b5
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/hydit/annotator/glyph.py
@@ -0,0 +1,249 @@
+# MIT License
+# Copyright (c) 2023 AIGText
+# https://github.com/AIGText/GlyphControl-release
+
+from PIL import Image, ImageFont, ImageDraw
+import random
+import numpy as np
+import cv2
+
+
+# resize height to image_height first, then shrink or pad to image_width
+def resize_and_pad_image(pil_image, image_size):
+ if isinstance(image_size, (tuple, list)) and len(image_size) == 2:
+ image_width, image_height = image_size
+ elif isinstance(image_size, int):
+ image_width = image_height = image_size
+ else:
+ raise ValueError(f"Image size should be int or list/tuple of int not {image_size}")
+
+ while pil_image.size[1] >= 2 * image_height:
+ pil_image = pil_image.resize(
+ tuple(x // 2 for x in pil_image.size), resample=Image.BOX
+ )
+
+ scale = image_height / pil_image.size[1]
+ pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC)
+
+ # shrink
+ if pil_image.size[0] > image_width:
+ pil_image = pil_image.resize((image_width, image_height), resample=Image.BICUBIC)
+
+ # padding
+ if pil_image.size[0] < image_width:
+ img = Image.new(mode="RGBA", size=(image_width, image_height), color=(255, 255, 255, 0))
+ width, _ = pil_image.size
+ img.paste(pil_image, ((image_width - width) // 2, 0))
+ pil_image = img
+
+ return pil_image
+
+
+def resize_and_pad_image2(pil_image, image_size):
+ if isinstance(image_size, (tuple, list)) and len(image_size) == 2:
+ image_width, image_height = image_size
+ elif isinstance(image_size, int):
+ image_width = image_height = image_size
+ else:
+ raise ValueError(f"Image size should be int or list/tuple of int not {image_size}")
+
+ while pil_image.size[1] >= 2 * image_height:
+ pil_image = pil_image.resize(
+ tuple(x // 2 for x in pil_image.size), resample=Image.BOX
+ )
+
+ scale = image_height / pil_image.size[1]
+ pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC)
+
+ # shrink
+ if pil_image.size[0] > image_width:
+ pil_image = pil_image.resize((image_width, image_height), resample=Image.BICUBIC)
+
+ # padding
+ if pil_image.size[0] < image_width:
+ img = Image.new(mode="RGB", size=(image_width, image_height), color="white")
+ width, _ = pil_image.size
+ img.paste(pil_image, ((image_width - width) // 2, 0))
+ pil_image = img
+
+ return pil_image
+
+
+def draw_visual_text(image_size, bboxes, rendered_txt_values, num_rows_values=None, align="center"):
+ # aligns = ["center", "left", "right"]
+ """Render text image based on the glyph instructions, i.e., the list of tuples (text, bbox, num_rows).
+    Currently we use the simfang.ttf font to render glyph images.
+ """
+ # print(image_size, bboxes, rendered_txt_values, num_rows_values, align)
+ background = Image.new("RGB", image_size, "white")
+ font = ImageFont.truetype("simfang.ttf", encoding='utf-8', size=512)
+ if num_rows_values is None:
+ num_rows_values = [1] * len(rendered_txt_values)
+
+ text_list = []
+ for text, bbox, num_rows in zip(rendered_txt_values, bboxes, num_rows_values):
+
+ if len(text) == 0:
+ continue
+
+ text = text.strip()
+ if num_rows != 1:
+ word_tokens = text.split()
+ num_tokens = len(word_tokens)
+ index_list = range(1, num_tokens + 1)
+ if num_tokens > num_rows:
+ index_list = random.sample(index_list, num_rows)
+ index_list.sort()
+ line_list = []
+ start_idx = 0
+ for index in index_list:
+ line_list.append(
+ " ".join(word_tokens
+ [start_idx: index]
+ )
+ )
+ start_idx = index
+ text = "\n".join(line_list)
+
+ if 'ratio' not in bbox or bbox['ratio'] == 0 or bbox['ratio'] < 1e-4:
+ image4ratio = Image.new("RGB", (512, 512), "white")
+ draw = ImageDraw.Draw(image4ratio)
+ _, _, w, h = draw.textbbox(xy=(0, 0), text=text, font=font)
+ ratio = w / h
+ else:
+ ratio = bbox['ratio']
+
+ width = int(bbox['width'] * image_size[1])
+ height = int(width / ratio)
+ top_left_x = int(bbox['top_left_x'] * image_size[0])
+ top_left_y = int(bbox['top_left_y'] * image_size[1])
+ yaw = bbox['yaw']
+
+ text_image = Image.new("RGB", (512, 512), "white")
+ draw = ImageDraw.Draw(text_image)
+ x, y, w, h = draw.textbbox(xy=(0, 0), text=text, font=font)
+ text_image = Image.new("RGBA", (w, h), (255, 255, 255, 0))
+ draw = ImageDraw.Draw(text_image)
+ draw.text((-x / 2, -y / 2), text, (0, 0, 0, 255), font=font, align=align)
+
+ text_image_ = resize_and_pad_image2(text_image.convert('RGB'), (288, 48))
+ # import pdb; pdb.set_trace()
+ text_list.append(np.array(text_image_))
+
+ text_image = resize_and_pad_image(text_image, (width, height))
+ text_image = text_image.rotate(angle=-yaw, expand=True, fillcolor=(255, 255, 255, 0))
+ # image = Image.new("RGB", (w, h), "white")
+ # draw = ImageDraw.Draw(image)
+ background.paste(text_image, (top_left_x, top_left_y), mask=text_image)
+
+ return background, text_list
+
+
+# [{'width': 0.1601562201976776, 'ratio': 81.99999451637203, 'yaw': 0.0, 'top_left_x': 0.712890625, 'top_left_y': 0.0},
+# {'width': 0.134765625, 'ratio': 34.5, 'yaw': 0.0, 'top_left_x': 0.4453125, 'top_left_y': 0.0},
+
+
+def insert_spaces(string, nSpace):
+ if nSpace == 0:
+ return string
+ new_string = ""
+ for char in string:
+ new_string += char + " " * nSpace
+ return new_string[:-nSpace]
+
+
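+# Render text centered on a 512x80 binary canvas, scaling the font so the string fits within ~90% of the canvas.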
+def draw_glyph(text, font='simfang.ttf'):
+ if isinstance(font, str):
+ font = ImageFont.truetype(font, encoding='utf-8', size=512)
+ g_size = 50
+ W, H = (512, 80)
+ new_font = font.font_variant(size=g_size)
+ img = Image.new(mode='1', size=(W, H), color=0)
+ draw = ImageDraw.Draw(img)
+ left, top, right, bottom = new_font.getbbox(text)
+ text_width = max(right-left, 5)
+ text_height = max(bottom - top, 5)
+ ratio = min(W*0.9/text_width, H*0.9/text_height)
+ new_font = font.font_variant(size=int(g_size*ratio))
+
+ text_width, text_height = new_font.getsize(text)
+ offset_x, offset_y = new_font.getoffset(text)
+ x = (img.width - text_width) // 2
+ y = (img.height - text_height) // 2 - offset_y//2
+ draw.text((x, y), text, font=new_font, fill='white')
+ img = np.expand_dims(np.array(img), axis=2).astype(np.float64)
+
+ return img
+
+
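+# Render text inside the minimum-area rectangle of polygon on a (height*scale, width*scale) canvas, inserting spaces or
+# stacking characters vertically when the box is near-vertical, then rotate the text layer into place and return a binary array.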
+def draw_glyph2(text, polygon, font='simfang.ttf', vertAng=10, scale=1, width=1024, height=1024, add_space=True):
+ if isinstance(font, str):
+ font = ImageFont.truetype(font, encoding='utf-8', size=60)
+ enlarge_polygon = polygon*scale
+ rect = cv2.minAreaRect(enlarge_polygon)
+ box = cv2.boxPoints(rect)
+ box = np.int0(box)
+ w, h = rect[1]
+ angle = rect[2]
+ if angle < -45:
+ angle += 90
+ angle = -angle
+ if w < h:
+ angle += 90
+
+ vert = False
+ if (abs(angle) % 90 < vertAng or abs(90-abs(angle) % 90) % 90 < vertAng):
+ _w = max(box[:, 0]) - min(box[:, 0])
+ _h = max(box[:, 1]) - min(box[:, 1])
+ if _h >= _w:
+ vert = True
+ angle = 0
+
+ img = np.zeros((height*scale, width*scale, 3), np.uint8)
+ img = Image.fromarray(img)
+
+ # infer font size
+ image4ratio = Image.new("RGB", img.size, "white")
+ draw = ImageDraw.Draw(image4ratio)
+ _, _, _tw, _th = draw.textbbox(xy=(0, 0), text=text, font=font)
+ text_w = min(w, h) * (_tw / _th)
+ if text_w <= max(w, h):
+ # add space
+ if len(text) > 1 and not vert and add_space:
+ for i in range(1, 100):
+ text_space = insert_spaces(text, i)
+ _, _, _tw2, _th2 = draw.textbbox(xy=(0, 0), text=text_space, font=font)
+ if min(w, h) * (_tw2 / _th2) > max(w, h):
+ break
+ text = insert_spaces(text, i-1)
+ font_size = min(w, h)*0.80
+ else:
+ # shrink = 0.75 if vert else 0.85
+ shrink = 1.0
+ font_size = min(w, h) / (text_w/max(w, h)) * shrink
+ new_font = font.font_variant(size=int(font_size))
+
+ left, top, right, bottom = new_font.getbbox(text)
+ text_width = right-left
+ text_height = bottom - top
+
+ layer = Image.new('RGBA', img.size, (0, 0, 0, 0))
+ draw = ImageDraw.Draw(layer)
+ if not vert:
+ draw.text((rect[0][0]-text_width//2, rect[0][1]-text_height//2-top), text, font=new_font, fill=(255, 255, 255, 255))
+ else:
+ x_s = min(box[:, 0]) + _w//2 - text_height//2
+ y_s = min(box[:, 1])
+ for c in text:
+ draw.text((x_s, y_s), c, font=new_font, fill=(255, 255, 255, 255))
+ _, _t, _, _b = new_font.getbbox(c)
+ y_s += _b
+
+ rotated_layer = layer.rotate(angle, expand=1, center=(rect[0][0], rect[0][1]))
+
+ x_offset = int((img.width - rotated_layer.width) / 2)
+ y_offset = int((img.height - rotated_layer.height) / 2)
+ img.paste(rotated_layer, (x_offset, y_offset), rotated_layer)
+ img = np.expand_dims(np.array(img.convert('1')), axis=2).astype(np.float64)
+
+ return img
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/hydit/annotator/util.py b/PyTorch/built-in/mlm/HunyuanDiT/hydit/annotator/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..e0b217ef9adf92dd5b1fe0debcfb07d0f241a4cb
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/hydit/annotator/util.py
@@ -0,0 +1,98 @@
+import random
+
+import numpy as np
+import cv2
+import os
+
+
+annotator_ckpts_path = os.path.join(os.path.dirname(__file__), 'ckpts')
+
+
+def HWC3(x):
+ assert x.dtype == np.uint8
+ if x.ndim == 2:
+ x = x[:, :, None]
+ assert x.ndim == 3
+ H, W, C = x.shape
+ assert C == 1 or C == 3 or C == 4
+ if C == 3:
+ return x
+ if C == 1:
+ return np.concatenate([x, x, x], axis=2)
+ if C == 4:
+ color = x[:, :, 0:3].astype(np.float32)
+ alpha = x[:, :, 3:4].astype(np.float32) / 255.0
+ y = color * alpha + 255.0 * (1.0 - alpha)
+ y = y.clip(0, 255).astype(np.uint8)
+ return y
+
+
+def resize_image(input_image, resolution):
+ H, W, C = input_image.shape
+ H = float(H)
+ W = float(W)
+ k = float(resolution) / min(H, W)
+ H *= k
+ W *= k
+ H = int(np.round(H / 64.0)) * 64
+ W = int(np.round(W / 64.0)) * 64
+ img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
+ return img
+
+
+def nms(x, t, s):
+ x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s)
+
+ f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8)
+ f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8)
+ f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8)
+ f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8)
+
+ y = np.zeros_like(x)
+
+ for f in [f1, f2, f3, f4]:
+ np.putmask(y, cv2.dilate(x, kernel=f) == x, x)
+
+ z = np.zeros_like(y, dtype=np.uint8)
+ z[y > t] = 255
+ return z
+
+
+def make_noise_disk(H, W, C, F):
+ noise = np.random.uniform(low=0, high=1, size=((H // F) + 2, (W // F) + 2, C))
+ noise = cv2.resize(noise, (W + 2 * F, H + 2 * F), interpolation=cv2.INTER_CUBIC)
+ noise = noise[F: F + H, F: F + W]
+ noise -= np.min(noise)
+ noise /= np.max(noise)
+ if C == 1:
+ noise = noise[:, :, None]
+ return noise
+
+
+def min_max_norm(x):
+ x -= np.min(x)
+ x /= np.maximum(np.max(x), 1e-5)
+ return x
+
+
+def safe_step(x, step=2):
+ y = x.astype(np.float32) * float(step + 1)
+ y = y.astype(np.int32).astype(np.float32) / float(step)
+ return y
+
+
+def img2mask(img, H, W, low=10, high=90):
+ assert img.ndim == 3 or img.ndim == 2
+ assert img.dtype == np.uint8
+
+ if img.ndim == 3:
+ y = img[:, :, random.randrange(0, img.shape[2])]
+ else:
+ y = img
+
+ y = cv2.resize(y, (W, H), interpolation=cv2.INTER_CUBIC)
+
+ if random.uniform(0, 1) < 0.5:
+ y = 255 - y
+
+ return y < np.percentile(y, random.randrange(low, high))
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/hydit/config.py b/PyTorch/built-in/mlm/HunyuanDiT/hydit/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..3310ae5e29869f0878eeac1ad2a3fa4d89ee3b3a
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/hydit/config.py
@@ -0,0 +1,213 @@
+import argparse
+
+from .constants import *
+from .modules.models import HUNYUAN_DIT_CONFIG, HUNYUAN_DIT_MODELS
+from .diffusion.gaussian_diffusion import ModelVarType
+
+import deepspeed
+
+
+def model_var_type(value):
+ try:
+ return ModelVarType[value]
+ except KeyError:
+ raise ValueError(f"Invalid choice '{value}', valid choices are {[v.name for v in ModelVarType]}")
+
+
+def get_args(default_args=None):
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--task-flag", type=str)
+
+ # General Setting
+ parser.add_argument("--batch-size", type=int, default=1, help="Per-GPU batch size")
+ parser.add_argument('--seed', type=int, default=42, help="A seed for all the prompts.")
+ parser.add_argument("--use-fp16", action="store_true", help="Use FP16 precision.")
+ parser.add_argument("--no-fp16", dest="use_fp16", action="store_false")
+ parser.set_defaults(use_fp16=True)
+ parser.add_argument("--extra-fp16", action="store_true", help="Use extra fp16 for vae and text_encoder.")
+
+ # HunYuan-DiT
+ parser.add_argument("--model", type=str, choices=list(HUNYUAN_DIT_CONFIG.keys()), default='DiT-g/2')
+ parser.add_argument("--image-size", type=int, nargs='+', default=[1024, 1024],
+                        help='Image size (h, w). If a single value is provided, the image will be treated as '
+ '(value, value).')
+ parser.add_argument("--qk-norm", action="store_true",
+ help="Query Key normalization. See http://arxiv.org/abs/2302.05442 for details.")
+ parser.set_defaults(qk_norm=True)
+ parser.add_argument("--norm", type=str, choices=["rms", "layer"], default="layer", help="Normalization layer type")
+ parser.add_argument("--text-states-dim", type=int, default=1024, help="Hidden size of CLIP text encoder.")
+ parser.add_argument("--text-len", type=int, default=77, help="Token length of CLIP text encoder output.")
+    parser.add_argument("--text-states-dim-t5", type=int, default=2048, help="Hidden size of the mT5 text encoder.")
+ parser.add_argument("--text-len-t5", type=int, default=256, help="Token length of T5 text encoder output.")
+
+ # LoRA config
+ parser.add_argument("--training-parts", type=str, default='all', choices=['all', 'lora'], help="Training parts")
+ parser.add_argument("--rank", type=int, default=64, help="Rank of LoRA")
+ parser.add_argument("--lora-ckpt", type=str, default=None, help="LoRA checkpoint")
+ parser.add_argument('--target-modules', type=str, nargs='+', default=['Wqkv', 'q_proj', 'kv_proj', 'out_proj'],
+ help="Target modules for LoRA fine tune")
+ parser.add_argument("--output-merge-path", type=str, default=None, help="Output path for merged model")
+
+ # controlnet config
+ parser.add_argument("--control-type", type=str, default='canny', choices=['canny', 'depth', 'pose'],
+ help="Controlnet condition type")
+ parser.add_argument("--control-weight", type=float, default=1.0,
+                        help="ControlNet weight. Use a float to apply the same weight to all layers, or a list to set a separate weight for each layer, for example, '[1.0 * (0.825 ** float(19 - i)) for i in range(19)]'")
+ parser.add_argument("--condition-image-path", type=str, default=None, help="Inference condition image path")
+
+ # Diffusion
+ parser.add_argument("--learn-sigma", action="store_true", help="Learn extra channels for sigma.")
+ parser.add_argument("--no-learn-sigma", dest="learn_sigma", action="store_false")
+ parser.set_defaults(learn_sigma=True)
+ parser.add_argument("--predict-type", type=str, choices=list(PREDICT_TYPE), default="v_prediction",
+ help="Diffusion predict type")
+ parser.add_argument("--noise-schedule", type=str, choices=list(NOISE_SCHEDULES), default="scaled_linear",
+ help="Noise schedule")
+ parser.add_argument("--beta-start", type=float, default=0.00085, help="Beta start value")
+ parser.add_argument("--beta-end", type=float, default=0.03, help="Beta end value")
+ parser.add_argument("--sigma-small", action="store_true")
+ parser.add_argument("--mse-loss-weight-type", type=str, default="constant",
+                        help="Min-SNR-gamma loss weighting. Can be 'constant' or 'min_snr_<gamma>' where gamma is an integer; gamma=5 is recommended in the paper.")
+ parser.add_argument("--model-var-type", type=model_var_type, default=None, help="Specify the model variable type.")
+ parser.add_argument("--noise-offset", type=float, default=0.0, help="Add extra noise to the input image.")
+
+ # ========================================================================================================
+ # Inference
+ # ========================================================================================================
+
+ # Basic Setting
+ parser.add_argument("--prompt", type=str, default="一只小猫", help="The prompt for generating images.")
+ parser.add_argument("--model-root", type=str, default="ckpts", help="Model root path.")
+
+ # Model setting
+ parser.add_argument("--load-key", type=str, choices=["ema", "module", "distill", 'merge'], default="ema",
+ help="Load model key for HunYuanDiT checkpoint.")
+ parser.add_argument('--size-cond', type=int, nargs='+', default=[1024, 1024],
+ help="Size condition used in sampling. 2 values are required for height and width. "
+                             "If a single value is provided, the image will be treated as (value, value).")
+ parser.add_argument('--target-ratios', type=str, nargs='+', default=None,
+ help="Target ratios for multi-resolution training.")
+    parser.add_argument("--cfg-scale", type=float, default=6.0, help="Classifier-free guidance scale.")
+ parser.add_argument("--negative", type=str, default=None, help="Negative prompt.")
+
+ # Acceleration
+ parser.add_argument("--infer-mode", type=str, choices=["fa", "torch", "trt"], default="torch",
+ help="Inference mode")
+ parser.add_argument("--onnx-workdir", type=str, default="onnx_model", help="Path to save ONNX model")
+
+ # Sampling
+ parser.add_argument("--sampler", type=str, choices=SAMPLER_FACTORY, default="ddpm", help="Diffusion sampler")
+ parser.add_argument("--infer-steps", type=int, default=100, help="Inference steps")
+
+ # Prompt enhancement
+ parser.add_argument("--enhance", action="store_true", help="Enhance prompt with mllm.")
+ parser.add_argument("--no-enhance", dest="enhance", action="store_false")
+ parser.add_argument("--load-4bit", help="load DialogGen model with 4bit quantization.", action="store_true")
+ parser.set_defaults(enhance=True)
+
+ # App
+ parser.add_argument("--lang", type=str, default="zh", choices=["zh", "en"], help="Language")
+
+ # ========================================================================================================
+ # Training
+ # ========================================================================================================
+
+ # Basic Setting
+ parser.add_argument("--lr", type=float, default=1e-4)
+ parser.add_argument("--epochs", type=int, default=1400)
+ parser.add_argument("--max-training-steps", type=int, default=10_000_000)
+ parser.add_argument("--gc-interval", type=int, default=40,
+ help='To address the memory bottleneck encountered during the preprocessing of the dataset,'
+ ' memory fragments are reclaimed here by invoking the gc.collect() function.')
+ parser.add_argument("--log-every", type=int, default=100)
+    parser.add_argument("--ckpt-every", type=int, default=100_000, help="Create a checkpoint every few steps.")
+ parser.add_argument("--ckpt-latest-every", type=int, default=10_000,
+                        help="Create a checkpoint named `latest.pt` every few steps.")
+ parser.add_argument("--num-workers", type=int, default=4)
+ parser.add_argument("--global-seed", type=int, default=1234)
+ parser.add_argument("--warmup-min-lr", type=float, default=1e-6)
+ parser.add_argument("--warmup-num-steps", type=float, default=0)
+ parser.add_argument("--weight-decay", type=float, default=0, help="weight-decay in optimizer")
+ parser.add_argument("--rope-img", type=str, default=None, choices=['extend', 'base512', 'base1024'],
+ help="Extend or interpolate the positional embedding of the image.")
+ parser.add_argument("--rope-real", action="store_true",
+ help="Use real part and imaginary part separately for RoPE.")
+
+ # Classifier-free
+ parser.add_argument("--uncond-p", type=float, default=0.2,
+ help="The probability of dropping training text used for CLIP feature extraction")
+ parser.add_argument("--uncond-p-t5", type=float, default=0.2,
+ help="The probability of dropping training text used for mT5 feature extraction")
+
+ # Directory
+ parser.add_argument("--results-dir", type=str, default="results")
+ parser.add_argument("--resume", type=str, default=None, help="Resume experiment from a checkpoint")
+    parser.add_argument("--no-strict", dest="strict", action="store_false", help="Disable strict loading of the checkpoint.")
+ parser.set_defaults(strict=True)
+ parser.add_argument("--resume-deepspeed", action="store_true",
+ help="Resume model and ema states from a checkpoint saved by Deepspeed version of DIT.")
+ parser.add_argument("--resume-split", action="store_true",
+ help="Resume model and ema states from two checkpoint separated from DeepSpeed ckpt.")
+ parser.add_argument("--ema-to-module", action="store_true",
+ help="If true, initialize the module with EMA weights.")
+ parser.add_argument("--module-to-ema", action="store_true",
+                        help="If true, initialize the EMA with module weights.")
+
+ # Dataset
+ parser.add_argument("--index-file", type=str, nargs='+',
+ help="During training, provide a JSON file with data indices.")
+ parser.add_argument("--random-flip", action="store_true", help="Random flip image")
+ parser.add_argument("--reset-loader", action="store_true",
+                        help="Reset the data loader. It is useful when resuming from a checkpoint but switching to a new dataset.")
+ parser.add_argument("--multireso", action="store_true", help="Use multi-resolution training.")
+ parser.add_argument("--reso-step", type=int, default=None, help="Step size for multi-resolution training.")
+
+ # Additional condition
+ parser.add_argument("--random-shrink-size-cond", action="store_true",
+ help="Randomly shrink the original size condition.")
+ parser.add_argument("--merge-src-cond", action="store_true", help="Merge the source condition into a single value.")
+
+ # EMA Model
+ parser.add_argument("--use-ema", action="store_true", help="Use EMA model")
+ parser.add_argument("--ema-dtype", type=str, choices=['fp16', 'fp32', 'none'], default="none",
+ help="EMA data type. If none, use the same data type as the model.")
+ parser.add_argument("--ema-decay", type=float, default=None,
+ help="EMA decay rate. If None, use the default value of the model.")
+ parser.add_argument("--ema-warmup", action="store_true",
+ help="EMA warmup. If True, perform ema_decay warmup from 0 to ema_decay.")
+ parser.add_argument("--ema-warmup-power", type=float, default=None,
+ help="EMA power. If None, use the default value of the model.")
+ parser.add_argument("--ema-reset-decay", action="store_true",
+                        help="Reset EMA decay to 0 and restart increasing the EMA decay. "
+ "Only works when --ema-warmup is enabled.")
+ # Acceleration
+ parser.add_argument("--use-flash-attn", action="store_true", help="During training, "
+ "flash attention is used to accelerate training.")
+ parser.add_argument("--no-flash-attn", dest="use_flash_attn",
+ action="store_false",
+ help="During training, flash attention is not used to accelerate training.")
+    parser.add_argument("--use-zero-stage", type=int, default=1, help="Use AngelPTM ZeRO stage. Supports 2 and 3.")
+ parser.add_argument("--grad-accu-steps", type=int, default=1, help="Gradient accumulation steps.")
+
+ # ========================================================================================================
+ # Deepspeed config
+ # ========================================================================================================
+ parser = deepspeed.add_config_arguments(parser)
+ parser.add_argument('--local_rank', type=int, default=None,
+ help='local rank passed from distributed launcher.')
+ parser.add_argument('--deepspeed-optimizer', action='store_true',
+ help='Switching to the optimizers in DeepSpeed')
+ parser.add_argument('--remote-device', type=str, default='none', choices=['none', 'cpu', 'nvme'],
+ help='Remote device for ZeRO-3 initialized parameters.')
+ parser.add_argument('--zero-stage', type=int, default=1)
+    parser.add_argument("--async-ema", action="store_true", help="Whether to use multiple streams to execute the EMA update.")
+
+    # Additional arguments added for the model migration
+ parser.add_argument("--cal-e2e", dest="cal_e2e", action="store_true")
+ parser.add_argument("--profiling", dest="profiling", action="store_true")
+ parser.add_argument("--seed-all", dest="seed_all", action="store_true")
+    parser.add_argument("--autocast-dtype", type=str, default="bf16", choices=['bf16', 'fp16'], help="The dtype used for autocast.")
+
+ args = parser.parse_args(default_args)
+
+ return args
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/hydit/constants.py b/PyTorch/built-in/mlm/HunyuanDiT/hydit/constants.py
new file mode 100644
index 0000000000000000000000000000000000000000..a46921578127b3c81ffea3d1bfb78d06a35421af
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/hydit/constants.py
@@ -0,0 +1,82 @@
+import torch
+
+# =======================================================
+
+NOISE_SCHEDULES = {
+ "linear",
+ "scaled_linear",
+ "squaredcos_cap_v2",
+}
+
+PREDICT_TYPE = {
+ "epsilon",
+ "sample",
+ "v_prediction",
+}
+
+# =======================================================
+
+NEGATIVE_PROMPT = '错误的眼睛,糟糕的人脸,毁容,糟糕的艺术,变形,多余的肢体,模糊的颜色,模糊,重复,病态,残缺,'
+
+# =======================================================
+TRT_MAX_BATCH_SIZE = 1
+TRT_MAX_WIDTH = 1280
+TRT_MAX_HEIGHT = 1280
+
+# =======================================================
+# Constants about models
+# =======================================================
+
+VAE_EMA_PATH = "ckpts/t2i/sdxl-vae-fp16-fix"
+TOKENIZER = "ckpts/t2i/tokenizer"
+TEXT_ENCODER = 'ckpts/t2i/clip_text_encoder'
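+# mT5 text-encoder configuration: checkpoint path plus loading options
+# (attention mask, output layer index, attention pooling, dtype).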
+T5_ENCODER = {
+ 'MT5': 'ckpts/t2i/mt5',
+ 'attention_mask': True,
+ 'layer_index': -1,
+ 'attention_pool': True,
+ 'torch_dtype': torch.float16,
+ 'learnable_replace': True
+}
+
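+# Maps each sampler name to a scheduler class name and the kwargs used to construct it.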
+SAMPLER_FACTORY = {
+ 'ddpm': {
+ 'scheduler': 'DDPMScheduler',
+ 'name': 'DDPM',
+ 'kwargs': {
+ 'steps_offset': 1,
+ 'clip_sample': False,
+ 'clip_sample_range': 1.0,
+ 'beta_schedule': 'scaled_linear',
+ 'beta_start': 0.00085,
+ 'beta_end': 0.03,
+ 'prediction_type': 'v_prediction',
+ }
+ },
+ 'ddim': {
+ 'scheduler': 'DDIMScheduler',
+ 'name': 'DDIM',
+ 'kwargs': {
+ 'steps_offset': 1,
+ 'clip_sample': False,
+ 'clip_sample_range': 1.0,
+ 'beta_schedule': 'scaled_linear',
+ 'beta_start': 0.00085,
+ 'beta_end': 0.03,
+ 'prediction_type': 'v_prediction',
+ }
+ },
+ 'dpmms': {
+ 'scheduler': 'DPMSolverMultistepScheduler',
+ 'name': 'DPMMS',
+ 'kwargs': {
+ 'beta_schedule': 'scaled_linear',
+ 'beta_start': 0.00085,
+ 'beta_end': 0.03,
+ 'prediction_type': 'v_prediction',
+ 'trained_betas': None,
+ 'solver_order': 2,
+ 'algorithm_type': 'dpmsolver++',
+ }
+ },
+}
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/hydit/data_loader/__init__.py b/PyTorch/built-in/mlm/HunyuanDiT/hydit/data_loader/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/hydit/data_loader/arrow_load_stream.py b/PyTorch/built-in/mlm/HunyuanDiT/hydit/data_loader/arrow_load_stream.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac2bd17ae04b0959ced080fae0c3345123bcc0e7
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/hydit/data_loader/arrow_load_stream.py
@@ -0,0 +1,255 @@
+import pickle
+import random
+from pathlib import Path
+import ast
+import numpy as np
+import re
+import json
+import time
+from functools import partial
+from PIL import Image
+
+import torch
+import torchvision.transforms as T
+import torch.nn.functional as F
+from torchvision.transforms import functional as TF
+from torch.utils.data import Dataset
+
+from IndexKits.index_kits import ArrowIndexV2, MultiResolutionBucketIndexV2, MultiIndexV2
+
+
+class TextImageArrowStream(Dataset):
+ def __init__(self,
+ args,
+ resolution=512,
+ random_flip=None,
+ enable_CN=True,
+ log_fn=print,
+ index_file=None,
+ multireso=False,
+ batch_size=-1,
+ world_size=1,
+ random_shrink_size_cond=False,
+ merge_src_cond=False,
+ uncond_p=0.0,
+ text_ctx_len=77,
+ tokenizer=None,
+ uncond_p_t5=0.0,
+ text_ctx_len_t5=256,
+ tokenizer_t5=None,
+ ):
+ self.args = args
+ self.resolution = resolution
+ self.log_fn = lambda x: log_fn(f" {Path(__file__).stem} | " + x)
+
+ self.random_flip = random_flip
+        # If True, the Chinese prompt is read from the `text_zh` column of the arrow file;
+        # otherwise, the English prompt is read from the `text_en` column,
+        # provided that the corresponding column exists in the arrow file.
+ self.enable_CN = enable_CN
+ self.index_file = index_file
+ self.multireso = multireso
+ self.batch_size = batch_size
+ self.world_size = world_size
+ self.index_manager = self.load_index()
+
+ # clip params
+ self.uncond_p = uncond_p
+ self.text_ctx_len = text_ctx_len
+ self.tokenizer = tokenizer
+
+ # t5 params
+ self.uncond_p_t5 = uncond_p_t5
+ self.text_ctx_len_t5 = text_ctx_len_t5
+ self.tokenizer_t5 = tokenizer_t5
+
+ # size condition
+ self.random_shrink_size_cond = random_shrink_size_cond
+ self.merge_src_cond = merge_src_cond
+
+ assert isinstance(resolution, int), f"resolution must be an integer, got {resolution}"
+ self.flip_norm = T.Compose(
+ [
+ T.RandomHorizontalFlip() if self.random_flip else T.Lambda(lambda x: x),
+ T.ToTensor(),
+ T.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ # show info
+ if self.merge_src_cond:
+ self.log_fn("Enable merging src condition: (oriW, oriH) --> ((WH)**0.5, (WH)**0.5)")
+
+ self.log_fn("Enable image_meta_size condition (original_size, target_size, crop_coords)")
+ self.log_fn(f"Image_transforms: {self.flip_norm}")
+
+ def load_index(self):
+ multireso = self.multireso
+ index_file = self.index_file
+ batch_size = self.batch_size
+ world_size = self.world_size
+
+ if multireso:
+ if isinstance(index_file, (list, tuple)):
+ if len(index_file) > 1:
+ raise ValueError(f"When enabling multireso, index_file should be a single file, but got {index_file}")
+ index_file = index_file[0]
+ index_manager = MultiResolutionBucketIndexV2(index_file, batch_size, world_size)
+ self.log_fn(f"Using MultiResolutionBucketIndexV2: {len(index_manager):,}")
+ else:
+ if isinstance(index_file, str):
+ index_file = [index_file]
+ if len(index_file) == 1:
+ index_manager = ArrowIndexV2(index_file[0])
+ self.log_fn(f"Using ArrowIndexV2: {len(index_manager):,}")
+ else:
+ index_manager = MultiIndexV2(index_file)
+ self.log_fn(f"Using MultiIndexV2: {len(index_manager):,}")
+
+ return index_manager
+
+ def shuffle(self, seed, fast=False):
+ self.index_manager.shuffle(seed, fast=fast)
+
+ def get_raw_image(self, index, image_key="image"):
+ try:
+ ret = self.index_manager.get_image(index, image_key)
+ except Exception as e:
+ self.log_fn(f'get_raw_image | Error: {e}')
+ ret = Image.new("RGB", (256, 256), (255, 255, 255))
+ return ret
+
+ @staticmethod
+ def random_crop_image(image, origin_size, target_size):
+ aspect_ratio = float(origin_size[0]) / float(origin_size[1])
+ if origin_size[0] < origin_size[1]:
+ new_width = target_size[0]
+ new_height = int(new_width / aspect_ratio)
+ else:
+ new_height = target_size[1]
+ new_width = int(new_height * aspect_ratio)
+
+ image = image.resize((new_width, new_height), Image.LANCZOS)
+
+ if new_width > target_size[0]:
+ x_start = random.randint(0, new_width - target_size[0])
+ y_start = 0
+ else:
+ x_start = 0
+ y_start = random.randint(0, new_height - target_size[1])
+ image_crop = image.crop((x_start, y_start, x_start + target_size[0], y_start + target_size[1]))
+ crops_coords_top_left = (x_start, y_start)
+ return image_crop, crops_coords_top_left
+
+ def get_style(self, index):
+ "Here we use a default learned embedder layer for future extension."
+ style = 0
+ return style
+
+ def get_image_with_hwxy(self, index, image_key="image"):
+
+ image = self.get_raw_image(index, image_key=image_key)
+ origin_size = image.size
+
+ if self.multireso:
+ target_size = self.index_manager.get_target_size(index)
+ image, crops_coords_top_left = self.index_manager.resize_and_crop(
+ image, target_size, resample=Image.LANCZOS, crop_type='random')
+ image_tensor = self.flip_norm(image)
+ else:
+ target_size = (self.resolution, self.resolution)
+ image_crop, crops_coords_top_left = self.random_crop_image(image, origin_size, target_size)
+ image_tensor = self.flip_norm(image_crop)
+
+ if self.random_shrink_size_cond:
+ origin_size = (1024 if origin_size[0] < 1024 else origin_size[0],
+ 1024 if origin_size[1] < 1024 else origin_size[1])
+ if self.merge_src_cond:
+ val = (origin_size[0] * origin_size[1]) ** 0.5
+ origin_size = (val, val)
+
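+        # Size condition fed to the model: original_size + target_size + crops_coords_top_left.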
+ image_meta_size = tuple(origin_size) + tuple(target_size) + tuple(crops_coords_top_left)
+ kwargs = {
+ 'image_meta_size': image_meta_size,
+ }
+
+ style = self.get_style(index)
+ kwargs['style'] = style
+
+ return image_tensor, kwargs
+
+ def get_text_info_with_encoder(self, description):
+ pad_num = 0
+ text_inputs = self.tokenizer(
+ description,
+ padding="max_length",
+ max_length=self.text_ctx_len,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids[0]
+ attention_mask = text_inputs.attention_mask[0].bool()
+ if pad_num > 0:
+ attention_mask[1:pad_num + 1] = False
+ return description, text_input_ids, attention_mask
+
+ def fill_t5_token_mask(self, fill_tensor, fill_number, setting_length):
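+        # Pad `fill_tensor` along dim=1 with `fill_number` until it reaches `setting_length` tokens.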
+ fill_length = setting_length - fill_tensor.shape[1]
+ if fill_length > 0:
+ fill_tensor = torch.cat((fill_tensor, fill_number * torch.ones(1, fill_length)), dim=1)
+ return fill_tensor
+
+ def get_text_info_with_encoder_t5(self, description_t5):
+ text_tokens_and_mask = self.tokenizer_t5(
+ description_t5,
+ max_length=self.text_ctx_len_t5,
+ truncation=True,
+ return_attention_mask=True,
+ add_special_tokens=True,
+ return_tensors='pt'
+ )
+        text_input_ids_t5 = self.fill_t5_token_mask(text_tokens_and_mask["input_ids"], fill_number=1, setting_length=self.text_ctx_len_t5).long()
+        attention_mask_t5 = self.fill_t5_token_mask(text_tokens_and_mask["attention_mask"], fill_number=0, setting_length=self.text_ctx_len_t5).bool()
+ return description_t5, text_input_ids_t5, attention_mask_t5
+
+ def get_original_text(self, ind):
+ text = self.index_manager.get_attribute(ind, 'text_zh' if self.enable_CN else 'text_en')
+ text = str(text).strip()
+ return text
+
+ def get_text(self, ind):
+ text = self.get_original_text(ind)
+ if text == '':
+ text = '随机生成一张图片'
+ return text
+
+ def __getitem__(self, ind):
+ # Get text
+ if random.random() < self.uncond_p:
+ description = ""
+ else:
+ description = self.get_text(ind)
+
+ # Get text for t5
+ if random.random() < self.uncond_p_t5:
+ description_t5 = ""
+ else:
+ description_t5 = self.get_text(ind)
+
+ original_pil_image, kwargs = self.get_image_with_hwxy(ind)
+
+ # Use encoder to embed tokens online
+ text, text_embedding, text_embedding_mask = self.get_text_info_with_encoder(description)
+
+ text_t5, text_embedding_t5, text_embedding_mask_t5 = self.get_text_info_with_encoder_t5(description_t5)
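+        # Return the image tensor, CLIP token ids/mask, T5 token ids/mask, and the conditioning kwargs.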
+ return (
+ original_pil_image,
+ text_embedding.clone().detach(),
+ text_embedding_mask.clone().detach(),
+ text_embedding_t5.clone().detach(),
+ text_embedding_mask_t5.clone().detach(),
+ {k: torch.tensor(np.array(v)).clone().detach() for k, v in kwargs.items()},
+ )
+
+ def __len__(self):
+ return len(self.index_manager)
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/hydit/data_loader/csv2arrow.py b/PyTorch/built-in/mlm/HunyuanDiT/hydit/data_loader/csv2arrow.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ffe81d7810d645297364ec13c40e1d62edc4088
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/hydit/data_loader/csv2arrow.py
@@ -0,0 +1,92 @@
+# -*- coding: utf-8 -*-
+import datetime
+import gc
+import os
+import time
+from multiprocessing import Pool
+import subprocess
+import pandas as pd
+import pyarrow as pa
+from tqdm import tqdm
+import hashlib
+from PIL import Image
+import sys
+
+
+def parse_data(data):
+ try:
+ img_path = data[0]
+
+ with open(img_path, "rb") as fp:
+ image = fp.read()
+ md5 = hashlib.md5(image).hexdigest()
+
+ with Image.open(img_path) as f:
+ width, height = f.size
+
+ return [data[1], md5, width, height, image]
+
+ except Exception as e:
+ print(f'error: {e}')
+ return
+
+def make_arrow(csv_root, dataset_root, start_id=0, end_id=-1):
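+    # Expects a CSV with `image_path` and `text_zh` columns; each image is read, MD5-hashed,
+    # and packed together with its caption into Arrow files of `num_slice` rows each.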
+ print(csv_root)
+ arrow_dir = dataset_root
+ print(arrow_dir)
+
+ if not os.path.exists(arrow_dir):
+ os.makedirs(arrow_dir)
+
+ data = pd.read_csv(csv_root)
+ data = data[["image_path", "text_zh"]]
+ columns_list = data.columns.tolist()
+ columns_list.append("image")
+
+ if end_id < 0:
+ end_id = len(data)
+ print(f'start_id:{start_id} end_id:{end_id}')
+ data = data[start_id:end_id]
+ num_slice = 5000
+ start_sub = int(start_id / num_slice)
+ sub_len = int(len(data) // num_slice) # if int(len(data) // num_slice) else 1
+ subs = list(range(sub_len + 1))
+ for sub in tqdm(subs):
+ arrow_path = os.path.join(arrow_dir, '{}.arrow'.format(str(sub + start_sub).zfill(5)))
+ if os.path.exists(arrow_path):
+ continue
+ print(f"{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')} start {sub + start_sub}")
+
+ sub_data = data[sub * num_slice: (sub + 1) * num_slice].values
+
+ bs = pool.map(parse_data, sub_data)
+ bs = [b for b in bs if b]
+ print(f'length of this arrow:{len(bs)}')
+
+ columns_list = ["text_zh", "md5", "width", "height", "image"]
+ dataframe = pd.DataFrame(bs, columns=columns_list)
+ table = pa.Table.from_pandas(dataframe)
+
+ os.makedirs(dataset_root, exist_ok=True)
+ with pa.OSFile(arrow_path, "wb") as sink:
+ with pa.RecordBatchFileWriter(sink, table.schema) as writer:
+ writer.write_table(table)
+ del dataframe
+ del table
+ del bs
+ gc.collect()
+
+
+if __name__ == '__main__':
+
+ if len(sys.argv) != 4:
+ print("Usage: python hydit/data_loader/csv2arrow.py ${csv_root} ${output_arrow_data_path} ${pool_num}")
+ print("csv_root: The path to your created CSV file. For more details, see https://github.com/Tencent/HunyuanDiT?tab=readme-ov-file#truck-training")
+ print("output_arrow_data_path: The path for storing the created Arrow file")
+ print("pool_num: The number of processes, used for multiprocessing. If you encounter memory issues, you can set pool_num to 1")
+ sys.exit(1)
+ csv_root = sys.argv[1]
+ output_arrow_data_path = sys.argv[2]
+ pool_num = int(sys.argv[3])
+ pool = Pool(pool_num)
+ make_arrow(csv_root, output_arrow_data_path)
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/hydit/diffusion/__init__.py b/PyTorch/built-in/mlm/HunyuanDiT/hydit/diffusion/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6a681113da5c6fc6a1cc216e2270171c4780898
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/hydit/diffusion/__init__.py
@@ -0,0 +1,49 @@
+from . import gaussian_diffusion as gd
+from .respace import SpacedDiffusion, space_timesteps
+
+
+def create_diffusion(
+ *,
+ steps=1000,
+ learn_sigma=True,
+ sigma_small=False,
+ noise_schedule="linear",
+ use_kl=False,
+ predict_type='epsilon',
+ rescale_timesteps=False,
+ rescale_learned_sigmas=False,
+ timestep_respacing="",
+ mse_loss_weight_type='constant',
+ beta_start=0.0001,
+ beta_end=0.02,
+ noise_offset=0.0,
+):
+ betas = gd.get_named_beta_schedule(noise_schedule, steps, beta_start, beta_end)
+ if use_kl:
+ loss_type = gd.LossType.RESCALED_KL
+ elif rescale_learned_sigmas:
+ loss_type = gd.LossType.RESCALED_MSE
+ else:
+ loss_type = gd.LossType.MSE
+ if timestep_respacing is None or timestep_respacing == "":
+ timestep_respacing = [steps]
+ mean_type = gd.predict_type_dict[predict_type]
+
+ return SpacedDiffusion(
+ use_timesteps=space_timesteps(steps, timestep_respacing),
+ betas=betas,
+ model_mean_type=mean_type,
+ model_var_type=(
+ (
+ gd.ModelVarType.FIXED_LARGE
+ if not sigma_small
+ else gd.ModelVarType.FIXED_SMALL
+ )
+ if not learn_sigma
+ else gd.ModelVarType.LEARNED_RANGE
+ ),
+ loss_type=loss_type,
+ rescale_timesteps=rescale_timesteps,
+ mse_loss_weight_type=mse_loss_weight_type,
+ noise_offset=noise_offset,
+ )
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/hydit/diffusion/diffusion_utils.py b/PyTorch/built-in/mlm/HunyuanDiT/hydit/diffusion/diffusion_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c62d5b0f914eb8bef789d3612e3c48d68ec0cef
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/hydit/diffusion/diffusion_utils.py
@@ -0,0 +1,69 @@
+import torch as th
+import numpy as np
+
+
+def normal_kl(mean1, logvar1, mean2, logvar2):
+ """
+ Compute the KL divergence between two gaussians.
+ Shapes are automatically broadcasted, so batches can be compared to
+ scalars, among other use cases.
+ """
+ tensor = None
+ for obj in (mean1, logvar1, mean2, logvar2):
+ if isinstance(obj, th.Tensor):
+ tensor = obj
+ break
+ assert tensor is not None, "at least one argument must be a Tensor"
+
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
+ # Tensors, but it does not work for th.exp().
+ logvar1, logvar2 = [
+ x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
+ for x in (logvar1, logvar2)
+ ]
+
+ return 0.5 * (
+ -1.0
+ + logvar2
+ - logvar1
+ + th.exp(logvar1 - logvar2)
+ + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
+ )
+
+
+def approx_standard_normal_cdf(x):
+ """
+ A fast approximation of the cumulative distribution function of the
+ standard normal.
+ """
+ return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
+
+
+def discretized_gaussian_log_likelihood(x, *, means, log_scales):
+ """
+ Compute the log-likelihood of a Gaussian distribution discretizing to a
+ given image.
+ :param x: the target images. It is assumed that this was uint8 values,
+ rescaled to the range [-1, 1].
+ :param means: the Gaussian mean Tensor.
+ :param log_scales: the Gaussian log stddev Tensor.
+ :return: a tensor like x of log probabilities (in nats).
+ """
+ assert x.shape == means.shape == log_scales.shape
+ centered_x = x - means
+ inv_stdv = th.exp(-log_scales)
+ plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
+ cdf_plus = approx_standard_normal_cdf(plus_in)
+ min_in = inv_stdv * (centered_x - 1.0 / 255.0)
+ cdf_min = approx_standard_normal_cdf(min_in)
+ log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
+ log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
+ cdf_delta = cdf_plus - cdf_min
+ log_probs = th.where(
+ x < -0.999,
+ log_cdf_plus,
+ th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
+ )
+ assert log_probs.shape == x.shape
+ return log_probs
+
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/hydit/diffusion/gaussian_diffusion.py b/PyTorch/built-in/mlm/HunyuanDiT/hydit/diffusion/gaussian_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..636d0f678cc585832f1ad09873928b5be94f62fd
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/hydit/diffusion/gaussian_diffusion.py
@@ -0,0 +1,1384 @@
+import math
+
+import numpy as np
+import torch as th
+import enum
+
+from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl
+from ..utils.tools import assert_shape
+
+from utils.npu_utils import is_npu_available
+
+
+def mean_flat(tensor):
+ """
+ Take the mean over all non-batch dimensions.
+ """
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+class ModelMeanType(enum.Enum):
+ """
+ Which type of output the model predicts.
+ """
+
+ PREVIOUS_X = enum.auto() # the model predicts x_{t-1}
+ START_X = enum.auto() # the model predicts x_0
+ EPSILON = enum.auto() # the model predicts epsilon
+ VELOCITY = enum.auto() # the model predicts v
+
+
+predict_type_dict = {
+ 'epsilon': ModelMeanType.EPSILON,
+ 'sample': ModelMeanType.START_X,
+ 'v_prediction': ModelMeanType.VELOCITY,
+}
+
+
+class ModelVarType(enum.Enum):
+ """
+ What is used as the model's output variance.
+ The LEARNED_RANGE option has been added to allow the model to predict
+ values between FIXED_SMALL and FIXED_LARGE, making its job easier.
+ """
+
+ LEARNED = enum.auto()
+ FIXED_SMALL = enum.auto()
+ FIXED_LARGE = enum.auto()
+ LEARNED_RANGE = enum.auto()
+
+
+class LossType(enum.Enum):
+ MSE = enum.auto() # use raw MSE loss (and KL when learning variances)
+ RESCALED_MSE = (
+ enum.auto()
+ ) # use raw MSE loss (with RESCALED_KL when learning variances)
+ KL = enum.auto() # use the variational lower-bound
+ RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB
+
+ def is_vb(self):
+ return self == LossType.KL or self == LossType.RESCALED_KL
+
+
+def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
+ betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
+ warmup_time = int(num_diffusion_timesteps * warmup_frac)
+ betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)
+ return betas
+
+
+def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
+ """
+ This is the deprecated API for creating beta schedules.
+ See get_named_beta_schedule() for the new library of schedules.
+ """
+ if beta_schedule == "quad":
+ betas = (
+ np.linspace(
+ beta_start ** 0.5,
+ beta_end ** 0.5,
+ num_diffusion_timesteps,
+ dtype=np.float64,
+ )
+ ** 2
+ )
+ elif beta_schedule == "linear":
+ betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
+ elif beta_schedule == "warmup10":
+ betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1)
+ elif beta_schedule == "warmup50":
+ betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5)
+ elif beta_schedule == "const":
+ betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
+ elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1
+ betas = 1.0 / np.linspace(
+ num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
+ )
+ else:
+ raise NotImplementedError(beta_schedule)
+ assert_shape(betas, (num_diffusion_timesteps,))
+ return betas
+
+
+def get_named_beta_schedule(schedule_name, num_diffusion_timesteps, beta_start=0.0001, beta_end=0.02):
+ """
+ Get a pre-defined beta schedule for the given name.
+ The beta schedule library consists of beta schedules which remain similar
+ in the limit of num_diffusion_timesteps.
+ Beta schedules may be added, but should not be removed or changed once
+ they are committed to maintain backwards compatibility.
+ """
+ if schedule_name == "linear":
+ # Linear schedule from Ho et al, extended to work for any number of
+ # diffusion steps.
+ scale = 1000 / num_diffusion_timesteps
+ return get_beta_schedule(
+ "linear",
+ beta_start=scale * beta_start, # DDPM
+ beta_end=scale * beta_end, # DDPM
+ num_diffusion_timesteps=num_diffusion_timesteps, # DDPM
+ )
+ elif schedule_name == "scaled_linear":
+ return get_beta_schedule(
+ "quad",
+ beta_start=beta_start, # StableDiffusion, should be 0.00085
+ beta_end=beta_end, # StableDiffusion, should be 0.012
+ num_diffusion_timesteps=num_diffusion_timesteps, # StableDiffusion
+ )
+ elif schedule_name == "squaredcos_cap_v2":
+ return betas_for_alpha_bar(
+ num_diffusion_timesteps,
+ lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
+ )
+ else:
+ raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function,
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
+
+ :param num_diffusion_timesteps: the number of betas to produce.
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
+ produces the cumulative product of (1-beta) up to that
+ part of the diffusion process.
+ :param max_beta: the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+ """
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+ return np.array(betas)
+
+
+class GaussianDiffusion:
+ """
+ Utilities for training and sampling diffusion models.
+
+ Ported directly from here, and then adapted over time to further experimentation.
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
+
+ :param betas: a 1-D numpy array of betas for each diffusion timestep,
+ starting at T and going to 1.
+ :param model_mean_type: a ModelMeanType determining what the model outputs.
+ :param model_var_type: a ModelVarType determining how variance is output.
+ :param loss_type: a LossType determining the loss function to use.
+ :param rescale_timesteps: if True, pass floating point timesteps into the
+ model so that they are always scaled like in the
+ original paper (0 to 1000).
+ """
+
+ def __init__(
+ self,
+ *,
+ betas,
+ model_mean_type,
+ model_var_type,
+ loss_type,
+ rescale_timesteps=False,
+ mse_loss_weight_type='constant',
+ noise_offset=0.0,
+ ):
+ self.model_mean_type = model_mean_type
+ self.model_var_type = model_var_type
+ self.loss_type = loss_type
+ self.rescale_timesteps = rescale_timesteps
+
+ self.mse_loss_weight_type = mse_loss_weight_type
+ self.noise_offset = noise_offset
+
+ # Use float64 for accuracy.
+ betas = np.array(betas, dtype=np.float64)
+ alphas = 1.0 - betas
+ alphas_cumprod = np.cumprod(alphas, axis=0)
+
+ self.betas = betas
+ assert len(betas.shape) == 1, "betas must be 1-D"
+ assert (betas > 0).all() and (betas <= 1).all()
+
+ self.num_timesteps = int(betas.shape[0])
+
+ self.alphas_cumprod = alphas_cumprod
+ self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
+ self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
+ assert_shape(self.alphas_cumprod_prev, (self.num_timesteps,))
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
+ self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
+ self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
+ self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
+ self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
+
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
+ self.posterior_variance = (
+ betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
+ )
+ # log calculation clipped because the posterior variance is 0 at the
+ # beginning of the diffusion chain.
+ self.posterior_log_variance_clipped = np.log(
+ np.append(self.posterior_variance[1], self.posterior_variance[1:])
+ ) if len(self.posterior_variance) > 1 else np.array([])
+
+ self.posterior_mean_coef1 = (
+ betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
+ )
+ self.posterior_mean_coef2 = (
+ (1.0 - self.alphas_cumprod_prev)
+ * np.sqrt(alphas)
+ / (1.0 - self.alphas_cumprod)
+ )
+
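+        # Map sampler names to the corresponding sampling-loop methods.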
+ self.sampler = {
+ "ddpm": self.p_sample_loop,
+ "ddim": self.ddim_sample_loop,
+ "plms": self.plms_sample_loop,
+ }
+
+ def q_mean_variance(self, x_start, t):
+ """
+ Get the distribution q(x_t | x_0).
+
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
+ """
+ mean = (
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ )
+ variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
+ log_variance = _extract_into_tensor(
+ self.log_one_minus_alphas_cumprod, t, x_start.shape
+ )
+ return mean, variance, log_variance
+
+ def q_sample(self, x_start, t, noise=None):
+ """
+ Diffuse the data for a given number of diffusion steps.
+
+ In other words, sample from q(x_t | x_0).
+
+ :param x_start: the initial data batch.
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
+ :param noise: if specified, the split-out normal noise.
+ :return: A noisy version of x_start.
+ """
+ if noise is None:
+ noise = th.randn_like(x_start)
+ assert_shape(noise, x_start)
+ return (
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
+ * noise
+ )
+
+ def q_posterior_mean_variance(self, x_start, x_t, t):
+ """
+ Compute the mean and variance of the diffusion posterior:
+
+ q(x_{t-1} | x_t, x_0)
+
+ """
+ assert_shape(x_start, x_t)
+ posterior_mean = (
+ _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
+ + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
+ )
+ posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
+ posterior_log_variance_clipped = _extract_into_tensor(
+ self.posterior_log_variance_clipped, t, x_t.shape
+ )
+ assert_shape(
+ posterior_mean.shape[:1],
+ posterior_variance.shape[:1],
+ posterior_log_variance_clipped.shape[:1],
+ x_start.shape[:1],
+ )
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
+
+ def p_mean_variance(
+ self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None,
+ model_var_type=None,
+ ):
+ """
+ Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
+ the initial x, x_0.
+
+ :param model: the model, which takes a signal and a batch of timesteps
+ as input.
+ :param x: the [N x C x ...] tensor at time t.
+ :param t: a 1-D Tensor of timesteps.
+ :param clip_denoised: if True, clip the denoised signal into [-1, 1].
+ :param denoised_fn: if not None, a function which applies to the
+ x_start prediction before it is used to sample. Applies before
+ clip_denoised.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+        :param model_var_type: if not None, override the default self.model_var_type.
+            Useful when training with a learned variance but sampling with a fixed one.
+ :return: a dict with the following keys:
+ - 'mean': the model mean output.
+ - 'variance': the model variance output.
+ - 'log_variance': the log of 'variance'.
+ - 'pred_xstart': the prediction for x_0.
+ """
+ if model_kwargs is None:
+ model_kwargs = {}
+
+ if model_var_type is None:
+ model_var_type = self.model_var_type
+
+ B, C = x.shape[:2]
+ assert_shape(t, (B,))
+ out_dict = model(x, t, **model_kwargs)
+ model_output = out_dict['x']
+
+ if len(out_dict) > 1:
+ extra = {k: v for k, v in out_dict.items() if k != 'x'}
+ else:
+ extra = None
+
+ # self.model_var_type corresponds to model output
+ if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
+ assert_shape(model_output, (B, C * 2, *x.shape[2:]))
+ model_output, model_var_values = th.split(model_output, C, dim=1)
+
+ # model_var_type corresponds to reverse diffusion process
+ if model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
+ if model_var_type == ModelVarType.LEARNED:
+ model_log_variance = model_var_values
+ model_variance = th.exp(model_log_variance)
+ else:
+ min_log = _extract_into_tensor(
+ self.posterior_log_variance_clipped, t, x.shape
+ )
+ max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
+ # The model_var_values is [-1, 1] for [min_var, max_var].
+ frac = (model_var_values + 1) / 2
+ model_log_variance = frac * max_log + (1 - frac) * min_log
+ model_variance = th.exp(model_log_variance)
+ else:
+ model_variance, model_log_variance = {
+ # for fixedlarge, we set the initial (log-)variance like so
+ # to get a better decoder log likelihood.
+ ModelVarType.FIXED_LARGE: (
+ np.append(self.posterior_variance[1], self.betas[1:]),
+ np.log(np.append(self.posterior_variance[1], self.betas[1:])),
+ ),
+ ModelVarType.FIXED_SMALL: (
+ self.posterior_variance,
+ self.posterior_log_variance_clipped,
+ ),
+ }[model_var_type]
+ model_variance = _extract_into_tensor(model_variance, t, x.shape)
+ model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
+
+ def process_xstart(x):
+ if denoised_fn is not None:
+ x = denoised_fn(x)
+ if clip_denoised:
+ return x.clamp(-1, 1)
+ return x
+
+ if self.model_mean_type == ModelMeanType.PREVIOUS_X:
+ pred_xstart = process_xstart(
+ self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)
+ )
+ model_mean = model_output
+ elif self.model_mean_type in [
+ ModelMeanType.START_X,
+ ModelMeanType.EPSILON,
+ ModelMeanType.VELOCITY
+ ]:
+ if self.model_mean_type == ModelMeanType.START_X:
+ pred_xstart = process_xstart(model_output)
+ elif self.model_mean_type == ModelMeanType.EPSILON:
+ pred_xstart = process_xstart(
+ self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
+ )
+ else:
+ pred_xstart = process_xstart(
+ self._predict_xstart_from_v(x_t=x, t=t, v=model_output)
+ )
+ model_mean, _, _ = self.q_posterior_mean_variance(
+ x_start=pred_xstart, x_t=x, t=t
+ )
+ else:
+ raise NotImplementedError(self.model_mean_type)
+
+ assert_shape(model_mean, model_log_variance, pred_xstart, x)
+ return {
+ "mean": model_mean,
+ "variance": model_variance,
+ "log_variance": model_log_variance,
+ "pred_xstart": pred_xstart,
+ "extra": extra,
+ }
+
+ def _predict_xstart_from_eps(self, x_t, t, eps):
+ assert_shape(x_t, eps)
+ return (
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
+ - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
+ )
+
+ def _predict_xstart_from_v(self, x_t, t, v):
+ assert_shape(x_t, v)
+ return (
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t
+ - _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
+ )
+
+ def _predict_xstart_from_xprev(self, x_t, t, xprev):
+ assert_shape(x_t, xprev)
+ return ( # (xprev - coef2*x_t) / coef1
+ _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev
+ - _extract_into_tensor(
+ self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape
+ )
+ * x_t
+ )
+
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
+ return (
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
+ - pred_xstart
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
+
+ def _velocity_from_xstart_and_noise(self, x_start, t, noise):
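+        # v-prediction target: v = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * x_0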
+ return (
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * noise
+ - _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
+ * x_start
+ )
+
+ def _vb_terms_bpd(
+ self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
+ ):
+ """
+ Get a term for the variational lower-bound.
+
+ The resulting units are bits (rather than nats, as one might expect).
+ This allows for comparison to other papers.
+
+ :return: a dict with the following keys:
+ - 'output': a shape [N] tensor of NLLs or KLs.
+ - 'pred_xstart': the x_0 predictions.
+ """
+ true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
+ x_start=x_start, x_t=x_t, t=t
+ )
+ out = self.p_mean_variance(
+ model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
+ )
+ kl = normal_kl(
+ true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
+ )
+ kl = mean_flat(kl) / np.log(2.0)
+
+ decoder_nll = -discretized_gaussian_log_likelihood(
+ x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
+ )
+ assert_shape(decoder_nll, x_start)
+ decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
+
+ # At the first timestep return the decoder NLL,
+ # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
+ output = th.where((t == 0), decoder_nll, kl)
+ return {"output": output, "pred_xstart": out["pred_xstart"], "extra": out["extra"]}
+
+ def training_losses(self, model, x_start, step, model_kwargs=None, controlnet=None, noise=None):
+ """
+ Compute training losses for a single timestep.
+
+ :param model: the model to evaluate loss on.
+ :param x_start: the [N x C x ...] tensor of inputs.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+ :param noise: if specified, the specific Gaussian noise to try to remove.
+ :return: a dict with the key "loss" containing a tensor of shape [N].
+ Some mean or variance settings may also have other keys.
+ """
+        # load_rand and dir are debug switches for loading pre-saved random tensors
+ load_rand = False
+ dir = './noise/0814'
+
+ local_rank = th.distributed.get_rank()
+ dev = x_start.device
+ if model_kwargs is None:
+ model_kwargs = {}
+ # Time steps
+ t = th.randint(0, self.num_timesteps, (x_start.shape[0],), device=x_start.device)
+
+ if load_rand:
+ t = th.load(f'{dir}/t_{step}_{local_rank}.pt').to(device=dev)
+ # Noise
+ if noise is None:
+ noise = th.randn_like(x_start)
+ if load_rand:
+ noise = th.load(f'{dir}/noise_{step}_{local_rank}.pt').to(
+ device=dev)
+ if self.noise_offset > 0:
+ # Add channel wise noise offset
+ # https://www.crosslabs.org/blog/diffusion-with-offset-noise
+ noise = noise + self.noise_offset * th.randn(*x_start.shape[:2], 1, 1, device=x_start.device)
+ x_t = self.q_sample(x_start, t, noise=noise)
+
+ terms = {}
+
+ if self.mse_loss_weight_type == 'constant':
+ mse_loss_weight = th.ones_like(t)
+ elif self.mse_loss_weight_type.startswith("min_snr_"):
+ alpha = _extract_into_tensor(self.sqrt_alphas_cumprod, t, t.shape)
+ sigma = _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, t.shape)
+ snr = (alpha / sigma) ** 2
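+            # SNR(t) = alpha_bar_t / (1 - alpha_bar_t), used for min-SNR loss re-weighting below.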
+
+ k = float(self.mse_loss_weight_type.split('min_snr_')[-1])
+ # min{snr, k}
+ mse_loss_weight = th.stack([snr, k * th.ones_like(t)], dim=1).min(dim=1)[0] / snr
+ else:
+ raise ValueError(self.mse_loss_weight_type)
+
+ if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
+ out_dict = self._vb_terms_bpd(
+ model=model,
+ x_start=x_start,
+ x_t=x_t,
+ t=t,
+ clip_denoised=False,
+ model_kwargs=model_kwargs,
+ )
+ terms["loss"] = out_dict["output"]
+ if self.loss_type == LossType.RESCALED_KL:
+ terms["loss"] *= self.num_timesteps
+ extra = out_dict["extra"]
+ elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
+            if controlnet is not None:
+ controls = controlnet(x_t, t, **model_kwargs)
+ model_kwargs.pop('condition')
+ model_kwargs.update(controls)
+ out_dict = model(x_t, t, **model_kwargs)
+ model_output = out_dict['x']
+ extra = {k: v for k, v in out_dict.items() if k != 'x'}
+
+ if self.model_var_type in [
+ ModelVarType.LEARNED,
+ ModelVarType.LEARNED_RANGE,
+ ]:
+ B, C = x_t.shape[:2]
+ assert_shape(model_output, (B, C * 2, *x_t.shape[2:]))
+ model_output, model_var_values = th.split(model_output, C, dim=1)
+ # Learn the variance using the variational bound, but don't let
+ # it affect our mean prediction.
+ frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
+ terms["vb"] = self._vb_terms_bpd(
+ model=lambda *args, r=frozen_out: dict(x=r),
+ x_start=x_start,
+ x_t=x_t,
+ t=t,
+ clip_denoised=False,
+ )["output"]
+ if self.loss_type == LossType.RESCALED_MSE:
+ # Divide by 1000 for equivalence with initial implementation.
+ # Without a factor of 1/1000, the VB term hurts the MSE term.
+ terms["vb"] *= self.num_timesteps / 1000.0
+
+ if self.model_mean_type == ModelMeanType.VELOCITY:
+ target = self._velocity_from_xstart_and_noise(x_start, t, noise)
+ else:
+ target = {
+ # ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
+ # x_start=x_start, x_t=x_t, t=t
+ # )[0],
+ ModelMeanType.START_X: x_start,
+ ModelMeanType.EPSILON: noise,
+ }[self.model_mean_type]
+ assert_shape(model_output, target, x_start)
+ raw_mse = mean_flat((target - model_output) ** 2).detach()
+ terms["mse"] = mse_loss_weight * mean_flat((target - model_output) ** 2)
+ if "vb" in terms:
+ terms["loss"] = terms["mse"] + terms["vb"]
+ terms["raw_loss"] = raw_mse + terms["vb"].detach()
+ else:
+ terms["loss"] = terms["mse"]
+ terms["raw_loss"] = raw_mse
+ else:
+ raise NotImplementedError(self.loss_type)
+
+ terms.update(extra)
+ return terms
+
+ def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
+ """
+ Compute the mean for the previous step, given a function cond_fn that
+ computes the gradient of a conditional log probability with respect to
+ x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
+ condition on y.
+
+ This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
+ """
+ gradient = cond_fn(x, t, **model_kwargs)
+ new_mean = (
+ p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
+ )
+ return new_mean
+
+ def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
+ """
+ Compute what the p_mean_variance output would have been, should the
+ model's score function be conditioned by cond_fn.
+ See condition_mean() for details on cond_fn.
+ Unlike condition_mean(), this instead uses the conditioning strategy
+ from Song et al (2020).
+ """
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
+
+ eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
+ eps = eps - (1 - alpha_bar).sqrt() * cond_fn(
+ x, t, **model_kwargs
+ )
+
+ out = p_mean_var.copy()
+ out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
+ out["mean"], _, _ = self.q_posterior_mean_variance(
+ x_start=out["pred_xstart"], x_t=x, t=t
+ )
+ return out
+
+ def p_sample(
+ self,
+ model,
+ x,
+ t,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ model_var_type=None,
+ **kwargs,
+ ):
+ """
+ Sample x_{t-1} from the model at the given timestep.
+
+ :param model: the model to sample from.
+ :param x: the current tensor at x_{t-1}.
+ :param t: the value of t, starting at 0 for the first diffusion step.
+ :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
+ :param denoised_fn: if not None, a function which applies to the
+ x_start prediction before it is used to sample.
+ :param cond_fn: if not None, this is a gradient function that acts
+ similarly to the model.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+        :param model_var_type: if not None, override the default self.model_var_type.
+            Useful when training with a learned variance but sampling with a fixed one.
+ :return: a dict containing the following keys:
+ - 'sample': a random sample from the model.
+ - 'pred_xstart': a prediction of x_0.
+ """
+ out = self.p_mean_variance(
+ model,
+ x,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ model_kwargs=model_kwargs,
+ model_var_type=model_var_type,
+ )
+ noise = th.randn_like(x)
+ nonzero_mask = (
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
+ ) # no noise when t == 0
+ if cond_fn is not None:
+ out["mean"] = self.condition_mean(
+ cond_fn, out, x, t, model_kwargs=model_kwargs
+ )
+
+ sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
+
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
+
+ def p_sample_loop(
+ self,
+ model,
+ shape,
+ noise=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ model_var_type=None,
+ device=None,
+ progress=False,
+ progress_leave=True,
+ **kwargs,
+ ):
+ """
+ Generate samples from the model.
+
+ :param model: the model module.
+ :param shape: the shape of the samples, (N, C, H, W).
+ :param noise: if specified, the noise from the encoder to sample.
+ Should be of the same shape as `shape`.
+ :param clip_denoised: if True, clip x_start predictions to [-1, 1].
+ :param denoised_fn: if not None, a function which applies to the
+ x_start prediction before it is used to sample.
+ :param cond_fn: if not None, this is a gradient function that acts
+ similarly to the model.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+        :param model_var_type: if not None, override the default self.model_var_type.
+            Useful when training with a learned variance but sampling with a fixed one.
+ :param device: if specified, the device to create the samples on.
+ If not specified, use a model parameter's device.
+ :param progress: if True, show a tqdm progress bar.
+ :return: a non-differentiable batch of samples.
+ """
+ final = None
+ for sample in self.p_sample_loop_progressive(
+ model,
+ shape,
+ noise=noise,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ cond_fn=cond_fn,
+ model_kwargs=model_kwargs,
+ model_var_type=model_var_type,
+ device=device,
+ progress=progress,
+ progress_leave=progress_leave,
+ **kwargs,
+ ):
+ final = sample
+ return final["sample"]
+
+ def p_sample_loop_progressive(
+ self,
+ model,
+ shape,
+ noise=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ model_var_type=None,
+ device=None,
+ progress=False,
+ progress_leave=True,
+ **kwargs,
+ ):
+ """
+ Generate samples from the model and yield intermediate samples from
+ each timestep of diffusion.
+
+ Arguments are the same as p_sample_loop().
+ Returns a generator over dicts, where each dict is the return value of
+ p_sample().
+ """
+ if device is None:
+ device = next(model.parameters()).device
+ assert isinstance(shape, (tuple, list))
+ if noise is not None:
+ img = noise
+ else:
+ img = th.randn(*shape, device=device)
+ indices = list(range(self.num_timesteps))[::-1]
+
+ if progress:
+ # Lazy import so that we don't depend on tqdm.
+ from tqdm.auto import tqdm
+
+ indices = tqdm(indices, leave=progress_leave)
+
+ for i in indices:
+ t = th.tensor([i] * shape[0], device=device)
+ with th.no_grad():
+ out = self.p_sample(
+ model,
+ img,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ cond_fn=cond_fn,
+ model_kwargs=model_kwargs,
+ model_var_type=model_var_type,
+ **kwargs,
+ )
+ yield out
+ img = out["sample"]
+
+ def ddim_sample(
+ self,
+ model,
+ x,
+ t,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ eta=0.0,
+ ):
+ """
+ Sample x_{t-1} from the model using DDIM.
+
+ Same usage as p_sample().
+ """
+ out = self.p_mean_variance(
+ model,
+ x,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ model_kwargs=model_kwargs,
+ )
+ if cond_fn is not None:
+ out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
+
+ # Usually our model outputs epsilon, but we re-derive it
+ # in case we used x_start or x_prev prediction.
+ eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
+
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
+ alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
+ sigma = (
+ eta
+ * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
+ * th.sqrt(1 - alpha_bar / alpha_bar_prev)
+ )
+ # Equation 12.
+ noise = th.randn_like(x)
+ mean_pred = (
+ out["pred_xstart"] * th.sqrt(alpha_bar_prev)
+ + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
+ )
+ nonzero_mask = (
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
+ ) # no noise when t == 0
+ sample = mean_pred + nonzero_mask * sigma * noise
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
+
+ def ddim_reverse_sample(
+ self,
+ model,
+ x,
+ t,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ eta=0.0,
+ ):
+ """
+ Sample x_{t+1} from the model using DDIM reverse ODE.
+ """
+ assert eta == 0.0, "Reverse ODE only for deterministic path"
+ out = self.p_mean_variance(
+ model,
+ x,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ model_kwargs=model_kwargs,
+ )
+ if cond_fn is not None:
+ out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
+ # Usually our model outputs epsilon, but we re-derive it
+ # in case we used x_start or x_prev prediction.
+ eps = (
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
+ - out["pred_xstart"]
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
+ alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
+
+ # Equation 12. reversed
+ mean_pred = (
+ out["pred_xstart"] * th.sqrt(alpha_bar_next)
+ + th.sqrt(1 - alpha_bar_next) * eps
+ )
+
+ return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
+
+ def ddim_sample_loop(
+ self,
+ model,
+ shape,
+ noise=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ device=None,
+ progress=False,
+ progress_leave=True,
+ eta=0.0,
+ ):
+ """
+ Generate samples from the model using DDIM.
+
+ Same usage as p_sample_loop().
+ """
+ final = None
+ for sample in self.ddim_sample_loop_progressive(
+ model,
+ shape,
+ noise=noise,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ cond_fn=cond_fn,
+ model_kwargs=model_kwargs,
+ device=device,
+ progress=progress,
+ progress_leave=progress_leave,
+ eta=eta,
+ ):
+ final = sample
+ return final["sample"]
+
+ def ddim_sample_loop_progressive(
+ self,
+ model,
+ shape,
+ noise=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ device=None,
+ progress=False,
+ progress_leave=True,
+ eta=0.0,
+ ):
+ """
+ Use DDIM to sample from the model and yield intermediate samples from
+ each timestep of DDIM.
+
+ Same usage as p_sample_loop_progressive().
+ """
+ if device is None:
+ device = next(model.parameters()).device
+ assert isinstance(shape, (tuple, list))
+ if noise is not None:
+ img = noise
+ else:
+ img = th.randn(*shape, device=device)
+ indices = list(range(self.num_timesteps))[::-1]
+
+ if progress:
+ # Lazy import so that we don't depend on tqdm.
+ from tqdm.auto import tqdm
+
+ indices = tqdm(indices, leave=progress_leave)
+
+ for i in indices:
+ t = th.tensor([i] * shape[0], device=device)
+ with th.no_grad():
+ out = self.ddim_sample(
+ model,
+ img,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ cond_fn=cond_fn,
+ model_kwargs=model_kwargs,
+ eta=eta,
+ )
+ yield out
+ img = out["sample"]
+
+ def _prior_bpd(self, x_start):
+ """
+ Get the prior KL term for the variational lower-bound, measured in
+ bits-per-dim.
+
+ This term can't be optimized, as it only depends on the encoder.
+
+ :param x_start: the [N x C x ...] tensor of inputs.
+ :return: a batch of [N] KL values (in bits), one per batch element.
+ """
+ batch_size = x_start.shape[0]
+ t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
+ qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
+ kl_prior = normal_kl(
+ mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
+ )
+ return mean_flat(kl_prior) / np.log(2.0)
+
+ def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
+ """
+ Compute the entire variational lower-bound, measured in bits-per-dim,
+ as well as other related quantities.
+
+ :param model: the model to evaluate loss on.
+ :param x_start: the [N x C x ...] tensor of inputs.
+ :param clip_denoised: if True, clip denoised samples.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+
+ :return: a dict containing the following keys:
+ - total_bpd: the total variational lower-bound, per batch element.
+ - prior_bpd: the prior term in the lower-bound.
+ - vb: an [N x T] tensor of terms in the lower-bound.
+ - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
+ - mse: an [N x T] tensor of epsilon MSEs for each timestep.
+ """
+ device = x_start.device
+ batch_size = x_start.shape[0]
+
+ vb = []
+ xstart_mse = []
+ mse = []
+ for t in list(range(self.num_timesteps))[::-1]:
+ t_batch = th.tensor([t] * batch_size, device=device)
+ noise = th.randn_like(x_start)
+ x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
+ # Calculate VLB term at the current timestep
+ with th.no_grad():
+ out = self._vb_terms_bpd(
+ model,
+ x_start=x_start,
+ x_t=x_t,
+ t=t_batch,
+ clip_denoised=clip_denoised,
+ model_kwargs=model_kwargs,
+ )
+ vb.append(out["output"])
+ xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
+ eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
+ mse.append(mean_flat((eps - noise) ** 2))
+
+ vb = th.stack(vb, dim=1)
+ xstart_mse = th.stack(xstart_mse, dim=1)
+ mse = th.stack(mse, dim=1)
+
+ prior_bpd = self._prior_bpd(x_start)
+ total_bpd = vb.sum(dim=1) + prior_bpd
+ return {
+ "total_bpd": total_bpd,
+ "prior_bpd": prior_bpd,
+ "vb": vb,
+ "xstart_mse": xstart_mse,
+ "mse": mse,
+ }
+
+ def get_eps(
+ self,
+ model,
+ x,
+ t,
+ model_kwargs,
+ cond_fn=None,
+ ):
+ model_output = model(x, t, **model_kwargs)['x']
+ if isinstance(model_output, tuple):
+ model_output, _ = model_output
+ eps = model_output[:, :4]
+ if cond_fn is not None:
+ alpha_bar = _extract_into_tensor_lerp(self.alphas_cumprod, t, x.shape)
+ eps = eps - th.sqrt(1 - alpha_bar) * cond_fn(x, t, **model_kwargs)
+ return eps
+
+ def eps_to_pred_xstart(
+ self,
+ x,
+ eps,
+ t,
+ ):
+ alpha_bar = _extract_into_tensor_lerp(self.alphas_cumprod, t, x.shape)
+ return (x - eps * th.sqrt(1 - alpha_bar)) / th.sqrt(alpha_bar)
+
+ def pndm_transfer(
+ self,
+ x,
+ eps,
+ t_1,
+ t_2,
+ ):
+ pred_xstart = self.eps_to_pred_xstart(x, eps, t_1)
+ alpha_bar_prev = _extract_into_tensor_lerp(self.alphas_cumprod, t_2, x.shape)
+ return pred_xstart * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev) * eps
+
+ def prk_sample_loop(
+ self,
+ model,
+ shape,
+ noise=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ device=None,
+ progress=False,
+ ):
+ """
+ Generate samples from the model using PRK.
+
+ Same usage as p_sample_loop().
+ """
+ final = None
+ for sample in self.prk_sample_loop_progressive(
+ model,
+ shape,
+ noise=noise,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ cond_fn=cond_fn,
+ model_kwargs=model_kwargs,
+ device=device,
+ progress=progress,
+ ):
+ final = sample
+ return final["sample"]
+
+ def prk_sample_loop_progressive(
+ self,
+ model,
+ shape,
+ noise=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ device=None,
+ progress=False,
+ ):
+ """
+ Use PRK to sample from the model and yield intermediate samples from
+ each timestep of PRK.
+
+ Same usage as p_sample_loop_progressive().
+ """
+ if device is None:
+ device = next(model.parameters()).device
+ assert isinstance(shape, (tuple, list))
+ if noise is not None:
+ img = noise
+ else:
+ img = th.randn(*shape, device=device)
+ indices = list(range(self.num_timesteps))[::-1][1:-1]
+
+ if progress:
+ # Lazy import so that we don't depend on tqdm.
+ from tqdm.auto import tqdm
+
+ indices = tqdm(indices, leave=False)
+
+ for i in indices:
+ t = th.tensor([i] * shape[0], device=device)
+ with th.no_grad():
+ out = self.prk_sample(
+ model,
+ img,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ cond_fn=cond_fn,
+ model_kwargs=model_kwargs,
+ )
+ yield out
+ img = out["sample"]
+
+ def prk_sample(
+ self,
+ model,
+ x,
+ t,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ ):
+ """
+ Sample x_{t-1} from the model using fourth-order Pseudo Runge-Kutta
+ (https://openreview.net/forum?id=PlKWVd2yBkY).
+
+ Same usage as p_sample().
+ """
+ if model_kwargs is None:
+ model_kwargs = {}
+
+ def process_xstart(x):
+ if denoised_fn is not None:
+ x = denoised_fn(x)
+ if clip_denoised:
+ return x.clamp(-1, 1)
+ return x
+
+ eps_1 = self.get_eps(model, x, t, model_kwargs, cond_fn)
+ x_1 = self.pndm_transfer(x, eps_1, t, t - 0.5)
+ eps_2 = self.get_eps(model, x_1, t - 0.5, model_kwargs, cond_fn)
+ x_2 = self.pndm_transfer(x, eps_2, t, t - 0.5)
+ eps_3 = self.get_eps(model, x_2, t - 0.5, model_kwargs, cond_fn)
+ x_3 = self.pndm_transfer(x, eps_3, t, t - 1)
+ eps_4 = self.get_eps(model, x_3, t - 1, model_kwargs, cond_fn)
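+        # Runge-Kutta style weighted average of the four eps evaluations: (1, 2, 2, 1) / 6.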
+ eps_prime = (eps_1 + 2 * eps_2 + 2 * eps_3 + eps_4) / 6
+
+ sample = self.pndm_transfer(x, eps_prime, t, t - 1)
+ pred_xstart = self.eps_to_pred_xstart(x, eps_prime, t)
+ pred_xstart = process_xstart(pred_xstart)
+ return {"sample": sample, "pred_xstart": pred_xstart, "eps": eps_prime}
+
+ def plms_sample_loop(
+ self,
+ model,
+ shape,
+ noise=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ device=None,
+ progress=False,
+ progress_leave=True,
+ ):
+ """
+ Generate samples from the model using PLMS.
+
+ Same usage as p_sample_loop().
+ """
+        assert self.model_mean_type == ModelMeanType.EPSILON, \
+            'plms_sample only supports model_mean_type == ModelMeanType.EPSILON'
+ final = None
+ for sample in self.plms_sample_loop_progressive(
+ model,
+ shape,
+ noise=noise,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ cond_fn=cond_fn,
+ model_kwargs=model_kwargs,
+ device=device,
+ progress=progress,
+ progress_leave=progress_leave,
+ ):
+ final = sample
+ return final["sample"]
+
+ def plms_sample_loop_progressive(
+ self,
+ model,
+ shape,
+ noise=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ device=None,
+ progress=False,
+ progress_leave=True,
+ ):
+ """
+ Use PLMS to sample from the model and yield intermediate samples from
+ each timestep of PLMS.
+
+ Same usage as p_sample_loop_progressive().
+ """
+ if device is None:
+ device = next(model.parameters()).device
+ assert isinstance(shape, (tuple, list))
+ if noise is not None:
+ img = noise
+ else:
+ img = th.randn(*shape, device=device)
+ indices = list(range(self.num_timesteps))[::-1][1:-1]
+
+ if progress:
+ # Lazy import so that we don't depend on tqdm.
+ from tqdm.auto import tqdm
+
+ indices = tqdm(indices, leave=progress_leave)
+
+ old_eps = []
+
+ for i in indices:
+ t = th.tensor([i] * shape[0], device=device)
+ with th.no_grad():
+ if len(old_eps) < 3:
+ out = self.prk_sample(
+ model,
+ img,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ cond_fn=cond_fn,
+ model_kwargs=model_kwargs,
+ )
+ else:
+ out = self.plms_sample(
+ model,
+ img,
+ old_eps,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ cond_fn=cond_fn,
+ model_kwargs=model_kwargs,
+ )
+ old_eps.pop(0)
+ old_eps.append(out["eps"])
+ yield out
+ img = out["sample"]
+
+ def plms_sample(
+ self,
+ model,
+ x,
+ old_eps,
+ t,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ ):
+ """
+ Sample x_{t-1} from the model using fourth-order Pseudo Linear Multistep
+ (https://openreview.net/forum?id=PlKWVd2yBkY).
+ """
+ if model_kwargs is None:
+ model_kwargs = {}
+
+ def process_xstart(x):
+ if denoised_fn is not None:
+ x = denoised_fn(x)
+ if clip_denoised:
+ return x.clamp(-1, 1)
+ return x
+
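+ # Fourth-order Adams-Bashforth style combination of the current noise
+ # prediction with the three cached predictions from previous steps.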
+ eps = self.get_eps(model, x, t, model_kwargs, cond_fn)
+ eps_prime = (55 * eps - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
+
+ sample = self.pndm_transfer(x, eps_prime, t, t - 1)
+ pred_xstart = self.eps_to_pred_xstart(x, eps, t)
+ pred_xstart = process_xstart(pred_xstart)
+ return {"sample": sample, "pred_xstart": pred_xstart, "eps": eps}
+
+
+def _extract_into_tensor(arr, timesteps, broadcast_shape):
+ """
+ Extract values from a 1-D numpy array for a batch of indices.
+
+ :param arr: the 1-D numpy array.
+ :param timesteps: a tensor of indices into the array to extract.
+ :param broadcast_shape: a larger shape of K dimensions with the batch
+ dimension equal to the length of timesteps.
+ :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
+ """
+ res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
+ while len(res.shape) < len(broadcast_shape):
+ res = res[..., None]
+ return res.expand(broadcast_shape)
+
+
+def _extract_into_tensor_lerp(arr, timesteps, broadcast_shape):
+ """
+ Extract values from arr with fractional time steps
+ """
+ timesteps = timesteps.float()
+ frac = timesteps.frac()
+ while len(frac.shape) < len(broadcast_shape):
+ frac = frac[..., None]
+ res_1 = _extract_into_tensor(arr, timesteps.floor().long(), broadcast_shape)
+ res_2 = _extract_into_tensor(arr, timesteps.ceil().long(), broadcast_shape)
+ return th.lerp(res_1, res_2, frac)
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/hydit/diffusion/pipeline.py b/PyTorch/built-in/mlm/HunyuanDiT/hydit/diffusion/pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..e765b7052328628b541ef65a0183f797d1233a57
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/hydit/diffusion/pipeline.py
@@ -0,0 +1,794 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import torch
+from diffusers.configuration_utils import FrozenDict
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.models.lora import adjust_lora_scale_text_encoder
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import (
+ deprecate,
+ logging,
+ replace_example_docstring,
+)
+from diffusers.utils.torch_utils import randn_tensor
+from transformers import BertModel, BertTokenizer
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+
+from ..modules.models import HunYuanDiT
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import StableDiffusionPipeline
+
+ >>> pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
+ >>> pipe = pipe.to("cuda")
+
+ >>> prompt = "a photo of an astronaut riding a horse on mars"
+ >>> image = pipe(prompt).images[0]
+ ```
+"""
+
+
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+ """
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+ """
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+ # rescale the results from guidance (fixes overexposure)
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+ return noise_cfg
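+# For example, with guidance_rescale=0.7 the returned prediction is
+# 0.7 * (noise_cfg * std_text / std_cfg) + 0.3 * noise_cfg, i.e. the standard
+# deviation of the guided noise is pulled back towards that of the text branch.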
+
+
+class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin):
+ r"""
+ Pipeline for text-to-image generation using Stable Diffusion.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ The pipeline also inherits the following loading methods:
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
+ - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+ - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+ - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
+ text_encoder (Optional[`~transformers.BertModel`, `~transformers.CLIPTextModel`]):
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
+ tokenizer (Optional[`~transformers.BertTokenizer`, `~transformers.CLIPTokenizer`]):
+ A `BertTokenizer` or `CLIPTokenizer` to tokenize text.
+ unet (Optional[`HunYuanDiT`, `UNet2DConditionModel`]):
+ A `HunYuanDiT` or `UNet2DConditionModel` to denoise the encoded image latents.
+ Notice: Here we still keep the word `unet` for compatibility with the previous version of the pipeline.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
+ about a model's potential harms.
+ feature_extractor ([`~transformers.CLIPImageProcessor`]):
+ A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
+ """
+ model_cpu_offload_seq = "text_encoder->unet->vae"
+ _optional_components = ["safety_checker", "feature_extractor"]
+ _exclude_from_cpu_offload = ["safety_checker"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: Union[BertModel, CLIPTextModel],
+ tokenizer: Union[BertTokenizer, CLIPTokenizer],
+ unet: Union[HunYuanDiT, UNet2DConditionModel],
+ scheduler: KarrasDiffusionSchedulers,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ requires_safety_checker: bool = True,
+ progress_bar_config: Dict[str, Any] = None,
+ embedder_t5=None,
+ infer_mode='torch',
+ ):
+ super().__init__()
+
+ # ========================================================
+ self.embedder_t5 = embedder_t5
+ self.infer_mode = infer_mode
+
+ # ========================================================
+ if progress_bar_config is None:
+ progress_bar_config = {}
+ if not hasattr(self, '_progress_bar_config'):
+ self._progress_bar_config = {}
+ self._progress_bar_config.update(progress_bar_config)
+ # ========================================================
+
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+ " file"
+ )
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["steps_offset"] = 1
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+ )
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["clip_sample"] = False
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
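+ # vae_scale_factor = 2 ** (number of VAE blocks - 1); for the standard
+ # 4-block VAE this is 8, i.e. latents are 1/8 of the pixel resolution in
+ # each spatial dimension.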
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+ """
+ self.vae.enable_tiling()
+
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ lora_scale: Optional[float] = None,
+ ):
+ deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
+ deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
+
+ prompt_embeds_tuple = self.encode_prompt(
+ prompt=prompt,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=lora_scale,
+ )
+
+ # concatenate for backwards comp
+ prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
+
+ return prompt_embeds
+
+ def encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ lora_scale: Optional[float] = None,
+ embedder=None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ embedder:
+ T5 embedder (including text encoder and tokenizer)
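+
+ Returns:
+ A tuple `(prompt_embeds, negative_prompt_embeds, attention_mask, uncond_attention_mask)`.
+ The negative embeddings and the unconditional attention mask may be `None` when
+ classifier-free guidance is not used.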
+ """
+ if embedder is None:
+ text_encoder = self.text_encoder
+ tokenizer = self.tokenizer
+ max_length = self.tokenizer.model_max_length
+ else:
+ text_encoder = embedder.model
+ tokenizer = embedder.tokenizer
+ max_length = embedder.max_length
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, tokenizer)
+
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_attention_mask=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = tokenizer.batch_decode(
+ untruncated_ids[:, tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ attention_mask = text_inputs.attention_mask.to(device)
+ prompt_embeds = text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ prompt_embeds = prompt_embeds[0]
+ attention_mask = attention_mask.repeat(num_images_per_prompt, 1)
+ else:
+ attention_mask = None
+
+ if text_encoder is not None:
+ prompt_embeds_dtype = text_encoder.dtype
+ elif self.unet is not None:
+ prompt_embeds_dtype = self.unet.dtype
+ else:
+ prompt_embeds_dtype = prompt_embeds.dtype
+
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ uncond_attention_mask = uncond_input.attention_mask.to(device)
+ negative_prompt_embeds = text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=uncond_attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+ uncond_attention_mask = uncond_attention_mask.repeat(num_images_per_prompt, 1)
+ else:
+ uncond_attention_mask = None
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ return prompt_embeds, negative_prompt_embeds, attention_mask, uncond_attention_mask
+
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is None:
+ has_nsfw_concept = None
+ else:
+ if torch.is_tensor(image):
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+ else:
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ return image, has_nsfw_concept
+
+ def decode_latents(self, latents):
+ deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
+ deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
+
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ height: int,
+ width: int,
+ prompt: Union[str, List[str]] = None,
+ num_inference_steps: Optional[int] = 50,
+ guidance_scale: Optional[float] = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: Optional[float] = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ prompt_embeds_t5: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds_t5: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guidance_rescale: float = 0.0,
+ image_meta_size: Optional[torch.LongTensor] = None,
+ style: Optional[torch.LongTensor] = None,
+ progress: bool = True,
+ use_fp16: bool = False,
+ freqs_cis_img: Optional[tuple] = None,
+ learn_sigma: bool = True,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ height (`int`):
+ The height in pixels of the generated image.
+ width (`int`):
+ The width in pixels of the generated image.
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ A higher guidance scale value encourages the model to generate images closely linked to the text
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that calls every `callback_steps` steps during inference. The function is called with the
+ following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor,
+ pred_x0: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
+ every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
+ "not-safe-for-work" (nsfw) content.
+ """
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+ )
+
+ prompt_embeds, negative_prompt_embeds, attention_mask, uncond_attention_mask = \
+ self.encode_prompt(prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ )
+ prompt_embeds_t5, negative_prompt_embeds_t5, attention_mask_t5, uncond_attention_mask_t5 = \
+ self.encode_prompt(prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds_t5,
+ negative_prompt_embeds=negative_prompt_embeds_t5,
+ lora_scale=text_encoder_lora_scale,
+ embedder=self.embedder_t5,
+ )
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ if do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+ attention_mask = torch.cat([uncond_attention_mask, attention_mask])
+ prompt_embeds_t5 = torch.cat([negative_prompt_embeds_t5, prompt_embeds_t5])
+ attention_mask_t5 = torch.cat([uncond_attention_mask_t5, attention_mask_t5])
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+ # expand scalar t to 1-D tensor to match the 1st dim of latent_model_input
+ t_expand = torch.tensor([t] * latent_model_input.shape[0], device=latent_model_input.device)
+
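+ # Half-precision path: cast the latents, timestep, prompt embeddings and the
+ # optional image_meta_size conditioning to fp16 before the denoiser forward pass.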
+ if use_fp16:
+ latent_model_input = latent_model_input.half()
+ t_expand = t_expand.half()
+ prompt_embeds = prompt_embeds.half()
+ ims = image_meta_size.half() if image_meta_size is not None else None
+ else:
+ ims = image_meta_size if image_meta_size is not None else None
+
+ # predict the noise residual
+ if self.infer_mode in ["fa", "torch"]:
+ noise_pred = self.unet(
+ latent_model_input,
+ t_expand,
+ encoder_hidden_states=prompt_embeds,
+ text_embedding_mask=attention_mask,
+ encoder_hidden_states_t5=prompt_embeds_t5,
+ text_embedding_mask_t5=attention_mask_t5,
+ image_meta_size=ims,
+ style=style,
+ cos_cis_img=freqs_cis_img[0],
+ sin_cis_img=freqs_cis_img[1],
+ return_dict=False,
+ )
+ elif self.infer_mode == "trt":
+ noise_pred = self.unet(
+ x=latent_model_input.contiguous(),
+ t_emb=t_expand.contiguous(),
+ context=prompt_embeds.contiguous(),
+ image_meta_size=ims.contiguous(),
+ style=style.contiguous(),
+ freqs_cis_img0=freqs_cis_img[0].to(device).contiguous(),
+ freqs_cis_img1=freqs_cis_img[1].to(device).contiguous(),
+ text_embedding_mask=attention_mask.contiguous(),
+ encoder_hidden_states_t5=prompt_embeds_t5.contiguous(),
+ text_embedding_mask_t5=attention_mask_t5.contiguous(),
+ )
+ else:
+ raise ValueError("Unknown infer_mode: {self.infer_mode}")
+ if learn_sigma:
+ noise_pred, _ = noise_pred.chunk(2, dim=1)
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ results = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=True)
+ latents = results.prev_sample
+ pred_x0 = results.pred_original_sample if hasattr(results, 'pred_original_sample') else None
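+ # pred_original_sample is the scheduler's running estimate of the clean
+ # latent x_0 (None for schedulers that do not expose it); it is passed to
+ # the callback below for previewing.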
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents, pred_x0)
+
+ if not output_type == "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+ else:
+ image = latents
+ has_nsfw_concept = None
+
+ if has_nsfw_concept is None:
+ do_denormalize = [True] * image.shape[0]
+ else:
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
\ No newline at end of file
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/hydit/diffusion/pipeline_controlnet.py b/PyTorch/built-in/mlm/HunyuanDiT/hydit/diffusion/pipeline_controlnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e5d0beb8e5dfda19fc9391a856c15468209d85a
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/hydit/diffusion/pipeline_controlnet.py
@@ -0,0 +1,821 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import torch
+from diffusers.configuration_utils import FrozenDict
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.models.lora import adjust_lora_scale_text_encoder
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import (
+ deprecate,
+ logging,
+ replace_example_docstring,
+)
+from diffusers.utils.torch_utils import randn_tensor
+from transformers import BertModel, BertTokenizer
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+
+from ..modules.models import HunYuanDiT
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import StableDiffusionPipeline
+
+ >>> pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
+ >>> pipe = pipe.to("cuda")
+
+ >>> prompt = "a photo of an astronaut riding a horse on mars"
+ >>> image = pipe(prompt).images[0]
+ ```
+"""
+
+
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+ """
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+ """
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+ # rescale the results from guidance (fixes overexposure)
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+ return noise_cfg
+
+
+class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin):
+ r"""
+ Pipeline for text-to-image generation using Stable Diffusion.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ The pipeline also inherits the following loading methods:
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
+ - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+ - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+ - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
+ text_encoder (Optional[`~transformers.BertModel`, `~transformers.CLIPTextModel`]):
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
+ tokenizer (Optional[`~transformers.BertTokenizer`, `~transformers.CLIPTokenizer`]):
+ A `BertTokenizer` or `CLIPTokenizer` to tokenize text.
+ unet (Optional[`HunYuanDiT`, `UNet2DConditionModel`]):
+ A `HunYuanDiT` or `UNet2DConditionModel` to denoise the encoded image latents.
+ Notice: Here we still keep the word `unet` for compatibility with the previous version of the pipeline.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
+ about a model's potential harms.
+ feature_extractor ([`~transformers.CLIPImageProcessor`]):
+ A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
+ """
+ model_cpu_offload_seq = "text_encoder->unet->vae"
+ _optional_components = ["safety_checker", "feature_extractor"]
+ _exclude_from_cpu_offload = ["safety_checker"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: Union[BertModel, CLIPTextModel],
+ tokenizer: Union[BertTokenizer, CLIPTokenizer],
+ unet: Union[HunYuanDiT, UNet2DConditionModel],
+ scheduler: KarrasDiffusionSchedulers,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ requires_safety_checker: bool = True,
+ progress_bar_config: Dict[str, Any] = None,
+ embedder_t5=None,
+ infer_mode='torch',
+ controlnet=None,
+ ):
+ super().__init__()
+
+ # ========================================================
+ self.embedder_t5 = embedder_t5
+ self.infer_mode = infer_mode
+
+ # ========================================================
+ if progress_bar_config is None:
+ progress_bar_config = {}
+ if not hasattr(self, '_progress_bar_config'):
+ self._progress_bar_config = {}
+ self._progress_bar_config.update(progress_bar_config)
+ # ========================================================
+
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+ " file"
+ )
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["steps_offset"] = 1
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+ )
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["clip_sample"] = False
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
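+ # Compared with the base pipeline above, the ControlNet module is additionally
+ # registered so that it follows the pipeline's device placement, offloading
+ # and serialization.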
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ controlnet=controlnet,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+ """
+ self.vae.enable_tiling()
+
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ lora_scale: Optional[float] = None,
+ ):
+ deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
+ deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
+
+ prompt_embeds_tuple = self.encode_prompt(
+ prompt=prompt,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=lora_scale,
+ )
+
+ # concatenate for backwards comp
+ prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
+
+ return prompt_embeds
+
+ def encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ lora_scale: Optional[float] = None,
+ embedder=None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ embedder:
+ T5 embedder (including text encoder and tokenizer)
+ """
+ if embedder is None:
+ text_encoder = self.text_encoder
+ tokenizer = self.tokenizer
+ max_length = self.tokenizer.model_max_length
+ else:
+ text_encoder = embedder.model
+ tokenizer = embedder.tokenizer
+ max_length = embedder.max_length
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, tokenizer)
+
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_attention_mask=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = tokenizer.batch_decode(
+ untruncated_ids[:, tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ attention_mask = text_inputs.attention_mask.to(device)
+ prompt_embeds = text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ prompt_embeds = prompt_embeds[0]
+ attention_mask = attention_mask.repeat(num_images_per_prompt, 1)
+ else:
+ attention_mask = None
+
+ if text_encoder is not None:
+ prompt_embeds_dtype = text_encoder.dtype
+ elif self.unet is not None:
+ prompt_embeds_dtype = self.unet.dtype
+ else:
+ prompt_embeds_dtype = prompt_embeds.dtype
+
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ uncond_attention_mask = uncond_input.attention_mask.to(device)
+ negative_prompt_embeds = text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=uncond_attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+ uncond_attention_mask = uncond_attention_mask.repeat(num_images_per_prompt, 1)
+ else:
+ uncond_attention_mask = None
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ return prompt_embeds, negative_prompt_embeds, attention_mask, uncond_attention_mask
+
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is None:
+ has_nsfw_concept = None
+ else:
+ if torch.is_tensor(image):
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+ else:
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ return image, has_nsfw_concept
+
+ def decode_latents(self, latents):
+ deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
+ deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
+
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ height: int,
+ width: int,
+ prompt: Union[str, List[str]] = None,
+ num_inference_steps: Optional[int] = 50,
+ guidance_scale: Optional[float] = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: Optional[float] = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ prompt_embeds_t5: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds_t5: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guidance_rescale: float = 0.0,
+ image_meta_size: Optional[torch.LongTensor] = None,
+ style: Optional[torch.LongTensor] = None,
+ progress: bool = True,
+ use_fp16: bool = False,
+ freqs_cis_img: Optional[tuple] = None,
+ learn_sigma: bool = True,
+ image=None,
+ control_weight=1.0
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ height (`int`):
+ The height in pixels of the generated image.
+ width (`int`):
+ The width in pixels of the generated image.
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
+ image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+ `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
+ numpy arrays and pytorch tensors, the expected value range is `[0, 1]`. If it is a tensor or a list
+ of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
+ list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image
+ latents as `image`; latents passed directly are not encoded again.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ A higher guidance scale value encourages the model to generate images closely linked to the text
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that is called every `callback_steps` steps during inference. The function is called with the
+ following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor,
+ pred_x0: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
+ every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
+ "not-safe-for-work" (nsfw) content.
+ """
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+ )
+
+ prompt_embeds, negative_prompt_embeds, attention_mask, uncond_attention_mask = \
+ self.encode_prompt(prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ )
+ prompt_embeds_t5, negative_prompt_embeds_t5, attention_mask_t5, uncond_attention_mask_t5 = \
+ self.encode_prompt(prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds_t5,
+ negative_prompt_embeds=negative_prompt_embeds_t5,
+ lora_scale=text_encoder_lora_scale,
+ embedder=self.embedder_t5,
+ )
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ if do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+ attention_mask = torch.cat([uncond_attention_mask, attention_mask])
+ prompt_embeds_t5 = torch.cat([negative_prompt_embeds_t5, prompt_embeds_t5])
+ attention_mask_t5 = torch.cat([uncond_attention_mask_t5, attention_mask_t5])
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
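+ # 7. Encode the control image into latent space and duplicate it for classifier-free guidance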
+ condition = self.vae.encode(image.float()).latent_dist.sample(generator).mul_(self.vae.config.scaling_factor).half()
+ condition = torch.cat([condition] * 2) if do_classifier_free_guidance else condition
+
+ # 8. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+ # expand scalar t to 1-D tensor to match the 1st dim of latent_model_input
+ t_expand = torch.tensor([t] * latent_model_input.shape[0], device=latent_model_input.device)
+
+ if use_fp16:
+ latent_model_input = latent_model_input.half()
+ t_expand = t_expand.half()
+ prompt_embeds = prompt_embeds.half()
+ ims = image_meta_size.half() if image_meta_size is not None else None
+ else:
+ ims = image_meta_size if image_meta_size is not None else None
+
+ # predict the noise residual
+ if self.infer_mode in ["fa", "torch"]:
+ controls = self.controlnet(
+ latent_model_input,
+ t_expand,
+ condition,
+ encoder_hidden_states=prompt_embeds,
+ text_embedding_mask=attention_mask,
+ encoder_hidden_states_t5=prompt_embeds_t5,
+ text_embedding_mask_t5=attention_mask_t5,
+ image_meta_size=ims,
+ style=style,
+ cos_cis_img=freqs_cis_img[0],
+ sin_cis_img=freqs_cis_img[1],
+ return_dict=False,
+ )
+ if isinstance(control_weight, list):
+ assert len(control_weight) == len(controls)
+ controls = [control * weight for control, weight in zip(controls, control_weight)]
+ else:
+ controls = [control * control_weight for control in controls]
+ noise_pred = self.unet(
+ latent_model_input,
+ t_expand,
+ encoder_hidden_states=prompt_embeds,
+ text_embedding_mask=attention_mask,
+ encoder_hidden_states_t5=prompt_embeds_t5,
+ text_embedding_mask_t5=attention_mask_t5,
+ image_meta_size=ims,
+ style=style,
+ cos_cis_img=freqs_cis_img[0],
+ sin_cis_img=freqs_cis_img[1],
+ return_dict=False,
+ controls=controls
+ )
+ elif self.infer_mode == "trt":
+ noise_pred = self.unet(
+ x=latent_model_input.contiguous(),
+ t_emb=t_expand.contiguous(),
+ context=prompt_embeds.contiguous(),
+ image_meta_size=ims.contiguous(),
+ style=style.contiguous(),
+ freqs_cis_img0=freqs_cis_img[0].to(device).contiguous(),
+ freqs_cis_img1=freqs_cis_img[1].to(device).contiguous(),
+ text_embedding_mask=attention_mask.contiguous(),
+ encoder_hidden_states_t5=prompt_embeds_t5.contiguous(),
+ text_embedding_mask_t5=attention_mask_t5.contiguous(),
+ )
+ else:
+ raise ValueError("Unknown infer_mode: {self.infer_mode}")
+ if learn_sigma:
+ noise_pred, _ = noise_pred.chunk(2, dim=1)
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ results = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=True)
+ latents = results.prev_sample
+ pred_x0 = results.pred_original_sample if hasattr(results, 'pred_original_sample') else None
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents, pred_x0)
+
+ if not output_type == "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+ else:
+ image = latents
+ has_nsfw_concept = None
+
+ if has_nsfw_concept is None:
+ do_denormalize = [True] * image.shape[0]
+ else:
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
\ No newline at end of file
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/hydit/diffusion/respace.py b/PyTorch/built-in/mlm/HunyuanDiT/hydit/diffusion/respace.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e2bf7c0a6bfd78a0295efba8e9f99ef95ca8471
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/hydit/diffusion/respace.py
@@ -0,0 +1,144 @@
+import numpy as np
+import torch as th
+
+from .gaussian_diffusion import GaussianDiffusion
+
+
+def space_timesteps(num_timesteps, section_counts):
+ """
+ Create a list of timesteps to use from an original diffusion process,
+ given the number of timesteps we want to take from equally-sized portions
+ of the original process.
+ For example, if there are 300 timesteps and the section counts are [10, 15, 20],
+ then the first 100 timesteps are strided down to 10 timesteps, the second 100
+ are strided down to 15 timesteps, and the final 100 are strided down to 20.
+ If the stride is a string starting with "ddim", then the fixed striding
+ from the DDIM paper is used, and only one section is allowed.
+ :param num_timesteps: the number of diffusion steps in the original
+ process to divide up.
+ :param section_counts: either a list of numbers, or a string containing
+ comma-separated numbers, indicating the step count
+ per section. As a special case, use "ddimN" where N
+ is a number of steps to use the striding from the
+ DDIM paper.
+ :return: a set of diffusion steps from the original process to use.
+ """
+ if isinstance(section_counts, str):
+ if section_counts.startswith("ddim"):
+ desired_count = int(section_counts[len("ddim") :])
+ for i in range(1, num_timesteps):
+ if len(range(0, num_timesteps, i)) == desired_count:
+ return set(range(0, num_timesteps, i))
+ raise ValueError(
+ f"cannot create exactly {num_timesteps} steps with an integer stride"
+ )
+ section_counts = [int(x) for x in section_counts.split(",")]
+ size_per = num_timesteps // len(section_counts)
+ extra = num_timesteps % len(section_counts)
+ start_idx = 0
+ all_steps = []
+ for i, section_count in enumerate(section_counts):
+ size = size_per + (1 if i < extra else 0)
+ if size < section_count:
+ raise ValueError(
+ f"cannot divide section of {size} steps into {section_count}"
+ )
+ if section_count <= 1:
+ frac_stride = 1
+ else:
+ frac_stride = (size - 1) / (section_count - 1)
+ cur_idx = 0.0
+ taken_steps = []
+ for _ in range(section_count):
+ taken_steps.append(start_idx + round(cur_idx))
+ cur_idx += frac_stride
+ all_steps += taken_steps
+ start_idx += size
+ return set(all_steps)
+
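+# Illustrative usage (assumed values): space_timesteps(1000, "ddim25") finds stride 40 and returns
+# set(range(0, 1000, 40)); space_timesteps(300, [10, 15, 20]) keeps 10, 15 and 20 evenly spaced
+# steps from the first, second and third 100-step sections, respectively.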
+
+class SpacedDiffusion(GaussianDiffusion):
+ """
+ Improved DDPM
+
+ A diffusion process which can skip steps in a base diffusion process.
+ :param use_timesteps: a collection (sequence or set) of timesteps from the
+ original diffusion process to retain.
+ :param kwargs: the kwargs to create the base diffusion process.
+ """
+
+ def __init__(self, use_timesteps, **kwargs):
+ self.use_timesteps = set(use_timesteps)
+ self.timestep_map = []
+ self.original_num_steps = len(kwargs["betas"])
+
+ base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
+ last_alpha_cumprod = 1.0
+ new_betas = []
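+ # Re-derive betas for the retained steps: since alpha_bar_t is the cumulative product of
+ # (1 - beta_t), the spaced beta is 1 - alpha_bar_t / alpha_bar_{t_prev} over consecutive kept steps.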
+ for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
+ if i in self.use_timesteps:
+ new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
+ last_alpha_cumprod = alpha_cumprod
+ self.timestep_map.append(i)
+ kwargs["betas"] = np.array(new_betas)
+ super().__init__(**kwargs)
+
+ def p_mean_variance(
+ self, model, *args, **kwargs
+ ): # pylint: disable=signature-differs
+ return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
+
+ def training_losses(
+ self, model, controlnet=None, *args, **kwargs
+ ): # pylint: disable=signature-differs
+ if controlnet is not None:
+ return super().training_losses(self._wrap_model(model), controlnet=self._wrap_model(controlnet), *args, **kwargs)
+ else:
+ return super().training_losses(self._wrap_model(model), *args, **kwargs)
+
+ def condition_mean(self, cond_fn, *args, **kwargs):
+ return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
+
+ def condition_score(self, cond_fn, *args, **kwargs):
+ return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
+
+ def get_eps(self, model, *args, **kwargs):
+ return super().get_eps(self._wrap_model(model), *args, **kwargs)
+
+ def _wrap_model(self, model):
+ if isinstance(model, _WrappedModel):
+ return model
+ return _WrappedModel(
+ model, self.timestep_map, self.original_num_steps
+ )
+
+ def _scale_timesteps(self, t):
+ # Scaling is done by the wrapped model.
+ return t
+
+
+class _WrappedModel:
+ """
+ Improved DDPM
+
+ When using a subsequence of the timesteps (e.g., 250 of them), we must wrap the model
+ to map the subsampled timesteps (1-250 with stride 1) back to the original ones (1-1000 with stride 4).
+ """
+ def __init__(self, model, timestep_map, original_num_steps):
+ self.model = model
+ self.timestep_map = timestep_map
+ # self.rescale_timesteps = rescale_timesteps
+ self.original_num_steps = original_num_steps
+
+ def __call__(self, x, ts, **kwargs):
+ """
+ Here we must interpolate because `ts` may be a float (e.g., 4.5)
+ in the PLMS/PNDM sampler.
+ """
+ ts = ts.float()
+ frac = ts.frac()
+ map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
+ new_ts_1 = map_tensor[ts.floor().long()]
+ new_ts_2 = map_tensor[ts.ceil().long()]
+ new_ts = th.lerp(new_ts_1, new_ts_2, frac)
+ return self.model(x, new_ts, **kwargs)
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/hydit/ds_config.py b/PyTorch/built-in/mlm/HunyuanDiT/hydit/ds_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..d108a6de496de852dd21ed755e04e7e7515a89ca
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/hydit/ds_config.py
@@ -0,0 +1,140 @@
+# -*- coding: utf-8 -*-
+import os
+
+def deepspeed_config_from_args(args, global_batch_size):
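+ # ZeRO stage 2 shards optimizer states and gradients across ranks; the stage 3 config below
+ # additionally shards parameters and offloads both optimizer states and parameters to CPU.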
+ if args.use_zero_stage == 2:
+ deepspeed_config = {
+ "train_batch_size": global_batch_size,
+ "train_micro_batch_size_per_gpu": args.batch_size,
+ "gradient_accumulation_steps": args.grad_accu_steps,
+ "steps_per_print": args.log_every,
+ "optimizer": {
+ "type": "AdamW",
+ "params": {
+ "lr": args.lr,
+ "betas": [
+ 0.9,
+ 0.999
+ ],
+ "eps": 1e-08,
+ "weight_decay": args.weight_decay
+ }
+ },
+
+ "zero_optimization": {
+ "stage": 2,
+ "reduce_scatter": False,
+ "reduce_bucket_size": 1e9,
+ },
+
+ "gradient_clipping": 1.0,
+ "prescale_gradients": True,
+
+ "fp16": {
+ "enabled": args.use_fp16,
+ "loss_scale": 0,
+ "loss_scale_window": 500,
+ "hysteresis": 2,
+ "min_loss_scale": 1e-3,
+ "initial_scale_power": 15
+ },
+
+ "bf16": {
+ "enabled": False
+ },
+
+ "wall_clock_breakdown": False
+ }
+
+ elif args.use_zero_stage == 3:
+ deepspeed_config = {
+ "train_batch_size": args.global_batch_size,
+ # "train_micro_batch_size_per_gpu": args.batch_size,
+ "gradient_accumulation_steps": args.grad_accu_steps,
+ "steps_per_print": args.log_every,
+
+ "optimizer": {
+ "type": "AdamW",
+ "params": {
+ "lr": args.lr,
+ "betas": [
+ 0.9,
+ 0.999
+ ],
+ "eps": 1e-08,
+ "weight_decay": args.weight_decay
+ }
+ },
+
+ "zero_optimization": {
+ "stage": 3,
+ "allgather_partitions": True,
+ "overlap_comm": True,
+ "reduce_scatter": True,
+ "contiguous_gradients": True,
+ "stage3_prefetch_bucket_size": 5e8,
+ "stage3_max_live_parameters" : 6e8,
+ "reduce_bucket_size": 1.2e9,
+ "sub_group_size": 1e9,
+ "sub_group_buffer_num": 10,
+ "pipeline_optimizer": True,
+ "max_contigous_event_size": 0,
+ "cache_sub_group_rate": 0.0,
+ "prefetch_cache_sub_group_rate": 1.0,
+ "max_contigous_params_size": -1,
+ "max_param_reduce_events": 0,
+ "stage3_param_persistence_threshold": 9e9,
+ "is_communication_time_profiling": False,
+ "save_large_model_multi_slice": True,
+ "use_fused_op_with_grad_norm_overflow": False,
+ "offload_optimizer": {
+ "device": "cpu",
+ "pin_memory": True
+ },
+ "offload_param": {
+ "device": "cpu",
+ "pin_memory": True
+ }
+ },
+
+ "gradient_clipping": 1.0,
+ "prescale_gradients": False,
+
+ "fp16": {
+ "enabled": True,
+ "loss_scale": 0,
+ "loss_scale_window": 500,
+ "hysteresis": 2,
+ "min_loss_scale": 1,
+ "initial_scale_power": 15
+ },
+
+ "bf16": {
+ "enabled": False
+ },
+
+ "wall_clock_breakdown": False,
+ "mem_chunk": {
+ "default_chunk_size": 536870911,
+ "use_fake_dist": False,
+ "client": {
+ "mem_tracer": {
+ "use_async_mem_monitor": True,
+ "warmup_gpu_chunk_mem_ratio": 0.8,
+ "overall_gpu_mem_ratio": 0.8,
+ "overall_cpu_mem_ratio": 1.0,
+ "margin_use_ratio": 0.8,
+ "use_fake_dist": False
+ },
+ "opts": {
+ "with_mem_cache": True,
+ "with_async_move": True
+ }
+ }
+ }
+ }
+ else:
+ raise ValueError(f"Unsupported ZeRO stage: {args.use_zero_stage}, expected 2 or 3")
+ return deepspeed_config
+
+
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/hydit/inference.py b/PyTorch/built-in/mlm/HunyuanDiT/hydit/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..682f29d7faa2d17a567e9f4d623f6eb8e1361181
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/hydit/inference.py
@@ -0,0 +1,413 @@
+import random
+import time
+from pathlib import Path
+
+import numpy as np
+import torch
+
+# For reproducibility
+# torch.backends.cudnn.benchmark = False
+# torch.backends.cudnn.deterministic = True
+
+from diffusers import schedulers
+from diffusers.models import AutoencoderKL
+from loguru import logger
+from transformers import BertModel, BertTokenizer
+from transformers.modeling_utils import logger as tf_logger
+
+from .constants import SAMPLER_FACTORY, NEGATIVE_PROMPT, TRT_MAX_WIDTH, TRT_MAX_HEIGHT, TRT_MAX_BATCH_SIZE
+from .diffusion.pipeline import StableDiffusionPipeline
+from .modules.models import HunYuanDiT, HUNYUAN_DIT_CONFIG
+from .modules.posemb_layers import get_2d_rotary_pos_embed, get_fill_resize_and_crop
+from .modules.text_encoder import MT5Embedder
+from .utils.tools import set_seeds
+from peft import LoraConfig
+
+
+class Resolution:
+ def __init__(self, width, height):
+ self.width = width
+ self.height = height
+
+ def __str__(self):
+ return f'{self.height}x{self.width}'
+
+
+class ResolutionGroup:
+ def __init__(self):
+ self.data = [
+ Resolution(1024, 1024), # 1:1
+ Resolution(1280, 1280), # 1:1
+ Resolution(1024, 768), # 4:3
+ Resolution(1152, 864), # 4:3
+ Resolution(1280, 960), # 4:3
+ Resolution(768, 1024), # 3:4
+ Resolution(864, 1152), # 3:4
+ Resolution(960, 1280), # 3:4
+ Resolution(1280, 768), # 16:9
+ Resolution(768, 1280), # 9:16
+ ]
+ self.supported_sizes = set([(r.width, r.height) for r in self.data])
+
+ def is_valid(self, width, height):
+ return (width, height) in self.supported_sizes
+
+
+STANDARD_RATIO = np.array([
+ 1.0, # 1:1
+ 4.0 / 3.0, # 4:3
+ 3.0 / 4.0, # 3:4
+ 16.0 / 9.0, # 16:9
+ 9.0 / 16.0, # 9:16
+])
+STANDARD_SHAPE = [
+ [(1024, 1024), (1280, 1280)], # 1:1
+ [(1280, 960)], # 4:3
+ [(960, 1280)], # 3:4
+ [(1280, 768)], # 16:9
+ [(768, 1280)], # 9:16
+]
+STANDARD_AREA = [
+ np.array([w * h for w, h in shapes])
+ for shapes in STANDARD_SHAPE
+]
+
+
+def get_standard_shape(target_width, target_height):
+ """
+ Map image size to standard size.
+ """
+ target_ratio = target_width / target_height
+ closest_ratio_idx = np.argmin(np.abs(STANDARD_RATIO - target_ratio))
+ closest_area_idx = np.argmin(np.abs(STANDARD_AREA[closest_ratio_idx] - target_width * target_height))
+ width, height = STANDARD_SHAPE[closest_ratio_idx][closest_area_idx]
+ return width, height
+
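+# For instance, a request of width=1000, height=700 (ratio ~1.43) maps to the 4:3 group and returns (1280, 960).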
+
+def _to_tuple(val):
+ if isinstance(val, (list, tuple)):
+ if len(val) == 1:
+ val = [val[0], val[0]]
+ elif len(val) == 2:
+ val = tuple(val)
+ else:
+ raise ValueError(f"Invalid value: {val}")
+ elif isinstance(val, (int, float)):
+ val = (val, val)
+ else:
+ raise ValueError(f"Invalid value: {val}")
+ return val
+
+
+def get_pipeline(args, vae, text_encoder, tokenizer, model, device, rank,
+ embedder_t5, infer_mode, sampler=None):
+ """
+ Get scheduler and pipeline for sampling. The sampler and pipeline are both
+ based on diffusers and make some modifications.
+
+ Returns
+ -------
+ pipeline: StableDiffusionPipeline
+ sampler_name: str
+ """
+ sampler = sampler or args.sampler
+
+ # Load sampler from factory
+ kwargs = SAMPLER_FACTORY[sampler]['kwargs']
+ scheduler = SAMPLER_FACTORY[sampler]['scheduler']
+
+ # Update sampler according to the arguments
+ kwargs['beta_schedule'] = args.noise_schedule
+ kwargs['beta_start'] = args.beta_start
+ kwargs['beta_end'] = args.beta_end
+ kwargs['prediction_type'] = args.predict_type
+
+ # Build scheduler according to the sampler.
+ scheduler_class = getattr(schedulers, scheduler)
+ scheduler = scheduler_class(**kwargs)
+
+ # Set timesteps for inference steps.
+ scheduler.set_timesteps(args.infer_steps, device)
+
+ # Only enable progress bar for rank 0
+ progress_bar_config = {} if rank == 0 else {'disable': True}
+
+ pipeline = StableDiffusionPipeline(vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=model,
+ scheduler=scheduler,
+ feature_extractor=None,
+ safety_checker=None,
+ requires_safety_checker=False,
+ progress_bar_config=progress_bar_config,
+ embedder_t5=embedder_t5,
+ infer_mode=infer_mode,
+ )
+
+ pipeline = pipeline.to(device)
+
+ return pipeline, sampler
+
+
+class End2End(object):
+ def __init__(self, args, models_root_path):
+ self.args = args
+
+ # Check arguments
+ t2i_root_path = Path(models_root_path) / "t2i"
+ self.root = t2i_root_path
+ logger.info(f"Got text-to-image model root path: {t2i_root_path}")
+
+ # Set device and disable gradient
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
+ torch.set_grad_enabled(False)
+ # Disable BertModel logging checkpoint info
+ tf_logger.setLevel('ERROR')
+
+ # ========================================================================
+ logger.info(f"Loading CLIP Text Encoder...")
+ text_encoder_path = self.root / "clip_text_encoder"
+ self.clip_text_encoder = BertModel.from_pretrained(str(text_encoder_path), False, revision=None).to(self.device)
+ logger.info(f"Loading CLIP Text Encoder finished")
+
+ # ========================================================================
+ logger.info(f"Loading CLIP Tokenizer...")
+ tokenizer_path = self.root / "tokenizer"
+ self.tokenizer = BertTokenizer.from_pretrained(str(tokenizer_path))
+ logger.info(f"Loading CLIP Tokenizer finished")
+
+ # ========================================================================
+ logger.info(f"Loading T5 Text Encoder and T5 Tokenizer...")
+ t5_text_encoder_path = self.root / 'mt5'
+ embedder_t5 = MT5Embedder(t5_text_encoder_path, torch_dtype=torch.float16, max_length=256)
+ self.embedder_t5 = embedder_t5
+ logger.info(f"Loading t5_text_encoder and t5_tokenizer finished")
+
+ # ========================================================================
+ logger.info(f"Loading VAE...")
+ vae_path = self.root / "sdxl-vae-fp16-fix"
+ self.vae = AutoencoderKL.from_pretrained(str(vae_path)).to(self.device)
+ logger.info(f"Loading VAE finished")
+
+ # ========================================================================
+ # Create model structure and load the checkpoint
+ logger.info(f"Building HunYuan-DiT model...")
+ model_config = HUNYUAN_DIT_CONFIG[self.args.model]
+ self.patch_size = model_config['patch_size']
+ self.head_size = model_config['hidden_size'] // model_config['num_heads']
+ self.resolutions, self.freqs_cis_img = self.standard_shapes() # Used for TensorRT models
+ self.image_size = _to_tuple(self.args.image_size)
+ latent_size = (self.image_size[0] // 8, self.image_size[1] // 8)
+
+ self.infer_mode = self.args.infer_mode
+ if self.infer_mode in ['fa', 'torch']:
+ model_dir = self.root / "model"
+ model_path = model_dir / f"pytorch_model_{self.args.load_key}.pt"
+ if not model_path.exists():
+ raise ValueError(f"model_path not exists: {model_path}")
+ # Build model structure
+ self.model = HunYuanDiT(self.args,
+ input_size=latent_size,
+ **model_config,
+ log_fn=logger.info,
+ ).half().to(self.device) # Force to use fp16
+ # Load model checkpoint
+ logger.info(f"Loading torch model {model_path}...")
+ state_dict = torch.load(model_path, map_location=lambda storage, loc: storage)
+ self.model.load_state_dict(state_dict)
+
+ lora_ckpt = args.lora_ckpt
+ if lora_ckpt is not None and lora_ckpt != "":
+ logger.info(f"Loading Lora checkpoint {lora_ckpt}...")
+
+ self.model.load_adapter(lora_ckpt)
+ self.model.merge_and_unload()
+
+
+ self.model.eval()
+ logger.info(f"Loading torch model finished")
+ elif self.infer_mode == 'trt':
+ from .modules.trt.hcf_model import TRTModel
+
+ trt_dir = self.root / "model_trt"
+ engine_dir = trt_dir / "engine"
+ plugin_path = trt_dir / "fmha_plugins/9.2_plugin_cuda11/fMHAPlugin.so"
+ model_name = "model_onnx"
+
+ logger.info(f"Loading TensorRT model {engine_dir}/{model_name}...")
+ self.model = TRTModel(model_name=model_name,
+ engine_dir=str(engine_dir),
+ image_height=TRT_MAX_HEIGHT,
+ image_width=TRT_MAX_WIDTH,
+ text_maxlen=args.text_len,
+ embedding_dim=args.text_states_dim,
+ plugin_path=str(plugin_path),
+ max_batch_size=TRT_MAX_BATCH_SIZE,
+ )
+ logger.info(f"Loading TensorRT model finished")
+ else:
+ raise ValueError(f"Unknown infer_mode: {self.infer_mode}")
+
+ # ========================================================================
+ # Build inference pipeline. We use a customized StableDiffusionPipeline.
+ logger.info(f"Loading inference pipeline...")
+ self.pipeline, self.sampler = self.load_sampler()
+ logger.info(f'Loading pipeline finished')
+
+ # ========================================================================
+ self.default_negative_prompt = NEGATIVE_PROMPT
+ logger.info("==================================================")
+ logger.info(f" Model is ready. ")
+ logger.info("==================================================")
+
+ def load_sampler(self, sampler=None):
+ pipeline, sampler = get_pipeline(self.args,
+ self.vae,
+ self.clip_text_encoder,
+ self.tokenizer,
+ self.model,
+ device=self.device,
+ rank=0,
+ embedder_t5=self.embedder_t5,
+ infer_mode=self.infer_mode,
+ sampler=sampler,
+ )
+ return pipeline, sampler
+
+ def calc_rope(self, height, width):
+ th = height // 8 // self.patch_size
+ tw = width // 8 // self.patch_size
+ base_size = 512 // 8 // self.patch_size
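+ # e.g., with a 1024x1024 target and patch_size 2 (assumed), th = tw = 64 and base_size = 32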
+ start, stop = get_fill_resize_and_crop((th, tw), base_size)
+ sub_args = [start, stop, (th, tw)]
+ rope = get_2d_rotary_pos_embed(self.head_size, *sub_args)
+ return rope
+
+ def standard_shapes(self):
+ resolutions = ResolutionGroup()
+ freqs_cis_img = {}
+ for reso in resolutions.data:
+ freqs_cis_img[str(reso)] = self.calc_rope(reso.height, reso.width)
+ return resolutions, freqs_cis_img
+
+ def predict(self,
+ user_prompt,
+ height=1024,
+ width=1024,
+ seed=None,
+ enhanced_prompt=None,
+ negative_prompt=None,
+ infer_steps=100,
+ guidance_scale=6,
+ batch_size=1,
+ src_size_cond=(1024, 1024),
+ sampler=None,
+ ):
+ # ========================================================================
+ # Arguments: seed
+ # ========================================================================
+ if seed is None:
+ seed = random.randint(0, 1_000_000)
+ if not isinstance(seed, int):
+ raise TypeError(f"`seed` must be an integer, but got {type(seed)}")
+ generator = set_seeds(seed, device=self.device)
+ # ========================================================================
+ # Arguments: target_width, target_height
+ # ========================================================================
+ if width <= 0 or height <= 0:
+ raise ValueError(f"`height` and `width` must be positive integers, got height={height}, width={width}")
+ logger.info(f"Input (height, width) = ({height}, {width})")
+ if self.infer_mode in ['fa', 'torch']:
+ # We must force height and width to align to 16 and to be an integer.
+ target_height = int((height // 16) * 16)
+ target_width = int((width // 16) * 16)
+ logger.info(f"Align to 16: (height, width) = ({target_height}, {target_width})")
+ elif self.infer_mode == 'trt':
+ target_width, target_height = get_standard_shape(width, height)
+ logger.info(f"Align to standard shape: (height, width) = ({target_height}, {target_width})")
+ else:
+ raise ValueError(f"Unknown infer_mode: {self.infer_mode}")
+
+ # ========================================================================
+ # Arguments: prompt, new_prompt, negative_prompt
+ # ========================================================================
+ if not isinstance(user_prompt, str):
+ raise TypeError(f"`user_prompt` must be a string, but got {type(user_prompt)}")
+ user_prompt = user_prompt.strip()
+ prompt = user_prompt
+
+ if enhanced_prompt is not None:
+ if not isinstance(enhanced_prompt, str):
+ raise TypeError(f"`enhanced_prompt` must be a string, but got {type(enhanced_prompt)}")
+ enhanced_prompt = enhanced_prompt.strip()
+ prompt = enhanced_prompt
+
+ # negative prompt
+ if negative_prompt is None or negative_prompt == '':
+ negative_prompt = self.default_negative_prompt
+ if not isinstance(negative_prompt, str):
+ raise TypeError(f"`negative_prompt` must be a string, but got {type(negative_prompt)}")
+
+ # ========================================================================
+ # Arguments: style. (A fixed argument. Don't Change it.)
+ # ========================================================================
+ style = torch.as_tensor([0, 0] * batch_size, device=self.device)
+
+ # ========================================================================
+ # Inner arguments: image_meta_size (Please refer to SDXL.)
+ # ========================================================================
+ if isinstance(src_size_cond, int):
+ src_size_cond = [src_size_cond, src_size_cond]
+ if not isinstance(src_size_cond, (list, tuple)):
+ raise TypeError(f"`src_size_cond` must be a list or tuple, but got {type(src_size_cond)}")
+ if len(src_size_cond) != 2:
+ raise ValueError(f"`src_size_cond` must be a tuple of 2 integers, but got {len(src_size_cond)}")
+ size_cond = list(src_size_cond) + [target_width, target_height, 0, 0]
+ image_meta_size = torch.as_tensor([size_cond] * 2 * batch_size, device=self.device)
+
+ # ========================================================================
+ start_time = time.time()
+ logger.debug(f"""
+ prompt: {user_prompt}
+ enhanced prompt: {enhanced_prompt}
+ seed: {seed}
+ (height, width): {(target_height, target_width)}
+ negative_prompt: {negative_prompt}
+ batch_size: {batch_size}
+ guidance_scale: {guidance_scale}
+ infer_steps: {infer_steps}
+ image_meta_size: {size_cond}
+ """)
+ reso = f'{target_height}x{target_width}'
+ if reso in self.freqs_cis_img:
+ freqs_cis_img = self.freqs_cis_img[reso]
+ else:
+ freqs_cis_img = self.calc_rope(target_height, target_width)
+
+ if sampler is not None and sampler != self.sampler:
+ self.pipeline, self.sampler = self.load_sampler(sampler)
+
+ samples = self.pipeline(
+ height=target_height,
+ width=target_width,
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ num_images_per_prompt=batch_size,
+ guidance_scale=guidance_scale,
+ num_inference_steps=infer_steps,
+ image_meta_size=image_meta_size,
+ style=style,
+ return_dict=False,
+ generator=generator,
+ freqs_cis_img=freqs_cis_img,
+ use_fp16=self.args.use_fp16,
+ learn_sigma=self.args.learn_sigma,
+ )[0]
+ gen_time = time.time() - start_time
+ logger.debug(f"Success, time: {gen_time}")
+
+ return {
+ 'images': samples,
+ 'seed': seed,
+ }
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/hydit/inference_controlnet.py b/PyTorch/built-in/mlm/HunyuanDiT/hydit/inference_controlnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..0aa161f19d0a4e4e8f1e8c664c75df930baf9280
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/hydit/inference_controlnet.py
@@ -0,0 +1,434 @@
+import random
+import time
+from pathlib import Path
+
+import numpy as np
+import torch
+
+# For reproducibility
+# torch.backends.cudnn.benchmark = False
+# torch.backends.cudnn.deterministic = True
+
+from diffusers import schedulers
+from diffusers.models import AutoencoderKL
+from loguru import logger
+from transformers import BertModel, BertTokenizer
+from transformers.modeling_utils import logger as tf_logger
+
+from .constants import SAMPLER_FACTORY, NEGATIVE_PROMPT, TRT_MAX_WIDTH, TRT_MAX_HEIGHT, TRT_MAX_BATCH_SIZE
+from .diffusion.pipeline_controlnet import StableDiffusionControlNetPipeline
+from .modules.models import HunYuanDiT, HUNYUAN_DIT_CONFIG
+from .modules.controlnet import HunYuanControlNet
+from .modules.posemb_layers import get_2d_rotary_pos_embed, get_fill_resize_and_crop
+from .modules.text_encoder import MT5Embedder
+from .utils.tools import set_seeds
+from peft import LoraConfig
+
+
+class Resolution:
+ def __init__(self, width, height):
+ self.width = width
+ self.height = height
+
+ def __str__(self):
+ return f'{self.height}x{self.width}'
+
+
+class ResolutionGroup:
+ def __init__(self):
+ self.data = [
+ Resolution(1024, 1024), # 1:1
+ Resolution(1280, 1280), # 1:1
+ Resolution(1024, 768), # 4:3
+ Resolution(1152, 864), # 4:3
+ Resolution(1280, 960), # 4:3
+ Resolution(768, 1024), # 3:4
+ Resolution(864, 1152), # 3:4
+ Resolution(960, 1280), # 3:4
+ Resolution(1280, 768), # 16:9
+ Resolution(768, 1280), # 9:16
+ ]
+ self.supported_sizes = set([(r.width, r.height) for r in self.data])
+
+ def is_valid(self, width, height):
+ return (width, height) in self.supported_sizes
+
+
+STANDARD_RATIO = np.array([
+ 1.0, # 1:1
+ 4.0 / 3.0, # 4:3
+ 3.0 / 4.0, # 3:4
+ 16.0 / 9.0, # 16:9
+ 9.0 / 16.0, # 9:16
+])
+STANDARD_SHAPE = [
+ [(1024, 1024), (1280, 1280)], # 1:1
+ [(1280, 960)], # 4:3
+ [(960, 1280)], # 3:4
+ [(1280, 768)], # 16:9
+ [(768, 1280)], # 9:16
+]
+STANDARD_AREA = [
+ np.array([w * h for w, h in shapes])
+ for shapes in STANDARD_SHAPE
+]
+
+
+def get_standard_shape(target_width, target_height):
+ """
+ Map image size to standard size.
+ """
+ target_ratio = target_width / target_height
+ closest_ratio_idx = np.argmin(np.abs(STANDARD_RATIO - target_ratio))
+ closest_area_idx = np.argmin(np.abs(STANDARD_AREA[closest_ratio_idx] - target_width * target_height))
+ width, height = STANDARD_SHAPE[closest_ratio_idx][closest_area_idx]
+ return width, height
+
+
+def _to_tuple(val):
+ if isinstance(val, (list, tuple)):
+ if len(val) == 1:
+ val = [val[0], val[0]]
+ elif len(val) == 2:
+ val = tuple(val)
+ else:
+ raise ValueError(f"Invalid value: {val}")
+ elif isinstance(val, (int, float)):
+ val = (val, val)
+ else:
+ raise ValueError(f"Invalid value: {val}")
+ return val
+
+
+def get_pipeline(args, vae, text_encoder, tokenizer, model, device, rank,
+ embedder_t5, infer_mode, controlnet, sampler=None):
+ """
+ Get scheduler and pipeline for sampling. The sampler and pipeline are both
+ based on diffusers and make some modifications.
+
+ Returns
+ -------
+ pipeline: StableDiffusionControlNetPipeline
+ sampler_name: str
+ """
+ sampler = sampler or args.sampler
+
+ # Load sampler from factory
+ kwargs = SAMPLER_FACTORY[sampler]['kwargs']
+ scheduler = SAMPLER_FACTORY[sampler]['scheduler']
+
+ # Update sampler according to the arguments
+ kwargs['beta_schedule'] = args.noise_schedule
+ kwargs['beta_start'] = args.beta_start
+ kwargs['beta_end'] = args.beta_end
+ kwargs['prediction_type'] = args.predict_type
+
+ # Build scheduler according to the sampler.
+ scheduler_class = getattr(schedulers, scheduler)
+ scheduler = scheduler_class(**kwargs)
+
+ # Set timesteps for inference steps.
+ scheduler.set_timesteps(args.infer_steps, device)
+
+ # Only enable progress bar for rank 0
+ progress_bar_config = {} if rank == 0 else {'disable': True}
+
+ pipeline = StableDiffusionControlNetPipeline(vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=model,
+ scheduler=scheduler,
+ feature_extractor=None,
+ safety_checker=None,
+ requires_safety_checker=False,
+ progress_bar_config=progress_bar_config,
+ embedder_t5=embedder_t5,
+ infer_mode=infer_mode,
+ controlnet=controlnet
+ )
+
+ pipeline = pipeline.to(device)
+
+ return pipeline, sampler
+
+
+class End2End(object):
+ def __init__(self, args, models_root_path):
+ self.args = args
+
+ # Check arguments
+ t2i_root_path = Path(models_root_path) / "t2i"
+ self.root = t2i_root_path
+ logger.info(f"Got text-to-image model root path: {t2i_root_path}")
+
+ # Set device and disable gradient
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
+ torch.set_grad_enabled(False)
+ # Disable BertModel logging checkpoint info
+ tf_logger.setLevel('ERROR')
+
+ # ========================================================================
+ logger.info(f"Loading CLIP Text Encoder...")
+ text_encoder_path = self.root / "clip_text_encoder"
+ self.clip_text_encoder = BertModel.from_pretrained(str(text_encoder_path), False, revision=None).to(self.device)
+ logger.info(f"Loading CLIP Text Encoder finished")
+
+ # ========================================================================
+ logger.info(f"Loading CLIP Tokenizer...")
+ tokenizer_path = self.root / "tokenizer"
+ self.tokenizer = BertTokenizer.from_pretrained(str(tokenizer_path))
+ logger.info(f"Loading CLIP Tokenizer finished")
+
+ # ========================================================================
+ logger.info(f"Loading T5 Text Encoder and T5 Tokenizer...")
+ t5_text_encoder_path = self.root / 'mt5'
+ embedder_t5 = MT5Embedder(t5_text_encoder_path, torch_dtype=torch.float16, max_length=256)
+ self.embedder_t5 = embedder_t5
+ logger.info(f"Loading t5_text_encoder and t5_tokenizer finished")
+
+ # ========================================================================
+ logger.info(f"Loading VAE...")
+ vae_path = self.root / "sdxl-vae-fp16-fix"
+ self.vae = AutoencoderKL.from_pretrained(str(vae_path)).to(self.device)
+ logger.info(f"Loading VAE finished")
+
+ # ========================================================================
+ # Create model structure and load the checkpoint
+ logger.info(f"Building HunYuan-DiT model...")
+ model_config = HUNYUAN_DIT_CONFIG[self.args.model]
+ self.patch_size = model_config['patch_size']
+ self.head_size = model_config['hidden_size'] // model_config['num_heads']
+ self.resolutions, self.freqs_cis_img = self.standard_shapes() # Used for TensorRT models
+ self.image_size = _to_tuple(self.args.image_size)
+ latent_size = (self.image_size[0] // 8, self.image_size[1] // 8)
+
+ self.infer_mode = self.args.infer_mode
+ if self.infer_mode in ['fa', 'torch']:
+ model_dir = self.root / "model"
+ model_path = model_dir / f"pytorch_model_{self.args.load_key}.pt"
+ if not model_path.exists():
+ raise ValueError(f"model_path not exists: {model_path}")
+ # Build model structure
+ controlnet_dir = self.root / "controlnet"
+ controlnet_path = controlnet_dir / f"pytorch_model_{self.args.control_type}_{self.args.load_key}.pt"
+ if not controlnet_path.exists():
+ raise ValueError(f"controlnet_path not exists: {controlnet_path}")
+ self.model = HunYuanDiT(self.args,
+ input_size=latent_size,
+ **model_config,
+ log_fn=logger.info,
+ ).half().to(self.device) # Force to use fp16
+ self.controlnet = HunYuanControlNet(self.args,
+ input_size=latent_size,
+ **model_config,
+ log_fn=logger.info,
+ ).half().to(self.device)
+ controlnet_state_dict = torch.load(controlnet_path)
+ if 'module' in controlnet_state_dict:
+ controlnet_state_dict = controlnet_state_dict['module']
+ self.controlnet.load_state_dict(controlnet_state_dict)
+ logger.info(f"Loading controlnet finished")
+ # Load model checkpoint
+ logger.info(f"Loading torch model {model_path}...")
+ state_dict = torch.load(model_path, map_location=lambda storage, loc: storage)
+ self.model.load_state_dict(state_dict)
+
+ lora_ckpt = args.lora_ckpt
+ if lora_ckpt is not None and lora_ckpt != "":
+ logger.info(f"Loading Lora checkpoint {lora_ckpt}...")
+
+ self.model.load_adapter(lora_ckpt)
+ self.model.merge_and_unload()
+
+
+ self.model.eval()
+ self.controlnet.eval()
+ logger.info(f"Loading torch model finished")
+ elif self.infer_mode == 'trt':
+ from .modules.trt.hcf_model import TRTModel
+
+ trt_dir = self.root / "model_trt"
+ engine_dir = trt_dir / "engine"
+ plugin_path = trt_dir / "fmha_plugins/9.2_plugin_cuda11/fMHAPlugin.so"
+ model_name = "model_onnx"
+
+ logger.info(f"Loading TensorRT model {engine_dir}/{model_name}...")
+ self.model = TRTModel(model_name=model_name,
+ engine_dir=str(engine_dir),
+ image_height=TRT_MAX_HEIGHT,
+ image_width=TRT_MAX_WIDTH,
+ text_maxlen=args.text_len,
+ embedding_dim=args.text_states_dim,
+ plugin_path=str(plugin_path),
+ max_batch_size=TRT_MAX_BATCH_SIZE,
+ )
+ logger.info(f"Loading TensorRT model finished")
+ else:
+ raise ValueError(f"Unknown infer_mode: {self.infer_mode}")
+
+ # ========================================================================
+ # Build inference pipeline. We use a customized StableDiffusionControlNetPipeline.
+ logger.info(f"Loading inference pipeline...")
+ self.pipeline, self.sampler = self.load_sampler()
+ logger.info(f'Loading pipeline finished')
+
+ # ========================================================================
+ self.default_negative_prompt = NEGATIVE_PROMPT
+ logger.info("==================================================")
+ logger.info(f" Model is ready. ")
+ logger.info("==================================================")
+
+ def load_sampler(self, sampler=None):
+ pipeline, sampler = get_pipeline(self.args,
+ self.vae,
+ self.clip_text_encoder,
+ self.tokenizer,
+ self.model,
+ device=self.device,
+ rank=0,
+ embedder_t5=self.embedder_t5,
+ infer_mode=self.infer_mode,
+ sampler=sampler,
+ controlnet=self.controlnet
+ )
+ return pipeline, sampler
+
+ def calc_rope(self, height, width):
+ th = height // 8 // self.patch_size
+ tw = width // 8 // self.patch_size
+ base_size = 512 // 8 // self.patch_size
+ start, stop = get_fill_resize_and_crop((th, tw), base_size)
+ sub_args = [start, stop, (th, tw)]
+ rope = get_2d_rotary_pos_embed(self.head_size, *sub_args)
+ return rope
+
+ def standard_shapes(self):
+ resolutions = ResolutionGroup()
+ freqs_cis_img = {}
+ for reso in resolutions.data:
+ freqs_cis_img[str(reso)] = self.calc_rope(reso.height, reso.width)
+ return resolutions, freqs_cis_img
+
+ def predict(self,
+ user_prompt,
+ image,
+ height=1024,
+ width=1024,
+ seed=None,
+ enhanced_prompt=None,
+ negative_prompt=None,
+ infer_steps=100,
+ guidance_scale=6,
+ batch_size=1,
+ src_size_cond=(1024, 1024),
+ sampler=None,
+ ):
+ # ========================================================================
+ # Arguments: seed
+ # ========================================================================
+ if seed is None:
+ seed = random.randint(0, 1_000_000)
+ if not isinstance(seed, int):
+ raise TypeError(f"`seed` must be an integer, but got {type(seed)}")
+ generator = set_seeds(seed, device=self.device)
+ # ========================================================================
+ # Arguments: target_width, target_height
+ # ========================================================================
+ if width <= 0 or height <= 0:
+ raise ValueError(f"`height` and `width` must be positive integers, got height={height}, width={width}")
+ logger.info(f"Input (height, width) = ({height}, {width})")
+ if self.infer_mode in ['fa', 'torch']:
+ # We must force height and width to align to 16 and to be an integer.
+ target_height = int((height // 16) * 16)
+ target_width = int((width // 16) * 16)
+ logger.info(f"Align to 16: (height, width) = ({target_height}, {target_width})")
+ elif self.infer_mode == 'trt':
+ target_width, target_height = get_standard_shape(width, height)
+ logger.info(f"Align to standard shape: (height, width) = ({target_height}, {target_width})")
+ else:
+ raise ValueError(f"Unknown infer_mode: {self.infer_mode}")
+
+ # ========================================================================
+ # Arguments: prompt, new_prompt, negative_prompt
+ # ========================================================================
+ if not isinstance(user_prompt, str):
+ raise TypeError(f"`user_prompt` must be a string, but got {type(user_prompt)}")
+ user_prompt = user_prompt.strip()
+ prompt = user_prompt
+
+ if enhanced_prompt is not None:
+ if not isinstance(enhanced_prompt, str):
+ raise TypeError(f"`enhanced_prompt` must be a string, but got {type(enhanced_prompt)}")
+ enhanced_prompt = enhanced_prompt.strip()
+ prompt = enhanced_prompt
+
+ # negative prompt
+ if negative_prompt is None or negative_prompt == '':
+ negative_prompt = self.default_negative_prompt
+ if not isinstance(negative_prompt, str):
+ raise TypeError(f"`negative_prompt` must be a string, but got {type(negative_prompt)}")
+
+ # ========================================================================
+ # Arguments: style. (A fixed argument. Don't Change it.)
+ # ========================================================================
+ style = torch.as_tensor([0, 0] * batch_size, device=self.device)
+
+ # ========================================================================
+ # Inner arguments: image_meta_size (Please refer to SDXL.)
+ # ========================================================================
+ if isinstance(src_size_cond, int):
+ src_size_cond = [src_size_cond, src_size_cond]
+ if not isinstance(src_size_cond, (list, tuple)):
+ raise TypeError(f"`src_size_cond` must be a list or tuple, but got {type(src_size_cond)}")
+ if len(src_size_cond) != 2:
+ raise ValueError(f"`src_size_cond` must be a tuple of 2 integers, but got {len(src_size_cond)}")
+ size_cond = list(src_size_cond) + [target_width, target_height, 0, 0]
+ image_meta_size = torch.as_tensor([size_cond] * 2 * batch_size, device=self.device)
+
+ # ========================================================================
+ start_time = time.time()
+ logger.debug(f"""
+ prompt: {user_prompt}
+ enhanced prompt: {enhanced_prompt}
+ seed: {seed}
+ (height, width): {(target_height, target_width)}
+ negative_prompt: {negative_prompt}
+ batch_size: {batch_size}
+ guidance_scale: {guidance_scale}
+ infer_steps: {infer_steps}
+ image_meta_size: {size_cond}
+ """)
+ reso = f'{target_height}x{target_width}'
+ if reso in self.freqs_cis_img:
+ freqs_cis_img = self.freqs_cis_img[reso]
+ else:
+ freqs_cis_img = self.calc_rope(target_height, target_width)
+
+ if sampler is not None and sampler != self.sampler:
+ self.pipeline, self.sampler = self.load_sampler(sampler)
+
+ samples = self.pipeline(
+ height=target_height,
+ width=target_width,
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ num_images_per_prompt=batch_size,
+ guidance_scale=guidance_scale,
+ num_inference_steps=infer_steps,
+ image_meta_size=image_meta_size,
+ style=style,
+ return_dict=False,
+ generator=generator,
+ freqs_cis_img=freqs_cis_img,
+ use_fp16=self.args.use_fp16,
+ learn_sigma=self.args.learn_sigma,
+ image=image,
+ control_weight=eval(self.args.control_weight),
+ )[0]
+ gen_time = time.time() - start_time
+ logger.debug(f"Success, time: {gen_time}")
+
+ return {
+ 'images': samples,
+ 'seed': seed,
+ }
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/hydit/lr_scheduler.py b/PyTorch/built-in/mlm/HunyuanDiT/hydit/lr_scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..027c7e6b42eaa8bfeac48a2405d479869bdda683
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/hydit/lr_scheduler.py
@@ -0,0 +1,761 @@
+"""
+Implementation of learning rate schedules.
+
+Taken and modified from PyTorch v1.0.1 source
+https://github.com/pytorch/pytorch/blob/v1.1.0/torch/optim/lr_scheduler.py
+"""
+
+import argparse
+from torch.optim import Optimizer
+import math
+
+LR_SCHEDULE = 'lr_schedule'
+LR_RANGE_TEST = 'LRRangeTest'
+ONE_CYCLE = 'OneCycle'
+WARMUP_LR = 'WarmupLR'
+WARMUP_DECAY_LR = 'WarmupDecayLR'
+VALID_LR_SCHEDULES = [LR_RANGE_TEST, ONE_CYCLE, WARMUP_LR, WARMUP_DECAY_LR]
+
+LR_RANGE_TEST_MIN_LR = 'lr_range_test_min_lr'
+LR_RANGE_TEST_STEP_RATE = 'lr_range_test_step_rate'
+LR_RANGE_TEST_STEP_SIZE = 'lr_range_test_step_size'
+LR_RANGE_TEST_STAIRCASE = 'lr_range_test_staircase'
+
+EDGE_VALUE = 'edge_value'
+MID_VALUE = 'mid_value'
+
+CYCLE_FIRST_STEP_SIZE = 'cycle_first_step_size'
+CYCLE_FIRST_STAIR_COUNT = 'cycle_first_stair_count'
+CYCLE_SECOND_STEP_SIZE = 'cycle_second_step_size'
+CYCLE_SECOND_STAIR_COUNT = 'cycle_second_stair_count'
+DECAY_STEP_SIZE = 'decay_step_size'
+
+CYCLE_MIN_LR = 'cycle_min_lr'
+CYCLE_MAX_LR = 'cycle_max_lr'
+DECAY_LR_RATE = 'decay_lr_rate'
+
+CYCLE_MIN_MOM = 'cycle_min_mom'
+CYCLE_MAX_MOM = 'cycle_max_mom'
+DECAY_MOM_RATE = 'decay_mom_rate'
+
+WARMUP_MIN_LR = 'warmup_min_lr'
+WARMUP_MAX_LR = 'warmup_max_lr'
+WARMUP_NUM_STEPS = 'warmup_num_steps'
+WARMUP_TYPE = 'warmup_type'
+WARMUP_LOG_RATE = 'log'
+WARMUP_LINEAR_RATE = 'linear'
+
+TOTAL_NUM_STEPS = 'total_num_steps'
+
+
+def add_tuning_arguments(parser):
+ group = parser.add_argument_group('Convergence Tuning', 'Convergence tuning configurations')
+
+ # LR scheduler
+ group.add_argument('--lr_schedule', type=str, default=None, help='LR schedule for training.')
+
+ # Learning rate range test
+ group.add_argument("--lr_range_test_min_lr", type=float, default=0.001, help='Starting lr value.')
+ group.add_argument("--lr_range_test_step_rate", type=float, default=1.0, help='scaling rate for LR range test.')
+ group.add_argument("--lr_range_test_step_size", type=int, default=1000, help='training steps per LR change.')
+ group.add_argument("--lr_range_test_staircase",
+ type=bool,
+ default=False,
+ help='use staircase scaling for LR range test.')
+
+ # OneCycle schedule
+ group.add_argument("--cycle_first_step_size",
+ type=int,
+ default=1000,
+ help='size of first step of 1Cycle schedule (training steps).')
+ group.add_argument("--cycle_first_stair_count",
+ type=int,
+ default=-1,
+ help='first stair count for 1Cycle schedule.')
+ group.add_argument("--cycle_second_step_size",
+ type=int,
+ default=-1,
+ help='size of second step of 1Cycle schedule (default first_step_size).')
+ group.add_argument("--cycle_second_stair_count",
+ type=int,
+ default=-1,
+ help='second stair count for 1Cycle schedule.')
+ group.add_argument("--decay_step_size",
+ type=int,
+ default=1000,
+ help='size of intervals for applying post cycle decay (training steps).')
+
+ # 1Cycle LR
+ group.add_argument("--cycle_min_lr", type=float, default=0.01, help='1Cycle LR lower bound.')
+ group.add_argument("--cycle_max_lr", type=float, default=0.1, help='1Cycle LR upper bound.')
+ group.add_argument("--decay_lr_rate", type=float, default=0.0, help='post cycle LR decay rate.')
+
+ # 1Cycle Momentum
+ group.add_argument('--cycle_momentum', default=False, action='store_true', help='Enable 1Cycle momentum schedule.')
+ group.add_argument("--cycle_min_mom", type=float, default=0.8, help='1Cycle momentum lower bound.')
+ group.add_argument("--cycle_max_mom", type=float, default=0.9, help='1Cycle momentum upper bound.')
+ group.add_argument("--decay_mom_rate", type=float, default=0.0, help='post cycle momentum decay rate.')
+
+ # Warmup LR
+ group.add_argument('--warmup_min_lr', type=float, default=0, help='WarmupLR minimum/initial LR value')
+ group.add_argument('--warmup_max_lr', type=float, default=0.001, help='WarmupLR maximum LR value.')
+ group.add_argument('--warmup_num_steps', type=int, default=1000, help='WarmupLR step count for LR warmup.')
+ group.add_argument('--warmup_type',
+ type=str,
+ default=WARMUP_LOG_RATE,
+ help='WarmupLR increasing function during warmup')
+ return parser
+
+
+def parse_arguments():
+ parser = argparse.ArgumentParser()
+ parser = add_tuning_arguments(parser)
+
+ lr_sched_args, unknown_args = parser.parse_known_args()
+ return lr_sched_args, unknown_args
+
+
+def override_lr_range_test_params(args, params):
+ if hasattr(args, LR_RANGE_TEST_MIN_LR) and args.lr_range_test_min_lr is not None:
+ params[LR_RANGE_TEST_MIN_LR] = args.lr_range_test_min_lr
+
+ if hasattr(args, LR_RANGE_TEST_STEP_RATE) and args.lr_range_test_step_rate is not None:
+ params[LR_RANGE_TEST_STEP_RATE] = args.lr_range_test_step_rate
+
+ if hasattr(args, LR_RANGE_TEST_STEP_SIZE) and args.lr_range_test_step_size is not None:
+ params[LR_RANGE_TEST_STEP_SIZE] = args.lr_range_test_step_size
+
+ if hasattr(args, LR_RANGE_TEST_STAIRCASE) and args.lr_range_test_staircase is not None:
+ params[LR_RANGE_TEST_STAIRCASE] = args.lr_range_test_staircase
+
+
+def override_1cycle_params(args, params):
+ if hasattr(args, CYCLE_FIRST_STEP_SIZE) and args.cycle_first_step_size is not None:
+ params[CYCLE_FIRST_STEP_SIZE] = args.cycle_first_step_size
+
+ if hasattr(args, CYCLE_FIRST_STAIR_COUNT) and args.cycle_first_stair_count is not None:
+ params[CYCLE_FIRST_STAIR_COUNT] = args.cycle_first_stair_count
+
+ if hasattr(args, CYCLE_SECOND_STEP_SIZE) and args.cycle_second_step_size is not None:
+ params[CYCLE_SECOND_STEP_SIZE] = args.cycle_second_step_size
+
+ if hasattr(args, CYCLE_SECOND_STAIR_COUNT) and args.cycle_second_stair_count is not None:
+ params[CYCLE_SECOND_STAIR_COUNT] = args.cycle_second_stair_count
+
+ if hasattr(args, DECAY_STEP_SIZE) and args.decay_step_size is not None:
+ params[DECAY_STEP_SIZE] = args.decay_step_size
+
+ # 1Cycle LR params
+ if hasattr(args, CYCLE_MIN_LR) and args.cycle_min_lr is not None:
+ params[CYCLE_MIN_LR] = args.cycle_min_lr
+
+ if hasattr(args, CYCLE_MAX_LR) and args.cycle_max_lr is not None:
+ params[CYCLE_MAX_LR] = args.cycle_max_lr
+
+ if hasattr(args, DECAY_LR_RATE) and args.decay_lr_rate is not None:
+ params[DECAY_LR_RATE] = args.decay_lr_rate
+
+ # 1Cycle MOM params
+ if hasattr(args, CYCLE_MIN_MOM) and args.cycle_min_mom is not None:
+ params[CYCLE_MIN_MOM] = args.cycle_min_mom
+
+ if hasattr(args, CYCLE_MAX_MOM) and args.cycle_max_mom is not None:
+ params[CYCLE_MAX_MOM] = args.cycle_max_mom
+
+ if hasattr(args, DECAY_MOM_RATE) and args.decay_mom_rate is not None:
+ params[DECAY_MOM_RATE] = args.decay_mom_rate
+
+
+def override_warmupLR_params(args, params):
+ if hasattr(args, WARMUP_MIN_LR) and args.warmup_min_lr is not None:
+ params[WARMUP_MIN_LR] = args.warmup_min_lr
+
+ if hasattr(args, WARMUP_MAX_LR) and args.warmup_max_lr is not None:
+ params[WARMUP_MAX_LR] = args.warmup_max_lr
+
+ if hasattr(args, WARMUP_NUM_STEPS) and args.warmup_num_steps is not None:
+ params[WARMUP_NUM_STEPS] = args.warmup_num_steps
+
+ if hasattr(args, WARMUP_TYPE) and args.warmup_type is not None:
+ params[WARMUP_TYPE] = args.warmup_type
+
+
+def override_params(args, params):
+ # LR range test params
+ override_lr_range_test_params(args, params)
+
+ # 1Cycle params
+ override_1cycle_params(args, params)
+
+ # WarmupLR params
+ override_warmupLR_params(args, params)
+
+
+def get_config_from_args(args):
+ if not hasattr(args, LR_SCHEDULE) or args.lr_schedule is None:
+ return None, '--{} not specified on command line'.format(LR_SCHEDULE)
+
+ if not args.lr_schedule in VALID_LR_SCHEDULES:
+ return None, '{} is not a supported LR schedule'.format(args.lr_schedule)
+
+ config = {}
+ config['type'] = args.lr_schedule
+ config['params'] = {}
+
+ if args.lr_schedule == LR_RANGE_TEST:
+ override_lr_range_test_params(args, config['params'])
+ elif args.lr_schedule == ONE_CYCLE:
+ override_1cycle_params(args, config['params'])
+ else:
+ override_warmupLR_params(args, config['params'])
+
+ return config, None
+
+
+def get_lr_from_config(config):
+ if not 'type' in config:
+ return None, 'LR schedule type not defined in config'
+
+ if not 'params' in config:
+ return None, 'LR schedule params not defined in config'
+
+ lr_schedule = config['type']
+ lr_params = config['params']
+
+ if not lr_schedule in VALID_LR_SCHEDULES:
+ return None, '{} is not a valid LR schedule'.format(lr_schedule)
+
+ if lr_schedule == LR_RANGE_TEST:
+ return lr_params[LR_RANGE_TEST_MIN_LR], ''
+ if lr_schedule == ONE_CYCLE:
+ return lr_params[CYCLE_MAX_LR], ''
+ # Warmup LR
+ return lr_params[WARMUP_MAX_LR], ''
+
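+# Illustrative usage sketch (the flag values below are assumptions, not defaults of this repo): given CLI flags
+# such as ``--lr_schedule WarmupLR --warmup_max_lr 1e-4 --warmup_num_steps 500``,
+#
+#   lr_sched_args, _ = parse_arguments()
+#   config, err = get_config_from_args(lr_sched_args)   # {'type': 'WarmupLR', 'params': {...}}
+#   if err is None:
+#       base_lr, _ = get_lr_from_config(config)          # 1e-4 in this example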
+
+"""
+Only optimizers that are a subclass of torch.optim.Optimizer are supported, so check both the passed optimizer and the wrapped
+optimizer to see whether the requirement is satisfied.
+TODO: Looking under the hood to examine the wrapped optimizer is a hack that requires a better long-term fix.
+"""
+
+
+def get_torch_optimizer(optimizer):
+ if isinstance(optimizer, Optimizer):
+ return optimizer
+
+ if hasattr(optimizer, 'optimizer') and isinstance(optimizer.optimizer, Optimizer):
+ return optimizer.optimizer
+
+ raise TypeError('{} is not a subclass of torch.optim.Optimizer'.format(type(optimizer).__name__))
+
+
+class LRRangeTest(object):
+ """Sets the learning rate of each parameter group according to
+ learning rate range test (LRRT) policy. The policy increases learning
+ rate starting from a base value with a constant frequency, as detailed in
+ the paper `A disciplined approach to neural network hyper-parameters: Part 1`_.
+
+ The LRRT policy is used for finding the maximum LR that trains a model without divergence, and can be used to
+ configure the LR boundaries for cyclic LR schedules.
+
+ LRRT changes the learning rate after every batch.
+ `step` should be called after a batch has been used for training.
+
+ Args:
+ optimizer (Optimizer): Wrapped optimizer.
+ lr_range_test_min_lr (float or list): Initial learning rate which is the
+ lower boundary in the range test for each parameter group.
+ lr_range_test_step_size (int): Interval of training steps to increase learning rate. Default: 2000
+ lr_range_test_step_rate (float): Scaling rate for range test. Default: 1.0
+ lr_range_test_staircase (bool): Scale in staircase fashion, rather than continuous. Default: False.
+ last_batch_iteration (int): The index of the last batch. This parameter is used when
+ resuming a training job. Since `step()` should be invoked after each
+ batch instead of after each epoch, this number represents the total
+ number of *batches* computed, not the total number of epochs computed.
+ When last_batch_iteration=-1, the schedule is started from the beginning.
+ Default: -1
+
+ Example:
+ >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
+ >>> scheduler = LRRangeTest(optimizer)
+ >>> data_loader = torch.utils.data.DataLoader(...)
+ >>> for epoch in range(10):
+ >>> for batch in data_loader:
+ >>> train_batch(...)
+ >>> scheduler.step()
+
+ _A disciplined approach to neural network hyper-parameters: Part 1 -- learning rate, batch size, momentum, and weight decay:
+ https://arxiv.org/abs/1803.09820
+"""
+
+ def __init__(self,
+ optimizer: Optimizer,
+ lr_range_test_min_lr: float = 1e-3,
+ lr_range_test_step_size: int = 2000,
+ lr_range_test_step_rate: float = 1.0,
+ lr_range_test_staircase: bool = False,
+ last_batch_iteration: int = -1):
+
+ self.optimizer = get_torch_optimizer(optimizer)
+
+ if isinstance(lr_range_test_min_lr, list) or isinstance(lr_range_test_min_lr, tuple):
+ if len(lr_range_test_min_lr) != len(self.optimizer.param_groups):
+ raise ValueError("expected {} lr_range_test_min_lr, got {}".format(len(self.optimizer.param_groups),
+ len(lr_range_test_min_lr)))
+ self.min_lr = list(lr_range_test_min_lr)
+ else:
+ self.min_lr = [lr_range_test_min_lr] * len(self.optimizer.param_groups)
+
+ self.step_size = lr_range_test_step_size
+ self.step_rate = lr_range_test_step_rate
+ self.last_batch_iteration = last_batch_iteration
+ self.staircase = lr_range_test_staircase
+ self.interval_fn = self._staircase_interval if lr_range_test_staircase else self._continuous_interval
+
+ if last_batch_iteration == -1:
+ self._update_optimizer(self.min_lr)
+
+ def _staircase_interval(self):
+ return math.floor(float(self.last_batch_iteration + 1) / self.step_size)
+
+ def _continuous_interval(self):
+ return float(self.last_batch_iteration + 1) / self.step_size
+
+ def _get_increase(self):
+ return (1 + self.step_rate * self.interval_fn())
+
+ def get_lr(self):
+ lr_increase = self._get_increase()
+ return [lr_range_test_min_lr * lr_increase for lr_range_test_min_lr in self.min_lr]
+
+ def get_last_lr(self):
+ """ Return last computed learning rate by current scheduler.
+ """
+ assert getattr(self, '_last_lr', None) is not None, "need to call step() first"
+ return self._last_lr
+
+ def _update_optimizer(self, group_lrs):
+ for param_group, lr in zip(self.optimizer.param_groups, group_lrs):
+ param_group['lr'] = lr
+
+ def step(self, batch_iteration=None):
+ if batch_iteration is None:
+ batch_iteration = self.last_batch_iteration + 1
+ self.last_batch_iteration = batch_iteration
+ self._update_optimizer(self.get_lr())
+ self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
+
+ def state_dict(self):
+ return {'last_batch_iteration': self.last_batch_iteration}
+
+ def load_state_dict(self, sd):
+ self.last_batch_iteration = sd['last_batch_iteration']
+
+
+
+class OneCycle(object):
+ """Sets the learning rate of each parameter group according to
+ 1Cycle learning rate policy (1CLR). 1CLR is a variation of the
+ Cyclical Learning Rate (CLR) policy that involves one cycle followed by
+ decay. The policy simultaneously cycles the learning rate (and momentum)
+ between two boundaries with a constant frequency, as detailed in
+ the paper `A disciplined approach to neural network hyper-parameters`_.
+
+ 1CLR policy changes the learning rate after every batch.
+ `step` should be called after a batch has been used for training.
+
+ This implementation was adapted from the github repo: `pytorch/pytorch`_
+
+ Args:
+ optimizer (Optimizer): Wrapped optimizer.
+ cycle_min_lr (float or list): Initial learning rate which is the
+ lower boundary in the cycle for each parameter group.
+ cycle_max_lr (float or list): Upper learning rate boundaries in the cycle
+ for each parameter group. Functionally,
+ it defines the cycle amplitude (cycle_max_lr - cycle_min_lr).
+ The lr at any cycle is the sum of cycle_min_lr
+ and some scaling of the amplitude; therefore
+ cycle_max_lr may not actually be reached depending on
+ scaling function.
+ decay_lr_rate(float): Decay rate for learning rate. Default: 0.
+ cycle_first_step_size (int): Number of training iterations in the
+ increasing half of a cycle. Default: 2000
+ cycle_second_step_size (int): Number of training iterations in the
+ decreasing half of a cycle. If cycle_second_step_size is None,
+ it is set to cycle_first_step_size. Default: None
+ cycle_first_stair_count(int): Number of stairs in first half of cycle phase. This means
+ lr/mom are changed in staircase fashion. Default 0, means staircase disabled.
+ cycle_second_stair_count(int): Number of stairs in second half of cycle phase. This means
+ lr/mom are changed in staircase fashion. Default 0, means staircase disabled.
+ decay_step_size (int): Intervals for applying decay in decay phase. Default: 0, means no decay.
+ cycle_momentum (bool): If ``True``, momentum is cycled inversely
+ to learning rate between 'cycle_min_mom' and 'cycle_max_mom'.
+ Default: True
+ cycle_min_mom (float or list): Initial momentum which is the
+ lower boundary in the cycle for each parameter group.
+ Default: 0.8
+ cycle_max_mom (float or list): Upper momentum boundaries in the cycle
+ for each parameter group. Functionally,
+ it defines the cycle amplitude (cycle_max_mom - cycle_min_mom).
+ The momentum at any cycle is the difference of cycle_max_mom
+ and some scaling of the amplitude; therefore
+ cycle_min_mom may not actually be reached depending on
+ scaling function. Default: 0.9
+ decay_mom_rate (float): Decay rate for momentum. Default: 0.
+ last_batch_iteration (int): The index of the last batch. This parameter is used when
+ resuming a training job. Since `step()` should be invoked after each
+ batch instead of after each epoch, this number represents the total
+ number of *batches* computed, not the total number of epochs computed.
+ When last_batch_iteration=-1, the schedule is started from the beginning.
+ Default: -1
+
+ Example:
+ >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
+ >>> scheduler = OneCycle(optimizer, 0.0001, 0.0010)
+ >>> data_loader = torch.utils.data.DataLoader(...)
+ >>> for epoch in range(10):
+ >>> for batch in data_loader:
+ >>> train_batch(...)
+ >>> scheduler.step()
+
+
+ .. _A disciplined approach to neural network hyper-parameters: Part 1 -- learning rate, batch size, momentum, and weight decay: https://arxiv.org/abs/1803.09820
+ """
+
+ def __init__(self,
+ optimizer,
+ cycle_min_lr,
+ cycle_max_lr,
+ decay_lr_rate=0.,
+ cycle_first_step_size=2000,
+ cycle_second_step_size=None,
+ cycle_first_stair_count=0,
+ cycle_second_stair_count=None,
+ decay_step_size=0,
+ cycle_momentum=True,
+ cycle_min_mom=0.8,
+ cycle_max_mom=0.9,
+ decay_mom_rate=0.,
+ last_batch_iteration=-1):
+
+ self.optimizer = get_torch_optimizer(optimizer)
+
+ # Initialize cycle shape
+ self._initialize_cycle(cycle_first_step_size, cycle_second_step_size, cycle_first_stair_count,
+ cycle_second_stair_count, decay_step_size)
+
+ # Initialize cycle lr
+ self._initialize_lr(self.optimizer, cycle_min_lr, cycle_max_lr, decay_lr_rate, last_batch_iteration)
+
+ # Initialize cyclic momentum
+ self.cycle_momentum = cycle_momentum
+ if cycle_momentum:
+ self._initialize_momentum(self.optimizer, cycle_min_mom, cycle_max_mom, decay_mom_rate,
+ last_batch_iteration)
+
+ # Initialize batch iteration tracker
+ self.last_batch_iteration = last_batch_iteration
+
+ # Configure cycle shape
+
+ def _initialize_cycle(self, cycle_first_step_size, cycle_second_step_size, cycle_first_stair_count,
+ cycle_second_stair_count, decay_step_size):
+ cycle_first_step_size = float(cycle_first_step_size)
+ cycle_second_step_size = float(
+ cycle_second_step_size) if cycle_second_step_size is not None else cycle_first_step_size
+
+ self.total_size = cycle_first_step_size + cycle_second_step_size
+ self.step_ratio = cycle_first_step_size / self.total_size
+ self.first_stair_count = cycle_first_stair_count
+ self.second_stair_count = cycle_first_stair_count if cycle_second_stair_count is None else cycle_second_stair_count
+ self.decay_step_size = decay_step_size
+
+ if math.isclose(self.decay_step_size, 0):
+ self.skip_lr_decay = True
+ self.skip_mom_decay = True
+ else:
+ self.skip_lr_decay = False
+ self.skip_mom_decay = False
+
+ # Configure lr schedule
+ def _initialize_lr(self, optimizer, cycle_min_lr, cycle_max_lr, decay_lr_rate, last_batch_iteration):
+ self.min_lrs = [cycle_min_lr] * len(optimizer.param_groups)
+ if last_batch_iteration == -1:
+ for lr, group in zip(self.min_lrs, optimizer.param_groups):
+ group['lr'] = lr
+
+ self.max_lrs = [cycle_max_lr] * len(optimizer.param_groups)
+ self.decay_lr_rate = decay_lr_rate
+
+ if math.isclose(self.decay_lr_rate, 0):
+ self.skip_lr_decay = True
+
+ # Configure momentum schedule
+ def _initialize_momentum(self, optimizer, cycle_min_mom, cycle_max_mom, decay_mom_rate, last_batch_iteration):
+ if 'betas' not in optimizer.defaults:
+ optimizer_name = type(optimizer).__name__
+ print(
+ f"cycle_momentum is disabled because optimizer {optimizer_name} does not support momentum, no betas attribute in defaults"
+ )
+ self.cycle_momentum = False
+ return
+
+ self.decay_mom_rate = decay_mom_rate
+ self.min_moms = [(cycle_min_mom, 0.99)] * len(optimizer.param_groups)
+ self.max_moms = [(cycle_max_mom, 0.99)] * len(optimizer.param_groups)
+
+ if last_batch_iteration == -1:
+ for momentum, group in zip(self.min_moms, optimizer.param_groups):
+ group['betas'] = momentum
+
+ if math.isclose(self.decay_mom_rate, 0):
+ self.skip_mom_decay = True
+
+ def _get_scale_factor(self):
+ batch_iteration = (self.last_batch_iteration + 1)
+ cycle = math.floor(1 + batch_iteration / self.total_size)
+ x = 1. + batch_iteration / self.total_size - cycle
+ if x <= self.step_ratio:
+ scale_factor = x / self.step_ratio
+ else:
+ scale_factor = (x - 1) / (self.step_ratio - 1)
+
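+ # Worked example (illustrative numbers): with cycle_first_step_size = cycle_second_step_size = 2000
+ # (total_size = 4000, step_ratio = 0.5), batch_iteration = 1000 gives x = 0.25 and scale_factor = 0.5
+ # on the way up, while batch_iteration = 3000 gives x = 0.75 and scale_factor = 0.5 on the way down.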
+ return scale_factor
+
+ def _get_cycle_mom(self):
+ scale_factor = self._get_scale_factor()
+ momentums = []
+ for base_betas, max_betas in zip(self.min_moms, self.max_moms):
+ cycle_min_mom = base_betas[0]
+ cycle_max_mom = max_betas[0]
+ base_height = (cycle_max_mom - cycle_min_mom) * scale_factor
+ momentum = cycle_max_mom - base_height
+ momentums.append((momentum, base_betas[1]))
+ return momentums
+
+ def _get_cycle_lr(self):
+ scale_factor = self._get_scale_factor()
+ lrs = []
+ for cycle_min_lr, cycle_max_lr in zip(self.min_lrs, self.max_lrs):
+ base_height = (cycle_max_lr - cycle_min_lr) * scale_factor
+ lr = cycle_min_lr + base_height
+ lrs.append(lr)
+
+ return lrs
+
+ def _get_decay_mom(self, decay_batch_iteration):
+ if self.skip_mom_decay:
+ return self.max_moms
+
+ decay_interval = decay_batch_iteration / self.decay_step_size
+ mom_decay_factor = (1 + self.decay_mom_rate * decay_interval)
+ momentums = [(beta0 * mom_decay_factor, beta1) for beta0, beta1 in self.max_moms]
+
+ return momentums
+
+ def _get_decay_lr(self, decay_batch_iteration):
+ """Calculates the learning rate at batch index. This function is used
+ after the cycle completes and post cycle decaying of lr/mom is enabled.
+ This function treats `self.last_batch_iteration` as the last batch index.
+ """
+ if self.skip_lr_decay:
+ return self.min_lrs
+
+ decay_interval = decay_batch_iteration / self.decay_step_size
+ lr_decay_factor = (1 + self.decay_lr_rate * decay_interval)
+ lrs = [cycle_min_lr / lr_decay_factor for cycle_min_lr in self.min_lrs]
+
+ return lrs
+
+ def get_lr(self):
+ """Calculates the learning rate at batch index. This function treats
+ `self.last_batch_iteration` as the last batch index.
+ """
+ if self.last_batch_iteration < self.total_size:
+ return self._get_cycle_lr()
+ return self._get_decay_lr(self.last_batch_iteration - self.total_size + 1)
+
+ def get_mom(self):
+ """Calculates the momentum at batch index. This function treats
+ `self.last_batch_iteration` as the last batch index.
+ """
+ if not self.cycle_momentum:
+ return None
+
+ if self.last_batch_iteration < self.total_size:
+ return self._get_cycle_mom()
+ return self._get_decay_mom(self.last_batch_iteration - self.total_size + 1)
+
+ def get_last_lr(self):
+ """ Return last computed learning rate by current scheduler.
+ """
+ assert getattr(self, '_last_lr', None) is not None, "need to call step() first"
+ return self._last_lr
+
+ def step(self, batch_iteration=None):
+ """ Updates the optimizer with the learning rate for the last batch index.
+ `self.last_batch_iteration` is treated as the last batch index.
+
+ If self.cycle_momentum is true, also updates optimizer momentum.
+ """
+ if batch_iteration is None:
+ batch_iteration = self.last_batch_iteration + 1
+
+ self.last_batch_iteration = batch_iteration
+ for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
+ param_group['lr'] = lr
+ self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
+
+ if self.cycle_momentum:
+ momentums = self.get_mom()
+ for param_group, momentum in zip(self.optimizer.param_groups, momentums):
+ param_group['betas'] = momentum
+
+ def state_dict(self):
+ return {'last_batch_iteration': self.last_batch_iteration}
+
+ def load_state_dict(self, sd):
+ self.last_batch_iteration = sd['last_batch_iteration']
+
+
+
+class WarmupLR(object):
+ """Increase the learning rate of each parameter group from min lr to max lr
+ over warmup_num_steps steps, and then fix at max lr.
+
+ Args:
+ optimizer (Optimizer): Wrapped optimizer.
+ warmup_min_lr (float or list): minimum learning rate. Default: 0
+ warmup_max_lr (float or list): maximum learning rate. Default: 0.001
+ warmup_num_steps (int): number of steps to warm up from min_lr to max_lr. Default: 1000
+ warmup_type {'log', 'linear'}: increasing function from min_lr to max_lr during warmup. Default: log
+ last_batch_iteration (int): The index of the last batch. Default: -1.
+ Example:
+ >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
+ >>> scheduler = WarmupLR(optimizer)
+ >>> data_loader = torch.utils.data.DataLoader(...)
+ >>> for epoch in range(10):
+ >>> for batch in data_loader:
+ >>> train_batch(...)
+ >>> scheduler.step()
+
+ """
+
+ def __init__(self,
+ optimizer: Optimizer,
+ warmup_min_lr: float = 0.0,
+ warmup_max_lr: float = 0.001,
+ warmup_num_steps: int = 1000,
+ warmup_type: str = WARMUP_LOG_RATE,
+ last_batch_iteration: int = -1):
+
+ self.optimizer = get_torch_optimizer(optimizer)
+
+ self.min_lrs = self._format_param(self.optimizer, warmup_min_lr, "min_lr")
+ self.max_lrs = self._format_param(self.optimizer, warmup_max_lr, "max_lr")
+ self.delta_lrs = [big - small for big, small in zip(self.max_lrs, self.min_lrs)]
+ self.warmup_num_steps = max(2, warmup_num_steps)
+ # Currently only support linear and log function
+ if warmup_type not in {WARMUP_LOG_RATE, WARMUP_LINEAR_RATE}:
+ print(f"Using unknown warmup_type: {warmup_type}. The increasing function "
+ f"is set to default (log)")
+ warmup_type = WARMUP_LOG_RATE
+ self.warmup_type = warmup_type
+ self.inverse_log_warm_up = 1.0 / math.log(self.warmup_num_steps)
+ self.last_batch_iteration = last_batch_iteration
+
+ def get_lr(self):
+ if self.last_batch_iteration < 0:
+ print("Attempting to get learning rate from scheduler before it has started")
+ return [0.0]
+ gamma = self._get_gamma()
+ return [min_lr + (delta_lr * gamma) for min_lr, delta_lr in zip(self.min_lrs, self.delta_lrs)]
+
+ def get_last_lr(self):
+ """ Return last computed learning rate by current scheduler.
+ """
+ assert getattr(self, '_last_lr', None) is not None, "need to call step() first"
+ return self._last_lr
+
+ def step(self, last_batch_iteration=None):
+ if last_batch_iteration is None:
+ last_batch_iteration = self.last_batch_iteration + 1
+ self.last_batch_iteration = last_batch_iteration
+ for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
+ param_group['lr'] = lr
+ self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
+
+ def state_dict(self):
+ return {'last_batch_iteration': self.last_batch_iteration}
+
+ def load_state_dict(self, sd):
+ self.last_batch_iteration = sd['last_batch_iteration']
+
+ def _get_gamma(self):
+ if self.last_batch_iteration < self.warmup_num_steps:
+ if self.warmup_type == WARMUP_LOG_RATE:
+ return self.inverse_log_warm_up * math.log(self.last_batch_iteration + 1)
+ elif self.warmup_type == WARMUP_LINEAR_RATE:
+ return self.last_batch_iteration / self.warmup_num_steps
+ return 1.0
+
+ def _format_param(self, optimizer, param_value, param_name):
+ if isinstance(param_value, list) or isinstance(param_value, tuple):
+ if len(param_value) != len(optimizer.param_groups):
+ raise ValueError("expected {} value for {}, got {}".format(len(optimizer.param_groups), param_name,
+ FileNotFoundError(param_value)))
+ return list(param_value)
+ return [param_value] * len(optimizer.param_groups)
+
+
+
+class WarmupDecayLR(WarmupLR):
+ """Increase the learning rate of each parameter group from min lr to max lr
+ over warmup_num_steps steps, and then decay at linear rate over the remaining training steps.
+
+ Args:
+ optimizer (Optimizer): Wrapped optimizer.
+ total_num_steps (int): total number of training steps
+ warmup_min_lr (float or list): minimum learning rate. Default: 0
+ warmup_max_lr (float or list): maximum learning rate. Default: 0.001
+ warmup_num_steps (int): number of steps to warm up from min_lr to max_lr. Default: 1000
+ warmup_type {'log', 'linear'}: increasing function from min_lr to max_lr during warmup. Default: log
+ last_batch_iteration (int): The index of the last batch. Default: -1.
+ Example:
+ >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
+ >>> scheduler = WarmupDecayLR(optimizer, 1000000)
+ >>> data_loader = torch.utils.data.DataLoader(...)
+ >>> for epoch in range(10):
+ >>> for batch in data_loader:
+ >>> train_batch(...)
+ >>> scheduler.step()
+
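+ With the example above (total_num_steps=1000000 and the default warmup_num_steps=1000), the LR rises to
+ warmup_max_lr over the first 1000 steps and then decays linearly to warmup_min_lr at step 1000000,
+ following gamma = (total_num_steps - step) / (total_num_steps - warmup_num_steps).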
+ """
+
+ def __init__(self,
+ optimizer: Optimizer,
+ total_num_steps: int,
+ warmup_min_lr: float = 0.0,
+ warmup_max_lr: float = 0.001,
+ warmup_num_steps: int = 1000,
+ warmup_type: str = WARMUP_LOG_RATE,
+ last_batch_iteration: int = -1):
+
+ self.total_num_steps = total_num_steps
+ super(WarmupDecayLR, self).__init__(optimizer, warmup_min_lr, warmup_max_lr, warmup_num_steps, warmup_type,
+ last_batch_iteration)
+ if self.total_num_steps < self.warmup_num_steps:
+ print('total_num_steps {} is less than warmup_num_steps {}'.format(
+ total_num_steps, warmup_num_steps))
+
+ def _get_gamma(self):
+ if self.last_batch_iteration < self.warmup_num_steps:
+ if self.warmup_type == WARMUP_LOG_RATE:
+ return self.inverse_log_warm_up * math.log(self.last_batch_iteration + 1)
+ elif self.warmup_type == WARMUP_LINEAR_RATE:
+ return self.last_batch_iteration / self.warmup_num_steps
+ return max(
+ 0.0,
+ float(self.total_num_steps - self.last_batch_iteration) /
+ float(max(1.0, self.total_num_steps - self.warmup_num_steps)))
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/hydit/modules/__init__.py b/PyTorch/built-in/mlm/HunyuanDiT/hydit/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/hydit/modules/attn_layers.py b/PyTorch/built-in/mlm/HunyuanDiT/hydit/modules/attn_layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba666e74d93e0d905f03b74b56aacc452877a272
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/hydit/modules/attn_layers.py
@@ -0,0 +1,447 @@
+import math  # needed for the attention scale in NpuFlashAttention below
+import torch
+import torch.nn as nn
+from typing import Tuple, Union, Optional
+
+from utils.npu_utils import is_npu_available
+if is_npu_available():
+ import torch_npu
+
+try:
+ import flash_attn
+ if hasattr(flash_attn, '__version__') and int(flash_attn.__version__[0]) == 2:
+ from flash_attn.flash_attn_interface import flash_attn_kvpacked_func
+ from flash_attn.modules.mha import FlashSelfAttention, FlashCrossAttention
+ else:
+ from flash_attn.flash_attn_interface import flash_attn_unpadded_kvpacked_func
+ from flash_attn.modules.mha import FlashSelfAttention, FlashCrossAttention
+except Exception as e:
+ print(f'flash_attn import failed: {e}')
+
+
+class NpuFlashAttention(torch.nn.Module):
+ def __init__(self, attention_dropout=0):
+ super().__init__()
+ self.attention_dropout = attention_dropout
+
+ def forward(self, query, key, value):
+ heads = query.shape[2]
+ attention_mask = None
+ output = torch_npu.npu_fusion_attention(
+ query, key, value, heads, input_layout='BSND',
+ pse=None,
+ atten_mask=attention_mask,
+ scale=1.0 / math.sqrt(query.shape[-1]),
+ pre_tockens=65536,
+ next_tockens=65536,
+ keep_prob=1. - self.attention_dropout,
+ sync=False,
+ inner_precise=0,
+ )[0]
+ return output
+
+
+def reshape_for_broadcast(freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], x: torch.Tensor, head_first=False):
+ """
+ Reshape frequency tensor for broadcasting it with another tensor.
+
+ This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
+ for the purpose of broadcasting the frequency tensor during element-wise operations.
+
+ Args:
+ freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Frequency tensor to be reshaped.
+ x (torch.Tensor): Target tensor for broadcasting compatibility.
+ head_first (bool): head dimension first (except batch dim) or not.
+
+ Returns:
+ torch.Tensor: Reshaped frequency tensor.
+
+ Raises:
+ AssertionError: If the frequency tensor doesn't match the expected shape.
+ AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
+ """
+ ndim = x.ndim
+ assert 0 <= 1 < ndim
+
+ if isinstance(freqs_cis, tuple):
+ # freqs_cis: (cos, sin) in real space
+ if head_first:
+ assert freqs_cis[0].shape == (
+ x.shape[-2], x.shape[-1]), f'freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}'
+ shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
+ else:
+ assert freqs_cis[0].shape == (
+ x.shape[1], x.shape[-1]), f'freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}'
+ shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
+ return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
+ else:
+ # freqs_cis: values in complex space
+ if head_first:
+ assert freqs_cis.shape == (
+ x.shape[-2], x.shape[-1]), f'freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}'
+ shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
+ else:
+ assert freqs_cis.shape == (
+ x.shape[1], x.shape[-1]), f'freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}'
+ shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
+ return freqs_cis.view(*shape)
+
+
+def rotate_half(x):
+ x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
+ return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
+
+
+def apply_rotary_emb(
+ xq: torch.Tensor,
+ xk: Optional[torch.Tensor],
+ freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
+ head_first: bool = False,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Apply rotary embeddings to input tensors using the given frequency tensor.
+
+ This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
+ frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
+ is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
+ returned as real tensors.
+
+ Args:
+ xq (torch.Tensor): Query tensor to apply rotary embeddings. [B, S, H, D]
+ xk (torch.Tensor): Key tensor to apply rotary embeddings. [B, S, H, D]
+ freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Precomputed frequency tensor for complex exponentials.
+ head_first (bool): head dimension first (except batch dim) or not.
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
+
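+ Example (a minimal shape sketch; the sizes and the all-ones/all-zeros frequencies are illustrative only):
+ >>> B, S, H, D = 2, 16, 8, 64
+ >>> xq, xk = torch.randn(B, S, H, D), torch.randn(B, S, H, D)
+ >>> cos, sin = torch.ones(S, D), torch.zeros(S, D)
+ >>> out_q, out_k = apply_rotary_emb(xq, xk, (cos, sin))
+ >>> out_q.shape, out_k.shape
+ (torch.Size([2, 16, 8, 64]), torch.Size([2, 16, 8, 64]))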
+ """
+ xk_out = None
+ if isinstance(freqs_cis, tuple):
+ cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first) # [S, D]
+ cos, sin = cos.to(xq.device), sin.to(xq.device)
+ xq_out = xq * cos + rotate_half(xq) * sin
+ if xk is not None:
+ xk_out = xk * cos + rotate_half(xk) * sin
+ else:
+ xq_ = torch.view_as_complex(xq.reshape(*xq.shape[:-1], -1, 2)) # [B, S, H, D//2]
+ freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(xq.device) # [S, D//2] --> [1, S, 1, D//2]
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
+ if xk is not None:
+ xk_ = torch.view_as_complex(xk.reshape(*xk.shape[:-1], -1, 2)) # [B, S, H, D//2]
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
+
+ return xq_out, xk_out
+
+
+class FlashSelfMHAModified(nn.Module):
+ """
+ Use QK Normalization.
+ """
+
+ def __init__(self,
+ dim,
+ num_heads,
+ qkv_bias=True,
+ qk_norm=False,
+ attn_drop=0.0,
+ proj_drop=0.0,
+ device=None,
+ dtype=None,
+ norm_layer=nn.LayerNorm,
+ FAG_deterministic=False,
+ ):
+ factory_kwargs = {'device': device, 'dtype': dtype}
+ super().__init__()
+ self.dim = dim
+ self.num_heads = num_heads
+ assert self.dim % num_heads == 0, "self.kdim must be divisible by num_heads"
+ self.head_dim = self.dim // num_heads
+ assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8"
+
+ self.Wqkv = nn.Linear(dim, 3 * dim, bias=qkv_bias, **factory_kwargs)
+ # TODO: eps should be 1 / 65530 if using fp16
+ self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6,
+ dtype=torch.float32) if qk_norm else nn.Identity()
+ self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6,
+ dtype=torch.float32) if qk_norm else nn.Identity()
+ if is_npu_available():
+ self.inner_attn = NpuFlashAttention(attention_dropout=attn_drop)
+ else:
+ self.inner_attn = FlashSelfAttention(attention_dropout=attn_drop, deterministic=FAG_deterministic)
+ self.out_proj = nn.Linear(dim, dim, bias=qkv_bias, **factory_kwargs)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x, freqs_cis_img=None):
+ """
+ Parameters
+ ----------
+ x: torch.Tensor
+ (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim)
+ freqs_cis_img: torch.Tensor
+ (batch, hidden_dim // 2), RoPE for image
+ """
+ b, s, d = x.shape
+
+ qkv = self.Wqkv(x)
+ qkv = qkv.view(b, s, 3, self.num_heads, self.head_dim) # [b, s, 3, h, d]
+ q, k, v = qkv.unbind(dim=2) # [b, s, h, d]
+ input_dtype = qkv.dtype
+ q = self.q_norm(q) # [b, s, h, d]
+ k = self.k_norm(k)
+
+ # Apply RoPE if needed
+ if freqs_cis_img is not None:
+ qq, kk = apply_rotary_emb(q, k, freqs_cis_img)
+ assert qq.shape == q.shape and kk.shape == k.shape, f'qq: {qq.shape}, q: {q.shape}, kk: {kk.shape}, k: {k.shape}'
+ q, k = qq, kk
+
+ if is_npu_available():
+ q = q.to(input_dtype)
+ k = k.to(input_dtype)
+ if q.dtype not in [torch.float16, torch.bfloat16] or k.dtype not in [torch.float16, torch.bfloat16] or v.dtype not in [torch.float16, torch.bfloat16]:
+ raise ValueError("The dtype of q/k/v must be torch.float16 or torch.bfloat16.")
+ context = self.inner_attn(q, k, v)
+ else:
+ qkv = torch.stack([q, k, v], dim=2) # [b, s, 3, h, d]
+ qkv = qkv.to(input_dtype)
+ context = self.inner_attn(qkv)
+
+ out = self.out_proj(context.view(b, s, d))
+ out = self.proj_drop(out)
+
+ out_tuple = (out,)
+
+ return out_tuple
+
+
+class FlashCrossMHAModified(nn.Module):
+ """
+ Use QK Normalization.
+ """
+
+ def __init__(self,
+ qdim,
+ kdim,
+ num_heads,
+ qkv_bias=True,
+ qk_norm=False,
+ attn_drop=0.0,
+ proj_drop=0.0,
+ device=None,
+ dtype=None,
+ norm_layer=nn.LayerNorm,
+ FAG_deterministic=False,
+ ):
+ factory_kwargs = {'device': device, 'dtype': dtype}
+ super().__init__()
+ self.qdim = qdim
+ self.kdim = kdim
+ self.num_heads = num_heads
+ assert self.qdim % num_heads == 0, "self.qdim must be divisible by num_heads"
+ self.head_dim = self.qdim // num_heads
+ assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8"
+
+ self.scale = self.head_dim ** -0.5
+
+ self.q_proj = nn.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs)
+ self.kv_proj = nn.Linear(kdim, 2 * qdim, bias=qkv_bias, **factory_kwargs)
+
+ # TODO: eps should be 1 / 65530 if using fp16
+ self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6,
+ dtype=torch.float32) if qk_norm else nn.Identity()
+ self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6,
+ dtype=torch.float32) if qk_norm else nn.Identity()
+
+ if is_npu_available():
+ self.inner_attn = NpuFlashAttention(attention_dropout=attn_drop)
+ else:
+ self.inner_attn = FlashCrossAttention(attention_dropout=attn_drop, deterministic=FAG_deterministic)
+
+ self.out_proj = nn.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x, y, freqs_cis_img=None):
+ """
+ Parameters
+ ----------
+ x: torch.Tensor
+ (batch, seqlen1, hidden_dim) (where hidden_dim = num_heads * head_dim)
+ y: torch.Tensor
+ (batch, seqlen2, hidden_dim2)
+ freqs_cis_img: torch.Tensor
+ (batch, hidden_dim // num_heads), RoPE for image
+ """
+ b, s1, _ = x.shape # [b, s1, D]
+ _, s2, _ = y.shape # [b, s2, 1024]
+
+ q = self.q_proj(x).view(b, s1, self.num_heads, self.head_dim) # [b, s1, h, d]
+ kv = self.kv_proj(y).view(b, s2, 2, self.num_heads, self.head_dim) # [b, s2, 2, h, d]
+ input_dtype = kv.dtype
+ k, v = kv.unbind(dim=2) # [b, s2, h, d]
+
+ q = self.q_norm(q) # [b, s1, h, d]
+ k = self.k_norm(k) # [b, s2, h, d]
+
+ # Apply RoPE if needed
+ if freqs_cis_img is not None:
+ qq, _ = apply_rotary_emb(q, None, freqs_cis_img)
+ assert qq.shape == q.shape, f'qq: {qq.shape}, q: {q.shape}'
+ q = qq # [b, s1, h, d]
+
+ if is_npu_available():
+ q = q.to(input_dtype)
+ k = k.to(input_dtype)
+ if q.dtype not in [torch.float16, torch.bfloat16] or k.dtype not in [torch.float16, torch.bfloat16] or v.dtype not in [torch.float16, torch.bfloat16]:
+ raise ValueError("The dtype of q/k/v must be torch.float16 or torch.bfloat16.")
+ context = self.inner_attn(q, k, v)
+ else:
+ kv = torch.stack([k, v], dim=2) # [b, s1, 2, h, d]
+ q = q.to(input_dtype)
+ kv = kv.to(input_dtype)
+ context = self.inner_attn(q, kv) # [b, s1, h, d]
+
+ context = context.view(b, s1, -1) # [b, s1, D]
+
+ out = self.out_proj(context)
+ out = self.proj_drop(out)
+
+ out_tuple = (out,)
+
+ return out_tuple
+
+
+class CrossAttention(nn.Module):
+ """
+ Use QK Normalization.
+ """
+
+ def __init__(self,
+ qdim,
+ kdim,
+ num_heads,
+ qkv_bias=True,
+ qk_norm=False,
+ attn_drop=0.0,
+ proj_drop=0.0,
+ device=None,
+ dtype=None,
+ norm_layer=nn.LayerNorm,
+ ):
+ factory_kwargs = {'device': device, 'dtype': dtype}
+ super().__init__()
+ self.qdim = qdim
+ self.kdim = kdim
+ self.num_heads = num_heads
+ assert self.qdim % num_heads == 0, "self.qdim must be divisible by num_heads"
+ self.head_dim = self.qdim // num_heads
+ assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8"
+ self.scale = self.head_dim ** -0.5
+
+ self.q_proj = nn.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs)
+ self.kv_proj = nn.Linear(kdim, 2 * qdim, bias=qkv_bias, **factory_kwargs)
+
+ # TODO: eps should be 1 / 65530 if using fp16
+ self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
+ self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.out_proj = nn.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x, y, freqs_cis_img=None):
+ """
+ Parameters
+ ----------
+ x: torch.Tensor
+ (batch, seqlen1, hidden_dim) (where hidden_dim = num heads * head dim)
+ y: torch.Tensor
+ (batch, seqlen2, hidden_dim2)
+ freqs_cis_img: torch.Tensor
+ (batch, hidden_dim // 2), RoPE for image
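+
+ Example (an illustrative shape check; the dimensions are arbitrary):
+ >>> attn = CrossAttention(qdim=1024, kdim=2048, num_heads=8)
+ >>> x, y = torch.randn(2, 64, 1024), torch.randn(2, 77, 2048)
+ >>> out, = attn(x, y)
+ >>> out.shape
+ torch.Size([2, 64, 1024])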
+ """
+ b, s1, c = x.shape # [b, s1, D]
+ _, s2, c = y.shape # [b, s2, 1024]
+
+ q = self.q_proj(x).view(b, s1, self.num_heads, self.head_dim) # [b, s1, h, d]
+ kv = self.kv_proj(y).view(b, s2, 2, self.num_heads, self.head_dim) # [b, s2, 2, h, d]
+ k, v = kv.unbind(dim=2) # [b, s, h, d]
+ q = self.q_norm(q)
+ k = self.k_norm(k)
+
+ # Apply RoPE if needed
+ if freqs_cis_img is not None:
+ qq, _ = apply_rotary_emb(q, None, freqs_cis_img)
+ assert qq.shape == q.shape, f'qq: {qq.shape}, q: {q.shape}'
+ q = qq
+
+ q = q * self.scale
+ q = q.transpose(-2, -3).contiguous() # q -> B, L1, H, C - B, H, L1, C
+ k = k.permute(0, 2, 3, 1).contiguous() # k -> B, L2, H, C - B, H, C, L2
+ attn = q @ k # attn -> B, H, L1, L2
+ attn = attn.softmax(dim=-1) # attn -> B, H, L1, L2
+ attn = self.attn_drop(attn)
+ x = attn @ v.transpose(-2, -3) # v -> B, L2, H, C - B, H, L2, C x-> B, H, L1, C
+ context = x.transpose(1, 2) # context -> B, H, L1, C - B, L1, H, C
+
+ context = context.contiguous().view(b, s1, -1)
+
+ out = self.out_proj(context) # context.reshape - B, L1, -1
+ out = self.proj_drop(out)
+
+ out_tuple = (out,)
+
+ return out_tuple
+
+
+class Attention(nn.Module):
+ """
+ We rename some layer names to align with flash attention
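+
+ Example (an illustrative shape check; the dimensions are arbitrary):
+ >>> attn = Attention(dim=1024, num_heads=8)
+ >>> out, = attn(torch.randn(2, 64, 1024))
+ >>> out.shape
+ torch.Size([2, 64, 1024])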
+ """
+
+ def __init__(self, dim, num_heads, qkv_bias=True, qk_norm=False, attn_drop=0., proj_drop=0.,
+ norm_layer=nn.LayerNorm,
+ ):
+ super().__init__()
+ self.dim = dim
+ self.num_heads = num_heads
+ assert self.dim % num_heads == 0, 'dim should be divisible by num_heads'
+ self.head_dim = self.dim // num_heads
+ # This assertion is aligned with flash attention
+ assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8"
+ self.scale = self.head_dim ** -0.5
+
+ # qkv --> Wqkv
+ self.Wqkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ # TODO: eps should be 1 / 65530 if using fp16
+ self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
+ self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.out_proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x, freqs_cis_img=None):
+ B, N, C = x.shape
+ qkv = self.Wqkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) # [3, b, h, s, d]
+ q, k, v = qkv.unbind(0) # [b, h, s, d]
+ q = self.q_norm(q) # [b, h, s, d]
+ k = self.k_norm(k) # [b, h, s, d]
+
+ # Apply RoPE if needed
+ if freqs_cis_img is not None:
+ qq, kk = apply_rotary_emb(q, k, freqs_cis_img, head_first=True)
+ assert qq.shape == q.shape and kk.shape == k.shape, \
+ f'qq: {qq.shape}, q: {q.shape}, kk: {kk.shape}, k: {k.shape}'
+ q, k = qq, kk
+
+ q = q * self.scale
+ attn = q @ k.transpose(-2, -1) # [b, h, s, d] @ [b, h, d, s]
+ attn = attn.softmax(dim=-1) # [b, h, s, s]
+ attn = self.attn_drop(attn)
+ x = attn @ v # [b, h, s, d]
+
+ x = x.transpose(1, 2).reshape(B, N, C) # [b, s, h, d]
+ x = self.out_proj(x)
+ x = self.proj_drop(x)
+
+ out_tuple = (x,)
+
+ return out_tuple
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/hydit/modules/controlnet.py b/PyTorch/built-in/mlm/HunyuanDiT/hydit/modules/controlnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..c051f2b66ca7df022e0be2c66552d8b02d3e521f
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/hydit/modules/controlnet.py
@@ -0,0 +1,273 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.models import ModelMixin
+from timm.models.vision_transformer import Mlp
+
+from .attn_layers import Attention, FlashCrossMHAModified, FlashSelfMHAModified, CrossAttention
+from .embedders import TimestepEmbedder, PatchEmbed, timestep_embedding
+from .norm_layers import RMSNorm
+from .poolers import AttentionPool
+
+from .models import FP32_Layernorm, FP32_SiLU, HunYuanDiTBlock
+
+def zero_module(module):
+ for p in module.parameters():
+ nn.init.zeros_(p)
+ return module
+
+class HunYuanControlNet(ModelMixin, ConfigMixin):
+ """
+ HunYuanControlNet: the ControlNet branch for HunYuanDiT, a diffusion model with a Transformer backbone.
+
+ Inherit ModelMixin and ConfigMixin to be compatible with the sampler StableDiffusionPipeline of diffusers.
+
+ Parameters
+ ----------
+ args: argparse.Namespace
+ The arguments parsed by argparse.
+ input_size: tuple
+ The size of the input image.
+ patch_size: int
+ The size of the patch.
+ in_channels: int
+ The number of input channels.
+ hidden_size: int
+ The hidden size of the transformer backbone.
+ depth: int
+ The number of transformer blocks.
+ num_heads: int
+ The number of attention heads.
+ mlp_ratio: float
+ The ratio of the hidden size of the MLP in the transformer block.
+ log_fn: callable
+ The logging function.
+ """
+ @register_to_config
+ def __init__(
+ self, args,
+ input_size=(32, 32),
+ patch_size=2,
+ in_channels=4,
+ hidden_size=1152,
+ depth=28,
+ num_heads=16,
+ mlp_ratio=4.0,
+ log_fn=print,
+ ):
+ super().__init__()
+ self.args = args
+ self.log_fn = log_fn
+ self.depth = depth
+ self.learn_sigma = args.learn_sigma
+ self.in_channels = in_channels
+ self.patch_size = patch_size
+ self.num_heads = num_heads
+ self.hidden_size = hidden_size
+ self.text_states_dim = args.text_states_dim
+ self.text_states_dim_t5 = args.text_states_dim_t5
+ self.text_len = args.text_len
+ self.text_len_t5 = args.text_len_t5
+ self.norm = args.norm
+
+ use_flash_attn = args.infer_mode == 'fa' or args.use_flash_attn
+ if use_flash_attn:
+ log_fn(f" Enable Flash Attention.")
+ qk_norm = args.qk_norm # See http://arxiv.org/abs/2302.05442 for details.
+
+ self.mlp_t5 = nn.Sequential(
+ nn.Linear(self.text_states_dim_t5, self.text_states_dim_t5 * 4, bias=True),
+ FP32_SiLU(),
+ nn.Linear(self.text_states_dim_t5 * 4, self.text_states_dim, bias=True),
+ )
+ # learnable replace
+ self.text_embedding_padding = nn.Parameter(
+ torch.randn(self.text_len + self.text_len_t5, self.text_states_dim, dtype=torch.float32))
+
+ # Attention pooling
+ self.pooler = AttentionPool(self.text_len_t5, self.text_states_dim_t5, num_heads=8, output_dim=1024)
+
+ # Here we use a default learned embedder layer for future extension.
+ self.style_embedder = nn.Embedding(1, hidden_size)
+
+ # Image size and crop size conditions
+ self.extra_in_dim = 256 * 6 + hidden_size
+
+ # Text embedding for `add`
+ self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size)
+ self.t_embedder = TimestepEmbedder(hidden_size)
+ self.extra_in_dim += 1024
+ self.extra_embedder = nn.Sequential(
+ nn.Linear(self.extra_in_dim, hidden_size * 4),
+ FP32_SiLU(),
+ nn.Linear(hidden_size * 4, hidden_size, bias=True),
+ )
+
+ # Image embedding
+ num_patches = self.x_embedder.num_patches
+ log_fn(f" Number of tokens: {num_patches}")
+
+ # HunYuanDiT blocks
+ self.blocks = nn.ModuleList([
+ HunYuanDiTBlock(hidden_size=hidden_size,
+ c_emb_size=hidden_size,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ text_states_dim=self.text_states_dim,
+ use_flash_attn=use_flash_attn,
+ qk_norm=qk_norm,
+ norm_type=self.norm,
+ skip=False,
+ )
+ for _ in range(19)
+ ])
+
+ # Input zero linear for the first block
+ self.before_proj = zero_module(nn.Linear(self.hidden_size, self.hidden_size))
+
+ # Output zero linear for every block
+ self.after_proj_list = nn.ModuleList(
+ [zero_module(nn.Linear(self.hidden_size, self.hidden_size)) for _ in range(len(self.blocks))]
+ )
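+ # Because both projections are zero-initialized, the control branch contributes nothing at the start of
+ # training, and its influence on the base model is learned gradually.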
+
+ self.fix_weight_modules = ['mlp_t5', 'text_embedding_padding', 'pooler', 'style_embedder', 'x_embedder', 't_embedder', 'extra_embedder']
+
+
+ def from_dit(self, dit):
+ """
+ Load the parameters from a pre-trained HunYuanDiT model.
+
+ Parameters
+ ----------
+ dit: HunYuanDiT
+ The pre-trained HunYuanDiT model.
+ """
+
+
+ self.mlp_t5.load_state_dict(dit.mlp_t5.state_dict())
+
+ self.text_embedding_padding.data = dit.text_embedding_padding.data
+ self.pooler.load_state_dict(dit.pooler.state_dict())
+ self.style_embedder.load_state_dict(dit.style_embedder.state_dict())
+ self.x_embedder.load_state_dict(dit.x_embedder.state_dict())
+ self.t_embedder.load_state_dict(dit.t_embedder.state_dict())
+ self.extra_embedder.load_state_dict(dit.extra_embedder.state_dict())
+
+ for i, block in enumerate(self.blocks):
+ block.load_state_dict(dit.blocks[i].state_dict())
+
+ def set_trainable(self):
+
+ self.mlp_t5.requires_grad_(False)
+ self.text_embedding_padding.requires_grad_(False)
+ self.pooler.requires_grad_(False)
+ self.style_embedder.requires_grad_(False)
+ self.x_embedder.requires_grad_(False)
+ self.t_embedder.requires_grad_(False)
+ self.extra_embedder.requires_grad_(False)
+
+ self.blocks.requires_grad_(True)
+ self.before_proj.requires_grad_(True)
+ self.after_proj_list.requires_grad_(True)
+
+ self.blocks.train()
+ self.before_proj.train()
+ self.after_proj_list.train()
+
+
+
+ def forward(self,
+ x,
+ t,
+ condition,
+ encoder_hidden_states=None,
+ text_embedding_mask=None,
+ encoder_hidden_states_t5=None,
+ text_embedding_mask_t5=None,
+ image_meta_size=None,
+ style=None,
+ cos_cis_img=None,
+ sin_cis_img=None,
+ return_dict=True,
+ ):
+ """
+ Forward pass of the ControlNet branch.
+
+ Parameters
+ ----------
+ x: torch.Tensor
+ (B, D, H, W)
+ t: torch.Tensor
+ (B)
+ condition: torch.Tensor
+ (B, D, H, W), the control condition in latent space (patchified by the same x_embedder as x)
+ encoder_hidden_states: torch.Tensor
+ CLIP text embedding, (B, L_clip, D)
+ text_embedding_mask: torch.Tensor
+ CLIP text embedding mask, (B, L_clip)
+ encoder_hidden_states_t5: torch.Tensor
+ T5 text embedding, (B, L_t5, D)
+ text_embedding_mask_t5: torch.Tensor
+ T5 text embedding mask, (B, L_t5)
+ image_meta_size: torch.Tensor
+ (B, 6)
+ style: torch.Tensor
+ (B)
+ cos_cis_img: torch.Tensor
+ sin_cis_img: torch.Tensor
+ return_dict: bool
+ Whether to return a dictionary.
+ """
+ text_states = encoder_hidden_states # 2,77,1024
+ text_states_t5 = encoder_hidden_states_t5 # 2,256,2048
+ text_states_mask = text_embedding_mask.bool() # 2,77
+ text_states_t5_mask = text_embedding_mask_t5.bool() # 2,256
+ b_t5, l_t5, c_t5 = text_states_t5.shape
+ text_states_t5 = self.mlp_t5(text_states_t5.view(-1, c_t5))
+ text_states = torch.cat([text_states, text_states_t5.view(b_t5, l_t5, -1)], dim=1) # 2,205,1024
+ clip_t5_mask = torch.cat([text_states_mask, text_states_t5_mask], dim=-1)
+
+ text_states = torch.where(clip_t5_mask.unsqueeze(2), text_states, self.text_embedding_padding.to(text_states))
+
+ _, _, oh, ow = x.shape
+ th, tw = oh // self.patch_size, ow // self.patch_size
+
+ # ========================= Build time and image embedding =========================
+ t = self.t_embedder(t)
+ x = self.x_embedder(x)
+
+ # Get image RoPE embedding according to `reso`lution.
+ freqs_cis_img = (cos_cis_img, sin_cis_img)
+
+ # ========================= Concatenate all extra vectors =========================
+ # Build text tokens with pooling
+ extra_vec = self.pooler(encoder_hidden_states_t5)
+
+ # Build image meta size tokens
+ image_meta_size = timestep_embedding(image_meta_size.view(-1), 256) # [B * 6, 256]
+ if self.args.use_fp16:
+ image_meta_size = image_meta_size.half()
+ image_meta_size = image_meta_size.view(-1, 6 * 256)
+ extra_vec = torch.cat([extra_vec, image_meta_size], dim=1) # [B, D + 6 * 256]
+
+ # Build style tokens
+ style_embedding = self.style_embedder(style)
+ extra_vec = torch.cat([extra_vec, style_embedding], dim=1)
+
+ # Concatenate all extra vectors
+ c = t + self.extra_embedder(extra_vec) # [B, D]
+
+ # ========================= Deal with Condition =========================
+ condition = self.x_embedder(condition)
+
+ # ========================= Forward pass through HunYuanDiT blocks =========================
+ controls = []
+ x = x + self.before_proj(condition) # add condition
+ for layer, block in enumerate(self.blocks):
+ x = block(x, c, text_states, freqs_cis_img)
+ controls.append(self.after_proj_list[layer](x)) # zero linear for output
+
+
+ if return_dict:
+ return {'controls': controls}
+ return controls
\ No newline at end of file
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/hydit/modules/ema.py b/PyTorch/built-in/mlm/HunyuanDiT/hydit/modules/ema.py
new file mode 100644
index 0000000000000000000000000000000000000000..acc31ee296f7f30bb37f101fdd911539df54e20a
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/hydit/modules/ema.py
@@ -0,0 +1,127 @@
+from collections import OrderedDict
+from copy import deepcopy
+
+import torch
+from deepspeed.utils import instrument_w_nvtx
+from pathlib import Path
+
+def requires_grad(model, flag=True):
+ """
+ Set requires_grad flag for all parameters in a model.
+ """
+ for p in model.parameters():
+ p.requires_grad = flag
+
+
+class EMA(object):
+ def __init__(self, args, model, device, logger):
+ if args.ema_dtype == 'fp32':
+ self.warmup = args.ema_warmup
+ self.update_after_step = 0
+ self.max_value = args.ema_decay if args.ema_decay is not None else 0.9999
+ self.inv_gamma = 1.0
+ self.power = args.ema_warmup_power if args.ema_warmup_power is not None else 2 / 3
+ self.min_value = 0.0
+ else:
+ self.warmup = args.ema_warmup
+ self.update_after_step = 0
+ self.max_value = args.ema_decay if args.ema_decay is not None else 0.992
+ self.inv_gamma = 1.0
+ self.power = args.ema_warmup_power if args.ema_warmup_power is not None else 0.446249
+ # 0.446249 == -math.log(1 - 0.992) / math.log(50000)
+ self.min_value = 0.0
+
+ self.ema_reset_decay = args.ema_reset_decay
+ self.decay_steps = 0
+
+ if args.ema_dtype == 'none':
+ ema_dtype = 'fp16' if args.use_fp16 else 'fp32'
+ else:
+ ema_dtype = args.ema_dtype
+
+ # Since module.half() and module.float() change the dtype in place, deepcopy the model first and only then cast the copy
+ self.ema_model = deepcopy(model)
+ if ema_dtype == 'fp16':
+ self.ema_model = self.ema_model.half().to(device)
+ elif ema_dtype == 'fp32':
+ self.ema_model = self.ema_model.float().to(device)
+ else:
+ raise ValueError(f"Unknown EMA dtype {ema_dtype}.")
+
+ requires_grad(self.ema_model, False)
+
+ logger.info(f" Using EMA with date type {args.ema_dtype} "
+ f"(decay={args.ema_decay}, warmup={args.ema_warmup}, warmup_power={args.ema_warmup_power}, "
+ f"reset_decay={args.ema_reset_decay}).")
+
+ def get_decay(self):
+ """
+ @crowsonkb's notes on EMA Warmup:
+ If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
+ to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
+ gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
+ at 215.4k steps).
+
+ @jarvizhang's notes on EMA max_value when enabling FP16:
+ If using FP16 for EMA, max_value=0.995 is better (don't go above 0.999 unless you know
+ what you are doing). This is because FP16 has less precision than FP32, so a decay too close
+ to 1 can push the EMA update below the precision that FP16 can represent.
+
+ gamma=1, power=0.446249 are good values for models (reaches decay factor 0.99 at 30K steps,
+ 0.992 at 50K steps).
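+
+ Worked check (illustrative): with inv_gamma=1 and power=2/3, step 31600 gives
+ 1 - (1 + 31600) ** (-2 / 3) ≈ 1 - 1/1000 ≈ 0.999, matching the note above.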
+ """
+ if self.warmup:
+ step = max(0, self.decay_steps - self.update_after_step - 1)
+ value = 1 - (1 + step / self.inv_gamma) ** -self.power
+
+ if step <= 0:
+ return 0.0
+
+ return max(self.min_value, min(value, self.max_value))
+ else:
+ return self.max_value
+
+ @torch.no_grad()
+ @instrument_w_nvtx
+ def update(self, model, step, decay=None):
+ """
+ Step the EMA model towards the current model.
+
+ Parameters
+ ----------
+ model: nn.Module
+ The current model
+ step: int
+ The current training step. This is used to determine the decay factor. If you want to control
+ the decay, you can pass in a custom step instead.
+ For example, if you want to restart the EMA decay, you can pass in step=0 at start and increase
+ step by step.
+ decay: float
+ The decay factor. If None, will be determined by the current step.
+ """
+ if decay is None:
+ if self.ema_reset_decay:
+ self.decay_steps += 1
+ else:
+ self.decay_steps = step
+ decay = self.get_decay()
+
+ ema_params = OrderedDict(self.ema_model.named_parameters())
+ model_params = OrderedDict(model.named_parameters())
+ for name, param in model_params.items():
+ # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed
+ ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)
+
+ return None
+
+ def state_dict(self, *args, **kwargs):
+ return self.ema_model.state_dict(*args, **kwargs)
+
+ def load_state_dict(self, *args, **kwargs):
+ return self.ema_model.load_state_dict(*args, **kwargs)
+
+ def train(self):
+ self.ema_model.train()
+
+ def eval(self):
+ self.ema_model.eval()
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/hydit/modules/embedders.py b/PyTorch/built-in/mlm/HunyuanDiT/hydit/modules/embedders.py
new file mode 100644
index 0000000000000000000000000000000000000000..9fe08cba22eef41ca9fd9f70fe6f062a4dd606c8
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/hydit/modules/embedders.py
@@ -0,0 +1,111 @@
+import math
+import torch
+import torch.nn as nn
+from einops import repeat
+
+from timm.models.layers import to_2tuple
+
+
+class PatchEmbed(nn.Module):
+    """ 2D Image to Patch Embedding
+
+    A Conv2d-based approach to patchifying a 2D image with an embedding projection.
+
+    Based on the impl in https://github.com/google-research/vision_transformer
+
+    Hacked together by / Copyright 2020 Ross Wightman
+
+    The _assert calls are removed from forward() to stay compatible with multi-resolution images.
+ """
+ def __init__(
+ self,
+ img_size=224,
+ patch_size=16,
+ in_chans=3,
+ embed_dim=768,
+ norm_layer=None,
+ flatten=True,
+ bias=True,
+ ):
+ super().__init__()
+ if isinstance(img_size, int):
+ img_size = to_2tuple(img_size)
+ elif isinstance(img_size, (tuple, list)) and len(img_size) == 2:
+ img_size = tuple(img_size)
+ else:
+ raise ValueError(f"img_size must be int or tuple/list of length 2. Got {img_size}")
+ patch_size = to_2tuple(patch_size)
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
+ self.num_patches = self.grid_size[0] * self.grid_size[1]
+ self.flatten = flatten
+
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+ def update_image_size(self, img_size):
+ self.img_size = img_size
+ self.grid_size = (img_size[0] // self.patch_size[0], img_size[1] // self.patch_size[1])
+ self.num_patches = self.grid_size[0] * self.grid_size[1]
+
+ def forward(self, x):
+ # B, C, H, W = x.shape
+ # _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).")
+ # _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")
+ x = self.proj(x)
+ if self.flatten:
+ x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
+ x = self.norm(x)
+ return x
+
+
+def timestep_embedding(t, dim, max_period=10000, repeat_only=False):
+ """
+ Create sinusoidal timestep embeddings.
+ :param t: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an (N, D) Tensor of positional embeddings.
+ """
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
+ if not repeat_only:
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period)
+ * torch.arange(start=0, end=half, dtype=torch.float32)
+ / half
+        ).to(device=t.device)  # size: [dim/2], an exponentially decaying curve
+ args = t[:, None].float() * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat(
+ [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
+ )
+ else:
+ embedding = repeat(t, "b -> b d", d=dim)
+ return embedding
+
+
+class TimestepEmbedder(nn.Module):
+ """
+ Embeds scalar timesteps into vector representations.
+ """
+ def __init__(self, hidden_size, frequency_embedding_size=256, out_size=None):
+ super().__init__()
+ if out_size is None:
+ out_size = hidden_size
+ self.mlp = nn.Sequential(
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+ nn.SiLU(),
+ nn.Linear(hidden_size, out_size, bias=True),
+ )
+ self.frequency_embedding_size = frequency_embedding_size
+
+ def forward(self, t):
+ t_freq = timestep_embedding(t, self.frequency_embedding_size).type(self.mlp[0].weight.dtype)
+ t_emb = self.mlp(t_freq)
+ return t_emb
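
A minimal usage sketch for the embedders defined above. The import path assumes the file location in this patch (`hydit/modules/embedders.py`), and the hidden size 1152 is taken from the DiT-XL/2 config further below:

```python
import torch
from hydit.modules.embedders import PatchEmbed, TimestepEmbedder, timestep_embedding

x = torch.randn(2, 4, 128, 128)                               # latent input (B, C, H, W)
patchify = PatchEmbed(img_size=(128, 128), patch_size=2, in_chans=4, embed_dim=1152)
tokens = patchify(x)                                          # (2, 64 * 64, 1152)

t = torch.tensor([10, 999])
t_emb = TimestepEmbedder(hidden_size=1152)(t)                 # (2, 1152)
sincos = timestep_embedding(t, dim=256)                       # (2, 256) raw sinusoidal features
print(tokens.shape, t_emb.shape, sincos.shape)
```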
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/hydit/modules/fp16_layers.py b/PyTorch/built-in/mlm/HunyuanDiT/hydit/modules/fp16_layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..6812811045eccbacde3a60fc3d230cd3fb9da0ad
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/hydit/modules/fp16_layers.py
@@ -0,0 +1,86 @@
+import torch
+from torch.autograd import Variable
+from torch.nn.parameter import Parameter
+
+
+_FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
+_HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor)
+
+
+def conversion_helper(val, conversion):
+ """Apply conversion to val. Recursively apply conversion if `val`
+    is a nested tuple/list structure."""
+ if isinstance(val, dict):
+ res_dict = {}
+ for k, v in val.items():
+            if k != 'cos_cis_img' and k != 'sin_cis_img':
+ res_dict[k] = conversion_helper(v, conversion)
+ else:
+ res_dict[k] = v
+ return res_dict
+ if not isinstance(val, (tuple, list)):
+ return conversion(val)
+ rtn = [conversion_helper(v, conversion) for v in val]
+ if isinstance(val, tuple):
+ rtn = tuple(rtn)
+ return rtn
+
+
+def fp32_to_float16(val, float16_convertor):
+ """Convert fp32 `val` to fp16/bf16"""
+ def half_conversion(val):
+ val_typecheck = val
+ if isinstance(val_typecheck, (Parameter, Variable)):
+ val_typecheck = val.data
+ if isinstance(val_typecheck, _FLOAT_TYPES):
+ val = float16_convertor(val)
+ return val
+ return conversion_helper(val, half_conversion)
+
+
+def float16_to_fp32(val):
+ """Convert fp16/bf16 `val` to fp32"""
+ def float_conversion(val):
+ val_typecheck = val
+ if isinstance(val_typecheck, (Parameter, Variable)):
+ val_typecheck = val.data
+        if isinstance(val_typecheck, _HALF_TYPES):
+ val = val.float()
+ return val
+ return conversion_helper(val, float_conversion)
+
+
+class Float16Module(torch.nn.Module):
+
+ def __init__(self, module, args):
+ super(Float16Module, self).__init__()
+
+ self.add_module('module', module.half())
+
+ def float16_convertor(val):
+ return val.half()
+
+ self.float16_convertor = float16_convertor
+
+ self.config = self.module.config
+ self.dtype = torch.float16
+
+ def forward(self, *inputs, **kwargs):
+ inputs = fp32_to_float16(inputs, self.float16_convertor)
+ kwargs = fp32_to_float16(kwargs, self.float16_convertor)
+ outputs = self.module(*inputs, **kwargs)
+ outputs = float16_to_fp32(outputs)
+ return outputs
+
+ def state_dict(self, destination=None, prefix='', keep_vars=False):
+ return self.module.state_dict(destination, prefix, keep_vars)
+
+
+ def state_dict_for_save_checkpoint(self, destination=None, prefix='',
+ keep_vars=False):
+ return self.module.state_dict_for_save_checkpoint(destination, prefix,
+ keep_vars)
+
+ def load_state_dict(self, state_dict, strict=True):
+ self.module.load_state_dict(state_dict, strict=strict)
+
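
A minimal sketch of how `Float16Module` is meant to be used: the wrapped weights stay in fp16 while callers pass and receive fp32 tensors. The toy module and its dummy `config` attribute are assumptions for illustration only, and the example assumes a CUDA device:

```python
import torch
import torch.nn as nn
from hydit.modules.fp16_layers import Float16Module

class Toy(nn.Module):
    config = None                          # Float16Module copies a `config` attribute from the module

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)

    def forward(self, x):
        return self.linear(x)

model = Float16Module(Toy().cuda(), args=None)   # `args` is unused by the wrapper itself
y = model(torch.randn(2, 8, device="cuda"))      # fp32 in, fp16 compute inside ...
print(y.dtype)                                   # ... fp32 out: torch.float32
```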
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/hydit/modules/models.py b/PyTorch/built-in/mlm/HunyuanDiT/hydit/modules/models.py
new file mode 100644
index 0000000000000000000000000000000000000000..69dfc4384f1ecf579e8ac32a91c264d21ac04573
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/hydit/modules/models.py
@@ -0,0 +1,508 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.models import ModelMixin
+from timm.models.vision_transformer import Mlp
+
+from .attn_layers import Attention, FlashCrossMHAModified, FlashSelfMHAModified, CrossAttention
+from .embedders import TimestepEmbedder, PatchEmbed, timestep_embedding
+from .norm_layers import RMSNorm
+from .poolers import AttentionPool
+
+from transformers.integrations import PeftAdapterMixin
+from typing import Any, Optional, Union
+from tqdm import tqdm
+from peft.utils import (
+ ModulesToSaveWrapper,
+ _get_submodules,
+)
+
+
+def modulate(x, shift, scale):
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+
+
+class FP32_Layernorm(nn.LayerNorm):
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+ origin_dtype = inputs.dtype
+ return F.layer_norm(inputs.float(), self.normalized_shape, self.weight.float(), self.bias.float(),
+ self.eps).to(origin_dtype)
+
+
+class FP32_SiLU(nn.SiLU):
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+ return torch.nn.functional.silu(inputs.float(), inplace=False).to(inputs.dtype)
+
+
+class HunYuanDiTBlock(nn.Module):
+ """
+ A HunYuanDiT block with `add` conditioning.
+ """
+ def __init__(self,
+ hidden_size,
+ c_emb_size,
+ num_heads,
+ mlp_ratio=4.0,
+ text_states_dim=1024,
+ use_flash_attn=False,
+ qk_norm=False,
+ norm_type="layer",
+ skip=False,
+ ):
+ super().__init__()
+ self.use_flash_attn = use_flash_attn
+ use_ele_affine = True
+
+ if norm_type == "layer":
+ norm_layer = FP32_Layernorm
+ elif norm_type == "rms":
+ norm_layer = RMSNorm
+ else:
+ raise ValueError(f"Unknown norm_type: {norm_type}")
+
+ # ========================= Self-Attention =========================
+ self.norm1 = norm_layer(hidden_size, elementwise_affine=use_ele_affine, eps=1e-6)
+ if use_flash_attn:
+ self.attn1 = FlashSelfMHAModified(hidden_size, num_heads=num_heads, qkv_bias=True, qk_norm=qk_norm)
+ else:
+ self.attn1 = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, qk_norm=qk_norm)
+
+ # ========================= FFN =========================
+ self.norm2 = norm_layer(hidden_size, elementwise_affine=use_ele_affine, eps=1e-6)
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
+
+ # ========================= Add =========================
+ # Simply use add like SDXL.
+ self.default_modulation = nn.Sequential(
+ FP32_SiLU(),
+ nn.Linear(c_emb_size, hidden_size, bias=True)
+ )
+
+ # ========================= Cross-Attention =========================
+ if use_flash_attn:
+ self.attn2 = FlashCrossMHAModified(hidden_size, text_states_dim, num_heads=num_heads, qkv_bias=True,
+ qk_norm=qk_norm)
+ else:
+ self.attn2 = CrossAttention(hidden_size, text_states_dim, num_heads=num_heads, qkv_bias=True,
+ qk_norm=qk_norm)
+ self.norm3 = norm_layer(hidden_size, elementwise_affine=True, eps=1e-6)
+
+ # ========================= Skip Connection =========================
+ if skip:
+ self.skip_norm = norm_layer(2 * hidden_size, elementwise_affine=True, eps=1e-6)
+ self.skip_linear = nn.Linear(2 * hidden_size, hidden_size)
+ else:
+ self.skip_linear = None
+
+ def forward(self, x, c=None, text_states=None, freq_cis_img=None, skip=None):
+ # Long Skip Connection
+ if self.skip_linear is not None:
+ cat = torch.cat([x, skip], dim=-1)
+ cat = self.skip_norm(cat)
+ x = self.skip_linear(cat)
+
+ # Self-Attention
+ shift_msa = self.default_modulation(c).unsqueeze(dim=1)
+ attn_inputs = (
+ self.norm1(x) + shift_msa, freq_cis_img,
+ )
+ x = x + self.attn1(*attn_inputs)[0]
+
+ # Cross-Attention
+ cross_inputs = (
+ self.norm3(x), text_states, freq_cis_img
+ )
+ x = x + self.attn2(*cross_inputs)[0]
+
+ # FFN Layer
+ mlp_inputs = self.norm2(x)
+ x = x + self.mlp(mlp_inputs)
+
+ return x
+
+
+class FinalLayer(nn.Module):
+ """
+ The final layer of HunYuanDiT.
+ """
+ def __init__(self, final_hidden_size, c_emb_size, patch_size, out_channels):
+ super().__init__()
+ self.norm_final = nn.LayerNorm(final_hidden_size, elementwise_affine=False, eps=1e-6)
+ self.linear = nn.Linear(final_hidden_size, patch_size * patch_size * out_channels, bias=True)
+ self.adaLN_modulation = nn.Sequential(
+ FP32_SiLU(),
+ nn.Linear(c_emb_size, 2 * final_hidden_size, bias=True)
+ )
+
+ def forward(self, x, c):
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
+ x = modulate(self.norm_final(x), shift, scale)
+ x = self.linear(x)
+ return x
+
+
+class HunYuanDiT(ModelMixin, ConfigMixin, PeftAdapterMixin):
+ """
+ HunYuanDiT: Diffusion model with a Transformer backbone.
+
+ Inherit ModelMixin and ConfigMixin to be compatible with the sampler StableDiffusionPipeline of diffusers.
+
+ Inherit PeftAdapterMixin to be compatible with the PEFT training pipeline.
+
+ Parameters
+ ----------
+ args: argparse.Namespace
+ The arguments parsed by argparse.
+ input_size: tuple
+ The size of the input image.
+ patch_size: int
+ The size of the patch.
+ in_channels: int
+ The number of input channels.
+ hidden_size: int
+ The hidden size of the transformer backbone.
+ depth: int
+ The number of transformer blocks.
+ num_heads: int
+ The number of attention heads.
+ mlp_ratio: float
+ The ratio of the hidden size of the MLP in the transformer block.
+ log_fn: callable
+ The logging function.
+ """
+ @register_to_config
+ def __init__(
+ self, args,
+ input_size=(32, 32),
+ patch_size=2,
+ in_channels=4,
+ hidden_size=1152,
+ depth=28,
+ num_heads=16,
+ mlp_ratio=4.0,
+ log_fn=print,
+ ):
+ super().__init__()
+ self.args = args
+ self.log_fn = log_fn
+ self.depth = depth
+ self.learn_sigma = args.learn_sigma
+ self.in_channels = in_channels
+ self.out_channels = in_channels * 2 if args.learn_sigma else in_channels
+ self.patch_size = patch_size
+ self.num_heads = num_heads
+ self.hidden_size = hidden_size
+ self.text_states_dim = args.text_states_dim
+ self.text_states_dim_t5 = args.text_states_dim_t5
+ self.text_len = args.text_len
+ self.text_len_t5 = args.text_len_t5
+ self.norm = args.norm
+
+ use_flash_attn = args.infer_mode == 'fa' or args.use_flash_attn
+ if use_flash_attn:
+ log_fn(f" Enable Flash Attention.")
+ qk_norm = args.qk_norm # See http://arxiv.org/abs/2302.05442 for details.
+
+ self.mlp_t5 = nn.Sequential(
+ nn.Linear(self.text_states_dim_t5, self.text_states_dim_t5 * 4, bias=True),
+ FP32_SiLU(),
+ nn.Linear(self.text_states_dim_t5 * 4, self.text_states_dim, bias=True),
+ )
+ # learnable replace
+ self.text_embedding_padding = nn.Parameter(
+ torch.randn(self.text_len + self.text_len_t5, self.text_states_dim))
+
+ # Attention pooling
+ self.pooler = AttentionPool(self.text_len_t5, self.text_states_dim_t5, num_heads=8, output_dim=1024)
+
+ # Here we use a default learned embedder layer for future extension.
+ self.style_embedder = nn.Embedding(1, hidden_size)
+
+ # Image size and crop size conditions
+ self.extra_in_dim = 256 * 6 + hidden_size
+
+ # Text embedding for `add`
+ self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size)
+ self.t_embedder = TimestepEmbedder(hidden_size)
+ self.extra_in_dim += 1024
+ self.extra_embedder = nn.Sequential(
+ nn.Linear(self.extra_in_dim, hidden_size * 4),
+ FP32_SiLU(),
+ nn.Linear(hidden_size * 4, hidden_size, bias=True),
+ )
+
+ # Image embedding
+ num_patches = self.x_embedder.num_patches
+ log_fn(f" Number of tokens: {num_patches}")
+
+        # HunYuanDiT Blocks
+ self.blocks = nn.ModuleList([
+ HunYuanDiTBlock(hidden_size=hidden_size,
+ c_emb_size=hidden_size,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ text_states_dim=self.text_states_dim,
+ use_flash_attn=use_flash_attn,
+ qk_norm=qk_norm,
+ norm_type=self.norm,
+ skip=layer > depth // 2,
+ )
+ for layer in range(depth)
+ ])
+
+ self.final_layer = FinalLayer(hidden_size, hidden_size, patch_size, self.out_channels)
+ self.unpatchify_channels = self.out_channels
+
+ self.initialize_weights()
+
+ def forward(self,
+ x,
+ t,
+ encoder_hidden_states=None,
+ text_embedding_mask=None,
+ encoder_hidden_states_t5=None,
+ text_embedding_mask_t5=None,
+ image_meta_size=None,
+ style=None,
+ cos_cis_img=None,
+ sin_cis_img=None,
+ return_dict=True,
+ controls=None,
+ ):
+ """
+ Forward pass of the encoder.
+
+ Parameters
+ ----------
+ x: torch.Tensor
+ (B, D, H, W)
+ t: torch.Tensor
+ (B)
+ encoder_hidden_states: torch.Tensor
+ CLIP text embedding, (B, L_clip, D)
+ text_embedding_mask: torch.Tensor
+ CLIP text embedding mask, (B, L_clip)
+ encoder_hidden_states_t5: torch.Tensor
+ T5 text embedding, (B, L_t5, D)
+ text_embedding_mask_t5: torch.Tensor
+ T5 text embedding mask, (B, L_t5)
+ image_meta_size: torch.Tensor
+ (B, 6)
+ style: torch.Tensor
+ (B)
+ cos_cis_img: torch.Tensor
+ sin_cis_img: torch.Tensor
+ return_dict: bool
+ Whether to return a dictionary.
+ """
+ text_states = encoder_hidden_states # 2,77,1024
+ text_states_t5 = encoder_hidden_states_t5 # 2,256,2048
+ text_states_mask = text_embedding_mask.bool() # 2,77
+ text_states_t5_mask = text_embedding_mask_t5.bool() # 2,256
+ b_t5, l_t5, c_t5 = text_states_t5.shape
+ text_states_t5 = self.mlp_t5(text_states_t5.view(-1, c_t5))
+ text_states = torch.cat([text_states, text_states_t5.view(b_t5, l_t5, -1)], dim=1) # 2,205,1024
+ clip_t5_mask = torch.cat([text_states_mask, text_states_t5_mask], dim=-1)
+
+ text_states = torch.where(clip_t5_mask.unsqueeze(2), text_states, self.text_embedding_padding.to(text_states))
+
+ _, _, oh, ow = x.shape
+ th, tw = oh // self.patch_size, ow // self.patch_size
+
+ # ========================= Build time and image embedding =========================
+ t = self.t_embedder(t)
+ x = self.x_embedder(x)
+
+ # Get image RoPE embedding according to `reso`lution.
+ freqs_cis_img = (cos_cis_img, sin_cis_img)
+
+ # ========================= Concatenate all extra vectors =========================
+ # Build text tokens with pooling
+ extra_vec = self.pooler(encoder_hidden_states_t5)
+
+ # Build image meta size tokens
+ image_meta_size = timestep_embedding(image_meta_size.view(-1), 256) # [B * 6, 256]
+ if self.args.use_fp16 and self.args.autocast_dtype == "fp16":
+ image_meta_size = image_meta_size.half()
+ image_meta_size = image_meta_size.view(-1, 6 * 256)
+ extra_vec = torch.cat([extra_vec, image_meta_size], dim=1) # [B, D + 6 * 256]
+
+ # Build style tokens
+ style_embedding = self.style_embedder(style)
+ extra_vec = torch.cat([extra_vec, style_embedding], dim=1)
+
+ # Concatenate all extra vectors
+ c = t + self.extra_embedder(extra_vec) # [B, D]
+
+ # ========================= Forward pass through HunYuanDiT blocks =========================
+ skips = []
+ for layer, block in enumerate(self.blocks):
+ if layer > self.depth // 2:
+ if controls is not None:
+ skip = skips.pop() + controls.pop()
+ else:
+ skip = skips.pop()
+ x = block(x, c, text_states, freqs_cis_img, skip) # (N, L, D)
+ else:
+ x = block(x, c, text_states, freqs_cis_img) # (N, L, D)
+
+ if layer < (self.depth // 2 - 1):
+ skips.append(x)
+ if controls is not None and len(controls) != 0:
+ raise ValueError("The number of controls is not equal to the number of skip connections.")
+
+ # ========================= Final layer =========================
+ x = self.final_layer(x, c) # (N, L, patch_size ** 2 * out_channels)
+ x = self.unpatchify(x, th, tw) # (N, out_channels, H, W)
+
+ if return_dict:
+ return {'x': x}
+ return x
+
+ def initialize_weights(self):
+ # Initialize transformer layers:
+ def _basic_init(module):
+ if isinstance(module, nn.Linear):
+ torch.nn.init.xavier_uniform_(module.weight)
+ if module.bias is not None:
+ nn.init.constant_(module.bias, 0)
+ self.apply(_basic_init)
+
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
+ w = self.x_embedder.proj.weight.data
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+ nn.init.constant_(self.x_embedder.proj.bias, 0)
+
+ # Initialize label embedding table:
+ nn.init.normal_(self.extra_embedder[0].weight, std=0.02)
+ nn.init.normal_(self.extra_embedder[2].weight, std=0.02)
+
+ # Initialize timestep embedding MLP:
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
+
+ # Zero-out adaLN modulation layers in HunYuanDiT blocks:
+ for block in self.blocks:
+ nn.init.constant_(block.default_modulation[-1].weight, 0)
+ nn.init.constant_(block.default_modulation[-1].bias, 0)
+
+ # Zero-out output layers:
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
+ nn.init.constant_(self.final_layer.linear.weight, 0)
+ nn.init.constant_(self.final_layer.linear.bias, 0)
+
+ def unpatchify(self, x, h, w):
+ """
+ x: (N, T, patch_size**2 * C)
+        imgs: (N, C, H * patch_size, W * patch_size)
+ """
+ c = self.unpatchify_channels
+ p = self.x_embedder.patch_size[0]
+ # h = w = int(x.shape[1] ** 0.5)
+ assert h * w == x.shape[1]
+
+ x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
+ x = torch.einsum('nhwpqc->nchpwq', x)
+ imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
+ return imgs
+
+ def _replace_module(self, parent, child_name, new_module, child) -> None:
+ setattr(parent, child_name, new_module)
+ # It's not necessary to set requires_grad here, as that is handled by
+ # _mark_only_adapters_as_trainable
+
+ # child layer wraps the original module, unpack it
+ if hasattr(child, "base_layer"):
+ child = child.get_base_layer()
+ elif hasattr(child, "quant_linear_module"):
+ # TODO maybe not necessary to have special treatment?
+ child = child.quant_linear_module
+
+ if not hasattr(new_module, "base_layer"):
+ new_module.weight = child.weight
+ if hasattr(child, "bias"):
+ new_module.bias = child.bias
+
+ if getattr(child, "state", None) is not None:
+ if hasattr(new_module, "base_layer"):
+ new_module.base_layer.state = child.state
+ else:
+ new_module.state = child.state
+ new_module.to(child.weight.device)
+
+ # dispatch to correct device
+ for name, module in new_module.named_modules():
+ # if any(prefix in name for prefix in PREFIXES):
+ # module.to(child.weight.device)
+ if "ranknum" in name:
+ module.to(child.weight.device)
+
+ def merge_and_unload(self,
+ merge=True,
+ progressbar: bool = False,
+ safe_merge: bool = False,
+ adapter_names = None,):
+ if merge:
+ if getattr(self, "quantization_method", None) == "gptq":
+ raise ValueError("Cannot merge layers when the model is gptq quantized")
+
+ def merge_recursively(module):
+ # helper function to recursively merge the base_layer of the target
+ path = []
+ layer = module
+ while hasattr(layer, "base_layer"):
+ path.append(layer)
+ layer = layer.base_layer
+ for layer_before, layer_after in zip(path[:-1], path[1:]):
+ layer_after.merge(safe_merge=safe_merge, adapter_names=adapter_names)
+ layer_before.base_layer = layer_after.base_layer
+ module.merge(safe_merge=safe_merge, adapter_names=adapter_names)
+
+ key_list = [key for key, _ in self.named_modules()]
+ desc = "Unloading " + ("and merging " if merge else "") + "model"
+
+ for key in tqdm(key_list, disable=not progressbar, desc=desc):
+ try:
+ parent, target, target_name = _get_submodules(self, key)
+ except AttributeError:
+ continue
+
+ if hasattr(target, "base_layer"):
+ if merge:
+ merge_recursively(target)
+ self._replace_module(parent, target_name, target.get_base_layer(), target)
+ elif isinstance(target, ModulesToSaveWrapper):
+ # save any additional trainable modules part of `modules_to_save`
+ new_module = target.modules_to_save[target.active_adapter]
+ if hasattr(new_module, "base_layer"):
+ # check if the module is itself a tuner layer
+ if merge:
+ new_module.merge(safe_merge=safe_merge, adapter_names=adapter_names)
+ new_module = new_module.get_base_layer()
+ setattr(parent, target_name, new_module)
+
+
+
+#################################################################################
+# HunYuanDiT Configs #
+#################################################################################
+
+HUNYUAN_DIT_CONFIG = {
+ 'DiT-g/2': {'depth': 40, 'hidden_size': 1408, 'patch_size': 2, 'num_heads': 16, 'mlp_ratio': 4.3637},
+ 'DiT-XL/2': {'depth': 28, 'hidden_size': 1152, 'patch_size': 2, 'num_heads': 16},
+}
+
+def DiT_g_2(args, **kwargs):
+ return HunYuanDiT(args, depth=40, hidden_size=1408, patch_size=2, num_heads=16, mlp_ratio=4.3637, **kwargs)
+def DiT_XL_2(args, **kwargs):
+ return HunYuanDiT(args, depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs)
+
+HUNYUAN_DIT_MODELS = {
+ 'DiT-g/2': DiT_g_2,
+ 'DiT-XL/2': DiT_XL_2,
+}
\ No newline at end of file
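
A construction-only sketch for the model defined above. In the repository the `args` namespace comes from `hydit.config.get_args()`; the values below are assumptions chosen to match the shapes commented in `forward` (CLIP length 77 / dim 1024, T5 length 256 / dim 2048) and the flags used by the run scripts, and building DiT-g/2 allocates the full model in host memory:

```python
from types import SimpleNamespace

from hydit.modules.models import HUNYUAN_DIT_MODELS

args = SimpleNamespace(
    learn_sigma=True, norm="layer",
    text_states_dim=1024, text_len=77,            # CLIP branch
    text_states_dim_t5=2048, text_len_t5=256,     # T5 branch
    infer_mode="torch", use_flash_attn=False, qk_norm=True,
    use_fp16=False, autocast_dtype="fp32",
)

# DiT-g/2: depth 40, hidden size 1408, 16 heads; input_size=(128, 128) corresponds to
# 1024x1024 images after the 8x VAE downsampling.
model = HUNYUAN_DIT_MODELS["DiT-g/2"](args, input_size=(128, 128))
print(sum(p.numel() for p in model.parameters()) / 1e9, "B parameters")
```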
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/hydit/modules/norm_layers.py b/PyTorch/built-in/mlm/HunyuanDiT/hydit/modules/norm_layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c25bc3f3b99047372982943027942cf7f2ded3a
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/hydit/modules/norm_layers.py
@@ -0,0 +1,74 @@
+import torch
+import torch.nn as nn
+from utils.npu_utils import is_npu_available
+if is_npu_available():
+ import torch_npu
+
+
+class RMSNorm(nn.Module):
+ def __init__(self, dim: int, elementwise_affine=True, eps: float = 1e-6):
+ """
+ Initialize the RMSNorm normalization layer.
+
+ Args:
+ dim (int): The dimension of the input tensor.
+ eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
+
+ Attributes:
+ eps (float): A small value added to the denominator for numerical stability.
+ weight (nn.Parameter): Learnable scaling parameter.
+
+ """
+ super().__init__()
+ self.eps = eps
+ if is_npu_available() or elementwise_affine:
+ self.weight = nn.Parameter(torch.ones(dim))
+
+ def _norm(self, x):
+ """
+ Apply the RMSNorm normalization to the input tensor.
+
+ Args:
+ x (torch.Tensor): The input tensor.
+
+ Returns:
+ torch.Tensor: The normalized tensor.
+
+ """
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+ def forward(self, x):
+ """
+ Forward pass through the RMSNorm layer.
+
+ Args:
+ x (torch.Tensor): The input tensor.
+
+ Returns:
+ torch.Tensor: The output tensor after applying RMSNorm.
+
+ """
+ if is_npu_available():
+ output = torch_npu.npu_rms_norm(x, self.weight, epsilon=self.eps)[0]
+ else:
+ output = self._norm(x.float()).type_as(x)
+ if hasattr(self, "weight"):
+ output = output * self.weight
+ return output
+
+
+class GroupNorm32(nn.GroupNorm):
+ def __init__(self, num_groups, num_channels, eps=1e-5, dtype=None):
+ super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps, dtype=dtype)
+
+ def forward(self, x):
+ y = super().forward(x).to(x.dtype)
+ return y
+
+def normalization(channels, dtype=None):
+ """
+ Make a standard normalization layer.
+ :param channels: number of input channels.
+ :return: an nn.Module for normalization.
+ """
+ return GroupNorm32(num_channels=channels, num_groups=32, dtype=dtype)
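
A quick numerical check of the `RMSNorm` defined above (CPU path; the `torch_npu` branch is only taken when an NPU is available). The import assumes the repository root is on `PYTHONPATH` so that `utils.npu_utils` resolves:

```python
import torch
from hydit.modules.norm_layers import RMSNorm

x = torch.randn(2, 16, 1408)
norm = RMSNorm(1408, elementwise_affine=True, eps=1e-6)
y = norm(x)
# With the default weight of ones, every token comes out with (approximately) unit RMS.
print(y.shape, y.pow(2).mean(-1).sqrt().mean().item())
```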
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/hydit/modules/poolers.py b/PyTorch/built-in/mlm/HunyuanDiT/hydit/modules/poolers.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4adcaca51fded2268a644ca4c70d5b33dfcd3b0
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/hydit/modules/poolers.py
@@ -0,0 +1,39 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class AttentionPool(nn.Module):
+ def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
+ super().__init__()
+ self.positional_embedding = nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim ** 0.5)
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
+ self.num_heads = num_heads
+
+ def forward(self, x):
+ x = x.permute(1, 0, 2) # NLC -> LNC
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (L+1)NC
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (L+1)NC
+ x, _ = F.multi_head_attention_forward(
+ query=x[:1], key=x, value=x,
+ embed_dim_to_check=x.shape[-1],
+ num_heads=self.num_heads,
+ q_proj_weight=self.q_proj.weight,
+ k_proj_weight=self.k_proj.weight,
+ v_proj_weight=self.v_proj.weight,
+ in_proj_weight=None,
+ in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
+ bias_k=None,
+ bias_v=None,
+ add_zero_attn=False,
+ dropout_p=0,
+ out_proj_weight=self.c_proj.weight,
+ out_proj_bias=self.c_proj.bias,
+ use_separate_proj_weight=True,
+ training=self.training,
+ need_weights=False
+ )
+ return x.squeeze(0)
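
Usage sketch for `AttentionPool`, with the same shapes HunYuanDiT uses above: `text_len_t5=256` tokens of dimension `text_states_dim_t5=2048`, pooled into a single 1024-dimensional vector:

```python
import torch
from hydit.modules.poolers import AttentionPool

pooler = AttentionPool(spacial_dim=256, embed_dim=2048, num_heads=8, output_dim=1024)
t5_states = torch.randn(2, 256, 2048)   # (B, L_t5, D_t5)
pooled = pooler(t5_states)              # (B, 1024): the mean token attends over all tokens
print(pooled.shape)
```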
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/hydit/modules/posemb_layers.py b/PyTorch/built-in/mlm/HunyuanDiT/hydit/modules/posemb_layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..dcb41a713cd94ea8472ff26e8865066887b1e486
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/hydit/modules/posemb_layers.py
@@ -0,0 +1,224 @@
+import torch
+import numpy as np
+from typing import Union
+
+
+def _to_tuple(x):
+ if isinstance(x, int):
+ return x, x
+ else:
+ return x
+
+
+def get_fill_resize_and_crop(src, tgt):
+ th, tw = _to_tuple(tgt)
+ h, w = _to_tuple(src)
+
+    tr = th / tw  # aspect ratio of the base resolution
+    r = h / w  # aspect ratio of the target resolution
+
+ # resize
+ if r > tr:
+ resize_height = th
+ resize_width = int(round(th / h * w))
+ else:
+ resize_width = tw
+ resize_height = int(round(tw / w * h)) # resize the target resolution down based on the base resolution
+
+ crop_top = int(round((th - resize_height) / 2.0))
+ crop_left = int(round((tw - resize_width) / 2.0))
+
+ return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
+
+
+def get_meshgrid(start, *args):
+ if len(args) == 0:
+ # start is grid_size
+ num = _to_tuple(start)
+ start = (0, 0)
+ stop = num
+ elif len(args) == 1:
+ # start is start, args[0] is stop, step is 1
+ start = _to_tuple(start)
+ stop = _to_tuple(args[0])
+ num = (stop[0] - start[0], stop[1] - start[1])
+ elif len(args) == 2:
+ # start is start, args[0] is stop, args[1] is num
+ start = _to_tuple(start)
+ stop = _to_tuple(args[0])
+ num = _to_tuple(args[1])
+ else:
+ raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
+
+ grid_h = np.linspace(start[0], stop[0], num[0], endpoint=False, dtype=np.float32)
+ grid_w = np.linspace(start[1], stop[1], num[1], endpoint=False, dtype=np.float32)
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
+    grid = np.stack(grid, axis=0)  # [2, H, W]
+ return grid
+
+#################################################################################
+# Sine/Cosine Positional Embedding Functions #
+#################################################################################
+# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
+
+def get_2d_sincos_pos_embed(embed_dim, start, *args, cls_token=False, extra_tokens=0):
+ """
+ grid_size: int of the grid height and width
+ return:
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+ """
+ grid = get_meshgrid(start, *args) # [2, H, w]
+ # grid_h = np.arange(grid_size, dtype=np.float32)
+ # grid_w = np.arange(grid_size, dtype=np.float32)
+ # grid = np.meshgrid(grid_w, grid_h) # here w goes first
+ # grid = np.stack(grid, axis=0) # [2, W, H]
+
+ grid = grid.reshape([2, 1, *grid.shape[1:]])
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+ if cls_token and extra_tokens > 0:
+ pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
+ return pos_embed
+
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+ assert embed_dim % 2 == 0
+
+ # use half of dimensions to encode grid_h
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
+
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
+ return emb
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+ """
+ embed_dim: output dimension for each position
+ pos: a list of positions to be encoded: size (W,H)
+ out: (M, D)
+ """
+ assert embed_dim % 2 == 0
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
+ omega /= embed_dim / 2.
+ omega = 1. / 10000**omega # (D/2,)
+
+ pos = pos.reshape(-1) # (M,)
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
+
+ emb_sin = np.sin(out) # (M, D/2)
+ emb_cos = np.cos(out) # (M, D/2)
+
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
+ return emb
+
+
+#################################################################################
+# Rotary Positional Embedding Functions #
+#################################################################################
+# https://github.com/facebookresearch/llama/blob/main/llama/model.py#L443
+
+def get_2d_rotary_pos_embed(embed_dim, start, *args, use_real=True):
+ """
+ This is a 2d version of precompute_freqs_cis, which is a RoPE for image tokens with 2d structure.
+
+ Parameters
+ ----------
+ embed_dim: int
+ embedding dimension size
+ start: int or tuple of int
+ If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop, step is 1;
+ If len(args) == 2, start is start, args[0] is stop, args[1] is num.
+ use_real: bool
+ If True, return real part and imaginary part separately. Otherwise, return complex numbers.
+
+ Returns
+ -------
+    pos_embed: torch.Tensor or (torch.Tensor, torch.Tensor)
+        If use_real is True, a (cos, sin) tuple, each of shape [HW, D]; otherwise a complex tensor of shape [HW, D/2].
+ """
+ grid = get_meshgrid(start, *args) # [2, H, w]
+ grid = grid.reshape([2, 1, *grid.shape[1:]]) # Returns a sampling matrix with the same resolution as the target resolution
+ pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real)
+ return pos_embed
+
+
+def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
+ assert embed_dim % 4 == 0
+
+ # use half of dimensions to encode grid_h
+ emb_h = get_1d_rotary_pos_embed(embed_dim // 2, grid[0].reshape(-1), use_real=use_real) # (H*W, D/4)
+ emb_w = get_1d_rotary_pos_embed(embed_dim // 2, grid[1].reshape(-1), use_real=use_real) # (H*W, D/4)
+
+ if use_real:
+ cos = torch.cat([emb_h[0], emb_w[0]], dim=1) # (H*W, D/2)
+ sin = torch.cat([emb_h[1], emb_w[1]], dim=1) # (H*W, D/2)
+ return cos, sin
+ else:
+ emb = torch.cat([emb_h, emb_w], dim=1) # (H*W, D/2)
+ return emb
+
+
+def get_1d_rotary_pos_embed(dim: int, pos: Union[np.ndarray, int], theta: float = 10000.0, use_real=False):
+ """
+ Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
+
+ This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
+ and the end index 'end'. The 'theta' parameter scales the frequencies.
+ The returned tensor contains complex values in complex64 data type.
+
+ Args:
+ dim (int): Dimension of the frequency tensor.
+ pos (np.ndarray, int): Position indices for the frequency tensor. [S] or scalar
+ theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
+ use_real (bool, optional): If True, return real part and imaginary part separately.
+ Otherwise, return complex numbers.
+
+ Returns:
+        torch.Tensor: Precomputed frequency tensor. Complex of shape [S, D/2], or a (cos, sin) tuple, each [S, D], if use_real.
+
+ """
+ if isinstance(pos, int):
+ pos = np.arange(pos)
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [D/2]
+ t = torch.from_numpy(pos).to(freqs.device) # type: ignore # [S]
+ freqs = torch.outer(t, freqs).float() # type: ignore # [S, D/2]
+ if use_real:
+ freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
+ freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
+ return freqs_cos, freqs_sin
+ else:
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
+ return freqs_cis
+
+
+
+def calc_sizes(rope_img, patch_size, th, tw):
+ if rope_img == 'extend':
+ # Expansion mode
+ sub_args = [(th, tw)]
+ elif rope_img.startswith('base'):
+ # Based on the specified dimensions, other dimensions are obtained through interpolation.
+ base_size = int(rope_img[4:]) // 8 // patch_size
+ start, stop = get_fill_resize_and_crop((th, tw), base_size)
+ sub_args = [start, stop, (th, tw)]
+ else:
+ raise ValueError(f"Unknown rope_img: {rope_img}")
+ return sub_args
+
+
+def init_image_posemb(rope_img,
+ resolutions,
+ patch_size,
+ hidden_size,
+ num_heads,
+ log_fn,
+ rope_real=True,
+ ):
+ freqs_cis_img = {}
+ for reso in resolutions:
+ th, tw = reso.height // 8 // patch_size, reso.width // 8 // patch_size
+ sub_args = calc_sizes(rope_img, patch_size, th, tw)
+ freqs_cis_img[str(reso)] = get_2d_rotary_pos_embed(hidden_size // num_heads, *sub_args, use_real=rope_real)
+ log_fn(f" Using image RoPE ({rope_img}) ({'real' if rope_real else 'complex'}): {sub_args} | ({reso}) "
+ f"{freqs_cis_img[str(reso)][0].shape if rope_real else freqs_cis_img[str(reso)].shape}")
+ return freqs_cis_img
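
A sketch that builds the image RoPE tables the same way `init_image_posemb` does for a single resolution, assuming the DiT-g/2 head dimension (1408 / 16 = 88), `patch_size=2`, 1024x1024 images, and the `base512` RoPE mode used by the run scripts below:

```python
from hydit.modules.posemb_layers import get_2d_rotary_pos_embed, get_fill_resize_and_crop

patch_size, head_dim = 2, 88
th = tw = 1024 // 8 // patch_size                      # 64 x 64 patch grid
base_size = 512 // 8 // patch_size                     # 'base512' mode
start, stop = get_fill_resize_and_crop((th, tw), base_size)
cos, sin = get_2d_rotary_pos_embed(head_dim, start, stop, (th, tw), use_real=True)
print(cos.shape, sin.shape)                            # (4096, 88) each; fed to the model as cos_cis_img / sin_cis_img
```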
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/hydit/modules/text_encoder.py b/PyTorch/built-in/mlm/HunyuanDiT/hydit/modules/text_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a16b21f4ffbd9acc896c93e2833caf5aecfa0a2
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/hydit/modules/text_encoder.py
@@ -0,0 +1,95 @@
+import torch
+import torch.nn as nn
+from transformers import AutoTokenizer, T5EncoderModel, T5ForConditionalGeneration
+
+
+class MT5Embedder(nn.Module):
+ available_models = ["t5-v1_1-xxl"]
+
+ def __init__(
+ self,
+ model_dir="t5-v1_1-xxl",
+ model_kwargs=None,
+ torch_dtype=None,
+ use_tokenizer_only=False,
+ conditional_generation=False,
+ max_length=128,
+ ):
+ super().__init__()
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
+ self.torch_dtype = torch_dtype or torch.bfloat16
+ self.max_length = max_length
+ if model_kwargs is None:
+ model_kwargs = {
+ # "low_cpu_mem_usage": True,
+ "torch_dtype": self.torch_dtype,
+ }
+ model_kwargs["device_map"] = {"shared": self.device, "encoder": self.device}
+ self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
+ if use_tokenizer_only:
+ return
+ if conditional_generation:
+ self.model = None
+ self.generation_model = T5ForConditionalGeneration.from_pretrained(
+ model_dir
+ )
+ return
+ self.model = T5EncoderModel.from_pretrained(model_dir, **model_kwargs).eval().to(self.torch_dtype)
+
+ def get_tokens_and_mask(self, texts):
+ text_tokens_and_mask = self.tokenizer(
+ texts,
+ max_length=self.max_length,
+ padding="max_length",
+ truncation=True,
+ return_attention_mask=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ tokens = text_tokens_and_mask["input_ids"][0]
+ mask = text_tokens_and_mask["attention_mask"][0]
+ # tokens = torch.tensor(tokens).clone().detach()
+ # mask = torch.tensor(mask, dtype=torch.bool).clone().detach()
+ return tokens, mask
+
+ def get_text_embeddings(self, texts, attention_mask=True, layer_index=-1):
+ text_tokens_and_mask = self.tokenizer(
+ texts,
+ max_length=self.max_length,
+ padding="max_length",
+ truncation=True,
+ return_attention_mask=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+
+ with torch.no_grad():
+ outputs = self.model(
+ input_ids=text_tokens_and_mask["input_ids"].to(self.device),
+ attention_mask=text_tokens_and_mask["attention_mask"].to(self.device)
+ if attention_mask
+ else None,
+ output_hidden_states=True,
+ )
+ text_encoder_embs = outputs["hidden_states"][layer_index].detach()
+
+ return text_encoder_embs, text_tokens_and_mask["attention_mask"].to(self.device)
+
+ @torch.no_grad()
+ def __call__(self, tokens, attention_mask, layer_index=-1):
+ with torch.cuda.amp.autocast():
+ outputs = self.model(
+ input_ids=tokens,
+ attention_mask=attention_mask,
+ output_hidden_states=True,
+ )
+
+ z = outputs.hidden_states[layer_index].detach()
+ return z
+
+ def general(self, text: str):
+ # input_ids = input_ids = torch.tensor([list(text.encode("utf-8"))]) + num_special_tokens
+ input_ids = self.tokenizer(text, max_length=128).input_ids
+ print(input_ids)
+ outputs = self.generation_model(input_ids)
+ return outputs
\ No newline at end of file
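
A usage sketch for `MT5Embedder`. The checkpoint directory is a placeholder; in the repository it is resolved from `hydit.constants.T5_ENCODER`, and loading the full mT5-XXL encoder requires a GPU with enough memory:

```python
import torch
from hydit.modules.text_encoder import MT5Embedder

embedder = MT5Embedder("/path/to/mt5-xxl", torch_dtype=torch.float16, max_length=256)
tokens, mask = embedder.get_tokens_and_mask(["画一只穿着西装的猪"])        # each of shape (256,)
embeddings, attn_mask = embedder.get_text_embeddings(["draw a pig in a suit"])
print(tokens.shape, embeddings.shape)                                      # embeddings: (1, 256, D)
```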
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/hydit/modules/trt/engine.py b/PyTorch/built-in/mlm/HunyuanDiT/hydit/modules/trt/engine.py
new file mode 100644
index 0000000000000000000000000000000000000000..bba110edf70dd4fd2c6df218a2d3aa1f915d6617
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/hydit/modules/trt/engine.py
@@ -0,0 +1,152 @@
+#
+# Copyright 2022 The HuggingFace Inc. team.
+# SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+from collections import OrderedDict
+from copy import copy
+
+import numpy as np
+import tensorrt as trt
+import torch
+from polygraphy import cuda
+from polygraphy.backend.common import bytes_from_path
+from polygraphy.backend.trt import CreateConfig, Profile
+from polygraphy.backend.trt import engine_from_bytes, engine_from_network, network_from_onnx_path, save_engine
+from polygraphy.backend.trt import util as trt_util
+import ctypes
+from glob import glob
+from cuda import cudart
+
+TRT_LOGGER = trt.Logger(trt.Logger.INFO)
+trt_util.TRT_LOGGER = TRT_LOGGER
+
+
+class Engine():
+ def __init__(
+ self,
+ model_name,
+ engine_dir,
+ onnx_file=None,
+ ):
+ self.engine_path = os.path.join(engine_dir, model_name + '.plan')
+ self.engine = None
+ self.context = None
+ self.buffers = OrderedDict()
+ self.tensors = OrderedDict()
+
+ self.weightNameList = None
+ self.refitter = None
+ self.onnx_initializers = None
+ self.onnx_file = onnx_file
+ self.trt_lora_weight = None
+ self.trt_lora_weight_mem = None
+ self.torch_weight = None
+
+ def __del__(self):
+ del self.engine
+ del self.context
+ del self.buffers
+ del self.tensors
+
+ def build(self, onnx_path, fp16, input_profile=None, enable_preview=False, sparse_weights=False):
+ print(f"Building TensorRT engine for {onnx_path}: {self.engine_path}")
+ p = Profile()
+ if input_profile:
+ for name, dims in input_profile.items():
+ assert len(dims) == 3
+ p.add(name, min=dims[0], opt=dims[1], max=dims[2])
+
+ preview_features = []
+ if enable_preview:
+ trt_version = [int(i) for i in trt.__version__.split(".")]
+ # FASTER_DYNAMIC_SHAPES_0805 should only be used for TRT 8.5.1 or above.
+ if trt_version[0] > 8 or \
+ (trt_version[0] == 8 and (trt_version[1] > 5 or (trt_version[1] == 5 and trt_version[2] >= 1))):
+ preview_features = [trt.PreviewFeature.FASTER_DYNAMIC_SHAPES_0805]
+
+ engine = engine_from_network(network_from_onnx_path(onnx_path), config=CreateConfig(fp16=fp16, profiles=[p],
+ preview_features=preview_features,
+ sparse_weights=sparse_weights))
+ save_engine(engine, path=self.engine_path)
+
+ def activate(self, plugin_path=""):
+ ctypes.cdll.LoadLibrary(plugin_path)
+ self.engine = engine_from_bytes(bytes_from_path(self.engine_path))
+ self.context = self.engine.create_execution_context()
+
+ def get_shared_memory(self):
+ _, device_memory = cudart.cudaMalloc(self.engine.device_memory_size)
+ self.device_memory = device_memory
+ return self.device_memory
+
+ def set_shared_memory(self, device_memory_size):
+ self.context.device_memory = device_memory_size
+
+ def binding_input(self, name, shape):
+ idx = self.engine.get_binding_index(name)
+ result = self.context.set_binding_shape(idx, shape)
+ return result
+
+ def allocate_buffers(self, shape_dict=None, device='cuda'):
+ print("Allocate buffers and bindings inputs:")
+ for idx in range(trt_util.get_bindings_per_profile(self.engine)):
+ binding = self.engine[idx]
+ print("binding: ", binding)
+ if shape_dict and binding in shape_dict:
+ shape = shape_dict[binding]
+ else:
+ shape = self.engine.get_binding_shape(binding)
+ nv_dtype = self.engine.get_binding_dtype(binding)
+ dtype_map = {trt.DataType.FLOAT: np.float32,
+ trt.DataType.HALF: np.float16,
+ trt.DataType.INT8: np.int8,
+ trt.DataType.INT64: np.int64,
+ trt.DataType.BOOL: bool}
+ if hasattr(trt.DataType, 'INT32'):
+ dtype_map[trt.DataType.INT32] = np.int32
+ dtype = dtype_map[nv_dtype]
+ if self.engine.binding_is_input(binding):
+ self.context.set_binding_shape(idx, shape)
+ # Workaround to convert np dtype to torch
+ np_type_tensor = np.empty(shape=[], dtype=dtype)
+ torch_type_tensor = torch.from_numpy(np_type_tensor)
+ tensor = torch.empty(tuple(shape), dtype=torch_type_tensor.dtype).to(device=device)
+
+ print(f" binding={binding}, shape={shape}, dtype={tensor.dtype}")
+ self.tensors[binding] = tensor
+ self.buffers[binding] = cuda.DeviceView(ptr=tensor.data_ptr(), shape=shape, dtype=dtype)
+
+ def infer(self, feed_dict, stream):
+ start_binding, end_binding = trt_util.get_active_profile_bindings(self.context)
+ # shallow copy of ordered dict
+ device_buffers = copy(self.buffers)
+ for name, buf in feed_dict.items():
+ assert isinstance(buf, cuda.DeviceView)
+ device_buffers[name] = buf
+ self.binding_input(name, buf.shape)
+ bindings = [0] * start_binding + [buf.ptr for buf in device_buffers.values()]
+ noerror = self.context.execute_async_v2(bindings=bindings, stream_handle=stream.ptr)
+ if not noerror:
+            raise ValueError("ERROR: inference failed.")
+
+ for idx in range(trt_util.get_bindings_per_profile(self.engine)):
+ binding = self.engine[idx]
+ if not self.engine.binding_is_input(binding):
+ shape = self.context.get_binding_shape(idx)
+ self.tensors[binding].resize_(tuple(shape))
+ return self.tensors
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/hydit/modules/trt/hcf_model.py b/PyTorch/built-in/mlm/HunyuanDiT/hydit/modules/trt/hcf_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0ded21c8192af6c03568f81eecf078eb56226a3
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/hydit/modules/trt/hcf_model.py
@@ -0,0 +1,123 @@
+import numpy as np
+import torch
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.models import ModelMixin
+from polygraphy import cuda
+
+from .engine import Engine
+
+
+class TRTModel(ModelMixin, ConfigMixin):
+ @register_to_config
+ def __init__(
+ self,
+ in_channels=4,
+ model_name="unet-dyn",
+ engine_dir="./unet",
+ device_id=0,
+ fp16=True,
+ image_width=1024,
+ image_height=1024,
+ text_maxlen=77,
+ embedding_dim=768,
+ max_batch_size=1,
+ plugin_path="./ckpts/trt_model/fmha_plugins/9.2_plugin_cuda11/fMHAPlugin.so",
+ ):
+ super().__init__()
+ # create engine
+ self.in_channels = in_channels # For pipeline compatibility
+ self.fp16 = fp16
+ self.max_batch_size = max_batch_size
+ self.model_name = model_name
+ self.engine_dir = engine_dir
+ self.engine = Engine(self.model_name, self.engine_dir)
+ self.engine.activate(plugin_path)
+ # create cuda stream
+ self.stream = cuda.Stream()
+ # create inputs buffer
+ self.latent_width = image_width // 8
+ self.latent_height = image_height // 8
+ self.text_maxlen = text_maxlen
+ self.embedding_dim = embedding_dim
+ shape_dict = {
+ 'x': (2 * self.max_batch_size, 4, self.latent_height, self.latent_width),
+ 't': (2 * self.max_batch_size,),
+ 'encoder_hidden_states': (2 * self.max_batch_size, self.text_maxlen, self.embedding_dim),
+ 'text_embedding_mask': (2 * self.max_batch_size, self.text_maxlen),
+ 'encoder_hidden_states_t5': (2 * self.max_batch_size, 256, 2048),
+ 'text_embedding_mask_t5': (2 * self.max_batch_size, 256),
+ 'image_meta_size': (2 * self.max_batch_size, 6),
+ 'style': (2 * self.max_batch_size,),
+ 'cos_cis_img': (6400, 88),
+ 'sin_cis_img': (6400, 88),
+ 'output': (2 * self.max_batch_size, 8, self.latent_height, self.latent_width),
+ }
+ device = "cuda:{}".format(device_id)
+ self.engine_device = torch.device(device)
+ self.engine.allocate_buffers(shape_dict=shape_dict, device=device)
+
+        print("[INFO] Created hcf nv controlled unet successfully")
+
+ @property
+ def device(self):
+ return self.engine_device
+
+ def __call__(self, x, t_emb, context, image_meta_size, style, freqs_cis_img0,
+ freqs_cis_img1, text_embedding_mask, encoder_hidden_states_t5, text_embedding_mask_t5):
+ return self.forward(x=x, t_emb=t_emb, context=context, image_meta_size=image_meta_size, style=style,
+ freqs_cis_img0=freqs_cis_img0, freqs_cis_img1=freqs_cis_img1,
+ text_embedding_mask=text_embedding_mask, encoder_hidden_states_t5=encoder_hidden_states_t5,
+ text_embedding_mask_t5=text_embedding_mask_t5)
+
+ def get_shared_memory(self):
+ return self.engine.get_shared_memory()
+
+ def set_shared_memory(self, shared_memory):
+ self.engine.set_shared_memory(shared_memory)
+
+ def forward(self, x, t_emb, context, image_meta_size, style, freqs_cis_img0,
+ freqs_cis_img1, text_embedding_mask, encoder_hidden_states_t5, text_embedding_mask_t5):
+ x_c = x.half()
+ t_emb_c = t_emb.half()
+ context_c = context.half()
+ image_meta_size_c = image_meta_size.half()
+ style_c = style.long()
+ freqs_cis_img0_c = freqs_cis_img0.float()
+ freqs_cis_img1_c = freqs_cis_img1.float()
+ text_embedding_mask_c = text_embedding_mask.long()
+ encoder_hidden_states_t5_c = encoder_hidden_states_t5.half()
+ text_embedding_mask_t5_c = text_embedding_mask_t5.long()
+ dtype = np.float16
+ batch_size = x.shape[0] // 2
+ if batch_size <= self.max_batch_size:
+ sample_inp = cuda.DeviceView(ptr=x_c.reshape(-1).data_ptr(), shape=x_c.shape, dtype=np.float16)
+ t_emb_inp = cuda.DeviceView(ptr=t_emb_c.reshape(-1).data_ptr(), shape=t_emb_c.shape, dtype=np.float16)
+ embeddings_inp = cuda.DeviceView(ptr=context_c.reshape(-1).data_ptr(), shape=context_c.shape,
+ dtype=np.float16)
+ image_meta_size_inp = cuda.DeviceView(ptr=image_meta_size_c.reshape(-1).data_ptr(),
+ shape=image_meta_size_c.shape, dtype=np.float16)
+ style_inp = cuda.DeviceView(ptr=style_c.reshape(-1).data_ptr(), shape=style_c.shape, dtype=np.int64)
+ freqs_cis_img0_inp = cuda.DeviceView(ptr=freqs_cis_img0_c.reshape(-1).data_ptr(),
+ shape=freqs_cis_img0_c.shape, dtype=np.float32)
+ freqs_cis_img1_inp = cuda.DeviceView(ptr=freqs_cis_img1_c.reshape(-1).data_ptr(),
+ shape=freqs_cis_img1_c.shape, dtype=np.float32)
+ text_embedding_mask_inp = cuda.DeviceView(ptr=text_embedding_mask_c.reshape(-1).data_ptr(),
+ shape=text_embedding_mask_c.shape, dtype=np.int64)
+ encoder_hidden_states_t5_inp = cuda.DeviceView(ptr=encoder_hidden_states_t5_c.reshape(-1).data_ptr(),
+ shape=encoder_hidden_states_t5_c.shape, dtype=np.float16)
+ text_embedding_mask_t5_inp = cuda.DeviceView(ptr=text_embedding_mask_t5_c.reshape(-1).data_ptr(),
+ shape=text_embedding_mask_t5_c.shape, dtype=np.int64)
+ feed_dict = {"x": sample_inp,
+ "t": t_emb_inp,
+ "encoder_hidden_states": embeddings_inp,
+ "image_meta_size": image_meta_size_inp,
+ "text_embedding_mask": text_embedding_mask_inp,
+ "encoder_hidden_states_t5": encoder_hidden_states_t5_inp,
+ "text_embedding_mask_t5": text_embedding_mask_t5_inp,
+ "style": style_inp, "cos_cis_img": freqs_cis_img0_inp,
+ "sin_cis_img": freqs_cis_img1_inp}
+ latent = self.engine.infer(feed_dict, self.stream)
+ return latent['output']
+ else:
+ raise ValueError(
+                "[ERROR] Input batch_size={} exceeds max_batch_size={}".format(batch_size, self.max_batch_size))
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/hydit/run_g.sh b/PyTorch/built-in/mlm/HunyuanDiT/hydit/run_g.sh
new file mode 100644
index 0000000000000000000000000000000000000000..63a9e30f9fe3f30f334d2f976f528be3d94628af
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/hydit/run_g.sh
@@ -0,0 +1,31 @@
+model='DiT-g/2'
+params=" \
+ --qk-norm \
+ --model ${model} \
+ --rope-img base512 \
+ --rope-real \
+ "
+deepspeed --num_gpus 8 --num_nodes 1 --master_port=29000 hydit/train_deepspeed.py ${params} "$@"
+
+#HOSTFILE="/home/l50041210/HunyuanDiT_combine/hostfile"
+#MASTER_ADDR=$(head -n1 $HOSTFILE | awk '{print $1;}')
+#MASTER_PORT=6001
+#NODE_ADDR=`hostname -I | awk '{for(i=1;i<=NF;i++)print $i}' | grep ${MASTER_ADDR%.*}. | awk -F " "'{print$1}'`
+#NODE_RANK=$(awk '{ranks[$1]=(FNR-1);}END{print ranks["'$NODE_ADDR'"];}' $HOSTFILE)
+#NNODES=$(cat $HOSTFILE | wc -l)
+#NPUS_PER_NODE=8
+#WORLD_SIZE=$(($NPUS_PER_NODE*$NNODES))
+#echo $MASTER_ADDR
+#echo $NODE_ADDR
+#echo $NODE_RANK
+#echo $NNODES
+#
+#DISTRIBUTED_ARGS="
+# --nproc_per_node $NPUS_PER_NODE \
+# --nnodes $NNODES \
+# --node_rank $NODE_RANK \
+# --master_addr $MASTER_ADDR \
+# --master_port $MASTER_PORT
+##"
+#
+#torchrun $DISTRIBUTED_ARGS hydit/train_deepspeed.py ${params} "$@"
\ No newline at end of file
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/hydit/run_g_controlnet.sh b/PyTorch/built-in/mlm/HunyuanDiT/hydit/run_g_controlnet.sh
new file mode 100644
index 0000000000000000000000000000000000000000..ab4d854d6d65b6a430c69ecb87054a72878662ec
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/hydit/run_g_controlnet.sh
@@ -0,0 +1,8 @@
+model='DiT-g/2'
+params=" \
+ --qk-norm \
+ --model ${model} \
+ --rope-img base512 \
+ --rope-real \
+ "
+deepspeed hydit/train_deepspeed_controlnet.py ${params} "$@"
\ No newline at end of file
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/hydit/train.sh b/PyTorch/built-in/mlm/HunyuanDiT/hydit/train.sh
new file mode 100644
index 0000000000000000000000000000000000000000..10c856dfe1064716d2227a4ca53385da709c5160
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/hydit/train.sh
@@ -0,0 +1,46 @@
+task_flag="dit_g2_full_1024p" # the task flag is used to identify folders.
+resume=./ckpts/t2i/model/ # checkpoint root for resume
+index_file=dataset/porcelain/jsons/porcelain_mt.json # index file for dataloader
+results_dir=./log_EXP # save root for results
+batch_size=1 # training batch size
+image_size=1024 # training image resolution
+grad_accu_steps=1 # gradient accumulation
+warmup_num_steps=0 # warm-up steps
+lr=0.0001 # learning rate
+ckpt_every=10000 # create a ckpt every a few steps.
+ckpt_latest_every=5000 # create a ckpt named `latest.pt` every a few steps.
+
+
+sh $(dirname "$0")/run_g.sh \
+ --task-flag ${task_flag} \
+ --noise-schedule scaled_linear --beta-start 0.00085 --beta-end 0.03 \
+ --predict-type v_prediction \
+ --uncond-p 0.44 \
+ --uncond-p-t5 0.44 \
+ --index-file ${index_file} \
+ --random-flip \
+ --lr ${lr} \
+ --batch-size ${batch_size} \
+ --image-size ${image_size} \
+ --global-seed 999 \
+ --grad-accu-steps ${grad_accu_steps} \
+ --warmup-num-steps ${warmup_num_steps} \
+ --use-flash-attn \
+ --use-fp16 \
+ --use-ema \
+ --ema-dtype fp32 \
+ --results-dir ${results_dir} \
+ --resume-split \
+ --resume ${resume} \
+ --ckpt-every ${ckpt_every} \
+ --ckpt-latest-every ${ckpt_latest_every} \
+ --log-every 1 \
+ --deepspeed \
+ --deepspeed-optimizer \
+ --use-zero-stage 2 \
+ --multireso \
+ --reso-step 64 \
+ --max-training-steps 5000 \
+ --norm 'layer' \
+    --autocast-dtype 'bf16' \
+ "$@"
\ No newline at end of file
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/hydit/train_controlnet.sh b/PyTorch/built-in/mlm/HunyuanDiT/hydit/train_controlnet.sh
new file mode 100644
index 0000000000000000000000000000000000000000..09c21603e5f31283d5a154b849a6c149d9256e2b
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/hydit/train_controlnet.sh
@@ -0,0 +1,46 @@
+task_flag="canny_controlnet" # the task flag is used to identify folders.
+control_type=canny
+resume=./ckpts/t2i/model/ # checkpoint root for resume
+index_file=/path/to/your/indexfile # index file for dataloader
+results_dir=./log_EXP # save root for results
+batch_size=1 # training batch size
+image_size=1024 # training image resolution
+grad_accu_steps=2 # gradient accumulation
+warmup_num_steps=0 # warm-up steps
+lr=0.0001 # learning rate
+ckpt_every=10000 # create a ckpt every a few steps.
+ckpt_latest_every=5000 # create a ckpt named `latest.pt` every a few steps.
+
+
+sh $(dirname "$0")/run_g_controlnet.sh \
+ --task-flag ${task_flag} \
+ --control-type ${control_type} \
+ --noise-schedule scaled_linear --beta-start 0.00085 --beta-end 0.03 \
+ --predict-type v_prediction \
+ --multireso \
+ --reso-step 64 \
+ --ema-to-module \
+ --uncond-p 0.44 \
+ --uncond-p-t5 0.44 \
+ --index-file ${index_file} \
+ --random-flip \
+ --lr ${lr} \
+ --batch-size ${batch_size} \
+ --image-size ${image_size} \
+ --global-seed 999 \
+ --grad-accu-steps ${grad_accu_steps} \
+ --warmup-num-steps ${warmup_num_steps} \
+ --use-flash-attn \
+ --use-fp16 \
+ --use-ema \
+ --ema-dtype fp32 \
+ --results-dir ${results_dir} \
+ --resume-split \
+ --resume ${resume} \
+ --ckpt-every ${ckpt_every} \
+ --ckpt-latest-every ${ckpt_latest_every} \
+ --log-every 10 \
+ --deepspeed \
+ --deepspeed-optimizer \
+ --use-zero-stage 2 \
+ "$@"
\ No newline at end of file
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/hydit/train_deepspeed.py b/PyTorch/built-in/mlm/HunyuanDiT/hydit/train_deepspeed.py
new file mode 100644
index 0000000000000000000000000000000000000000..05ce0655f821dbb21d80d3d04297f43b61020f82
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/hydit/train_deepspeed.py
@@ -0,0 +1,545 @@
+import gc
+import json
+import os
+import random
+import sys
+import time
+from functools import partial
+from glob import glob
+from pathlib import Path
+import numpy as np
+
+import deepspeed
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.distributed as dist
+from torch.utils.data import DataLoader
+from torch.distributed.optim import ZeroRedundancyOptimizer
+from torchvision.transforms import functional as TF
+from diffusers.models import AutoencoderKL
+from transformers import BertModel, BertTokenizer, logging as tf_logging
+
+from hydit.config import get_args
+from hydit.constants import VAE_EMA_PATH, TEXT_ENCODER, TOKENIZER, T5_ENCODER
+from hydit.lr_scheduler import WarmupLR
+from hydit.data_loader.arrow_load_stream import TextImageArrowStream
+from hydit.diffusion import create_diffusion
+from hydit.ds_config import deepspeed_config_from_args
+from hydit.modules.ema import EMA
+from hydit.modules.fp16_layers import Float16Module
+from hydit.modules.models import HUNYUAN_DIT_MODELS
+from hydit.modules.posemb_layers import init_image_posemb
+from hydit.utils.tools import create_logger, set_seeds, create_exp_folder, model_resume, get_trainable_params
+from IndexKits.index_kits import ResolutionGroup
+from IndexKits.index_kits.sampler import DistributedSamplerWithStartIndex, BlockDistributedSampler
+from peft import LoraConfig, get_peft_model
+
+from utils.npu_utils import is_npu_available, seed_all, AUTOCAST_MAPPING
+
+if is_npu_available():
+ import torch_npu
+ from torch_npu.npu.amp import autocast
+ from torch_npu.contrib import transfer_to_npu
+else:
+ from torch.cuda.amp import autocast
+
+
+def deepspeed_initialize(args, logger, model, opt, deepspeed_config):
+ logger.info(f"Initialize deepspeed...")
+ logger.info(f" Using deepspeed optimizer")
+
+ def get_learning_rate_scheduler(warmup_min_lr, lr, warmup_num_steps, opt):
+ return WarmupLR(opt, warmup_min_lr, lr, warmup_num_steps)
+
+ logger.info(f" Building scheduler with warmup_min_lr={args.warmup_min_lr}, warmup_num_steps={args.warmup_num_steps}")
+ model, opt, _, scheduler = deepspeed.initialize(
+ model=model,
+ model_parameters=get_trainable_params(model),
+ config_params=deepspeed_config,
+ args=args,
+ lr_scheduler=partial(get_learning_rate_scheduler, args.warmup_min_lr,
+ args.lr, args.warmup_num_steps) if args.warmup_num_steps > 0 else None,
+ )
+ return model, opt, scheduler
+
+
+def save_checkpoint(args, rank, logger, model, ema, epoch, train_steps, checkpoint_dir):
+ def save_lora_weight(checkpoint_dir, client_state, tag=f"{train_steps:07d}.pt"):
+ cur_ckpt_save_dir = f"{checkpoint_dir}/{tag}"
+ if rank == 0:
+ if args.use_fp16:
+ model.module.module.save_pretrained(cur_ckpt_save_dir)
+ else:
+ model.module.save_pretrained(cur_ckpt_save_dir)
+
+ checkpoint_path = "[Not rank 0. Disabled output.]"
+
+ client_state = {
+ "steps": train_steps,
+ "epoch": epoch,
+ "args": args
+ }
+ if ema is not None:
+ client_state['ema'] = ema.state_dict()
+
+ dst_paths = []
+ if train_steps % args.ckpt_every == 0:
+ checkpoint_path = f"{checkpoint_dir}/{train_steps:07d}.pt"
+ try:
+ if args.training_parts == "lora":
+ save_lora_weight(checkpoint_dir, client_state, tag=f"{train_steps:07d}.pt")
+ else:
+ model.save_checkpoint(checkpoint_dir, client_state=client_state, tag=f"{train_steps:07d}.pt")
+ dst_paths.append(checkpoint_path)
+ logger.info(f"Saved checkpoint to {checkpoint_path}")
+        except Exception as e:
+            logger.error(f"Failed to save checkpoint to {checkpoint_path}: {e}")
+
+ if train_steps % args.ckpt_latest_every == 0 or train_steps == args.max_training_steps:
+ save_name = "latest.pt"
+ checkpoint_path = f"{checkpoint_dir}/{save_name}"
+ try:
+ if args.training_parts == "lora":
+ save_lora_weight(checkpoint_dir, client_state, tag=f"{save_name}")
+ else:
+ model.save_checkpoint(checkpoint_dir, client_state=client_state, tag=f"{save_name}")
+ dst_paths.append(checkpoint_path)
+ logger.info(f"Saved checkpoint to {checkpoint_path}")
+        except Exception as e:
+            logger.error(f"Failed to save checkpoint to {checkpoint_path}: {e}")
+
+ dist.barrier()
+ if rank == 0 and len(dst_paths) > 0:
+ # Delete optimizer states to avoid occupying too much disk space.
+ for dst_path in dst_paths:
+ for opt_state_path in glob(f"{dst_path}/zero_dp_rank_*_tp_rank_00_pp_rank_00_optim_states.pt"):
+ os.remove(opt_state_path)
+
+ return checkpoint_path
+
+
+@torch.no_grad()
+def prepare_model_inputs(args, batch, device, vae, text_encoder, text_encoder_t5, freqs_cis_img):
+ image, text_embedding, text_embedding_mask, text_embedding_t5, text_embedding_mask_t5, kwargs = batch
+
+ # clip & mT5 text embedding
+ text_embedding = text_embedding.to(device)
+ text_embedding_mask = text_embedding_mask.to(device)
+ encoder_hidden_states = text_encoder(
+ text_embedding.to(device),
+ attention_mask=text_embedding_mask.to(device),
+ )[0]
+ text_embedding_t5 = text_embedding_t5.to(device).squeeze(1)
+ text_embedding_mask_t5 = text_embedding_mask_t5.to(device).squeeze(1)
+ with torch.no_grad():
+ output_t5 = text_encoder_t5(
+ input_ids=text_embedding_t5,
+ attention_mask=text_embedding_mask_t5 if T5_ENCODER['attention_mask'] else None,
+ output_hidden_states=True
+ )
+ encoder_hidden_states_t5 = output_t5['hidden_states'][T5_ENCODER['layer_index']].detach()
+
+ # additional condition
+ image_meta_size = kwargs['image_meta_size'].to(device)
+ style = kwargs['style'].to(device)
+
+ if args.extra_fp16:
+ image = image.half()
+ image_meta_size = image_meta_size.half() if image_meta_size is not None else None
+
+ # Map input images to latent space + normalize latents:
+ image = image.to(device)
+ vae_scaling_factor = vae.config.scaling_factor
+ latents = vae.encode(image).latent_dist.sample().mul_(vae_scaling_factor)
+
+ # positional embedding
+ _, _, height, width = image.shape
+ reso = f"{height}x{width}"
+ cos_cis_img, sin_cis_img = freqs_cis_img[reso]
+
+ # Model conditions
+ model_kwargs = dict(
+ encoder_hidden_states=encoder_hidden_states,
+ text_embedding_mask=text_embedding_mask,
+ encoder_hidden_states_t5=encoder_hidden_states_t5,
+ text_embedding_mask_t5=text_embedding_mask_t5,
+ image_meta_size=image_meta_size,
+ style=style,
+ cos_cis_img=cos_cis_img,
+ sin_cis_img=sin_cis_img,
+ )
+
+ return latents, model_kwargs
+
+
+def main(args):
+ if args.training_parts == "lora":
+ args.use_ema = False
+
+ assert torch.cuda.is_available(), "Training currently requires at least one GPU."
+
+ deepspeed.init_distributed('nccl') #adapt to deepspeed 0.14.4
+ world_size = dist.get_world_size()
+ batch_size = args.batch_size
+ grad_accu_steps = args.grad_accu_steps
+ global_batch_size = world_size * batch_size * grad_accu_steps
+
+ rank = dist.get_rank()
+ device = rank % torch.cuda.device_count()
+ seed = args.global_seed * world_size + rank
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ torch.cuda.set_device(device)
+ print(f"Starting rank={rank}, seed={seed}, world_size={world_size}.")
+ deepspeed_config = deepspeed_config_from_args(args, global_batch_size)
+
+ # Setup an experiment folder
+ experiment_dir, checkpoint_dir, logger = create_exp_folder(args, rank)
+
+ # Log all the arguments
+ logger.info(sys.argv)
+ logger.info(str(args))
+ # Save to a json file
+ args_dict = vars(args)
+ args_dict['world_size'] = world_size
+ with open(f"{experiment_dir}/args.json", 'w') as f:
+ json.dump(args_dict, f, indent=4)
+
+ # Disable the message "Some weights of the model checkpoint at ... were not used when initializing BertModel."
+ # If needed, just comment the following line.
+ tf_logging.set_verbosity_error()
+
+ # ===========================================================================
+ # Building HYDIT
+ # ===========================================================================
+
+ logger.info("Building HYDIT Model.")
+
+ # ---------------------------------------------------------------------------
+    # Training sample base size, such as 256/512/1024. Note that this is only a
+    # base size, not necessarily the actual size of the training samples. The
+    # actual sample sizes are determined by `resolutions` when multi-resolution
+    # training is enabled.
+ # ---------------------------------------------------------------------------
+ image_size = args.image_size
+ if len(image_size) == 1:
+ image_size = [image_size[0], image_size[0]]
+ if len(image_size) != 2:
+ raise ValueError(f"Invalid image size: {args.image_size}")
+ assert image_size[0] % 8 == 0 and image_size[
+ 1] % 8 == 0, "Image size must be divisible by 8 (for the VAE encoder). " \
+ f"got {image_size}"
+ latent_size = [image_size[0] // 8, image_size[1] // 8]
+
+ # initialize model by deepspeed
+ assert args.deepspeed, f"Must enable deepspeed in this script: train_deepspeed.py"
+ with deepspeed.zero.Init(data_parallel_group=torch.distributed.group.WORLD,
+ remote_device=None if args.remote_device == 'none' else args.remote_device,
+ config_dict_or_path=deepspeed_config,
+ mpu=None,
+ enabled=args.zero_stage == 3):
+ model = HUNYUAN_DIT_MODELS[args.model](args,
+ input_size=latent_size,
+ log_fn=logger.info,
+ )
+ # Multi-resolution / Single-resolution training.
+ if args.multireso:
+ resolutions = ResolutionGroup(image_size[0],
+ align=16,
+ step=args.reso_step,
+ target_ratios=args.target_ratios).data
+ else:
+ resolutions = ResolutionGroup(image_size[0],
+ align=16,
+ target_ratios=['1:1']).data
+
+ freqs_cis_img = init_image_posemb(args.rope_img,
+ resolutions=resolutions,
+ patch_size=model.patch_size,
+ hidden_size=model.hidden_size,
+ num_heads=model.num_heads,
+ log_fn=logger.info,
+ rope_real=args.rope_real,
+ )
+
+ # Create EMA model and convert to fp16 if needed.
+ ema = None
+ if args.use_ema:
+ ema = EMA(args, model, device, logger)
+
+ # Setup FP16 main model:
+ if args.use_fp16:
+ model = Float16Module(model, args)
+ logger.info(f" Using main model with data type {'fp16' if args.use_fp16 else 'fp32'}")
+
+ diffusion = create_diffusion(
+ noise_schedule=args.noise_schedule,
+ predict_type=args.predict_type,
+ learn_sigma=args.learn_sigma,
+ mse_loss_weight_type=args.mse_loss_weight_type,
+ beta_start=args.beta_start,
+ beta_end=args.beta_end,
+ noise_offset=args.noise_offset,
+ )
+
+ # Setup VAE
+ logger.info(f" Loading vae from {VAE_EMA_PATH}")
+ vae = AutoencoderKL.from_pretrained(VAE_EMA_PATH)
+ # Setup BERT text encoder
+ logger.info(f" Loading Bert text encoder from {TEXT_ENCODER}")
+ text_encoder = BertModel.from_pretrained(TEXT_ENCODER, False, revision=None)
+ # Setup BERT tokenizer:
+ logger.info(f" Loading Bert tokenizer from {TOKENIZER}")
+ tokenizer = BertTokenizer.from_pretrained(TOKENIZER)
+ # Setup T5 text encoder
+ from hydit.modules.text_encoder import MT5Embedder
+ mt5_path = T5_ENCODER['MT5']
+ embedder_t5 = MT5Embedder(mt5_path, torch_dtype=T5_ENCODER['torch_dtype'], max_length=args.text_len_t5)
+ tokenizer_t5 = embedder_t5.tokenizer
+ text_encoder_t5 = embedder_t5.model
+
+ if args.extra_fp16:
+ logger.info(f" Using fp16 for extra modules: vae, text_encoder")
+ vae = vae.half().to(device)
+ text_encoder = text_encoder.half().to(device)
+ text_encoder_t5 = text_encoder_t5.half().to(device)
+ else:
+ vae = vae.to(device)
+ text_encoder = text_encoder.to(device)
+ text_encoder_t5 = text_encoder_t5.to(device)
+
+ logger.info(f" Optimizer parameters: lr={args.lr}, weight_decay={args.weight_decay}")
+ logger.info(" Using deepspeed optimizer")
+ opt = None
+
+ # ===========================================================================
+ # Building Dataset
+ # ===========================================================================
+
+ logger.info(f"Building Streaming Dataset.")
+ logger.info(f" Loading index file {args.index_file} (v2)")
+
+ dataset = TextImageArrowStream(args=args,
+ resolution=image_size[0],
+ random_flip=args.random_flip,
+ log_fn=logger.info,
+ index_file=args.index_file,
+ multireso=args.multireso,
+ batch_size=batch_size,
+ world_size=world_size,
+ random_shrink_size_cond=args.random_shrink_size_cond,
+ merge_src_cond=args.merge_src_cond,
+ uncond_p=args.uncond_p,
+ text_ctx_len=args.text_len,
+ tokenizer=tokenizer,
+ uncond_p_t5=args.uncond_p_t5,
+ text_ctx_len_t5=args.text_len_t5,
+ tokenizer_t5=tokenizer_t5,
+ )
+ if args.multireso:
+ sampler = BlockDistributedSampler(dataset, num_replicas=world_size, rank=rank, seed=args.global_seed,
+ shuffle=False, drop_last=True, batch_size=batch_size)
+ else:
+ sampler = DistributedSamplerWithStartIndex(dataset, num_replicas=world_size, rank=rank, seed=args.global_seed,
+ shuffle=False, drop_last=True)
+ loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, sampler=sampler,
+ num_workers=args.num_workers, pin_memory=True, drop_last=True, persistent_workers=True)
+ logger.info(f" Dataset contains {len(dataset):,} images.")
+ logger.info(f" Index file: {args.index_file}.")
+ if args.multireso:
+ logger.info(f' Using MultiResolutionBucketIndexV2 with step {dataset.index_manager.step} '
+ f'and base size {dataset.index_manager.base_size}')
+ logger.info(f'\n {dataset.index_manager.resolutions}')
+
+ # ===========================================================================
+ # Loading parameter
+ # ===========================================================================
+
+ logger.info(f"Loading parameter")
+ start_epoch = 0
+ start_epoch_step = 0
+ train_steps = 0
+ # Resume checkpoint if needed
+    if args.resume is not None and len(args.resume) > 0:
+ model, ema, start_epoch, start_epoch_step, train_steps = model_resume(args, model, ema, logger, len(loader))
+
+ if args.training_parts == "lora":
+ loraconfig = LoraConfig(
+ r=args.rank,
+ lora_alpha=args.rank,
+ target_modules=args.target_modules
+ )
+ if args.use_fp16:
+ model.module = get_peft_model(model.module, loraconfig)
+ else:
+ model = get_peft_model(model, loraconfig)
+
+ logger.info(f" Training parts: {args.training_parts}")
+
+ model, opt, scheduler = deepspeed_initialize(args, logger, model, opt, deepspeed_config)
+
+ # ===========================================================================
+ # Training
+ # ===========================================================================
+ with autocast(dtype=AUTOCAST_MAPPING[args.autocast_dtype]):
+ model.train()
+ if args.use_ema:
+ ema.eval()
+
+ print(f" Worker {rank} ready.")
+ dist.barrier()
+
+ iters_per_epoch = len(loader)
+ logger.info(" ****************************** Running training ******************************")
+ logger.info(f" Number GPUs: {world_size}")
+ logger.info(f" Number training samples: {len(dataset):,}")
+ logger.info(f" Number parameters: {sum(p.numel() for p in model.parameters()):,}")
+ logger.info(f" Number trainable params: {sum(p.numel() for p in get_trainable_params(model)):,}")
+ logger.info(" ------------------------------------------------------------------------------")
+ logger.info(f" Iters per epoch: {iters_per_epoch:,}")
+ logger.info(f" Batch size per device: {batch_size}")
+ logger.info(
+ f" Batch size all device: {batch_size * world_size * grad_accu_steps:,} (world_size * batch_size * grad_accu_steps)")
+ logger.info(f" Gradient Accu steps: {args.grad_accu_steps}")
+ logger.info(f" Total optimization steps: {args.epochs * iters_per_epoch // grad_accu_steps:,}")
+
+ logger.info(f" Training epochs: {start_epoch}/{args.epochs}")
+ logger.info(f" Training epoch steps: {start_epoch_step:,}/{iters_per_epoch:,}")
+ logger.info(
+ f" Training total steps: {train_steps:,}/{min(args.max_training_steps, args.epochs * iters_per_epoch):,}")
+ logger.info(" ------------------------------------------------------------------------------")
+ logger.info(f" Noise schedule: {args.noise_schedule}")
+ logger.info(f" Beta limits: ({args.beta_start}, {args.beta_end})")
+ logger.info(f" Learn sigma: {args.learn_sigma}")
+ logger.info(f" Prediction type: {args.predict_type}")
+ logger.info(f" Noise offset: {args.noise_offset}")
+
+ logger.info(" ------------------------------------------------------------------------------")
+ logger.info(f" Using EMA model: {args.use_ema} ({args.ema_dtype})")
+ if args.use_ema:
+ logger.info(f" Using EMA decay: {ema.max_value if args.use_ema else None}")
+ logger.info(f" Using EMA warmup power: {ema.power if args.use_ema else None}")
+ logger.info(f" Using main model fp16: {args.use_fp16}")
+ logger.info(f" Using extra modules fp16: {args.extra_fp16}")
+ logger.info(" ------------------------------------------------------------------------------")
+ logger.info(f" Experiment directory: {experiment_dir}")
+ logger.info(" *******************************************************************************")
+
+ if args.gc_interval > 0:
+ gc.disable()
+ gc.collect()
+
+ # Variables for monitoring/logging purposes:
+ log_steps = 0
+ running_loss = 0
+ start_time = time.time()
+
+ if args.async_ema:
+ ema_stream = torch.cuda.Stream()
+
+ # Training loop
+ for epoch in range(start_epoch, args.epochs):
+ logger.info(f" Start random shuffle with seed={seed}")
+            # Make sure all processes use the same seed to shuffle the dataset.
+ dataset.shuffle(seed=args.global_seed + epoch, fast=True)
+ logger.info(f" End of random shuffle")
+
+ # Move sampler to start_index
+ if not args.multireso:
+ start_index = start_epoch_step * world_size * batch_size
+ if start_index != sampler.start_index:
+ sampler.start_index = start_index
+ # Reset start_epoch_step to zero, to ensure next epoch will start from the beginning.
+ start_epoch_step = 0
+ logger.info(f" Iters left this epoch: {len(loader):,}")
+
+ logger.info(f" Beginning epoch {epoch}...")
+ step = 0
+ for batch in loader:
+ step += 1
+                if args.cal_e2e:
+                    # Measure the end-to-end time of steps 100-200, then exit.
+                    if train_steps == 100:
+                        e2e_start = time.time()
+                    if train_steps == 200:
+                        logger.info(f"(End-to-end time of steps 100-200 = {time.time() - e2e_start:.4f}s)")
+                        exit()
+
+ latents, model_kwargs = prepare_model_inputs(args, batch, device, vae, text_encoder, text_encoder_t5,
+ freqs_cis_img)
+
+ # training model by deepspeed while use fp16
+ if args.use_fp16:
+ if args.use_ema and args.async_ema:
+ with torch.cuda.stream(ema_stream):
+ ema.update(model.module.module, step=step)
+ torch.cuda.current_stream().wait_stream(ema_stream)
+
+ loss_dict = diffusion.training_losses(model=model, x_start=latents, model_kwargs=model_kwargs,
+ step=train_steps)
+ loss = loss_dict["loss"].mean()
+ model.backward(loss)
+ last_batch_iteration = (train_steps + 1) // (global_batch_size // (batch_size * world_size))
+ model.step(lr_kwargs={'last_batch_iteration': last_batch_iteration})
+
+                if args.use_ema and (not args.async_ema or step == len(loader) - 1):
+ if args.use_fp16:
+ ema.update(model.module.module, step=step)
+ else:
+ ema.update(model.module, step=step)
+
+ # ===========================================================================
+ # Log loss values:
+ # ===========================================================================
+ running_loss += loss.item()
+ log_steps += 1
+ train_steps += 1
+ if train_steps % args.log_every == 0:
+ # Measure training speed:
+ torch.cuda.synchronize()
+ end_time = time.time()
+ steps_per_sec = log_steps / (end_time - start_time)
+ # Reduce loss history over all processes:
+ avg_loss = torch.tensor(running_loss / log_steps, device=device)
+ dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM)
+ avg_loss = avg_loss.item() / world_size
+ # get lr from deepspeed fused optimizer
+ logger.info(f"(step={train_steps:07d}) " +
+ (
+ f"(update_step={train_steps // args.grad_accu_steps:07d}) " if args.grad_accu_steps > 1 else "") +
+ f"Train Loss: {avg_loss:.4f}, "
+ f"Lr: {opt.param_groups[0]['lr']:.6g}, "
+ f"Steps/Sec: {steps_per_sec:.2f}, "
+                            f"Millisec/Step: {(end_time - start_time) * 1000 / log_steps:.2f}, "
+ f"Samples/Sec: {int(steps_per_sec * batch_size * world_size):d}")
+ # Reset monitoring variables:
+ running_loss = 0
+ log_steps = 0
+ start_time = time.time()
+
+ # collect gc:
+ if args.gc_interval > 0 and (step % args.gc_interval == 0):
+ gc.collect()
+
+ if (train_steps % args.ckpt_every == 0 or train_steps % args.ckpt_latest_every == 0) and train_steps > 0:
+ save_checkpoint(args, rank, logger, model, ema, epoch, train_steps, checkpoint_dir)
+
+ if train_steps >= args.max_training_steps:
+ logger.info(f"Breaking step loop at {train_steps}.")
+ break
+
+ if train_steps >= args.max_training_steps:
+ logger.info(f"Breaking epoch loop at {epoch}.")
+ break
+
+ dist.destroy_process_group()
+
+
+if __name__ == "__main__":
+ args = get_args()
+ if args.seed_all:
+ seed_all(is_gpu=not is_npu_available(), mode=True)
+ # Start
+ main(args)
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/hydit/train_deepspeed_controlnet.py b/PyTorch/built-in/mlm/HunyuanDiT/hydit/train_deepspeed_controlnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..4df1b58acd793b21100ee98fc8557ef983b6b51c
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/hydit/train_deepspeed_controlnet.py
@@ -0,0 +1,584 @@
+import gc
+import json
+import os
+import random
+import sys
+import time
+from functools import partial
+from glob import glob
+from pathlib import Path
+import numpy as np
+
+import deepspeed
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.distributed as dist
+from torch.utils.data import DataLoader
+from torch.distributed.optim import ZeroRedundancyOptimizer
+from torchvision.transforms import functional as TF
+from diffusers.models import AutoencoderKL
+from transformers import BertModel, BertTokenizer, logging as tf_logging
+
+from hydit.config import get_args
+from hydit.constants import VAE_EMA_PATH, TEXT_ENCODER, TOKENIZER, T5_ENCODER
+from hydit.lr_scheduler import WarmupLR
+from hydit.data_loader.arrow_load_stream import TextImageArrowStream
+from hydit.diffusion import create_diffusion
+from hydit.ds_config import deepspeed_config_from_args
+from hydit.modules.ema import EMA
+from hydit.modules.fp16_layers import Float16Module
+from hydit.modules.models import HUNYUAN_DIT_MODELS
+from hydit.modules.controlnet import HunYuanControlNet
+from hydit.modules.posemb_layers import init_image_posemb
+from hydit.utils.tools import create_logger, set_seeds, create_exp_folder, model_resume, get_trainable_params
+from IndexKits.index_kits import ResolutionGroup
+from IndexKits.index_kits.sampler import DistributedSamplerWithStartIndex, BlockDistributedSampler
+from peft import LoraConfig, get_peft_model
+
+from hydit.annotator.dwpose import DWposeDetector
+torch.optim.lr_scheduler.LRScheduler = torch.optim.lr_scheduler._LRScheduler  # alias the public LRScheduler name for older torch versions that only define _LRScheduler
+from transformers import pipeline
+import cv2
+from PIL import Image
+
+depth_estimator = pipeline('depth-estimation', device='cuda:{}'.format(int(os.getenv('LOCAL_RANK', '0'))))
+pose_detector = DWposeDetector()
+
+def deepspeed_initialize(args, logger, model, opt, deepspeed_config):
+ logger.info(f"Initialize deepspeed...")
+ logger.info(f" Using deepspeed optimizer")
+
+ def get_learning_rate_scheduler(warmup_min_lr, lr, warmup_num_steps, opt):
+ return WarmupLR(opt, warmup_min_lr, lr, warmup_num_steps)
+
+ logger.info(f" Building scheduler with warmup_min_lr={args.warmup_min_lr}, warmup_num_steps={args.warmup_num_steps}")
+ model, opt, _, scheduler = deepspeed.initialize(
+ model=model,
+ model_parameters=get_trainable_params(model),
+ config_params=deepspeed_config,
+ args=args,
+ lr_scheduler=partial(get_learning_rate_scheduler, args.warmup_min_lr, args.lr, args.warmup_num_steps) if args.warmup_num_steps > 0 else None,
+ )
+ return model, opt, scheduler
+
+def save_checkpoint(args, rank, logger, model, ema, epoch, train_steps, checkpoint_dir):
+ def save_lora_weight(checkpoint_dir, client_state, tag=f"{train_steps:07d}.pt"):
+ cur_ckpt_save_dir = f"{checkpoint_dir}/{tag}"
+ if rank == 0:
+ if args.use_fp16:
+ model.module.module.save_pretrained(cur_ckpt_save_dir)
+ else:
+ model.module.save_pretrained(cur_ckpt_save_dir)
+
+ checkpoint_path = "[Not rank 0. Disabled output.]"
+
+ client_state = {
+ "steps": train_steps,
+ "epoch": epoch,
+ "args": args
+ }
+ if ema is not None:
+ client_state['ema'] = ema.state_dict()
+
+ dst_paths = []
+ if train_steps % args.ckpt_every == 0:
+ checkpoint_path = f"{checkpoint_dir}/{train_steps:07d}.pt"
+ try:
+ if args.training_parts == "lora":
+ save_lora_weight(checkpoint_dir, client_state, tag=f"{train_steps:07d}.pt")
+ else:
+ model.save_checkpoint(checkpoint_dir, client_state=client_state, tag=f"{train_steps:07d}.pt")
+ dst_paths.append(checkpoint_path)
+ logger.info(f"Saved checkpoint to {checkpoint_path}")
+        except Exception as e:
+            logger.error(f"Failed to save checkpoint to {checkpoint_path}: {e}")
+
+ if train_steps % args.ckpt_latest_every == 0 or train_steps == args.max_training_steps:
+ save_name = "latest.pt"
+ checkpoint_path = f"{checkpoint_dir}/{save_name}"
+ try:
+ if args.training_parts == "lora":
+ save_lora_weight(checkpoint_dir, client_state, tag=f"{save_name}")
+ else:
+ model.save_checkpoint(checkpoint_dir, client_state=client_state, tag=f"{save_name}")
+ dst_paths.append(checkpoint_path)
+ logger.info(f"Saved checkpoint to {checkpoint_path}")
+        except Exception as e:
+            logger.error(f"Failed to save checkpoint to {checkpoint_path}: {e}")
+
+ dist.barrier()
+ if rank == 0 and len(dst_paths) > 0:
+ # Delete optimizer states to avoid occupying too much disk space.
+ for dst_path in dst_paths:
+ for opt_state_path in glob(f"{dst_path}/*_00_optim_states.pt"):
+ os.remove(opt_state_path)
+
+ return checkpoint_path
+
+def get_canny(np_img, low_threshold=100, high_threshold=200):
+    # Extract Canny edges and stack them into a 3-channel control image.
+    image = cv2.Canny(np_img, low_threshold, high_threshold)
+    image = image[:, :, None]
+    image = np.concatenate([image, image, image], axis=2)
+    return image
+
+def get_depth(np_img):
+ pil_img = Image.fromarray(np_img)
+ depth = depth_estimator(pil_img)['depth']
+ depth = np.array(depth)
+ depth = depth[:, :, None]
+ depth = np.concatenate([depth, depth, depth], axis=2)
+ return depth
+
+def get_pose(np_img):
+ return pose_detector(np_img)[0]
+
+@torch.no_grad()
+def prepare_model_inputs(args, batch, device, vae, text_encoder, text_encoder_t5, freqs_cis_img):
+ image, text_embedding, text_embedding_mask, text_embedding_t5, text_embedding_mask_t5, kwargs = batch
+
+ # clip & mT5 text embedding
+ text_embedding = text_embedding.to(device)
+ text_embedding_mask = text_embedding_mask.to(device)
+ encoder_hidden_states = text_encoder(
+ text_embedding.to(device),
+ attention_mask=text_embedding_mask.to(device),
+ )[0]
+ text_embedding_t5 = text_embedding_t5.to(device).squeeze(1)
+ text_embedding_mask_t5 = text_embedding_mask_t5.to(device).squeeze(1)
+ with torch.no_grad():
+ output_t5 = text_encoder_t5(
+ input_ids=text_embedding_t5,
+ attention_mask=text_embedding_mask_t5 if T5_ENCODER['attention_mask'] else None,
+ output_hidden_states=True
+ )
+ encoder_hidden_states_t5 = output_t5['hidden_states'][T5_ENCODER['layer_index']].detach()
+
+ # additional condition
+ image_meta_size = kwargs['image_meta_size'].to(device)
+ style = kwargs['style'].to(device)
+
+ np_img = image.squeeze(0).add(1).mul(255 / 2).permute(1, 2, 0).cpu().numpy().astype('uint8')
+ if args.control_type == 'canny':
+ condition = get_canny(np_img)
+ elif args.control_type == 'depth':
+ condition = get_depth(np_img)
+ elif args.control_type == 'pose':
+ condition = get_pose(np_img)
+ else:
+ raise NotImplementedError
+    condition = Image.fromarray(condition)
+ condition = TF.to_tensor(condition)
+ condition = TF.normalize(condition, [0.5], [0.5])
+ condition = condition.unsqueeze(0).to(device)
+
+ if args.extra_fp16:
+ image = image.half()
+ image_meta_size = image_meta_size.half() if image_meta_size is not None else None
+
+ # Map input images to latent space + normalize latents:
+ image = image.to(device)
+ vae_scaling_factor = vae.config.scaling_factor
+ latents = vae.encode(image).latent_dist.sample().mul_(vae_scaling_factor)
+ condition = vae.encode(condition).latent_dist.sample().mul_(vae_scaling_factor)
+
+ # positional embedding
+ _, _, height, width = image.shape
+ reso = f"{height}x{width}"
+ cos_cis_img, sin_cis_img = freqs_cis_img[reso]
+
+ # Model conditions
+ model_kwargs = dict(
+ encoder_hidden_states=encoder_hidden_states,
+ text_embedding_mask=text_embedding_mask,
+ encoder_hidden_states_t5=encoder_hidden_states_t5,
+ text_embedding_mask_t5=text_embedding_mask_t5,
+ image_meta_size=image_meta_size,
+ style=style,
+ cos_cis_img=cos_cis_img,
+ sin_cis_img=sin_cis_img,
+ condition=condition,
+ )
+
+ return latents, model_kwargs
+
+def main(args):
+ if args.training_parts == "lora":
+ args.use_ema = False
+ args.use_ema = False
+
+ assert torch.cuda.is_available(), "Training currently requires at least one GPU."
+
+ dist.init_process_group("nccl")
+ world_size = dist.get_world_size()
+ batch_size = args.batch_size
+ grad_accu_steps = args.grad_accu_steps
+ global_batch_size = world_size * batch_size * grad_accu_steps
+
+ rank = dist.get_rank()
+ device = rank % torch.cuda.device_count()
+ seed = args.global_seed * world_size + rank
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ torch.cuda.set_device(device)
+ print(f"Starting rank={rank}, seed={seed}, world_size={world_size}.")
+ deepspeed_config = deepspeed_config_from_args(args, global_batch_size)
+
+ # Setup an experiment folder
+ experiment_dir, checkpoint_dir, logger = create_exp_folder(args, rank)
+
+ # Log all the arguments
+ logger.info(sys.argv)
+ logger.info(str(args))
+ # Save to a json file
+ args_dict = vars(args)
+ args_dict['world_size'] = world_size
+ with open(f"{experiment_dir}/args.json", 'w') as f:
+ json.dump(args_dict, f, indent=4)
+
+ # Disable the message "Some weights of the model checkpoint at ... were not used when initializing BertModel."
+ # If needed, just comment the following line.
+ tf_logging.set_verbosity_error()
+
+ # ===========================================================================
+ # Building HYDIT
+ # ===========================================================================
+
+ logger.info("Building HYDIT Model.")
+
+ # ---------------------------------------------------------------------------
+    # Training sample base size, such as 256/512/1024. Note that this is only a
+    # base size, not necessarily the actual size of the training samples. The
+    # actual sample sizes are determined by `resolutions` when multi-resolution
+    # training is enabled.
+ # ---------------------------------------------------------------------------
+ image_size = args.image_size
+ if len(image_size) == 1:
+ image_size = [image_size[0], image_size[0]]
+ if len(image_size) != 2:
+ raise ValueError(f"Invalid image size: {args.image_size}")
+ assert image_size[0] % 8 == 0 and image_size[1] % 8 == 0, "Image size must be divisible by 8 (for the VAE encoder). " \
+ f"got {image_size}"
+ latent_size = [image_size[0] // 8, image_size[1] // 8]
+
+ # initialize model by deepspeed
+    assert args.deepspeed, "Must enable deepspeed in this script: train_deepspeed_controlnet.py"
+ with deepspeed.zero.Init(data_parallel_group=torch.distributed.group.WORLD,
+ remote_device=None if args.remote_device == 'none' else args.remote_device,
+ config_dict_or_path=deepspeed_config,
+ mpu=None,
+ enabled=args.zero_stage == 3):
+ model = HUNYUAN_DIT_MODELS[args.model](args,
+ input_size=latent_size,
+ log_fn=logger.info,
+ )
+ controlnet = HunYuanControlNet(args,
+ input_size=latent_size,
+ depth=40, hidden_size=1408, patch_size=2, num_heads=16, mlp_ratio=4.3637,
+ log_fn=logger.info,
+ )
+ # Multi-resolution / Single-resolution training.
+ if args.multireso:
+ resolutions = ResolutionGroup(image_size[0],
+ align=16,
+ step=args.reso_step,
+ target_ratios=args.target_ratios).data
+ else:
+ resolutions = ResolutionGroup(image_size[0],
+ align=16,
+ target_ratios=['1:1']).data
+
+ freqs_cis_img = init_image_posemb(args.rope_img,
+ resolutions=resolutions,
+ patch_size=model.patch_size,
+ hidden_size=model.hidden_size,
+ num_heads=model.num_heads,
+ log_fn=logger.info,
+ rope_real=args.rope_real,
+ )
+
+ # Create EMA model and convert to fp16 if needed.
+ ema = None
+ if args.use_ema:
+ ema = EMA(args, model, device, logger)
+
+ # Setup FP16 main model:
+ if args.use_fp16:
+ model = Float16Module(model, args)
+ controlnet = Float16Module(controlnet, args)
+ logger.info(f" Using main model with data type {'fp16' if args.use_fp16 else 'fp32'}")
+
+ diffusion = create_diffusion(
+ noise_schedule=args.noise_schedule,
+ predict_type=args.predict_type,
+ learn_sigma=args.learn_sigma,
+ mse_loss_weight_type=args.mse_loss_weight_type,
+ beta_start=args.beta_start,
+ beta_end=args.beta_end,
+ noise_offset=args.noise_offset,
+ )
+
+ # Setup VAE
+ logger.info(f" Loading vae from {VAE_EMA_PATH}")
+ vae = AutoencoderKL.from_pretrained(VAE_EMA_PATH)
+ # Setup BERT text encoder
+ logger.info(f" Loading Bert text encoder from {TEXT_ENCODER}")
+ text_encoder = BertModel.from_pretrained(TEXT_ENCODER, False, revision=None)
+ # Setup BERT tokenizer:
+ logger.info(f" Loading Bert tokenizer from {TOKENIZER}")
+ tokenizer = BertTokenizer.from_pretrained(TOKENIZER)
+ # Setup T5 text encoder
+ from hydit.modules.text_encoder import MT5Embedder
+ mt5_path = T5_ENCODER['MT5']
+ embedder_t5 = MT5Embedder(mt5_path, torch_dtype=T5_ENCODER['torch_dtype'], max_length=args.text_len_t5)
+ tokenizer_t5 = embedder_t5.tokenizer
+ text_encoder_t5 = embedder_t5.model
+
+ if args.extra_fp16:
+ logger.info(f" Using fp16 for extra modules: vae, text_encoder")
+ vae = vae.half().to(device)
+ text_encoder = text_encoder.half().to(device)
+ text_encoder_t5 = text_encoder_t5.half().to(device)
+ else:
+ vae = vae.to(device)
+ text_encoder = text_encoder.to(device)
+ text_encoder_t5 = text_encoder_t5.to(device)
+
+ logger.info(f" Optimizer parameters: lr={args.lr}, weight_decay={args.weight_decay}")
+ logger.info(" Using deepspeed optimizer")
+ opt = None
+
+ # ===========================================================================
+ # Building Dataset
+ # ===========================================================================
+
+ logger.info(f"Building Streaming Dataset.")
+ logger.info(f" Loading index file {args.index_file} (v2)")
+
+ dataset = TextImageArrowStream(args=args,
+ resolution=image_size[0],
+ random_flip=args.random_flip,
+ log_fn=logger.info,
+ index_file=args.index_file,
+ multireso=args.multireso,
+ batch_size=batch_size,
+ world_size=world_size,
+ random_shrink_size_cond=args.random_shrink_size_cond,
+ merge_src_cond=args.merge_src_cond,
+ uncond_p=args.uncond_p,
+ text_ctx_len=args.text_len,
+ tokenizer=tokenizer,
+ uncond_p_t5=args.uncond_p_t5,
+ text_ctx_len_t5=args.text_len_t5,
+ tokenizer_t5=tokenizer_t5,
+ )
+ if args.multireso:
+ sampler = BlockDistributedSampler(dataset, num_replicas=world_size, rank=rank, seed=args.global_seed,
+ shuffle=False, drop_last=True, batch_size=batch_size)
+ else:
+ sampler = DistributedSamplerWithStartIndex(dataset, num_replicas=world_size, rank=rank, seed=args.global_seed,
+ shuffle=False, drop_last=True)
+ loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, sampler=sampler,
+ num_workers=args.num_workers, pin_memory=True, drop_last=True)
+ logger.info(f" Dataset contains {len(dataset):,} images.")
+ logger.info(f" Index file: {args.index_file}.")
+ if args.multireso:
+ logger.info(f' Using MultiResolutionBucketIndexV2 with step {dataset.index_manager.step} '
+ f'and base size {dataset.index_manager.base_size}')
+ logger.info(f'\n {dataset.index_manager.resolutions}')
+
+ # ===========================================================================
+ # Loading parameter
+ # ===========================================================================
+
+ logger.info(f"Loading parameter")
+ start_epoch = 0
+ start_epoch_step = 0
+ train_steps = 0
+ # Resume checkpoint if needed
+    if args.resume is not None and len(args.resume) > 0:
+ model, ema, start_epoch, start_epoch_step, train_steps = model_resume(args, model, ema, logger, len(loader))
+
+ if args.training_parts == "lora":
+ loraconfig = LoraConfig(
+ r=args.rank,
+ lora_alpha=args.rank,
+ target_modules=args.target_modules
+ )
+ if args.use_fp16:
+ model.module = get_peft_model(model.module, loraconfig)
+ else:
+ model = get_peft_model(model, loraconfig)
+
+ logger.info(f" Training parts: {args.training_parts}")
+
+ if args.use_fp16:
+ controlnet.module.from_dit(model.module)
+ controlnet.module.set_trainable()
+ else:
+ controlnet.from_dit(model)
+ controlnet.set_trainable()
+ logger.info(f" ControlNet loaded from DIT")
+
+
+
+ controlnet, opt, scheduler = deepspeed_initialize(args, logger, controlnet, opt, deepspeed_config)
+
+ # ===========================================================================
+ # Training
+ # ===========================================================================
+
+ model.eval()
+ model.requires_grad_(False)
+ model = model.to(device)
+
+ if args.use_ema:
+ ema.eval()
+
+ print(f" Worker {rank} ready.")
+ dist.barrier()
+
+ iters_per_epoch = len(loader)
+ logger.info(" ****************************** Running training ******************************")
+ logger.info(f" Number GPUs: {world_size}")
+ logger.info(f" Number training samples: {len(dataset):,}")
+ logger.info(f" Number parameters: {sum(p.numel() for p in controlnet.parameters()):,}")
+ logger.info(f" Number trainable params: {sum(p.numel() for p in get_trainable_params(controlnet)):,}")
+ logger.info(" ------------------------------------------------------------------------------")
+ logger.info(f" Iters per epoch: {iters_per_epoch:,}")
+ logger.info(f" Batch size per device: {batch_size}")
+ logger.info(f" Batch size all device: {batch_size * world_size * grad_accu_steps:,} (world_size * batch_size * grad_accu_steps)")
+ logger.info(f" Gradient Accu steps: {args.grad_accu_steps}")
+ logger.info(f" Total optimization steps: {args.epochs * iters_per_epoch // grad_accu_steps:,}")
+
+ logger.info(f" Training epochs: {start_epoch}/{args.epochs}")
+ logger.info(f" Training epoch steps: {start_epoch_step:,}/{iters_per_epoch:,}")
+ logger.info(f" Training total steps: {train_steps:,}/{min(args.max_training_steps, args.epochs * iters_per_epoch):,}")
+ logger.info(" ------------------------------------------------------------------------------")
+ logger.info(f" Noise schedule: {args.noise_schedule}")
+ logger.info(f" Beta limits: ({args.beta_start}, {args.beta_end})")
+ logger.info(f" Learn sigma: {args.learn_sigma}")
+ logger.info(f" Prediction type: {args.predict_type}")
+ logger.info(f" Noise offset: {args.noise_offset}")
+
+ logger.info(" ------------------------------------------------------------------------------")
+ logger.info(f" Using EMA model: {args.use_ema} ({args.ema_dtype})")
+ if args.use_ema:
+ logger.info(f" Using EMA decay: {ema.max_value if args.use_ema else None}")
+ logger.info(f" Using EMA warmup power: {ema.power if args.use_ema else None}")
+ logger.info(f" Using main model fp16: {args.use_fp16}")
+ logger.info(f" Using extra modules fp16: {args.extra_fp16}")
+ logger.info(" ------------------------------------------------------------------------------")
+ logger.info(f" Experiment directory: {experiment_dir}")
+ logger.info(" *******************************************************************************")
+
+ if args.gc_interval > 0:
+ gc.disable()
+ gc.collect()
+
+ # Variables for monitoring/logging purposes:
+ log_steps = 0
+ running_loss = 0
+ start_time = time.time()
+
+ if args.async_ema:
+ ema_stream = torch.cuda.Stream()
+
+ # Training loop
+ for epoch in range(start_epoch, args.epochs):
+ logger.info(f" Start random shuffle with seed={seed}")
+        # Make sure all processes use the same seed to shuffle the dataset.
+ dataset.shuffle(seed=args.global_seed + epoch, fast=True)
+ logger.info(f" End of random shuffle")
+
+ # Move sampler to start_index
+ if not args.multireso:
+ start_index = start_epoch_step * world_size * batch_size
+ if start_index != sampler.start_index:
+ sampler.start_index = start_index
+ # Reset start_epoch_step to zero, to ensure next epoch will start from the beginning.
+ start_epoch_step = 0
+ logger.info(f" Iters left this epoch: {len(loader):,}")
+
+ logger.info(f" Beginning epoch {epoch}...")
+ step = 0
+ for batch in loader:
+ step += 1
+
+ latents, model_kwargs = prepare_model_inputs(args, batch, device, vae, text_encoder, text_encoder_t5, freqs_cis_img)
+
+ # training model by deepspeed while use fp16
+ if args.use_fp16:
+ if args.use_ema and args.async_ema:
+ with torch.cuda.stream(ema_stream):
+ ema.update(model.module.module, step=step)
+ torch.cuda.current_stream().wait_stream(ema_stream)
+
+ loss_dict = diffusion.training_losses(model=model, x_start=latents, model_kwargs=model_kwargs, controlnet=controlnet)
+ loss = loss_dict["loss"].mean()
+ controlnet.backward(loss)
+ last_batch_iteration = (train_steps + 1) // (global_batch_size // (batch_size * world_size))
+ controlnet.step(lr_kwargs={'last_batch_iteration': last_batch_iteration})
+
+            if args.use_ema and (not args.async_ema or step == len(loader) - 1):
+ if args.use_fp16:
+ ema.update(model.module.module, step=step)
+ else:
+ ema.update(model.module, step=step)
+
+ # ===========================================================================
+ # Log loss values:
+ # ===========================================================================
+ running_loss += loss.item()
+ log_steps += 1
+ train_steps += 1
+ if train_steps % args.log_every == 0:
+ # Measure training speed:
+ torch.cuda.synchronize()
+ end_time = time.time()
+ steps_per_sec = log_steps / (end_time - start_time)
+ # Reduce loss history over all processes:
+ avg_loss = torch.tensor(running_loss / log_steps, device=device)
+ dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM)
+ avg_loss = avg_loss.item() / world_size
+ # get lr from deepspeed fused optimizer
+ logger.info(f"(step={train_steps:07d}) " +
+ (f"(update_step={train_steps // args.grad_accu_steps:07d}) " if args.grad_accu_steps > 1 else "") +
+ f"Train Loss: {avg_loss:.4f}, "
+ f"Lr: {opt.param_groups[0]['lr']:.6g}, "
+ f"Steps/Sec: {steps_per_sec:.2f}, "
+ f"Samples/Sec: {int(steps_per_sec * batch_size * world_size):d}")
+ # Reset monitoring variables:
+ running_loss = 0
+ log_steps = 0
+ start_time = time.time()
+
+ # collect gc:
+ if args.gc_interval > 0 and (step % args.gc_interval == 0):
+ gc.collect()
+
+ if (train_steps % args.ckpt_every == 0 or train_steps % args.ckpt_latest_every == 0 # or train_steps == args.max_training_steps
+ ) and train_steps > 0:
+ save_checkpoint(args, rank, logger, controlnet, ema, epoch, train_steps, checkpoint_dir)
+
+ if train_steps >= args.max_training_steps:
+ logger.info(f"Breaking step loop at {train_steps}.")
+ break
+
+ if train_steps >= args.max_training_steps:
+ logger.info(f"Breaking epoch loop at {epoch}.")
+ break
+
+ dist.destroy_process_group()
+
+
+if __name__ == "__main__":
+ # Start
+ main(get_args())
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/hydit/utils/tools.py b/PyTorch/built-in/mlm/HunyuanDiT/hydit/utils/tools.py
new file mode 100644
index 0000000000000000000000000000000000000000..71d829227fbd6c9ee3d0f11ea873a5ac7abd70e1
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/hydit/utils/tools.py
@@ -0,0 +1,190 @@
+import random
+import logging
+from pathlib import Path
+import shutil
+
+import numpy as np
+from PIL import Image
+import torch
+import torch.distributed as dist
+from tqdm.auto import tqdm
+import math
+import torch.nn.functional as F
+import os
+
+def get_trainable_params(model):
+ params = model.parameters()
+ params = [p for p in params if p.requires_grad]
+ return params
+
+
+def set_seeds(seed_list, device=None):
+ if isinstance(seed_list, (tuple, list)):
+ seed = sum(seed_list)
+ else:
+ seed = seed_list
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+
+ return torch.Generator(device).manual_seed(seed)
+
+def get_start_epoch(resume_path, ckpt, steps_per_epoch):
+ if 'epoch' in ckpt:
+ start_epoch = ckpt['epoch']
+ else:
+ start_epoch = 0
+ if 'steps' in ckpt:
+ train_steps = ckpt['steps']
+ else:
+ try:
+ train_steps = int(Path(resume_path).stem)
+        except ValueError:
+ train_steps = start_epoch * steps_per_epoch
+
+ start_epoch_step = train_steps % steps_per_epoch + 1
+ return start_epoch, start_epoch_step, train_steps
+
+def assert_shape(*args):
+ if len(args) < 2:
+ return
+ cond = True
+ fail_str = f"{args[0] if isinstance(args[0], (list, tuple)) else args[0].shape}"
+ for i in range(1, len(args)):
+ shape1 = args[i] if isinstance(args[i], (list, tuple)) else args[i].shape
+ shape2 = args[i - 1] if isinstance(args[i - 1], (list, tuple)) else args[i - 1].shape
+ cond = cond and (shape1 == shape2)
+ fail_str += f" vs {args[i] if isinstance(args[i], (list, tuple)) else args[i].shape}"
+ assert cond, fail_str
+
+
+def create_logger(logging_dir=None, logging_file=None, ddp=True):
+ """
+ Create a logger that writes to a log file and stdout.
+ """
+ if not ddp or (ddp and dist.get_rank() == 0): # real logger
+ if logging_file is not None:
+ file_handler = [logging.FileHandler(logging_file)]
+ elif logging_dir is not None:
+ file_handler = [logging.FileHandler(f"{logging_dir}/log.txt")]
+ else:
+ file_handler = []
+ logging.basicConfig(
+ level=logging.INFO,
+ format='[\033[34m%(asctime)s\033[0m] %(message)s',
+ datefmt='%Y-%m-%d %H:%M:%S',
+ handlers=[logging.StreamHandler()] + file_handler
+ )
+ logger = logging.getLogger(__name__)
+ else:
+ logger = logging.getLogger(__name__)
+ logger.addHandler(logging.NullHandler())
+ return logger
+
+def create_exp_folder(args, rank):
+ if rank == 0:
+ os.makedirs(args.results_dir, exist_ok=True)
+ existed_experiments = list(Path(args.results_dir).glob("*dit*"))
+ if len(existed_experiments) == 0:
+ experiment_index = 1
+ else:
+ existed_experiments.sort()
+ print('existed_experiments', existed_experiments)
+ experiment_index = max([int(x.stem.split('-')[0]) for x in existed_experiments]) + 1
+ dist.barrier()
+ model_string_name = args.task_flag if args.task_flag else args.model.replace("/", "-")
+ experiment_dir = f"{args.results_dir}/{experiment_index:03d}-{model_string_name}" # Create an experiment folder
+ checkpoint_dir = f"{experiment_dir}/checkpoints" # Stores saved model checkpoints
+ if rank == 0:
+ os.makedirs(checkpoint_dir, exist_ok=True)
+ logger = create_logger(experiment_dir)
+ logger.info(f"Experiment directory created at {experiment_dir}")
+ else:
+ logger = create_logger()
+ experiment_dir = ""
+
+ return experiment_dir, checkpoint_dir, logger
+
+
+def model_resume(args, model, ema, logger, len_loader):
+ """
+ Load pretrained weights.
+ """
+ start_epoch = 0
+ start_epoch_step = 0
+ train_steps = 0
+ resume_path = args.resume
+ if not Path(resume_path).exists():
+ raise FileNotFoundError(f" Cannot find checkpoint from {resume_path}")
+
+ logger.info(f" Resume deepspeed={args.resume_deepspeed}, "
+ f"Resume split={args.resume_split}, "
+ f"Resume from checkpoint {resume_path}")
+ # Resume model and ema states (not include optimizer states) from a checkpoint saved by Deepspeed version of DIT.
+ if args.resume_deepspeed:
+ assert 'mp_rank_00_model_states.pt' in os.listdir(resume_path), f' Cannot find dp chkpt from {resume_path}'
+ resume_ckpt = torch.load(os.path.join(resume_path, 'mp_rank_00_model_states.pt'),
+ map_location=lambda storage, loc: storage)
+ # Resume main model
+ if args.ema_to_module:
+ logger.info(" Resume main model from the ema states.")
+ model.load_state_dict(resume_ckpt['ema'], strict=args.strict)
+ else:
+ logger.info(" Resume main model from the main states.")
+ model.load_state_dict(resume_ckpt['module'], strict=args.strict)
+ # Resume EMA model
+ if args.use_ema:
+ if args.module_to_ema:
+ logger.info(" Resume EMA model from the main states.")
+ ema.load_state_dict(resume_ckpt['module'], strict=args.strict)
+ else:
+ logger.info(" Resume EMA model from the EMA states.")
+ ema.load_state_dict(resume_ckpt['ema'], strict=args.strict)
+ if not args.reset_loader:
+ start_epoch, start_epoch_step, train_steps = get_start_epoch(args.resume, resume_ckpt, len_loader)
+ # Resume model and ema states (not include optimizer states) from two checkpoints separated from DeepSpeed ckpt.
+ elif args.resume_split:
+ # Resume main model
+ if args.ema_to_module:
+ assert 'pytorch_model_ema.pt' in os.listdir(
+ resume_path), f' Cannot find pytorch_model_ema.pt from {resume_path}'
+ logger.info(f" Resume main model from ema states.")
+ resume_ckpt_ema = torch.load(os.path.join(resume_path, 'pytorch_model_ema.pt'),
+ map_location=lambda storage, loc: storage)
+ model.load_state_dict(resume_ckpt_ema, strict=args.strict)
+ else:
+ assert 'pytorch_model_module.pt' in os.listdir(
+ resume_path), f' Cannot find pytorch_model_module.pt from {resume_path}'
+ logger.info(f" Resume main model from main states.")
+ resume_ckpt_module = torch.load(os.path.join(resume_path, 'pytorch_model_module.pt'),
+ map_location=lambda storage, loc: storage)
+ model.load_state_dict(resume_ckpt_module, strict=args.strict)
+ # Resume ema model
+ if args.use_ema:
+ if args.module_to_ema:
+ if "resume_ckpt_module" in locals():
+ logger.info(f" Resume ema model from main states.")
+ ema.load_state_dict(resume_ckpt_module, strict=args.strict)
+ else:
+ assert 'pytorch_model_module.pt' in os.listdir(
+ resume_path), f' Cannot find pytorch_model_module.pt from {resume_path}'
+ logger.info(f" Resume ema model from module states.")
+ resume_ckpt_module = torch.load(os.path.join(resume_path, 'pytorch_model_module.pt'),
+ map_location=lambda storage, loc: storage)
+ ema.load_state_dict(resume_ckpt_module, strict=args.strict)
+ else:
+ if "resume_ckpt_ema" in locals():
+ logger.info(f" Resume ema model from EMA states.")
+ ema.load_state_dict(resume_ckpt_ema, strict=args.strict)
+ else:
+ assert 'pytorch_model_ema.pt' in os.listdir(
+ resume_path), f' Cannot find pytorch_model_ema.pt from {resume_path}'
+ logger.info(f" Resume ema model from EMA states.")
+ resume_ckpt_ema = torch.load(os.path.join(resume_path, 'pytorch_model_ema.pt'),
+ map_location=lambda storage, loc: storage)
+ ema.load_state_dict(resume_ckpt_ema, strict=args.strict)
+ else:
+        raise ValueError("If `resume` is True, then either `resume_split` or `resume_deepspeed` must be true.")
+
+ return model, ema, start_epoch, start_epoch_step, train_steps
\ No newline at end of file
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/lite/README.md b/PyTorch/built-in/mlm/HunyuanDiT/lite/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..d9f52703dcb587aa219cf837178fa9506693db9a
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/lite/README.md
@@ -0,0 +1,35 @@
+## Running HunyuanDiT Inference with under 6GB of GPU VRAM
+
+### Instructions
+Running HunyuanDiT with under 6GB of GPU VRAM is now supported, based on [**diffusers**](https://huggingface.co/docs/diffusers/main/en/api/pipelines/hunyuandit). Here we provide instructions and a demo to get you started quickly.
+
+The 6GB-lite version supports NVIDIA GPUs of the Ampere architecture or newer, such as the RTX 3070/3080/4080/4090 and the A100.
+
+The only thing you need to do is install the following libraries:
+
+```bash
+pip install -U bitsandbytes
+pip install git+https://github.com/huggingface/diffusers
+pip install torch==2.0.0
+```
+
+Then you can run HunyuanDiT text-to-image generation directly with under 6GB of GPU VRAM!
+
+Here is a demo:
+
+```bash
+cd HunyuanDiT
+
+# Quick start
+model_id=Tencent-Hunyuan/HunyuanDiT-v1.1-Diffusers-Distilled
+prompt="一个宇航员在骑马"   # "an astronaut riding a horse"
+infer_steps=50
+guidance_scale=6
+python3 lite/inference.py ${model_id} "${prompt}" ${infer_steps} ${guidance_scale}
+```
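+
+For reference, the snippet below is a condensed sketch of the flow implemented in `lite/inference.py`: the mT5 text encoder is first loaded in 8-bit via bitsandbytes to encode the prompt, then released, so that only the fp16 transformer and VAE stay resident on the GPU during denoising. The model id, prompt, step count, and guidance scale are simply the demo values from above.
+
+```python
+import gc
+import torch
+from diffusers import HunyuanDiTPipeline
+from transformers import T5EncoderModel
+
+model_id = "Tencent-Hunyuan/HunyuanDiT-v1.1-Diffusers-Distilled"
+prompt = "一个宇航员在骑马"  # "an astronaut riding a horse"
+
+# 1) Encode the prompt with the mT5 encoder quantized to 8 bit (bitsandbytes).
+text_encoder_2 = T5EncoderModel.from_pretrained(
+    model_id, subfolder="text_encoder_2", load_in_8bit=True, device_map="auto")
+encoder_pipe = HunyuanDiTPipeline.from_pretrained(
+    model_id, text_encoder_2=text_encoder_2, transformer=None, vae=None,
+    torch_dtype=torch.float16, device_map="balanced")
+with torch.no_grad():
+    emb1 = encoder_pipe.encode_prompt(prompt, negative_prompt="")
+    emb2 = encoder_pipe.encode_prompt(prompt, negative_prompt="",
+                                      max_sequence_length=256, text_encoder_index=1)
+del text_encoder_2, encoder_pipe
+gc.collect()
+torch.cuda.empty_cache()
+
+# 2) Denoise and decode with no text encoder resident on the GPU.
+pipe = HunyuanDiTPipeline.from_pretrained(
+    model_id, text_encoder=None, text_encoder_2=None,
+    torch_dtype=torch.float16).to("cuda")
+prompt_embeds, neg_embeds, prompt_mask, neg_mask = emb1
+prompt_embeds_2, neg_embeds_2, prompt_mask_2, neg_mask_2 = emb2
+image = pipe(
+    prompt_embeds=prompt_embeds, prompt_embeds_2=prompt_embeds_2,
+    negative_prompt_embeds=neg_embeds, negative_prompt_embeds_2=neg_embeds_2,
+    prompt_attention_mask=prompt_mask, prompt_attention_mask_2=prompt_mask_2,
+    negative_prompt_attention_mask=neg_mask, negative_prompt_attention_mask_2=neg_mask_2,
+    num_inference_steps=50, guidance_scale=6,
+).images[0]
+image.save("lite_image.png")
+```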
+
+Note: other features in hydit require torch 1.13.1, so you may need to downgrade your torch version to use them:
+
+```bash
+pip install torch==1.13.1
+```
\ No newline at end of file
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/lite/inference.py b/PyTorch/built-in/mlm/HunyuanDiT/lite/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..40f8bcd0bcc8cb8614124c612637f73e9364c7db
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/lite/inference.py
@@ -0,0 +1,175 @@
+import random
+import torch
+from diffusers import HunyuanDiTPipeline
+from transformers import T5EncoderModel
+import time
+from loguru import logger
+import gc
+import sys
+
+NEGATIVE_PROMPT = ''
+
+TEXT_ENCODER_CONF = {
+ "negative_prompt": NEGATIVE_PROMPT,
+ "prompt_embeds": None,
+ "negative_prompt_embeds": None,
+ "prompt_attention_mask": None,
+ "negative_prompt_attention_mask": None,
+ "max_sequence_length": 256,
+ "text_encoder_index": 1,
+}
+
+def flush():
+ gc.collect()
+ torch.cuda.empty_cache()
+
+
+class End2End(object):
+ def __init__(self, model_id="Tencent-Hunyuan/HunyuanDiT-v1.1-Diffusers-Distilled"):
+ self.model_id = model_id
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
+ # ========================================================================
+ self.default_negative_prompt = NEGATIVE_PROMPT
+ logger.info("==================================================")
+ logger.info(f" Model is ready. ")
+ logger.info("==================================================")
+
+ def load_pipeline(self):
+ self.pipeline= HunyuanDiTPipeline.from_pretrained(
+ self.model_id,
+ text_encoder=None,
+ text_encoder_2=None,
+ torch_dtype=torch.float16,
+ ).to(self.device)
+
+
+ def get_text_emb(self, prompts):
+ with torch.no_grad():
+ text_encoder_2 = T5EncoderModel.from_pretrained(
+ self.model_id,
+ subfolder="text_encoder_2",
+ load_in_8bit=True,
+ device_map="auto",
+ )
+ encoder_pipeline = HunyuanDiTPipeline.from_pretrained(
+ self.model_id,
+ text_encoder_2=text_encoder_2,
+ transformer=None,
+ vae=None,
+ torch_dtype=torch.float16,
+ device_map="balanced",
+ )
+ TEXT_ENCODER_CONF["negative_prompt"]=self.default_negative_prompt
+ prompt_emb1 = encoder_pipeline.encode_prompt(prompts, negative_prompt=self.default_negative_prompt)
+ prompt_emb2 = encoder_pipeline.encode_prompt(prompts, **TEXT_ENCODER_CONF)
+ del text_encoder_2
+ del encoder_pipeline
+ flush()
+ return prompt_emb1, prompt_emb2
+
+ def predict(self,
+ user_prompt,
+ seed=None,
+ enhanced_prompt=None,
+ negative_prompt=None,
+ infer_steps=50,
+ guidance_scale=6,
+ batch_size=1,
+ ):
+ # ========================================================================
+ # Arguments: seed
+ # ========================================================================
+ if seed is None:
+ seed = random.randint(0, 1_000_000)
+ if not isinstance(seed, int):
+ raise TypeError(f"`seed` must be an integer, but got {type(seed)}")
+ generator = torch.Generator(device=self.device).manual_seed(seed)
+
+ # ========================================================================
+ # Arguments: prompt, new_prompt, negative_prompt
+ # ========================================================================
+ if not isinstance(user_prompt, str):
+ raise TypeError(f"`user_prompt` must be a string, but got {type(user_prompt)}")
+ user_prompt = user_prompt.strip()
+ prompt = user_prompt
+
+ if enhanced_prompt is not None:
+ if not isinstance(enhanced_prompt, str):
+ raise TypeError(f"`enhanced_prompt` must be a string, but got {type(enhanced_prompt)}")
+ enhanced_prompt = enhanced_prompt.strip()
+ prompt = enhanced_prompt
+
+ # negative prompt
+ if negative_prompt is not None and negative_prompt != '':
+ self.default_negative_prompt = negative_prompt
+ if not isinstance(self.default_negative_prompt, str):
+ raise TypeError(f"`negative_prompt` must be a string, but got {type(negative_prompt)}")
+
+
+ # ========================================================================
+
+ logger.debug(f"""
+ prompt: {user_prompt}
+ enhanced prompt: {enhanced_prompt}
+ seed: {seed}
+ negative_prompt: {negative_prompt}
+ batch_size: {batch_size}
+ guidance_scale: {guidance_scale}
+ infer_steps: {infer_steps}
+ """)
+
+
+        # get text embeddings
+        flush()
+        prompt_emb1, prompt_emb2 = self.get_text_emb(prompt)
+        prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask = prompt_emb1
+        prompt_embeds_2, negative_prompt_embeds_2, prompt_attention_mask_2, negative_prompt_attention_mask_2 = prompt_emb2
+ del prompt_emb1
+ del prompt_emb2
+ # get pipeline
+ self.load_pipeline()
+ samples = self.pipeline(
+ prompt_embeds=prompt_embeds,
+ prompt_embeds_2=prompt_embeds_2,
+ negative_prompt_embeds=negative_prompt_embeds,
+ negative_prompt_embeds_2=negative_prompt_embeds_2,
+ prompt_attention_mask=prompt_attention_mask,
+ prompt_attention_mask_2=prompt_attention_mask_2,
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
+ negative_prompt_attention_mask_2=negative_prompt_attention_mask_2,
+ num_images_per_prompt=batch_size,
+ guidance_scale=guidance_scale,
+ num_inference_steps=infer_steps,
+ generator=generator,
+ ).images[0]
+
+ return {
+ 'images': samples,
+ 'seed': seed,
+ }
+
+
+if __name__ == "__main__":
+
+ if len(sys.argv) != 5:
+ print("Usage: python lite/inference.py ${model_id} ${prompt} ${infer_steps} ${guidance_scale}")
+ print("model_id: Choose a diffusers repository from the official Hugging Face repository https://huggingface.co/Tencent-Hunyuan, "
+ "such as Tencent-Hunyuan/HunyuanDiT-v1.1-Diffusers, "
+ "Tencent-Hunyuan/HunyuanDiT-v1.1-Diffusers-Distilled, "
+ "Tencent-Hunyuan/HunyuanDiT-Diffusers, or Tencent-Hunyuan/HunyuanDiT-Diffusers-Distilled.")
+ print("prompt: the input prompt")
+ print("infer_steps: infer_steps")
+ print("guidance_scale: guidance_scale")
+ sys.exit(1)
+ model_id = sys.argv[1]
+ prompt = sys.argv[2]
+ infer_steps = int(sys.argv[3])
+    guidance_scale = float(sys.argv[4])
+ gen = End2End(model_id)
+ seed = 42
+ results = gen.predict(prompt,
+ seed = seed,
+ infer_steps=infer_steps,
+ guidance_scale=guidance_scale,
+ )
+ results['images'].save('./lite_image.png')
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/lora/README.md b/PyTorch/built-in/mlm/HunyuanDiT/lora/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..ef549842ad4802de4b8d1f24e64c5807d39b9842
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/lora/README.md
@@ -0,0 +1,217 @@
+
+## Using LoRA to fine-tune HunyuanDiT
+
+
+### Instructions
+
+ The dependencies and installation are basically the same as the [**base model**](https://huggingface.co/Tencent-Hunyuan/HunyuanDiT-v1.1).
+
+ We provide two types of trained LoRA weights for you to test.
+
+ Download the LoRA weights and run a quick test with the following commands:
+
+```bash
+cd HunyuanDiT
+# Use the huggingface-cli tool to download the model.
+huggingface-cli download Tencent-Hunyuan/HYDiT-LoRA --local-dir ./ckpts/t2i/lora
+
+# Quick start
+python sample_t2i.py --prompt "青花瓷风格,一只猫在追蝴蝶" --no-enhance --load-key ema --lora-ckpt ./ckpts/t2i/lora/porcelain
+```
+
+Examples of training data and inference results are as follows:
+
+Examples of training data (image captions):
+
+* 青花瓷风格,一只蓝色的鸟儿站在蓝色的花瓶上,周围点缀着白色花朵,背景是白色 (Porcelain style, a blue bird stands on a blue vase, surrounded by white flowers, with a white background.)
+* 青花瓷风格,这是一幅蓝白相间的陶瓷盘子,上面描绘着一只狐狸和它的幼崽在森林中漫步,背景是白色 (Porcelain style, this is a blue and white ceramic plate depicting a fox and its cubs strolling in the forest, with a white background.)
+* 青花瓷风格,在黑色背景上,一只蓝色的狼站在蓝白相间的盘子上,周围是树木和月亮 (Porcelain style, on a black background, a blue wolf stands on a blue and white plate, surrounded by trees and the moon.)
+* 青花瓷风格,在蓝色背景上,一只蓝色蝴蝶和白色花朵被放置在中央 (Porcelain style, on a blue background, a blue butterfly and white flowers are placed in the center.)
+
+Examples of inference results (prompts):
+
+* 青花瓷风格,苏州园林 (Porcelain style, Suzhou Gardens.)
+* 青花瓷风格,一朵荷花 (Porcelain style, a lotus flower.)
+* 青花瓷风格,一只羊 (Porcelain style, a sheep.)
+* 青花瓷风格,一个女孩在雨中跳舞 (Porcelain style, a girl dancing in the rain.)
+
+### Training
+
+We provide three types of weights for fine-tuning LoRA, `ema`, `module` and `distill`, and you can choose according to the actual effect. By default, we use `ema` weights.
+
+Here is an example: we load the `ema` weights into the main model and perform LoRA fine-tuning by adding the `--ema-to-module` parameter.
+
+If you want to load the `module` weights into the main model, just remove the `--ema-to-module` parameter.
+
+If multiple resolutions are used, you also need to add the `--multireso` and `--reso-step 64` parameters (see the sketch after the script below).
+
+```bash
+model='DiT-g/2' # model type
+task_flag="lora_porcelain_ema_rank64" # task flag
+resume=./ckpts/t2i/model/ # resume checkpoint
+index_file=dataset/porcelain/jsons/porcelain.json # the selected data indices
+results_dir=./log_EXP # save root for results
+batch_size=1 # training batch size
+image_size=1024 # training image resolution
+grad_accu_steps=2 # gradient accumulation steps
+warmup_num_steps=0 # warm-up steps
+lr=0.0001 # learning rate
+ckpt_every=100                          # create a ckpt every few steps.
+ckpt_latest_every=2000                  # create a ckpt named `latest.pt` every few steps.
+rank=64 # rank of lora
+max_training_steps=2000 # Maximum training iteration steps
+
+PYTHONPATH=./ deepspeed hydit/train_deepspeed.py \
+ --task-flag ${task_flag} \
+ --model ${model} \
+    --training-parts lora \
+ --rank ${rank} \
+ --resume-split \
+ --resume ${resume} \
+ --ema-to-module \
+ --lr ${lr} \
+ --noise-schedule scaled_linear --beta-start 0.00085 --beta-end 0.03 \
+ --predict-type v_prediction \
+ --uncond-p 0.44 \
+ --uncond-p-t5 0.44 \
+ --index-file ${index_file} \
+ --random-flip \
+ --batch-size ${batch_size} \
+ --image-size ${image_size} \
+ --global-seed 999 \
+ --grad-accu-steps ${grad_accu_steps} \
+ --warmup-num-steps ${warmup_num_steps} \
+ --use-flash-attn \
+ --use-fp16 \
+ --ema-dtype fp32 \
+ --results-dir ${results_dir} \
+ --ckpt-every ${ckpt_every} \
+    --max-training-steps ${max_training_steps} \
+ --ckpt-latest-every ${ckpt_latest_every} \
+ --log-every 10 \
+ --deepspeed \
+ --deepspeed-optimizer \
+ --use-zero-stage 2 \
+ --qk-norm \
+ --rope-img base512 \
+ --rope-real \
+ "$@"
+```
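+
+Because the script forwards any extra arguments via `"$@"`, a multi-resolution run can simply append the two flags when launching it. A minimal sketch, assuming the script above is saved as `lora/train_lora.sh` and run from the repository root:
+
+```bash
+bash lora/train_lora.sh --multireso --reso-step 64
+```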
+
+Recommended parameter settings
+
+| Parameter | Description | Recommended Parameter Value | Note|
+|:---------------:|:---------:|:---------------------------------------------------:|:--:|
+| `--batch-size` | Training batch size | 1 | Depends on GPU memory|
+| `--grad-accu-steps` | Size of gradient accumulation | 2 | - |
+| `--rank` | Rank of LoRA | 64 | Choose from 8 to 128 |
+| `--max-training-steps` | Training steps | 2000 | Depends on the training data size; as a reference, about 2000 steps for 100 images|
+| `--lr` | Learning rate | 0.0001 | - |
+
+
+### Inference
+
+After the training is complete, you can use the following command line for inference.
+We provide the `--lora-ckpt` parameter for selecting the folder that contains the LoRA weights and configuration.
+
+a. Using LoRA during inference
+
+```bash
+python sample_t2i.py --prompt "青花瓷风格,一只小狗" --no-enhance --lora-ckpt log_EXP/001-lora_porcelain_ema_rank64/checkpoints/0001000.pt
+```
+
+b. Using LoRA in gradio
+```bash
+python app/hydit_app.py --infer-mode fa --no-enhance --lora-ckpt log_EXP/001-lora_porcelain_ema_rank64/checkpoints/0001000.pt
+```
+
+c. Merge LoRA weights into the main model
+
+We provide the `--output-merge-path` parameter to set the path for saving the merged weights.
+
+```bash
+PYTHONPATH=./ python lora/merge.py --lora-ckpt log_EXP/001-lora_porcelain_ema_rank64/checkpoints/0001000.pt --output-merge-path ./ckpts/t2i/model/pytorch_model_merge.pt
+```
+
+d. To use the LoRA weights we trained in the diffusers framework, we provide the following script. Because some modifications were made to the weight layout for compatibility with diffusers, the LoRA weights cannot be loaded directly.
+
+
+```python
+import torch
+from diffusers import HunyuanDiTPipeline
+
+num_layers = 40
+def load_hunyuan_dit_lora(transformer_state_dict, lora_state_dict, lora_scale):
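+    # Reconstruct each LoRA delta as lora_scale * (lora_B @ lora_A) and add it to the matching
+    # base weight; the fused Wqkv and kv_proj matrices are chunked so they line up with
+    # diffusers' separate to_q / to_k / to_v projections.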
+ for i in range(num_layers):
+ Wqkv = torch.matmul(lora_state_dict[f"blocks.{i}.attn1.Wqkv.lora_B.weight"], lora_state_dict[f"blocks.{i}.attn1.Wqkv.lora_A.weight"])
+ q, k, v = torch.chunk(Wqkv, 3, dim=0)
+ transformer_state_dict[f"blocks.{i}.attn1.to_q.weight"] += lora_scale * q
+ transformer_state_dict[f"blocks.{i}.attn1.to_k.weight"] += lora_scale * k
+ transformer_state_dict[f"blocks.{i}.attn1.to_v.weight"] += lora_scale * v
+
+ out_proj = torch.matmul(lora_state_dict[f"blocks.{i}.attn1.out_proj.lora_B.weight"], lora_state_dict[f"blocks.{i}.attn1.out_proj.lora_A.weight"])
+ transformer_state_dict[f"blocks.{i}.attn1.to_out.0.weight"] += lora_scale * out_proj
+
+ q_proj = torch.matmul(lora_state_dict[f"blocks.{i}.attn2.q_proj.lora_B.weight"], lora_state_dict[f"blocks.{i}.attn2.q_proj.lora_A.weight"])
+ transformer_state_dict[f"blocks.{i}.attn2.to_q.weight"] += lora_scale * q_proj
+
+ kv_proj = torch.matmul(lora_state_dict[f"blocks.{i}.attn2.kv_proj.lora_B.weight"], lora_state_dict[f"blocks.{i}.attn2.kv_proj.lora_A.weight"])
+ k, v = torch.chunk(kv_proj, 2, dim=0)
+ transformer_state_dict[f"blocks.{i}.attn2.to_k.weight"] += lora_scale * k
+ transformer_state_dict[f"blocks.{i}.attn2.to_v.weight"] += lora_scale * v
+
+ out_proj = torch.matmul(lora_state_dict[f"blocks.{i}.attn2.out_proj.lora_B.weight"], lora_state_dict[f"blocks.{i}.attn2.out_proj.lora_A.weight"])
+ transformer_state_dict[f"blocks.{i}.attn2.to_out.0.weight"] += lora_scale * out_proj
+
+ q_proj = torch.matmul(lora_state_dict["pooler.q_proj.lora_B.weight"], lora_state_dict["pooler.q_proj.lora_A.weight"])
+ transformer_state_dict["time_extra_emb.pooler.q_proj.weight"] += lora_scale * q_proj
+
+ return transformer_state_dict
+
+pipe = HunyuanDiTPipeline.from_pretrained("Tencent-Hunyuan/HunyuanDiT-v1.1-Diffusers", torch_dtype=torch.float16)
+pipe.to("cuda")
+
+from safetensors import safe_open
+
+lora_state_dict = {}
+with safe_open("./ckpts/t2i/lora/jade/adapter_model.safetensors", framework="pt", device=0) as f:
+ for k in f.keys():
+        lora_state_dict[k[17:]] = f.get_tensor(k) # strip the 'base_model.model.' prefix
+
+transformer_state_dict = pipe.transformer.state_dict()
+transformer_state_dict = load_hunyuan_dit_lora(transformer_state_dict, lora_state_dict, lora_scale=1.0)
+pipe.transformer.load_state_dict(transformer_state_dict)
+
+prompt = "玉石绘画风格,一只猫在追蝴蝶"
+image = pipe(
+ prompt,
+ num_inference_steps=100,
+ guidance_scale=6.0,
+).images[0]
+image.save('img.png')
+```
+
+
+e. For more information, please refer to [HYDiT-LoRA](https://huggingface.co/Tencent-Hunyuan/HYDiT-LoRA).
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/lora/merge.py b/PyTorch/built-in/mlm/HunyuanDiT/lora/merge.py
new file mode 100644
index 0000000000000000000000000000000000000000..4fd7dc664aaaaf6412ea549080501aaea6e80a69
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/lora/merge.py
@@ -0,0 +1,28 @@
+import torch
+import os
+from hydit.config import get_args
+from hydit.modules.models import HUNYUAN_DIT_MODELS
+
+from hydit.inference import _to_tuple
+
+args = get_args()
+
+image_size = _to_tuple(args.image_size)
+latent_size = (image_size[0] // 8, image_size[1] // 8)
+
+model = HUNYUAN_DIT_MODELS[args.model](args,
+ input_size=latent_size,
+ log_fn=print,
+ )
+model_path = os.path.join(args.model_root, 't2i', 'model', f"pytorch_model_{args.load_key}.pt")
+state_dict = torch.load(model_path, map_location=lambda storage, loc: storage)
+
+print(f"Loading model from {model_path}")
+model.load_state_dict(state_dict)
+
+print(f"Loading lora from {args.lora_ckpt}")
+model.load_adapter(args.lora_ckpt)
+model.merge_and_unload()
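+# merge_and_unload() (a PEFT method) folds the LoRA deltas into the base weights,
+# so the checkpoint saved below can be loaded without any adapter.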
+
+torch.save(model.state_dict(), args.output_merge_path)
+print(f"Model saved to {args.output_merge_path}")
\ No newline at end of file
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/lora/train_lora.sh b/PyTorch/built-in/mlm/HunyuanDiT/lora/train_lora.sh
new file mode 100644
index 0000000000000000000000000000000000000000..eadbbe1d0382ca6cf632b46f51f83671594ea78f
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/lora/train_lora.sh
@@ -0,0 +1,50 @@
+model='DiT-g/2' # model type
+task_flag="lora_porcelain_ema_rank64" # task flag
+resume=./ckpts/t2i/model/ # resume checkpoint
+index_file=dataset/porcelain/jsons/porcelain.json # the selected data indices
+results_dir=./log_EXP # save root for results
+batch_size=1 # training batch size
+image_size=1024 # training image resolution
+grad_accu_steps=2 # gradient accumulation steps
+warmup_num_steps=0 # warm-up steps
+lr=0.0001 # learning rate
+ckpt_every=100                          # create a ckpt every few steps.
+ckpt_latest_every=2000                  # create a ckpt named `latest.pt` every few steps.
+rank=64 # rank of lora
+max_training_steps=2000 # Maximum training iteration steps
+
+PYTHONPATH=./ deepspeed hydit/train_deepspeed.py \
+ --task-flag ${task_flag} \
+ --model ${model} \
+ --training-parts lora \
+ --rank ${rank} \
+ --resume-split \
+ --resume ${resume} \
+ --ema-to-module \
+ --lr ${lr} \
+ --noise-schedule scaled_linear --beta-start 0.00085 --beta-end 0.03 \
+ --predict-type v_prediction \
+ --uncond-p 0.44 \
+ --uncond-p-t5 0.44 \
+ --index-file ${index_file} \
+ --random-flip \
+ --batch-size ${batch_size} \
+ --image-size ${image_size} \
+ --global-seed 999 \
+ --grad-accu-steps ${grad_accu_steps} \
+ --warmup-num-steps ${warmup_num_steps} \
+ --use-flash-attn \
+ --use-fp16 \
+ --ema-dtype fp32 \
+ --results-dir ${results_dir} \
+ --ckpt-every ${ckpt_every} \
+    --max-training-steps ${max_training_steps} \
+ --ckpt-latest-every ${ckpt_latest_every} \
+ --log-every 10 \
+ --deepspeed \
+ --deepspeed-optimizer \
+ --use-zero-stage 2 \
+ --qk-norm \
+ --rope-img base512 \
+ --rope-real \
+ "$@"
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/mllm/README.md b/PyTorch/built-in/mlm/HunyuanDiT/mllm/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..860245d590987b9bb6807a59b575d0441e202dd6
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/mllm/README.md
@@ -0,0 +1,109 @@
+# Hunyuan-MLLM
+We provide two multimodal large language models: Hunyuan-Captioner and DialogGen. The former provides fine-grained text descriptions for training data, while the latter enhances the user's prompt during inference and supports multi-turn text-to-image generation.
+
+## Contents
+- [Hunyuan-Captioner](https://github.com/Tencent/HunyuanDiT/tree/main/mllm#hunyuan-captioner)
+ - [Instructions](https://github.com/Tencent/HunyuanDiT/tree/main/mllm#instructions)
+ - [Examples](https://github.com/Tencent/HunyuanDiT/tree/main/mllm#examples)
+ - [Inference](https://github.com/Tencent/HunyuanDiT/tree/main/mllm#inference)
+ - [Gradio](https://github.com/Tencent/HunyuanDiT/tree/main/mllm#gradio)
+- [DialogGen](https://github.com/Tencent/HunyuanDiT/tree/main/mllm#dialoggen)
+ - [Inference](https://github.com/Tencent/HunyuanDiT/tree/main/mllm#inference-1)
+ - [Gradio](https://github.com/Tencent/HunyuanDiT/tree/main/mllm#gradio-1)
+
+## Hunyuan-Captioner
+Hunyuan-Captioner meets the needs of text-to-image techniques by maintaining a high degree of image-text consistency. It can generate high-quality image descriptions from a variety of angles, including object description, object relationships, background information, image style, etc. Our code is based on the [LLaVA](https://github.com/haotian-liu/LLaVA) implementation.
+
+### Examples
+
+
+
+
+### Instructions
+a. Install dependencies
+
+The dependencies and installation are basically the same as the [**base model**](https://huggingface.co/Tencent-Hunyuan/HunyuanDiT-v1.1).
+
+b. Data download
+```shell
+cd HunyuanDiT
+wget -O ./dataset/data_demo.zip https://dit.hunyuan.tencent.com/download/HunyuanDiT/data_demo.zip
+unzip ./dataset/data_demo.zip -d ./dataset
+mkdir ./dataset/porcelain/arrows ./dataset/porcelain/jsons
+```
+
+c. Model download
+```shell
+# Use the huggingface-cli tool to download the model.
+huggingface-cli download Tencent-Hunyuan/HunyuanCaptioner --local-dir ./ckpts/captioner
+```
+
+### Inference
+
+Currently supported prompt templates:
+
+|Mode | Prompt template |Description |
+| --- | --- | --- |
+|caption_zh | 描述这张图片 |Caption in Chinese |
+|insert_content | 根据提示词“{}”,描述这张图片 |Insert specific knowledge into caption|
+|caption_en | Please describe the content of this image |Caption in English |
+| | | |
+
+
+a. Single picture inference in Chinese
+
+```bash
+python mllm/caption_demo.py --mode "caption_zh" --image_file "mllm/images/demo1.png" --model_path "./ckpts/captioner"
+```
+
+b. Insert specific knowledge into caption
+
+```bash
+python mllm/caption_demo.py --mode "insert_content" --content "宫保鸡丁" --image_file "mllm/images/demo2.png" --model_path "./ckpts/captioner"
+```
+
+c. Single picture inference in English
+
+```bash
+python mllm/caption_demo.py --mode "caption_en" --image_file "mllm/images/demo3.png" --model_path "./ckpts/captioner"
+```
+
+d. Multiple pictures inference in Chinese
+
+```bash
+### Convert multiple pictures to csv file.
+python mllm/make_csv.py --img_dir "mllm/images" --input_file "mllm/images/demo.csv"
+
+### Multiple pictures inference
+python mllm/caption_demo.py --mode "caption_zh" --input_file "mllm/images/demo.csv" --output_file "mllm/images/demo_res.csv" --model_path "./ckpts/captioner"
+```
+
+(Optional) To convert the output csv file to Arrow format, please refer to [Data Preparation #3](#data-preparation) for detailed instructions.
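+
+If you only need a quick, self-contained conversion, the sketch below turns the caption csv into a single Arrow file using `pandas` and `pyarrow`. It is an illustrative stand-in rather than the repository's data-preparation script; the column names follow the csv produced above.
+
+```python
+import pandas as pd
+import pyarrow as pa
+import pyarrow.feather as feather
+
+# Read the csv written by caption_demo.py (columns: img_path, text_zh).
+df = pd.read_csv("mllm/images/demo_res.csv")
+
+# Convert to an Arrow table and store it in Arrow (Feather v2) format.
+table = pa.Table.from_pandas(df)
+feather.write_feather(table, "dataset/porcelain/arrows/demo_res.arrow")
+```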
+
+### Gradio
+To launch a Gradio demo locally, please execute the following commands sequentially. Ensure each command is running in the background. For more detailed instructions, please refer to [LLaVA](https://github.com/haotian-liu/LLaVA).
+```bash
+cd mllm
+python -m llava.serve.controller --host 0.0.0.0 --port 10000
+python -m llava.serve.gradio_web_server --controller http://0.0.0.0:10000 --model-list-mode reload --port 443
+python -m llava.serve.model_worker --host 0.0.0.0 --controller http://0.0.0.0:10000 --port 40000 --worker http://0.0.0.0:40000 --model-path "../ckpts/captioner" --model-name LlavaMistral
+```
+Then the demo can be accessed at http://0.0.0.0:443. Note that 0.0.0.0 here needs to be replaced with your server's IP address.
+
+
+
+
+## Hunyuan-DialogGen
+We additionally provide inference commands for [DialogGen](https://github.com/Centaurusalpha/DialogGen).
+### Inference
+```bash
+cd HunyuanDiT
+python mllm/dialoggen_demo.py --prompt "画一只小猫"
+```
+
+### Gradio
+```bash
+# Start a multi-turn T2I generation UI.
+# If your GPU memory is less than 32GB, use '--load-4bit' to enable 4-bit quantization, which requires at least 22GB of memory.
+python app/multiTurnT2I_app.py
+```
\ No newline at end of file
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/mllm/caption_demo.py b/PyTorch/built-in/mlm/HunyuanDiT/mllm/caption_demo.py
new file mode 100644
index 0000000000000000000000000000000000000000..d5d823cf04ab72cb9ec0c6f63c7586a45fb294af
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/mllm/caption_demo.py
@@ -0,0 +1,182 @@
+import argparse
+import torch
+import sys
+import os
+import pandas as pd
+import tqdm
+# Add the mllm directory under the current working directory to sys.path
+sys.path.append(os.getcwd()+"/mllm")
+
+
+from llava.constants import (
+ IMAGE_TOKEN_INDEX,
+ DEFAULT_IMAGE_TOKEN,
+ DEFAULT_IM_START_TOKEN,
+ DEFAULT_IM_END_TOKEN,
+ IMAGE_PLACEHOLDER,
+)
+from llava.conversation import conv_templates, SeparatorStyle
+from llava.model.builder import load_pretrained_model
+from llava.utils import disable_torch_init
+from llava.mm_utils import (
+ process_images,
+ tokenizer_image_token,
+ get_model_name_from_path,
+)
+
+import requests
+from PIL import Image
+from io import BytesIO
+import re
+
+
+def image_parser(image_file, sep=','):
+ out = image_file.split(sep)
+ return out
+
+
+def load_image(image_file):
+ if image_file.startswith("http") or image_file.startswith("https"):
+ response = requests.get(image_file)
+ image = Image.open(BytesIO(response.content)).convert("RGB")
+ else:
+ image = Image.open(image_file).convert("RGB")
+ return image
+
+
+def load_images(image_files):
+ out = []
+ for image_file in image_files:
+ image = load_image(image_file)
+ out.append(image)
+ return out
+
+
+def init_dialoggen_model(model_path, model_base=None, load_4bit=False):
+ model_name = get_model_name_from_path(model_path)
+ tokenizer, model, image_processor, context_len = load_pretrained_model(
+ model_path, model_base, model_name, llava_type_model=True, load_4bit=load_4bit)
+ return {"tokenizer": tokenizer,
+ "model": model,
+ "image_processor": image_processor}
+
+
+def eval_model(models,
+ query='详细描述一下这张图片',
+ image_file=None,
+ sep=',',
+ temperature=0.2,
+ top_p=None,
+ num_beams=1,
+ max_new_tokens=512,
+ return_history=False,
+ history=None,
+ skip_special=False
+ ):
+ # Model
+ disable_torch_init()
+
+ qs = query
+ image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
+ if IMAGE_PLACEHOLDER in qs:
+ if models["model"].config.mm_use_im_start_end:
+ qs = re.sub(IMAGE_PLACEHOLDER, image_token_se, qs)
+ else:
+ qs = re.sub(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN, qs)
+ else:
+ if models["model"].config.mm_use_im_start_end:
+ qs = image_token_se + "\n" + qs
+ else:
+ qs = DEFAULT_IMAGE_TOKEN + "\n" + qs
+
+ if not history:
+ conv = conv_templates['llava_v1'].copy()
+ else:
+ conv = history
+
+ if skip_special:
+ conv.append_message(conv.roles[0], query)
+ else:
+ conv.append_message(conv.roles[0], qs)
+ conv.append_message(conv.roles[1], None)
+ prompt = conv.get_prompt()
+
+ if image_file is not None:
+ image_files = image_parser(image_file, sep=sep)
+ images = load_images(image_files)
+ image_sizes = [x.size for x in images]
+ images_tensor = process_images(
+ images,
+ models["image_processor"],
+ models["model"].config
+ ).to(models["model"].device, dtype=torch.float16)
+ else:
+        # no image provided: format the input as in training, using an all-zero image tensor
+ image_sizes = [(1024, 1024)]
+ images_tensor = torch.zeros(1, 5, 3, models["image_processor"].crop_size["height"], models["image_processor"].crop_size["width"])
+ images_tensor = images_tensor.to(models["model"].device, dtype=torch.float16)
+
+ input_ids = (
+ tokenizer_image_token(prompt, models["tokenizer"], IMAGE_TOKEN_INDEX, return_tensors="pt")
+ .unsqueeze(0)
+ .cuda()
+ )
+ with torch.inference_mode():
+ output_ids = models["model"].generate(
+ input_ids,
+ images=images_tensor,
+ image_sizes=image_sizes,
+ do_sample=True if temperature > 0 else False,
+ temperature=temperature,
+ top_p=top_p,
+ num_beams=num_beams,
+ max_new_tokens=max_new_tokens,
+ use_cache=True,
+ )
+
+ outputs = models["tokenizer"].batch_decode(output_ids, skip_special_tokens=True)[0].strip()
+ if return_history:
+ return outputs, conv
+ return outputs
+
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--model_path', type=str, default='/apdcephfs_data_cq5_1/share_300167803/dengxinchi/project/LLaVA/checkpoints/exps/caption/caption_v3_prolr2-5_lr2-5_v5_checkpoint-50000_merge_lora')
+ parser.add_argument('--mode', choices=['caption_zh','caption_en','insert_content'], default="caption_zh")
+ parser.add_argument('--content', type=str, default=None)
+ parser.add_argument('--input_file', type=str, default=None) # 'images/demo.csv'
+ parser.add_argument('--output_file', type=str, default=None) # 'images/demo_res.csv'
+ parser.add_argument('--image_file', type=str, default='images/demo1.jpeg') # 'images/demo1.jpeg'
+ args = parser.parse_args()
+
+ if args.mode == 'caption_zh':
+ query = "描述这张图片"
+ elif args.mode == 'caption_en':
+ query = 'Please describe the content of this image'
+ elif args.mode == 'insert_content':
+ assert args.content is not None
+ query = f'根据提示词“{args.content}”,描述这张图片'
+
+
+ models = init_dialoggen_model(args.model_path)
+
+    if args.input_file is not None:
+ df = pd.read_csv(args.input_file)
+ text_zh = []
+ for i in tqdm.tqdm(range(len(df))):
+ img_path = df.loc[i]["img_path"]
+ res = eval_model(models,
+ query=query,
+ image_file=img_path,
+ )
+ text_zh.append(res)
+ df["text_zh"] = text_zh
+ df.to_csv(args.output_file, index=False, encoding='utf-8-sig')
+ else:
+ res = eval_model(models,
+ query=query,
+ image_file=args.image_file,
+ )
+ print(res)
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/mllm/dialoggen_demo.py b/PyTorch/built-in/mlm/HunyuanDiT/mllm/dialoggen_demo.py
new file mode 100644
index 0000000000000000000000000000000000000000..d02f9bb5d2e97ad5255ed8c6010a171155b5f70c
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/mllm/dialoggen_demo.py
@@ -0,0 +1,189 @@
+import argparse
+import torch
+import sys
+import os
+# Add the mllm directory under the current working directory to sys.path
+sys.path.append(os.getcwd()+"/mllm")
+
+
+from llava.constants import (
+ IMAGE_TOKEN_INDEX,
+ DEFAULT_IMAGE_TOKEN,
+ DEFAULT_IM_START_TOKEN,
+ DEFAULT_IM_END_TOKEN,
+ IMAGE_PLACEHOLDER,
+)
+from llava.conversation import conv_templates, SeparatorStyle
+from llava.model.builder import load_pretrained_model
+from llava.utils import disable_torch_init
+from llava.mm_utils import (
+ process_images,
+ tokenizer_image_token,
+ get_model_name_from_path,
+)
+
+import requests
+from PIL import Image
+from io import BytesIO
+import re
+
+
+def image_parser(image_file, sep=','):
+ out = image_file.split(sep)
+ return out
+
+
+def load_image(image_file):
+ if image_file.startswith("http") or image_file.startswith("https"):
+ response = requests.get(image_file)
+ image = Image.open(BytesIO(response.content)).convert("RGB")
+ else:
+ image = Image.open(image_file).convert("RGB")
+ return image
+
+
+def load_images(image_files):
+ out = []
+ for image_file in image_files:
+ image = load_image(image_file)
+ out.append(image)
+ return out
+
+
+def init_dialoggen_model(model_path, model_base=None, load_4bit=False):
+ model_name = get_model_name_from_path(model_path)
+ tokenizer, model, image_processor, context_len = load_pretrained_model(
+ model_path, model_base, model_name, llava_type_model=True, load_4bit=load_4bit)
+ return {"tokenizer": tokenizer,
+ "model": model,
+ "image_processor": image_processor}
+
+
+def eval_model(models,
+ query='详细描述一下这张图片',
+ image_file=None,
+ sep=',',
+ temperature=0.2,
+ top_p=None,
+ num_beams=1,
+ max_new_tokens=512,
+ return_history=False,
+ history=None,
+ skip_special=False
+ ):
+ # Model
+ disable_torch_init()
+
+ qs = query
+ image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
+ if IMAGE_PLACEHOLDER in qs:
+ if models["model"].config.mm_use_im_start_end:
+ qs = re.sub(IMAGE_PLACEHOLDER, image_token_se, qs)
+ else:
+ qs = re.sub(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN, qs)
+ else:
+ if models["model"].config.mm_use_im_start_end:
+ qs = image_token_se + "\n" + qs
+ else:
+ qs = DEFAULT_IMAGE_TOKEN + "\n" + qs
+
+ if not history:
+ conv = conv_templates['llava_v1'].copy()
+ else:
+ conv = history
+
+ if skip_special:
+ conv.append_message(conv.roles[0], query)
+ else:
+ conv.append_message(conv.roles[0], qs)
+ conv.append_message(conv.roles[1], None)
+ prompt = conv.get_prompt()
+
+ if image_file is not None:
+ image_files = image_parser(image_file, sep=sep)
+ images = load_images(image_files)
+ image_sizes = [x.size for x in images]
+ images_tensor = process_images(
+ images,
+ models["image_processor"],
+ models["model"].config
+ ).to(models["model"].device, dtype=torch.float16)
+ else:
+        # no image provided: format the input as in training, using an all-zero image tensor
+ image_sizes = [(1024, 1024)]
+ images_tensor = torch.zeros(1, 5, 3, models["image_processor"].crop_size["height"], models["image_processor"].crop_size["width"])
+ images_tensor = images_tensor.to(models["model"].device, dtype=torch.float16)
+
+ input_ids = (
+ tokenizer_image_token(prompt, models["tokenizer"], IMAGE_TOKEN_INDEX, return_tensors="pt")
+ .unsqueeze(0)
+ .cuda()
+ )
+ with torch.inference_mode():
+ output_ids = models["model"].generate(
+ input_ids,
+ images=images_tensor,
+ image_sizes=image_sizes,
+ do_sample=True if temperature > 0 else False,
+ temperature=temperature,
+ top_p=top_p,
+ num_beams=num_beams,
+ max_new_tokens=max_new_tokens,
+ use_cache=True,
+ )
+
+ outputs = models["tokenizer"].batch_decode(output_ids, skip_special_tokens=True)[0].strip()
+ if return_history:
+ return outputs, conv
+ return outputs
+
+
+def remove_prefix(text):
+ if text.startswith("<画图>"):
+ return text[len("<画图>"):], True
+ elif text.startswith("对不起"):
+        # the model refused to draw
+ return "", False
+ else:
+ return text, True
+
+
+class DialogGen(object):
+ def __init__(self, model_path, load_4bit=False):
+ self.models = init_dialoggen_model(model_path, load_4bit=load_4bit)
+ self.query_template = "请先判断用户的意图,若为画图则在输出前加入<画图>:{}"
+
+ def __call__(self, prompt, return_history=False, history=None, skip_special=False):
+ enhanced_prompt = eval_model(
+ models=self.models,
+ query=self.query_template.format(prompt),
+ image_file=None,
+ return_history=return_history,
+ history=history,
+ skip_special=skip_special
+ )
+ if return_history:
+ return enhanced_prompt
+
+ enhanced_prompt, compliance = remove_prefix(enhanced_prompt)
+ if not compliance:
+ return False, ""
+ return True, enhanced_prompt
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--model_path', type=str, default='./ckpts/dialoggen')
+ parser.add_argument('--prompt', type=str, default='画一只小猫')
+ parser.add_argument('--image_file', type=str, default=None) # 'images/demo1.jpeg'
+ args = parser.parse_args()
+
+ query = f"请先判断用户的意图,若为画图则在输出前加入<画图>:{args.prompt}"
+
+ models = init_dialoggen_model(args.model_path)
+
+ res = eval_model(models,
+ query=query,
+ image_file=args.image_file,
+ )
+ print(res)
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/mllm/llava/__init__.py b/PyTorch/built-in/mlm/HunyuanDiT/mllm/llava/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d1f016db1028101d45ba7d68cb3f0bcb558c2bb
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/mllm/llava/__init__.py
@@ -0,0 +1 @@
+from .model import LlavaLlamaForCausalLM
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/mllm/llava/constants.py b/PyTorch/built-in/mlm/HunyuanDiT/mllm/llava/constants.py
new file mode 100644
index 0000000000000000000000000000000000000000..374be090510b302de9882d880c755787a8eafe11
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/mllm/llava/constants.py
@@ -0,0 +1,13 @@
+CONTROLLER_HEART_BEAT_EXPIRATION = 30
+WORKER_HEART_BEAT_INTERVAL = 15
+
+LOGDIR = "."
+
+# Model Constants
+IGNORE_INDEX = -100
+IMAGE_TOKEN_INDEX = -200
+DEFAULT_IMAGE_TOKEN = "<image>"
+DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
+DEFAULT_IM_START_TOKEN = "<im_start>"
+DEFAULT_IM_END_TOKEN = "<im_end>"
+IMAGE_PLACEHOLDER = "<image-placeholder>"
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/mllm/llava/conversation.py b/PyTorch/built-in/mlm/HunyuanDiT/mllm/llava/conversation.py
new file mode 100644
index 0000000000000000000000000000000000000000..00c56867dd1fd88094df9556f3d1c57e71a7ada8
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/mllm/llava/conversation.py
@@ -0,0 +1,396 @@
+import dataclasses
+from enum import auto, Enum
+from typing import List, Tuple
+import base64
+from io import BytesIO
+from PIL import Image
+
+
+class SeparatorStyle(Enum):
+ """Different separator style."""
+ SINGLE = auto()
+ TWO = auto()
+ MPT = auto()
+ PLAIN = auto()
+ LLAMA_2 = auto()
+
+
+@dataclasses.dataclass
+class Conversation:
+ """A class that keeps all conversation history."""
+ system: str
+ roles: List[str]
+ messages: List[List[str]]
+ offset: int
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
+ sep: str = "###"
+ sep2: str = None
+ version: str = "Unknown"
+
+ skip_next: bool = False
+
+ def get_prompt(self):
+ messages = self.messages
+ if len(messages) > 0 and type(messages[0][1]) is tuple:
+ messages = self.messages.copy()
+ init_role, init_msg = messages[0].copy()
+            init_msg = init_msg[0].replace("<image>", "").strip()
+            if 'mmtag' in self.version:
+                messages[0] = (init_role, init_msg)
+                messages.insert(0, (self.roles[0], "<Image><image></Image>"))
+                messages.insert(1, (self.roles[1], "Received."))
+            else:
+                messages[0] = (init_role, "<image>\n" + init_msg)
+
+ if self.sep_style == SeparatorStyle.SINGLE:
+ ret = self.system + self.sep
+ for role, message in messages:
+ if message:
+ if type(message) is tuple:
+ message, _, _ = message
+ ret += role + ": " + message + self.sep
+ else:
+ ret += role + ":"
+ elif self.sep_style == SeparatorStyle.TWO:
+ seps = [self.sep, self.sep2]
+ ret = self.system + seps[0]
+ for i, (role, message) in enumerate(messages):
+ if message:
+ if type(message) is tuple:
+ message, _, _ = message
+ ret += role + ": " + message + seps[i % 2]
+ else:
+ ret += role + ":"
+ elif self.sep_style == SeparatorStyle.MPT:
+ ret = self.system + self.sep
+ for role, message in messages:
+ if message:
+ if type(message) is tuple:
+ message, _, _ = message
+ ret += role + message + self.sep
+ else:
+ ret += role
+ elif self.sep_style == SeparatorStyle.LLAMA_2:
+            wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
+ wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
+ ret = ""
+
+ for i, (role, message) in enumerate(messages):
+ if i == 0:
+ assert message, "first message should not be none"
+ assert role == self.roles[0], "first message should come from user"
+ if message:
+ if type(message) is tuple:
+ message, _, _ = message
+ if i == 0: message = wrap_sys(self.system) + message
+ if i % 2 == 0:
+ message = wrap_inst(message)
+ ret += self.sep + message
+ else:
+ ret += " " + message + " " + self.sep2
+ else:
+ ret += ""
+ ret = ret.lstrip(self.sep)
+ elif self.sep_style == SeparatorStyle.PLAIN:
+ seps = [self.sep, self.sep2]
+ ret = self.system
+ for i, (role, message) in enumerate(messages):
+ if message:
+ if type(message) is tuple:
+ message, _, _ = message
+ ret += message + seps[i % 2]
+ else:
+ ret += ""
+ else:
+ raise ValueError(f"Invalid style: {self.sep_style}")
+
+ return ret
+
+ def append_message(self, role, message):
+ self.messages.append([role, message])
+
+ def process_image(self, image, image_process_mode, return_pil=False, image_format='PNG', max_len=1344, min_len=672):
+ if image_process_mode == "Pad":
+ def expand2square(pil_img, background_color=(122, 116, 104)):
+ width, height = pil_img.size
+ if width == height:
+ return pil_img
+ elif width > height:
+ result = Image.new(pil_img.mode, (width, width), background_color)
+ result.paste(pil_img, (0, (width - height) // 2))
+ return result
+ else:
+ result = Image.new(pil_img.mode, (height, height), background_color)
+ result.paste(pil_img, ((height - width) // 2, 0))
+ return result
+ image = expand2square(image)
+ elif image_process_mode in ["Default", "Crop"]:
+ pass
+ elif image_process_mode == "Resize":
+ image = image.resize((336, 336))
+ else:
+ raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
+ if max(image.size) > max_len:
+ max_hw, min_hw = max(image.size), min(image.size)
+ aspect_ratio = max_hw / min_hw
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
+ longest_edge = int(shortest_edge * aspect_ratio)
+ W, H = image.size
+ if H > W:
+ H, W = longest_edge, shortest_edge
+ else:
+ H, W = shortest_edge, longest_edge
+ image = image.resize((W, H))
+ if return_pil:
+ return image
+ else:
+ buffered = BytesIO()
+ image.save(buffered, format=image_format)
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
+ return img_b64_str
+
+ def get_images(self, return_pil=False):
+ images = []
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
+ if i % 2 == 0:
+ if type(msg) is tuple:
+ msg, image, image_process_mode = msg
+ image = self.process_image(image, image_process_mode, return_pil=return_pil)
+ images.append(image)
+ return images
+
+ def to_gradio_chatbot(self):
+ ret = []
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
+ if i % 2 == 0:
+ if type(msg) is tuple:
+ msg, image, image_process_mode = msg
+ img_b64_str = self.process_image(
+ image, "Default", return_pil=False,
+ image_format='JPEG')
+                    img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
+                    msg = img_str + msg.replace('<image>', '').strip()
+ ret.append([msg, None])
+ else:
+ ret.append([msg, None])
+ else:
+ ret[-1][-1] = msg
+ return ret
+
+ def copy(self):
+ return Conversation(
+ system=self.system,
+ roles=self.roles,
+ messages=[[x, y] for x, y in self.messages],
+ offset=self.offset,
+ sep_style=self.sep_style,
+ sep=self.sep,
+ sep2=self.sep2,
+ version=self.version)
+
+ def dict(self):
+ if len(self.get_images()) > 0:
+ return {
+ "system": self.system,
+ "roles": self.roles,
+ "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
+ "offset": self.offset,
+ "sep": self.sep,
+ "sep2": self.sep2,
+ }
+ return {
+ "system": self.system,
+ "roles": self.roles,
+ "messages": self.messages,
+ "offset": self.offset,
+ "sep": self.sep,
+ "sep2": self.sep2,
+ }
+
+
+conv_vicuna_v0 = Conversation(
+ system="A chat between a curious human and an artificial intelligence assistant. "
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+ roles=("Human", "Assistant"),
+ messages=(
+ ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
+ ("Assistant",
+ "Renewable energy sources are those that can be replenished naturally in a relatively "
+ "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
+ "Non-renewable energy sources, on the other hand, are finite and will eventually be "
+ "depleted, such as coal, oil, and natural gas. Here are some key differences between "
+ "renewable and non-renewable energy sources:\n"
+ "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
+ "energy sources are finite and will eventually run out.\n"
+ "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
+ "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
+ "and other negative effects.\n"
+ "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
+ "have lower operational costs than non-renewable sources.\n"
+ "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
+ "locations than non-renewable sources.\n"
+ "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
+ "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
+ "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
+ "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
+ ),
+ offset=2,
+ sep_style=SeparatorStyle.SINGLE,
+ sep="###",
+)
+
+conv_vicuna_v1 = Conversation(
+ system="A chat between a curious user and an artificial intelligence assistant. "
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
+ roles=("USER", "ASSISTANT"),
+ version="v1",
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.TWO,
+ sep=" ",
+    sep2="</s>",
+)
+
+conv_llama_2 = Conversation(
+ system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
+
+If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
+ roles=("USER", "ASSISTANT"),
+ version="llama_v2",
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.LLAMA_2,
+    sep="<s>",
+    sep2="</s>",
+)
+
+conv_llava_llama_2 = Conversation(
+ system="You are a helpful language and vision assistant. "
+ "You are able to understand the visual content that the user provides, "
+ "and assist the user with a variety of tasks using natural language.",
+ roles=("USER", "ASSISTANT"),
+ version="llama_v2",
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.LLAMA_2,
+    sep="<s>",
+    sep2="</s>",
+)
+
+conv_mpt = Conversation(
+ system="""<|im_start|>system
+A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
+ version="mpt",
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.MPT,
+ sep="<|im_end|>",
+)
+
+conv_llava_plain = Conversation(
+ system="",
+ roles=("", ""),
+ messages=(
+ ),
+ offset=0,
+ sep_style=SeparatorStyle.PLAIN,
+ sep="\n",
+)
+
+conv_llava_v0 = Conversation(
+ system="A chat between a curious human and an artificial intelligence assistant. "
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+ roles=("Human", "Assistant"),
+ messages=(
+ ),
+ offset=0,
+ sep_style=SeparatorStyle.SINGLE,
+ sep="###",
+)
+
+conv_llava_v0_mmtag = Conversation(
+ system="A chat between a curious user and an artificial intelligence assistant. "
+ "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
+           "The visual content will be provided with the following format: <Image>visual content</Image>.",
+ roles=("Human", "Assistant"),
+ messages=(
+ ),
+ offset=0,
+ sep_style=SeparatorStyle.SINGLE,
+ sep="###",
+ version="v0_mmtag",
+)
+
+conv_llava_v1 = Conversation(
+ system="A chat between a curious human and an artificial intelligence assistant. "
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+ roles=("USER", "ASSISTANT"),
+ version="v1",
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.TWO,
+ sep=" ",
+    sep2="</s>",
+)
+
+conv_llava_v1_mmtag = Conversation(
+ system="A chat between a curious user and an artificial intelligence assistant. "
+ "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
+           "The visual content will be provided with the following format: <Image>visual content</Image>.",
+ roles=("USER", "ASSISTANT"),
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.TWO,
+ sep=" ",
+    sep2="</s>",
+ version="v1_mmtag",
+)
+
+conv_mistral_instruct = Conversation(
+ system="",
+ roles=("USER", "ASSISTANT"),
+ version="llama_v2",
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.LLAMA_2,
+ sep="",
+    sep2="</s>",
+)
+
+conv_chatml_direct = Conversation(
+ system="""<|im_start|>system
+Answer the questions.""",
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
+ version="mpt",
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.MPT,
+ sep="<|im_end|>",
+)
+
+default_conversation = conv_vicuna_v1
+conv_templates = {
+ "default": conv_vicuna_v0,
+ "v0": conv_vicuna_v0,
+ "v1": conv_vicuna_v1,
+ "vicuna_v1": conv_vicuna_v1,
+ "llama_2": conv_llama_2,
+ "mistral_instruct": conv_mistral_instruct,
+ "chatml_direct": conv_chatml_direct,
+ "mistral_direct": conv_chatml_direct,
+
+ "plain": conv_llava_plain,
+ "v0_plain": conv_llava_plain,
+ "llava_v0": conv_llava_v0,
+ "v0_mmtag": conv_llava_v0_mmtag,
+ "llava_v1": conv_llava_v1,
+ "v1_mmtag": conv_llava_v1_mmtag,
+ "llava_llama_2": conv_llava_llama_2,
+
+ "mpt": conv_mpt,
+}
+
+
+if __name__ == "__main__":
+ print(default_conversation.get_prompt())
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/mllm/llava/mm_utils.py b/PyTorch/built-in/mlm/HunyuanDiT/mllm/llava/mm_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9717a81f8e481a452bbd99f4aa0baad95d0306df
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/mllm/llava/mm_utils.py
@@ -0,0 +1,247 @@
+from PIL import Image
+from io import BytesIO
+import base64
+import torch
+import math
+import ast
+
+from transformers import StoppingCriteria
+from llava.constants import IMAGE_TOKEN_INDEX
+
+
+def select_best_resolution(original_size, possible_resolutions):
+ """
+ Selects the best resolution from a list of possible resolutions based on the original size.
+
+ Args:
+ original_size (tuple): The original size of the image in the format (width, height).
+ possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
+
+ Returns:
+ tuple: The best fit resolution in the format (width, height).
+ """
+ original_width, original_height = original_size
+ best_fit = None
+ max_effective_resolution = 0
+ min_wasted_resolution = float('inf')
+
+ for width, height in possible_resolutions:
+ scale = min(width / original_width, height / original_height)
+ downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
+ effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
+ wasted_resolution = (width * height) - effective_resolution
+
+ if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
+ max_effective_resolution = effective_resolution
+ min_wasted_resolution = wasted_resolution
+ best_fit = (width, height)
+
+ return best_fit
+
+
+def resize_and_pad_image(image, target_resolution):
+ """
+ Resize and pad an image to a target resolution while maintaining aspect ratio.
+
+ Args:
+ image (PIL.Image.Image): The input image.
+ target_resolution (tuple): The target resolution (width, height) of the image.
+
+ Returns:
+ PIL.Image.Image: The resized and padded image.
+ """
+ original_width, original_height = image.size
+ target_width, target_height = target_resolution
+
+ scale_w = target_width / original_width
+ scale_h = target_height / original_height
+
+ if scale_w < scale_h:
+ new_width = target_width
+ new_height = min(math.ceil(original_height * scale_w), target_height)
+ else:
+ new_height = target_height
+ new_width = min(math.ceil(original_width * scale_h), target_width)
+
+ # Resize the image
+ resized_image = image.resize((new_width, new_height))
+
+ new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0))
+ paste_x = (target_width - new_width) // 2
+ paste_y = (target_height - new_height) // 2
+ new_image.paste(resized_image, (paste_x, paste_y))
+
+ return new_image
+
+
+def divide_to_patches(image, patch_size):
+ """
+ Divides an image into patches of a specified size.
+
+ Args:
+ image (PIL.Image.Image): The input image.
+ patch_size (int): The size of each patch.
+
+ Returns:
+ list: A list of PIL.Image.Image objects representing the patches.
+ """
+ patches = []
+ width, height = image.size
+ for i in range(0, height, patch_size):
+ for j in range(0, width, patch_size):
+ box = (j, i, j + patch_size, i + patch_size)
+ patch = image.crop(box)
+ patches.append(patch)
+
+ return patches
+
+
+def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
+ """
+ Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
+
+ Args:
+ image_size (tuple): The size of the input image in the format (width, height).
+ grid_pinpoints (str): A string representation of a list of possible resolutions.
+ patch_size (int): The size of each image patch.
+
+ Returns:
+ tuple: The shape of the image patch grid in the format (width, height).
+ """
+ if type(grid_pinpoints) is list:
+ possible_resolutions = grid_pinpoints
+ else:
+ possible_resolutions = ast.literal_eval(grid_pinpoints)
+ width, height = select_best_resolution(image_size, possible_resolutions)
+ return width // patch_size, height // patch_size
+
+
+def process_anyres_image(image, processor, grid_pinpoints):
+ """
+ Process an image with variable resolutions.
+
+ Args:
+ image (PIL.Image.Image): The input image to be processed.
+ processor: The image processor object.
+ grid_pinpoints (str): A string representation of a list of possible resolutions.
+
+ Returns:
+ torch.Tensor: A tensor containing the processed image patches.
+ """
+ if type(grid_pinpoints) is list:
+ possible_resolutions = grid_pinpoints
+ else:
+ possible_resolutions = ast.literal_eval(grid_pinpoints)
+ best_resolution = select_best_resolution(image.size, possible_resolutions)
+ image_padded = resize_and_pad_image(image, best_resolution)
+
+ patches = divide_to_patches(image_padded, processor.crop_size['height'])
+
+ image_original_resize = image.resize((processor.size['shortest_edge'], processor.size['shortest_edge']))
+
+ image_patches = [image_original_resize] + patches
+ image_patches = [processor.preprocess(image_patch, return_tensors='pt')['pixel_values'][0]
+ for image_patch in image_patches]
+ return torch.stack(image_patches, dim=0)
+
+
+def load_image_from_base64(image):
+ return Image.open(BytesIO(base64.b64decode(image)))
+
+
+def expand2square(pil_img, background_color):
+ width, height = pil_img.size
+ if width == height:
+ return pil_img
+ elif width > height:
+ result = Image.new(pil_img.mode, (width, width), background_color)
+ result.paste(pil_img, (0, (width - height) // 2))
+ return result
+ else:
+ result = Image.new(pil_img.mode, (height, height), background_color)
+ result.paste(pil_img, ((height - width) // 2, 0))
+ return result
+
+
+def process_images(images, image_processor, model_cfg):
+ image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
+ new_images = []
+ if image_aspect_ratio == 'pad':
+ for image in images:
+ image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
+ image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
+ new_images.append(image)
+ elif image_aspect_ratio == "anyres":
+ for image in images:
+ image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints)
+ new_images.append(image)
+ else:
+ return image_processor(images, return_tensors='pt')['pixel_values']
+ if all(x.shape == new_images[0].shape for x in new_images):
+ new_images = torch.stack(new_images, dim=0)
+ return new_images
+
+
+def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
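+    # Tokenize the prompt around each '<image>' placeholder and splice IMAGE_TOKEN_INDEX
+    # between the chunks so the model can locate image positions in the input sequence.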
+    prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
+
+ def insert_separator(X, sep):
+ return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
+
+ input_ids = []
+ offset = 0
+ if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
+ offset = 1
+ input_ids.append(prompt_chunks[0][0])
+
+ for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
+ input_ids.extend(x[offset:])
+
+ if return_tensors is not None:
+ if return_tensors == 'pt':
+ return torch.tensor(input_ids, dtype=torch.long)
+ raise ValueError(f'Unsupported tensor type: {return_tensors}')
+ return input_ids
+
+
+def get_model_name_from_path(model_path):
+ model_path = model_path.strip("/")
+ model_paths = model_path.split("/")
+ if model_paths[-1].startswith('checkpoint-'):
+ return model_paths[-2] + "_" + model_paths[-1]
+ else:
+ return model_paths[-1]
+
+class KeywordsStoppingCriteria(StoppingCriteria):
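+    # Stops generation once any of the given keyword strings appears at the end of the
+    # newly generated tokens (checked both on token ids and on the decoded text).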
+ def __init__(self, keywords, tokenizer, input_ids):
+ self.keywords = keywords
+ self.keyword_ids = []
+ self.max_keyword_len = 0
+ for keyword in keywords:
+ cur_keyword_ids = tokenizer(keyword).input_ids
+ if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
+ cur_keyword_ids = cur_keyword_ids[1:]
+ if len(cur_keyword_ids) > self.max_keyword_len:
+ self.max_keyword_len = len(cur_keyword_ids)
+ self.keyword_ids.append(torch.tensor(cur_keyword_ids))
+ self.tokenizer = tokenizer
+ self.start_len = input_ids.shape[1]
+
+ def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+ offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
+ self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
+ for keyword_id in self.keyword_ids:
+ truncated_output_ids = output_ids[0, -keyword_id.shape[0]:]
+ if torch.equal(truncated_output_ids, keyword_id):
+ return True
+ outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
+ for keyword in self.keywords:
+ if keyword in outputs:
+ return True
+ return False
+
+ def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+ outputs = []
+ for i in range(output_ids.shape[0]):
+ outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
+ return all(outputs)
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/mllm/llava/model/__init__.py b/PyTorch/built-in/mlm/HunyuanDiT/mllm/llava/model/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..dbd91789f0cde61dd13a7f9a5f7a69488ad07279
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/mllm/llava/model/__init__.py
@@ -0,0 +1,6 @@
+try:
+ from .language_model.llava_llama import LlavaLlamaForCausalLM, LlavaConfig
+ from .language_model.llava_mpt import LlavaMptForCausalLM, LlavaMptConfig
+ from .language_model.llava_mistral import LlavaMistralForCausalLM, LlavaMistralConfig
+except:
+ pass
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/mllm/llava/model/apply_delta.py b/PyTorch/built-in/mlm/HunyuanDiT/mllm/llava/model/apply_delta.py
new file mode 100644
index 0000000000000000000000000000000000000000..666dd9691bde7d54ddf2871e311d6f621e29f099
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/mllm/llava/model/apply_delta.py
@@ -0,0 +1,48 @@
+"""
+Usage:
+python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta
+"""
+import argparse
+
+import torch
+from tqdm import tqdm
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from llava import LlavaLlamaForCausalLM
+
+
+def apply_delta(base_model_path, target_model_path, delta_path):
+ print("Loading base model")
+ base = AutoModelForCausalLM.from_pretrained(
+ base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
+
+ print("Loading delta")
+ delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
+ delta_tokenizer = AutoTokenizer.from_pretrained(delta_path)
+
+ print("Applying delta")
+ for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"):
+ if name not in base.state_dict():
+ assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'
+ continue
+ if param.data.shape == base.state_dict()[name].shape:
+ param.data += base.state_dict()[name]
+ else:
+ assert name in ['model.embed_tokens.weight', 'lm_head.weight'], \
+ f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}'
+ bparam = base.state_dict()[name]
+ param.data[:bparam.shape[0], :bparam.shape[1]] += bparam
+
+ print("Saving target model")
+ delta.save_pretrained(target_model_path)
+ delta_tokenizer.save_pretrained(target_model_path)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--base-model-path", type=str, required=True)
+ parser.add_argument("--target-model-path", type=str, required=True)
+ parser.add_argument("--delta-path", type=str, required=True)
+
+ args = parser.parse_args()
+
+ apply_delta(args.base_model_path, args.target_model_path, args.delta_path)
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/mllm/llava/model/builder.py b/PyTorch/built-in/mlm/HunyuanDiT/mllm/llava/model/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..263d5d1dea46912d7f8eb767007e89c4471eacf4
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/mllm/llava/model/builder.py
@@ -0,0 +1,166 @@
+# Copyright 2023 Haotian Liu
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+import warnings
+import shutil
+
+from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
+import torch
+from llava.model import *
+from llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
+
+
+def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda", use_flash_attn=False, llava_type_model=True, **kwargs):
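+    # Loads either a full LLaVA checkpoint, a LoRA checkpoint on top of `model_base`,
+    # or a plain language model, optionally in 8-bit/4-bit or with FlashAttention-2.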
+ kwargs = {"device_map": device_map, **kwargs}
+
+ if device != "cuda":
+ kwargs['device_map'] = {"": device}
+ if load_8bit:
+ kwargs['load_in_8bit'] = True
+ elif load_4bit:
+ kwargs['load_in_4bit'] = True
+ kwargs['quantization_config'] = BitsAndBytesConfig(
+ load_in_4bit=True,
+ bnb_4bit_compute_dtype=torch.float16,
+ bnb_4bit_use_double_quant=True,
+ bnb_4bit_quant_type='nf4'
+ )
+ else:
+ kwargs['torch_dtype'] = torch.float16
+
+ if use_flash_attn:
+ kwargs['attn_implementation'] = 'flash_attention_2'
+
+ if 'llava' in model_name.lower():
+ # Load LLaVA model
+ if 'lora' in model_name.lower() and model_base is None:
+ warnings.warn('There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged.')
+ if 'lora' in model_name.lower() and model_base is not None:
+ from llava.model.language_model.llava_llama import LlavaConfig
+ lora_cfg_pretrained = LlavaConfig.from_pretrained(model_path)
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
+ print('Loading LLaVA from base model...')
+ model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs)
+ token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features
+ if model.lm_head.weight.shape[0] != token_num:
+ model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))
+ model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))
+
+ print('Loading additional LLaVA weights...')
+ if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')):
+ non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu')
+ else:
+ # this is probably from HF Hub
+ from huggingface_hub import hf_hub_download
+ def load_from_hf(repo_id, filename, subfolder=None):
+ cache_file = hf_hub_download(
+ repo_id=repo_id,
+ filename=filename,
+ subfolder=subfolder)
+ return torch.load(cache_file, map_location='cpu')
+ non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin')
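+            # Strip the 'base_model.' / 'model.' prefixes added by PEFT when saving,
+            # so the keys line up with this model's state_dict.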
+ non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()}
+ if any(k.startswith('model.model.') for k in non_lora_trainables):
+ non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()}
+ model.load_state_dict(non_lora_trainables, strict=False)
+
+ from peft import PeftModel
+ print('Loading LoRA weights...')
+ model = PeftModel.from_pretrained(model, model_path)
+ print('Merging LoRA weights...')
+ model = model.merge_and_unload()
+ print('Model is loaded...')
+ elif model_base is not None:
+ # this may be mm projector only
+ print('Loading LLaVA from base model...')
+ if 'mpt' in model_name.lower():
+ if not os.path.isfile(os.path.join(model_path, 'configuration_mpt.py')):
+ shutil.copyfile(os.path.join(model_base, 'configuration_mpt.py'), os.path.join(model_path, 'configuration_mpt.py'))
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True)
+ cfg_pretrained = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
+ model = LlavaMptForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
+ else:
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
+ cfg_pretrained = AutoConfig.from_pretrained(model_path)
+ model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
+
+ mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu')
+ mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
+ model.load_state_dict(mm_projector_weights, strict=False)
+ else:
+ if 'mpt' in model_name.lower():
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
+ model = LlavaMptForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
+ elif 'mistral' in model_name.lower():
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
+ model = LlavaMistralForCausalLM.from_pretrained(
+ model_path,
+ low_cpu_mem_usage=True,
+ **kwargs
+ )
+ else:
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
+ model = LlavaLlamaForCausalLM.from_pretrained(
+ model_path,
+ low_cpu_mem_usage=True,
+ **kwargs
+ )
+ else:
+ # Load language model
+ if model_base is not None:
+ # PEFT model
+ from peft import PeftModel
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
+ model = AutoModelForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, **kwargs)
+ print(f"Loading LoRA weights from {model_path}")
+ model = PeftModel.from_pretrained(model, model_path)
+ print(f"Merging weights")
+ model = model.merge_and_unload()
+ print('Convert to FP16...')
+ model.to(torch.float16)
+ else:
+ use_fast = False
+ if 'mpt' in model_name.lower():
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
+ model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs)
+ else:
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
+ model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
+
+ image_processor = None
+
+ if llava_type_model:
+ mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
+ mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
+ if mm_use_im_patch_token:
+ tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
+ if mm_use_im_start_end:
+ tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
+ model.resize_token_embeddings(len(tokenizer))
+
+ vision_tower = model.get_vision_tower()
+ if not vision_tower.is_loaded:
+ vision_tower.load_model(device_map=device_map)
+ if device_map != 'auto':
+ vision_tower.to(device=device_map, dtype=torch.float16)
+ image_processor = vision_tower.image_processor
+
+ if hasattr(model.config, "max_sequence_length"):
+ context_len = model.config.max_sequence_length
+ else:
+ context_len = 2048
+
+ return tokenizer, model, image_processor, context_len
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/mllm/llava/model/consolidate.py b/PyTorch/built-in/mlm/HunyuanDiT/mllm/llava/model/consolidate.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e324210e229eeba23b75791bba82df7c6e639eb
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/mllm/llava/model/consolidate.py
@@ -0,0 +1,29 @@
+"""
+Usage:
+python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate
+"""
+import argparse
+
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from llava.model import *
+from llava.model.utils import auto_upgrade
+
+
+def consolidate_ckpt(src_path, dst_path):
+ print("Loading model")
+ auto_upgrade(src_path)
+ src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
+ src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False)
+ src_model.save_pretrained(dst_path)
+ src_tokenizer.save_pretrained(dst_path)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--src", type=str, required=True)
+ parser.add_argument("--dst", type=str, required=True)
+
+ args = parser.parse_args()
+
+ consolidate_ckpt(args.src, args.dst)
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/mllm/llava/model/language_model/llava_llama.py b/PyTorch/built-in/mlm/HunyuanDiT/mllm/llava/model/language_model/llava_llama.py
new file mode 100644
index 0000000000000000000000000000000000000000..069d0d1c10da42f5d278598e8534f166d1f9f5ff
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/mllm/llava/model/language_model/llava_llama.py
@@ -0,0 +1,158 @@
+# Copyright 2023 Haotian Liu
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from transformers import AutoConfig, AutoModelForCausalLM, \
+ LlamaConfig, LlamaModel, LlamaForCausalLM
+
+from transformers.modeling_outputs import CausalLMOutputWithPast
+from transformers.generation.utils import GenerateOutput
+
+from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
+
+
+class LlavaConfig(LlamaConfig):
+ model_type = "llava_llama"
+
+
+class LlavaLlamaModel(LlavaMetaModel, LlamaModel):
+ config_class = LlavaConfig
+
+ def __init__(self, config: LlamaConfig):
+ super(LlavaLlamaModel, self).__init__(config)
+
+
+class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
+ config_class = LlavaConfig
+
+ def __init__(self, config):
+ super(LlamaForCausalLM, self).__init__(config)
+ self.model = LlavaLlamaModel(config)
+ self.pretraining_tp = config.pretraining_tp
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_model(self):
+ return self.model
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ images: Optional[torch.FloatTensor] = None,
+ image_sizes: Optional[List[List[int]]] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+
+ if inputs_embeds is None:
+ (
+ input_ids,
+ position_ids,
+ attention_mask,
+ past_key_values,
+ inputs_embeds,
+ labels
+ ) = self.prepare_inputs_labels_for_multimodal(
+ input_ids,
+ position_ids,
+ attention_mask,
+ past_key_values,
+ labels,
+ images,
+ image_sizes
+ )
+
+ return super().forward(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ labels=labels,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict
+ )
+
+ @torch.no_grad()
+ def generate(
+ self,
+ inputs: Optional[torch.Tensor] = None,
+ images: Optional[torch.Tensor] = None,
+ image_sizes: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> Union[GenerateOutput, torch.LongTensor]:
+ position_ids = kwargs.pop("position_ids", None)
+ attention_mask = kwargs.pop("attention_mask", None)
+ if "inputs_embeds" in kwargs:
+ raise NotImplementedError("`inputs_embeds` is not supported")
+
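+        # When images are given, splice them into the token embeddings first and
+        # call the Hugging Face generate() with inputs_embeds instead of input_ids.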
+ if images is not None:
+ (
+ inputs,
+ position_ids,
+ attention_mask,
+ _,
+ inputs_embeds,
+ _
+ ) = self.prepare_inputs_labels_for_multimodal(
+ inputs,
+ position_ids,
+ attention_mask,
+ None,
+ None,
+ images,
+ image_sizes=image_sizes
+ )
+ else:
+ inputs_embeds = self.get_model().embed_tokens(inputs)
+
+ return super().generate(
+ position_ids=position_ids,
+ attention_mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ **kwargs
+ )
+
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
+ inputs_embeds=None, **kwargs):
+ images = kwargs.pop("images", None)
+ image_sizes = kwargs.pop("image_sizes", None)
+ inputs = super().prepare_inputs_for_generation(
+ input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
+ )
+ if images is not None:
+ inputs['images'] = images
+ if image_sizes is not None:
+ inputs['image_sizes'] = image_sizes
+ return inputs
+
+AutoConfig.register("llava_llama", LlavaConfig)
+AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM)
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/mllm/llava/model/language_model/llava_mistral.py b/PyTorch/built-in/mlm/HunyuanDiT/mllm/llava/model/language_model/llava_mistral.py
new file mode 100644
index 0000000000000000000000000000000000000000..0def682ea3c497e36aa85f1c53eb2cfab6e2fb87
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/mllm/llava/model/language_model/llava_mistral.py
@@ -0,0 +1,158 @@
+# Copyright 2023 Haotian Liu
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+from torch.nn import CrossEntropyLoss
+
+from transformers import AutoConfig, AutoModelForCausalLM, \
+ MistralConfig, MistralModel, MistralForCausalLM
+
+from transformers.modeling_outputs import CausalLMOutputWithPast
+from transformers.generation.utils import GenerateOutput
+
+from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
+
+
+class LlavaMistralConfig(MistralConfig):
+ model_type = "llava_mistral"
+
+
+class LlavaMistralModel(LlavaMetaModel, MistralModel):
+ config_class = LlavaMistralConfig
+
+ def __init__(self, config: MistralConfig):
+ super(LlavaMistralModel, self).__init__(config)
+
+
+class LlavaMistralForCausalLM(MistralForCausalLM, LlavaMetaForCausalLM):
+ config_class = LlavaMistralConfig
+
+ def __init__(self, config):
+ super(MistralForCausalLM, self).__init__(config)
+ self.model = LlavaMistralModel(config)
+
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_model(self):
+ return self.model
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ images: Optional[torch.FloatTensor] = None,
+ image_sizes: Optional[List[List[int]]] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+
+ if inputs_embeds is None:
+ (
+ input_ids,
+ position_ids,
+ attention_mask,
+ past_key_values,
+ inputs_embeds,
+ labels
+ ) = self.prepare_inputs_labels_for_multimodal(
+ input_ids,
+ position_ids,
+ attention_mask,
+ past_key_values,
+ labels,
+ images,
+ image_sizes
+ )
+
+ return super().forward(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ labels=labels,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict
+ )
+
+ @torch.no_grad()
+ def generate(
+ self,
+ inputs: Optional[torch.Tensor] = None,
+ images: Optional[torch.Tensor] = None,
+ image_sizes: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> Union[GenerateOutput, torch.LongTensor]:
+ position_ids = kwargs.pop("position_ids", None)
+ attention_mask = kwargs.pop("attention_mask", None)
+ if "inputs_embeds" in kwargs:
+ raise NotImplementedError("`inputs_embeds` is not supported")
+
+ if images is not None:
+ (
+ inputs,
+ position_ids,
+ attention_mask,
+ _,
+ inputs_embeds,
+ _
+ ) = self.prepare_inputs_labels_for_multimodal(
+ inputs,
+ position_ids,
+ attention_mask,
+ None,
+ None,
+ images,
+ image_sizes=image_sizes
+ )
+ else:
+ inputs_embeds = self.get_model().embed_tokens(inputs)
+
+ return super().generate(
+ position_ids=position_ids,
+ attention_mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ **kwargs
+ )
+
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
+ inputs_embeds=None, **kwargs):
+ images = kwargs.pop("images", None)
+ image_sizes = kwargs.pop("image_sizes", None)
+ inputs = super().prepare_inputs_for_generation(
+ input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
+ )
+ if images is not None:
+ inputs['images'] = images
+ if image_sizes is not None:
+ inputs['image_sizes'] = image_sizes
+ return inputs
+
+AutoConfig.register("llava_mistral", LlavaMistralConfig)
+AutoModelForCausalLM.register(LlavaMistralConfig, LlavaMistralForCausalLM)
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/mllm/llava/model/language_model/llava_mpt.py b/PyTorch/built-in/mlm/HunyuanDiT/mllm/llava/model/language_model/llava_mpt.py
new file mode 100644
index 0000000000000000000000000000000000000000..02e5237ece031af23fcd76b5b4e0d9b0bc5f55cc
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/mllm/llava/model/language_model/llava_mpt.py
@@ -0,0 +1,97 @@
+# Copyright 2023 Haotian Liu
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import Optional, Tuple
+
+import torch
+
+from transformers import AutoConfig, AutoModelForCausalLM, \
+ MptConfig, MptForCausalLM, MptModel
+from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
+
+
+class LlavaMptConfig(MptConfig):
+ model_type = "llava_mpt"
+
+
+class LlavaMptModel(LlavaMetaModel, MptModel):
+ config_class = LlavaMptConfig
+
+ def __init__(self, config: MptConfig):
+ config.hidden_size = config.d_model
+ super(LlavaMptModel, self).__init__(config)
+
+ def embed_tokens(self, x):
+ return self.wte(x)
+
+
+class LlavaMptForCausalLM(MptForCausalLM, LlavaMetaForCausalLM):
+ config_class = LlavaMptConfig
+ supports_gradient_checkpointing = True
+
+ def __init__(self, config):
+ super(MptForCausalLM, self).__init__(config)
+
+ self.transformer = LlavaMptModel(config)
+ self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_model(self):
+ return self.transformer
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, LlavaMptModel):
+ module.gradient_checkpointing = value
+
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ images=None):
+
+ input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images)
+
+ return super().forward(
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ labels=labels,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
+ images = kwargs.pop("images", None)
+ _inputs = super().prepare_inputs_for_generation(
+ input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
+ )
+ _inputs['images'] = images
+ return _inputs
+
+
+AutoConfig.register("llava_mpt", LlavaMptConfig)
+AutoModelForCausalLM.register(LlavaMptConfig, LlavaMptForCausalLM)
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/mllm/llava/model/llava_arch.py b/PyTorch/built-in/mlm/HunyuanDiT/mllm/llava/model/llava_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b299d3c416a0f5ffea3d03d7be5a32b77319533
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/mllm/llava/model/llava_arch.py
@@ -0,0 +1,368 @@
+# Copyright 2023 Haotian Liu
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from abc import ABC, abstractmethod
+
+import torch
+import torch.nn as nn
+
+from .multimodal_encoder.builder import build_vision_tower
+from .multimodal_projector.builder import build_vision_projector
+
+from llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
+
+from llava.mm_utils import get_anyres_image_grid_shape
+
+
+class LlavaMetaModel:
+
+ def __init__(self, config):
+ super(LlavaMetaModel, self).__init__(config)
+
+ if hasattr(config, "mm_vision_tower"):
+ self.vision_tower = build_vision_tower(config, delay_load=True)
+ self.mm_projector = build_vision_projector(config)
+
+ if 'unpad' in getattr(config, 'mm_patch_merge_type', ''):
+ self.image_newline = nn.Parameter(
+ torch.empty(config.hidden_size, dtype=self.dtype)
+ )
+
+ def get_vision_tower(self):
+ vision_tower = getattr(self, 'vision_tower', None)
+ if type(vision_tower) is list:
+ vision_tower = vision_tower[0]
+ return vision_tower
+
+ def initialize_vision_modules(self, model_args, fsdp=None):
+ vision_tower = model_args.vision_tower
+ mm_vision_select_layer = model_args.mm_vision_select_layer
+ mm_vision_select_feature = model_args.mm_vision_select_feature
+ pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
+ mm_patch_merge_type = model_args.mm_patch_merge_type
+
+ self.config.mm_vision_tower = vision_tower
+
+ if self.get_vision_tower() is None:
+ vision_tower = build_vision_tower(model_args)
+
+ if fsdp is not None and len(fsdp) > 0:
+ self.vision_tower = [vision_tower]
+ else:
+ self.vision_tower = vision_tower
+ else:
+ if fsdp is not None and len(fsdp) > 0:
+ vision_tower = self.vision_tower[0]
+ else:
+ vision_tower = self.vision_tower
+ vision_tower.load_model()
+
+ self.config.use_mm_proj = True
+ self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
+ self.config.mm_hidden_size = vision_tower.hidden_size
+ self.config.mm_vision_select_layer = mm_vision_select_layer
+ self.config.mm_vision_select_feature = mm_vision_select_feature
+ self.config.mm_patch_merge_type = mm_patch_merge_type
+
+ if getattr(self, 'mm_projector', None) is None:
+ self.mm_projector = build_vision_projector(self.config)
+
+ if 'unpad' in mm_patch_merge_type:
+ embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype))
+ self.image_newline = nn.Parameter(
+ torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std
+ )
+ else:
+ # In case it is frozen by LoRA
+ for p in self.mm_projector.parameters():
+ p.requires_grad = True
+
+ if pretrain_mm_mlp_adapter is not None:
+ mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
+ def get_w(weights, keyword):
+ return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}
+
+ self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'))
+
+
+def unpad_image(tensor, original_size):
+ """
+ Unpads a PyTorch tensor of a padded and resized image.
+
+ Args:
+ tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format.
+        original_size (tuple): The original size of the image (width, height).
+
+ Returns:
+ torch.Tensor: The unpadded image tensor.
+ """
+ original_width, original_height = original_size
+ current_height, current_width = tensor.shape[1:]
+
+ original_aspect_ratio = original_width / original_height
+ current_aspect_ratio = current_width / current_height
+
+ if original_aspect_ratio > current_aspect_ratio:
+ scale_factor = current_width / original_width
+ new_height = int(original_height * scale_factor)
+ padding = (current_height - new_height) // 2
+ unpadded_tensor = tensor[:, padding:current_height - padding, :]
+ else:
+ scale_factor = current_height / original_height
+ new_width = int(original_width * scale_factor)
+ padding = (current_width - new_width) // 2
+ unpadded_tensor = tensor[:, :, padding:current_width - padding]
+
+ return unpadded_tensor
+
+
+class LlavaMetaForCausalLM(ABC):
+
+ @abstractmethod
+ def get_model(self):
+ pass
+
+ def get_vision_tower(self):
+ return self.get_model().get_vision_tower()
+
+ def encode_images(self, images):
+ image_features = self.get_model().get_vision_tower()(images)
+ image_features = self.get_model().mm_projector(image_features)
+ return image_features
+
+ def prepare_inputs_labels_for_multimodal(
+ self, input_ids, position_ids, attention_mask, past_key_values, labels,
+ images, image_sizes=None
+ ):
+ vision_tower = self.get_vision_tower()
+ if vision_tower is None or images is None or input_ids.shape[1] == 1:
+ return input_ids, position_ids, attention_mask, past_key_values, None, labels
+
+ if type(images) is list or images.ndim == 5:
+ if type(images) is list:
+ images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]
+ concat_images = torch.cat([image for image in images], dim=0)
+ image_features = self.encode_images(concat_images)
+ split_sizes = [image.shape[0] for image in images]
+ image_features = torch.split(image_features, split_sizes, dim=0)
+ mm_patch_merge_type = getattr(self.config, 'mm_patch_merge_type', 'flat')
+ image_aspect_ratio = getattr(self.config, 'image_aspect_ratio', 'square')
+ if mm_patch_merge_type == 'flat':
+ image_features = [x.flatten(0, 1) for x in image_features]
+ elif mm_patch_merge_type.startswith('spatial'):
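+                # 'spatial' merging keeps the features of the resized base image and
+                # appends the per-tile features laid out on their 2D grid; with
+                # 'unpad', padding rows/columns are cropped and a learned
+                # image_newline embedding is appended at the end of each row.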
+ new_image_features = []
+ for image_idx, image_feature in enumerate(image_features):
+ if image_feature.shape[0] > 1:
+ base_image_feature = image_feature[0]
+ image_feature = image_feature[1:]
+ height = width = self.get_vision_tower().num_patches_per_side
+ assert height * width == base_image_feature.shape[0]
+ if image_aspect_ratio == 'anyres':
+ num_patch_width, num_patch_height = get_anyres_image_grid_shape(image_sizes[image_idx], self.config.image_grid_pinpoints, self.get_vision_tower().config.image_size)
+ image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
+ else:
+ raise NotImplementedError
+ if 'unpad' in mm_patch_merge_type:
+ image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
+ image_feature = image_feature.flatten(1, 2).flatten(2, 3)
+ image_feature = unpad_image(image_feature, image_sizes[image_idx])
+ image_feature = torch.cat((
+ image_feature,
+ self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)
+ ), dim=-1)
+ image_feature = image_feature.flatten(1, 2).transpose(0, 1)
+ else:
+ image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
+ image_feature = image_feature.flatten(0, 3)
+ image_feature = torch.cat((base_image_feature, image_feature), dim=0)
+ else:
+ image_feature = image_feature[0]
+ if 'unpad' in mm_patch_merge_type:
+ image_feature = torch.cat((
+ image_feature,
+ self.model.image_newline[None].to(image_feature.device)
+ ), dim=0)
+ new_image_features.append(image_feature)
+ image_features = new_image_features
+ else:
+ raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}")
+ else:
+ image_features = self.encode_images(images)
+
+ # TODO: image start / end is not implemented here to support pretraining.
+ if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
+ raise NotImplementedError
+
+ # Let's just add dummy tensors if they do not exist,
+ # it is a headache to deal with None all the time.
+ # But it is not ideal, and if you have a better idea,
+ # please open an issue / submit a PR, thanks.
+ _labels = labels
+ _position_ids = position_ids
+ _attention_mask = attention_mask
+ if attention_mask is None:
+ attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
+ else:
+ attention_mask = attention_mask.bool()
+ if position_ids is None:
+ position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
+ if labels is None:
+ labels = torch.full_like(input_ids, IGNORE_INDEX)
+
+ # remove the padding using attention_mask -- FIXME
+ _input_ids = input_ids
+ input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
+ labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
+
+ new_input_embeds = []
+ new_labels = []
+ cur_image_idx = 0
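+        # Splice the sequences: each IMAGE_TOKEN_INDEX placeholder in input_ids is
+        # replaced by the projected image features, and the labels at those
+        # positions are set to IGNORE_INDEX.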
+ for batch_idx, cur_input_ids in enumerate(input_ids):
+ num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
+ if num_images == 0:
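+                # Text-only sample: concatenate a zero-length slice of image features
+                # so the vision projector output stays connected to the graph.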
+ cur_image_features = image_features[cur_image_idx]
+ cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
+ cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
+ new_input_embeds.append(cur_input_embeds)
+ new_labels.append(labels[batch_idx])
+ cur_image_idx += 1
+ continue
+
+ image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
+ cur_input_ids_noim = []
+ cur_labels = labels[batch_idx]
+ cur_labels_noim = []
+ for i in range(len(image_token_indices) - 1):
+ cur_input_ids_noim.append(cur_input_ids[image_token_indices[i]+1:image_token_indices[i+1]])
+ cur_labels_noim.append(cur_labels[image_token_indices[i]+1:image_token_indices[i+1]])
+ split_sizes = [x.shape[0] for x in cur_labels_noim]
+ cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
+ cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
+ cur_new_input_embeds = []
+ cur_new_labels = []
+
+ for i in range(num_images + 1):
+ cur_new_input_embeds.append(cur_input_embeds_no_im[i])
+ cur_new_labels.append(cur_labels_noim[i])
+ if i < num_images:
+ cur_image_features = image_features[cur_image_idx]
+ cur_image_idx += 1
+ cur_new_input_embeds.append(cur_image_features)
+ cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
+
+ cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
+
+ cur_new_input_embeds = torch.cat(cur_new_input_embeds)
+ cur_new_labels = torch.cat(cur_new_labels)
+
+ new_input_embeds.append(cur_new_input_embeds)
+ new_labels.append(cur_new_labels)
+
+ # Truncate sequences to max length as image embeddings can make the sequence longer
+ tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None)
+ if tokenizer_model_max_length is not None:
+ new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
+ new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
+
+ # Combine them
+ max_len = max(x.shape[0] for x in new_input_embeds)
+ batch_size = len(new_input_embeds)
+
+ new_input_embeds_padded = []
+ new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
+ attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
+ position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
+
+ for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
+ cur_len = cur_new_embed.shape[0]
+ if getattr(self.config, 'tokenizer_padding_side', 'right') == "left":
+ new_input_embeds_padded.append(torch.cat((
+ torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device),
+ cur_new_embed
+ ), dim=0))
+ if cur_len > 0:
+ new_labels_padded[i, -cur_len:] = cur_new_labels
+ attention_mask[i, -cur_len:] = True
+ position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
+ else:
+ new_input_embeds_padded.append(torch.cat((
+ cur_new_embed,
+ torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)
+ ), dim=0))
+ if cur_len > 0:
+ new_labels_padded[i, :cur_len] = cur_new_labels
+ attention_mask[i, :cur_len] = True
+ position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
+
+ new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
+
+ if _labels is None:
+ new_labels = None
+ else:
+ new_labels = new_labels_padded
+
+ if _attention_mask is None:
+ attention_mask = None
+ else:
+ attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
+
+ if _position_ids is None:
+ position_ids = None
+
+ return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
+
+ def initialize_vision_tokenizer(self, model_args, tokenizer):
+ if model_args.mm_use_im_patch_token:
+ tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
+ self.resize_token_embeddings(len(tokenizer))
+
+ if model_args.mm_use_im_start_end:
+ num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
+ self.resize_token_embeddings(len(tokenizer))
+
+ if num_new_tokens > 0:
+ input_embeddings = self.get_input_embeddings().weight.data
+ output_embeddings = self.get_output_embeddings().weight.data
+
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
+ dim=0, keepdim=True)
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
+ dim=0, keepdim=True)
+
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
+
+ if model_args.tune_mm_mlp_adapter:
+ for p in self.get_input_embeddings().parameters():
+ p.requires_grad = True
+ for p in self.get_output_embeddings().parameters():
+ p.requires_grad = False
+
+ if model_args.pretrain_mm_mlp_adapter:
+ mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')
+ embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
+ assert num_new_tokens == 2
+ if input_embeddings.shape == embed_tokens_weight.shape:
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
+ elif embed_tokens_weight.shape[0] == num_new_tokens:
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight
+ else:
+ raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")
+ elif model_args.mm_use_im_patch_token:
+ if model_args.tune_mm_mlp_adapter:
+ for p in self.get_input_embeddings().parameters():
+ p.requires_grad = False
+ for p in self.get_output_embeddings().parameters():
+ p.requires_grad = False
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/mllm/llava/model/make_delta.py b/PyTorch/built-in/mlm/HunyuanDiT/mllm/llava/model/make_delta.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ae55d59c2c8bab80299272314a41bbeb959d8ed
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/mllm/llava/model/make_delta.py
@@ -0,0 +1,52 @@
+"""
+Usage:
+python3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta
+"""
+import argparse
+
+import torch
+from tqdm import tqdm
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from llava.model.utils import auto_upgrade
+
+
+def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id):
+ print("Loading base model")
+ base = AutoModelForCausalLM.from_pretrained(
+ base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
+
+ print("Loading target model")
+ auto_upgrade(target_model_path)
+ target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
+
+ print("Calculating delta")
+ for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"):
+ if name not in base.state_dict():
+ assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'
+ continue
+ if param.data.shape == base.state_dict()[name].shape:
+ param.data -= base.state_dict()[name]
+ else:
+ assert name in ['model.embed_tokens.weight', 'lm_head.weight'], f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}'
+ bparam = base.state_dict()[name]
+ param.data[:bparam.shape[0], :bparam.shape[1]] -= bparam
+
+ print("Saving delta")
+ if hub_repo_id:
+ kwargs = {"push_to_hub": True, "repo_id": hub_repo_id}
+ else:
+ kwargs = {}
+ target.save_pretrained(delta_path, **kwargs)
+ target_tokenizer = AutoTokenizer.from_pretrained(target_model_path)
+ target_tokenizer.save_pretrained(delta_path, **kwargs)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--base-model-path", type=str, required=True)
+ parser.add_argument("--target-model-path", type=str, required=True)
+ parser.add_argument("--delta-path", type=str, required=True)
+ parser.add_argument("--hub-repo-id", type=str, default=None)
+ args = parser.parse_args()
+
+ make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id)
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/mllm/llava/model/multimodal_encoder/builder.py b/PyTorch/built-in/mlm/HunyuanDiT/mllm/llava/model/multimodal_encoder/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..e89507c49df413945453959d48b51a71b9031ef7
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/mllm/llava/model/multimodal_encoder/builder.py
@@ -0,0 +1,11 @@
+import os
+from .clip_encoder import CLIPVisionTower
+
+
+def build_vision_tower(vision_tower_cfg, **kwargs):
+ vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
+ is_absolute_path_exists = os.path.exists(vision_tower)
+ if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion") or "ShareGPT4V" in vision_tower:
+ return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
+
+ raise ValueError(f'Unknown vision tower: {vision_tower}')
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/mllm/llava/model/multimodal_encoder/clip_encoder.py b/PyTorch/built-in/mlm/HunyuanDiT/mllm/llava/model/multimodal_encoder/clip_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..89ce2d9e084a2b324dc5e19fb2d2d889f3a60602
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/mllm/llava/model/multimodal_encoder/clip_encoder.py
@@ -0,0 +1,88 @@
+import torch
+import torch.nn as nn
+
+from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
+
+
+class CLIPVisionTower(nn.Module):
+ def __init__(self, vision_tower, args, delay_load=False):
+ super().__init__()
+
+ self.is_loaded = False
+
+ self.vision_tower_name = vision_tower
+ self.select_layer = args.mm_vision_select_layer
+ self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
+
+ if not delay_load:
+ self.load_model()
+ elif getattr(args, 'unfreeze_mm_vision_tower', False):
+ self.load_model()
+ else:
+ self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
+
+ def load_model(self, device_map=None):
+ if self.is_loaded:
+ print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
+ return
+
+ self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
+ self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
+ self.vision_tower.requires_grad_(False)
+
+ self.is_loaded = True
+
+ def feature_select(self, image_forward_outs):
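+        # 'patch' drops the leading CLS token and keeps only the patch embeddings;
+        # 'cls_patch' keeps the full sequence including CLS.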
+ image_features = image_forward_outs.hidden_states[self.select_layer]
+ if self.select_feature == 'patch':
+ image_features = image_features[:, 1:]
+ elif self.select_feature == 'cls_patch':
+ image_features = image_features
+ else:
+ raise ValueError(f'Unexpected select feature: {self.select_feature}')
+ return image_features
+
+ @torch.no_grad()
+ def forward(self, images):
+ if type(images) is list:
+ image_features = []
+ for image in images:
+ image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
+ image_feature = self.feature_select(image_forward_out).to(image.dtype)
+ image_features.append(image_feature)
+ else:
+ image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
+ image_features = self.feature_select(image_forward_outs).to(images.dtype)
+
+ return image_features
+
+ @property
+ def dummy_feature(self):
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
+
+ @property
+ def dtype(self):
+ return self.vision_tower.dtype
+
+ @property
+ def device(self):
+ return self.vision_tower.device
+
+ @property
+ def config(self):
+ if self.is_loaded:
+ return self.vision_tower.config
+ else:
+ return self.cfg_only
+
+ @property
+ def hidden_size(self):
+ return self.config.hidden_size
+
+ @property
+ def num_patches_per_side(self):
+ return self.config.image_size // self.config.patch_size
+
+ @property
+ def num_patches(self):
+ return (self.config.image_size // self.config.patch_size) ** 2
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/mllm/llava/model/multimodal_projector/builder.py b/PyTorch/built-in/mlm/HunyuanDiT/mllm/llava/model/multimodal_projector/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..31cd4f48e6055cd6d00a162af30b1c8139e26b57
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/mllm/llava/model/multimodal_projector/builder.py
@@ -0,0 +1,51 @@
+import torch
+import torch.nn as nn
+import re
+
+
+class IdentityMap(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x, *args, **kwargs):
+ return x
+
+ @property
+ def config(self):
+ return {"mm_projector_type": 'identity'}
+
+
+class SimpleResBlock(nn.Module):
+ def __init__(self, channels):
+ super().__init__()
+ self.pre_norm = nn.LayerNorm(channels)
+
+ self.proj = nn.Sequential(
+ nn.Linear(channels, channels),
+ nn.GELU(),
+ nn.Linear(channels, channels)
+ )
+ def forward(self, x):
+ x = self.pre_norm(x)
+ return x + self.proj(x)
+
+
+def build_vision_projector(config, delay_load=False, **kwargs):
+ projector_type = getattr(config, 'mm_projector_type', 'linear')
+
+ if projector_type == 'linear':
+ return nn.Linear(config.mm_hidden_size, config.hidden_size)
+
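+    # 'mlp{N}x_gelu' builds an N-layer MLP with GELU activations between the linear
+    # layers, e.g. 'mlp2x_gelu' -> Linear, GELU, Linear.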
+ mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
+ if mlp_gelu_match:
+ mlp_depth = int(mlp_gelu_match.group(1))
+ modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
+ for _ in range(1, mlp_depth):
+ modules.append(nn.GELU())
+ modules.append(nn.Linear(config.hidden_size, config.hidden_size))
+ return nn.Sequential(*modules)
+
+ if projector_type == 'identity':
+ return IdentityMap()
+
+ raise ValueError(f'Unknown projector type: {projector_type}')
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/mllm/llava/model/utils.py b/PyTorch/built-in/mlm/HunyuanDiT/mllm/llava/model/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..2563f89c6cedf5e73508afec8f9979105df9b745
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/mllm/llava/model/utils.py
@@ -0,0 +1,20 @@
+from transformers import AutoConfig
+
+
+def auto_upgrade(config):
+ cfg = AutoConfig.from_pretrained(config)
+ if 'llava' in config and 'llava' not in cfg.model_type:
+ assert cfg.model_type == 'llama'
+ print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.")
+ print("You must upgrade the checkpoint to the new code base (this can be done automatically).")
+ confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]")
+ if confirm.lower() in ["y", "yes"]:
+ print("Upgrading checkpoint...")
+ assert len(cfg.architectures) == 1
+ setattr(cfg.__class__, "model_type", "llava")
+ cfg.architectures[0] = 'LlavaLlamaForCausalLM'
+ cfg.save_pretrained(config)
+ print("Checkpoint upgraded.")
+ else:
+ print("Checkpoint upgrade aborted.")
+ exit(1)
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/mllm/llava/serve/__init__.py b/PyTorch/built-in/mlm/HunyuanDiT/mllm/llava/serve/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/mllm/llava/serve/cli.py b/PyTorch/built-in/mlm/HunyuanDiT/mllm/llava/serve/cli.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4a52e9eb1a5fd711d2dce0afd452f3b425755f0
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/mllm/llava/serve/cli.py
@@ -0,0 +1,141 @@
+import argparse
+import torch
+
+from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
+from llava.conversation import conv_templates, SeparatorStyle
+from llava.model.builder import load_pretrained_model
+from llava.utils import disable_torch_init
+from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path
+
+from PIL import Image
+
+import requests
+from PIL import Image
+from io import BytesIO
+from transformers import TextStreamer
+
+
+def load_image(image_file):
+ if image_file.startswith('http://') or image_file.startswith('https://'):
+ response = requests.get(image_file)
+ image = Image.open(BytesIO(response.content)).convert('RGB')
+ else:
+ image = Image.open(image_file).convert('RGB')
+ return image
+
+
+def main(args):
+ # Model
+ disable_torch_init()
+
+ model_name = get_model_name_from_path(args.model_path)
+ tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name,
+ args.load_8bit, args.load_4bit,
+ device=args.device)
+
+ if "llama-2" in model_name.lower():
+ conv_mode = "llava_llama_2"
+ elif "mistral" in model_name.lower():
+ conv_mode = "mistral_instruct"
+ elif "v1.6-34b" in model_name.lower():
+ conv_mode = "chatml_direct"
+ elif "v1" in model_name.lower():
+ conv_mode = "llava_v1"
+ elif "mpt" in model_name.lower():
+ conv_mode = "mpt"
+ else:
+ conv_mode = "llava_v0"
+
+ if args.conv_mode is not None and conv_mode != args.conv_mode:
+        print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(
+            conv_mode, args.conv_mode, args.conv_mode))
+ else:
+ args.conv_mode = conv_mode
+
+ conv = conv_templates[args.conv_mode].copy()
+ if "mpt" in model_name.lower():
+ roles = ('user', 'assistant')
+ else:
+ roles = conv.roles
+
+ if args.image_file is not None:
+ image = load_image(args.image_file)
+ image_size = image.size
+ # Similar operation in model_worker.py
+ image_tensor = process_images([image], image_processor, model.config)
+ if type(image_tensor) is list:
+ image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor]
+ else:
+ image_tensor = image_tensor.to(model.device, dtype=torch.float16)
+ else:
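+        # No image file given: feed a dummy all-zero image tensor and a placeholder
+        # image size so the multimodal pipeline can still run.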
+ image = True
+ image_size = (1024, 1024)
+ image_tensor = torch.zeros(1, 5, 3, 336, 336)
+ image_tensor = image_tensor.to(model.device, dtype=torch.float16)
+
+
+ while True:
+ try:
+ inp = input(f"{roles[0]}: ")
+ except EOFError:
+ inp = ""
+ if not inp:
+ print("exit...")
+ break
+
+ print(f"{roles[1]}: ", end="")
+
+ if image is not None:
+ # first message
+ if model.config.mm_use_im_start_end:
+ inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp
+ else:
+ inp = inp.replace(DEFAULT_IMAGE_TOKEN, '').strip()
+ inp = DEFAULT_IMAGE_TOKEN + '\n' + inp
+ inp = inp.strip()
+ image = None
+
+ conv.append_message(conv.roles[0], inp)
+ conv.append_message(conv.roles[1], None)
+ prompt = conv.get_prompt()
+
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(
+ model.device)
+ stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
+ keywords = [stop_str]
+ streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+
+ with torch.inference_mode():
+ output_ids = model.generate(
+ input_ids,
+ images=image_tensor,
+ image_sizes=[image_size],
+ do_sample=True if args.temperature > 0 else False,
+ temperature=args.temperature,
+ max_new_tokens=args.max_new_tokens,
+ streamer=streamer,
+ use_cache=True)
+
+ outputs = tokenizer.decode(output_ids[0]).strip()
+ conv.messages[-1][-1] = outputs
+
+ if args.debug:
+ print("\n", {"prompt": prompt, "outputs": outputs}, "\n")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
+ parser.add_argument("--model-base", type=str, default=None)
+ parser.add_argument("--image-file", type=str, default=None)
+ parser.add_argument("--device", type=str, default="cuda")
+ parser.add_argument("--conv-mode", type=str, default=None)
+ parser.add_argument("--temperature", type=float, default=0.2)
+ parser.add_argument("--max-new-tokens", type=int, default=512)
+ parser.add_argument("--load-8bit", action="store_true")
+ parser.add_argument("--load-4bit", action="store_true")
+ parser.add_argument("--debug", action="store_true")
+ args = parser.parse_args()
+ main(args)
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/mllm/llava/serve/controller.py b/PyTorch/built-in/mlm/HunyuanDiT/mllm/llava/serve/controller.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4bf1b4c47ccdb1401b18f8397868ec016d1c43a
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/mllm/llava/serve/controller.py
@@ -0,0 +1,298 @@
+"""
+A controller manages distributed workers.
+It sends worker addresses to clients.
+"""
+import argparse
+import asyncio
+import dataclasses
+from enum import Enum, auto
+import json
+import logging
+import time
+from typing import List, Union
+import threading
+
+from fastapi import FastAPI, Request
+from fastapi.responses import StreamingResponse
+import numpy as np
+import requests
+import uvicorn
+
+from llava.constants import CONTROLLER_HEART_BEAT_EXPIRATION
+from llava.utils import build_logger, server_error_msg
+
+
+logger = build_logger("controller", "controller.log")
+
+
+class DispatchMethod(Enum):
+ LOTTERY = auto()
+ SHORTEST_QUEUE = auto()
+
+ @classmethod
+ def from_str(cls, name):
+ if name == "lottery":
+ return cls.LOTTERY
+ elif name == "shortest_queue":
+ return cls.SHORTEST_QUEUE
+ else:
+ raise ValueError(f"Invalid dispatch method")
+
+
+@dataclasses.dataclass
+class WorkerInfo:
+ model_names: List[str]
+ speed: int
+ queue_length: int
+ check_heart_beat: bool
+    last_heart_beat: float
+
+
+def heart_beat_controller(controller):
+ while True:
+ time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION)
+        controller.remove_stale_workers_by_expiration()
+
+
+class Controller:
+ def __init__(self, dispatch_method: str):
+ # Dict[str -> WorkerInfo]
+ self.worker_info = {}
+ self.dispatch_method = DispatchMethod.from_str(dispatch_method)
+
+ self.heart_beat_thread = threading.Thread(
+ target=heart_beat_controller, args=(self,), daemon=True)
+ self.heart_beat_thread.start()
+
+ logger.info("Init controller")
+
+ def register_worker(self, worker_name: str, check_heart_beat: bool,
+ worker_status: dict):
+ if worker_name not in self.worker_info:
+ logger.info(f"Register a new worker: {worker_name}")
+ else:
+ logger.info(f"Register an existing worker: {worker_name}")
+
+ if not worker_status:
+ worker_status = self.get_worker_status(worker_name)
+ if not worker_status:
+ return False
+
+ self.worker_info[worker_name] = WorkerInfo(
+ worker_status["model_names"], worker_status["speed"], worker_status["queue_length"],
+ check_heart_beat, time.time())
+
+ logger.info(f"Register done: {worker_name}, {worker_status}")
+ return True
+
+ def get_worker_status(self, worker_name: str):
+ try:
+ r = requests.post(worker_name + "/worker_get_status", timeout=5)
+ except requests.exceptions.RequestException as e:
+ logger.error(f"Get status fails: {worker_name}, {e}")
+ return None
+
+ if r.status_code != 200:
+ logger.error(f"Get status fails: {worker_name}, {r}")
+ return None
+
+ return r.json()
+
+ def remove_worker(self, worker_name: str):
+ del self.worker_info[worker_name]
+
+ def refresh_all_workers(self):
+ old_info = dict(self.worker_info)
+ self.worker_info = {}
+
+ for w_name, w_info in old_info.items():
+ if not self.register_worker(w_name, w_info.check_heart_beat, None):
+ logger.info(f"Remove stale worker: {w_name}")
+
+ def list_models(self):
+ model_names = set()
+
+ for w_name, w_info in self.worker_info.items():
+ model_names.update(w_info.model_names)
+
+ return list(model_names)
+
+ def get_worker_address(self, model_name: str):
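+        # LOTTERY samples a worker with probability proportional to its advertised
+        # speed; SHORTEST_QUEUE picks the worker with the lowest queue_length / speed.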
+ if self.dispatch_method == DispatchMethod.LOTTERY:
+ worker_names = []
+ worker_speeds = []
+ for w_name, w_info in self.worker_info.items():
+ if model_name in w_info.model_names:
+ worker_names.append(w_name)
+ worker_speeds.append(w_info.speed)
+ worker_speeds = np.array(worker_speeds, dtype=np.float32)
+ norm = np.sum(worker_speeds)
+ if norm < 1e-4:
+ return ""
+ worker_speeds = worker_speeds / norm
+ if True: # Directly return address
+ pt = np.random.choice(np.arange(len(worker_names)),
+ p=worker_speeds)
+ worker_name = worker_names[pt]
+ return worker_name
+
+ # Check status before returning
+ while True:
+ pt = np.random.choice(np.arange(len(worker_names)),
+ p=worker_speeds)
+ worker_name = worker_names[pt]
+
+ if self.get_worker_status(worker_name):
+ break
+ else:
+ self.remove_worker(worker_name)
+ worker_speeds[pt] = 0
+ norm = np.sum(worker_speeds)
+ if norm < 1e-4:
+ return ""
+ worker_speeds = worker_speeds / norm
+ continue
+ return worker_name
+ elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE:
+ worker_names = []
+ worker_qlen = []
+ for w_name, w_info in self.worker_info.items():
+ if model_name in w_info.model_names:
+ worker_names.append(w_name)
+ worker_qlen.append(w_info.queue_length / w_info.speed)
+ if len(worker_names) == 0:
+ return ""
+ min_index = np.argmin(worker_qlen)
+ w_name = worker_names[min_index]
+ self.worker_info[w_name].queue_length += 1
+ logger.info(f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}")
+ return w_name
+ else:
+ raise ValueError(f"Invalid dispatch method: {self.dispatch_method}")
+
+ def receive_heart_beat(self, worker_name: str, queue_length: int):
+ if worker_name not in self.worker_info:
+ logger.info(f"Receive unknown heart beat. {worker_name}")
+ return False
+
+ self.worker_info[worker_name].queue_length = queue_length
+ self.worker_info[worker_name].last_heart_beat = time.time()
+ logger.info(f"Receive heart beat. {worker_name}")
+ return True
+
+    def remove_stale_workers_by_expiration(self):
+ expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION
+ to_delete = []
+ for worker_name, w_info in self.worker_info.items():
+ if w_info.check_heart_beat and w_info.last_heart_beat < expire:
+ to_delete.append(worker_name)
+
+ for worker_name in to_delete:
+ self.remove_worker(worker_name)
+
+ def worker_api_generate_stream(self, params):
+ worker_addr = self.get_worker_address(params["model"])
+ if not worker_addr:
+ logger.info(f"no worker: {params['model']}")
+ ret = {
+ "text": server_error_msg,
+ "error_code": 2,
+ }
+ yield json.dumps(ret).encode() + b"\0"
+
+ try:
+ response = requests.post(worker_addr + "/worker_generate_stream",
+ json=params, stream=True, timeout=5)
+ for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
+ if chunk:
+ yield chunk + b"\0"
+ except requests.exceptions.RequestException as e:
+ logger.info(f"worker timeout: {worker_addr}")
+ ret = {
+ "text": server_error_msg,
+ "error_code": 3,
+ }
+ yield json.dumps(ret).encode() + b"\0"
+
+
+ # Let the controller act as a worker to achieve hierarchical
+ # management. This can be used to connect isolated sub networks.
+ def worker_api_get_status(self):
+ model_names = set()
+ speed = 0
+ queue_length = 0
+
+ for w_name in self.worker_info:
+ worker_status = self.get_worker_status(w_name)
+ if worker_status is not None:
+ model_names.update(worker_status["model_names"])
+ speed += worker_status["speed"]
+ queue_length += worker_status["queue_length"]
+
+ return {
+ "model_names": list(model_names),
+ "speed": speed,
+ "queue_length": queue_length,
+ }
+
+
+app = FastAPI()
+
+
+@app.post("/register_worker")
+async def register_worker(request: Request):
+ data = await request.json()
+ controller.register_worker(
+ data["worker_name"], data["check_heart_beat"],
+ data.get("worker_status", None))
+
+
+@app.post("/refresh_all_workers")
+async def refresh_all_workers():
+ models = controller.refresh_all_workers()
+
+
+@app.post("/list_models")
+async def list_models():
+ models = controller.list_models()
+ return {"models": models}
+
+
+@app.post("/get_worker_address")
+async def get_worker_address(request: Request):
+ data = await request.json()
+ addr = controller.get_worker_address(data["model"])
+ return {"address": addr}
+
+
+@app.post("/receive_heart_beat")
+async def receive_heart_beat(request: Request):
+ data = await request.json()
+ exist = controller.receive_heart_beat(
+ data["worker_name"], data["queue_length"])
+ return {"exist": exist}
+
+
+@app.post("/worker_generate_stream")
+async def worker_api_generate_stream(request: Request):
+ params = await request.json()
+ generator = controller.worker_api_generate_stream(params)
+ return StreamingResponse(generator)
+
+
+@app.post("/worker_get_status")
+async def worker_api_get_status(request: Request):
+ return controller.worker_api_get_status()
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--host", type=str, default="localhost")
+ parser.add_argument("--port", type=int, default=21001)
+ parser.add_argument("--dispatch-method", type=str, choices=[
+ "lottery", "shortest_queue"], default="shortest_queue")
+ args = parser.parse_args()
+ logger.info(f"args: {args}")
+
+ controller = Controller(args.dispatch_method)
+ uvicorn.run(app, host=args.host, port=args.port, log_level="info")
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/mllm/llava/serve/gradio_web_server.py b/PyTorch/built-in/mlm/HunyuanDiT/mllm/llava/serve/gradio_web_server.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c016cc0261c957c850e3a3bb1dbbcbed2976efc
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/mllm/llava/serve/gradio_web_server.py
@@ -0,0 +1,436 @@
+import argparse
+import datetime
+import json
+import os
+import time
+
+import gradio as gr
+import requests
+
+from llava.conversation import (default_conversation, conv_templates,
+ SeparatorStyle)
+from llava.constants import LOGDIR
+from llava.utils import (build_logger, server_error_msg,
+ violates_moderation, moderation_msg)
+import hashlib
+
+
+logger = build_logger("gradio_web_server", "gradio_web_server.log")
+
+headers = {"User-Agent": "LLaVA Client"}
+
+no_change_btn = gr.Button()
+enable_btn = gr.Button(interactive=True)
+disable_btn = gr.Button(interactive=False)
+
+priority = {
+ "vicuna-13b": "aaaaaaa",
+ "koala-13b": "aaaaaab",
+}
+
+
+def get_conv_log_filename():
+ t = datetime.datetime.now()
+ name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
+ return name
+
+
+def get_model_list():
+ ret = requests.post(args.controller_url + "/refresh_all_workers")
+ assert ret.status_code == 200
+ ret = requests.post(args.controller_url + "/list_models")
+ models = ret.json()["models"]
+ models.sort(key=lambda x: priority.get(x, x))
+ logger.info(f"Models: {models}")
+ return models
+
+
+get_window_url_params = """
+function() {
+ const params = new URLSearchParams(window.location.search);
+ url_params = Object.fromEntries(params);
+ console.log(url_params);
+ return url_params;
+ }
+"""
+
+
+def load_demo(url_params, request: gr.Request):
+ logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
+
+ dropdown_update = gr.Dropdown(visible=True)
+ if "model" in url_params:
+ model = url_params["model"]
+ if model in models:
+ dropdown_update = gr.Dropdown(value=model, visible=True)
+
+ state = default_conversation.copy()
+ return state, dropdown_update
+
+
+def load_demo_refresh_model_list(request: gr.Request):
+ logger.info(f"load_demo. ip: {request.client.host}")
+ models = get_model_list()
+ state = default_conversation.copy()
+ dropdown_update = gr.Dropdown(
+ choices=models,
+ value=models[0] if len(models) > 0 else "",
+ )
+
+ return state, dropdown_update
+
+
+def vote_last_response(state, vote_type, model_selector, request: gr.Request):
+ with open(get_conv_log_filename(), "a") as fout:
+ data = {
+ "tstamp": round(time.time(), 4),
+ "type": vote_type,
+ "model": model_selector,
+ "state": state.dict(),
+ "ip": request.client.host,
+ }
+ fout.write(json.dumps(data) + "\n")
+
+
+def upvote_last_response(state, model_selector, request: gr.Request):
+ logger.info(f"upvote. ip: {request.client.host}")
+ vote_last_response(state, "upvote", model_selector, request)
+ return ("",) + (disable_btn,) * 3
+
+
+def downvote_last_response(state, model_selector, request: gr.Request):
+ logger.info(f"downvote. ip: {request.client.host}")
+ vote_last_response(state, "downvote", model_selector, request)
+ return ("",) + (disable_btn,) * 3
+
+
+def flag_last_response(state, model_selector, request: gr.Request):
+ logger.info(f"flag. ip: {request.client.host}")
+ vote_last_response(state, "flag", model_selector, request)
+ return ("",) + (disable_btn,) * 3
+
+
+def regenerate(state, image_process_mode, request: gr.Request):
+ logger.info(f"regenerate. ip: {request.client.host}")
+ state.messages[-1][-1] = None
+ prev_human_msg = state.messages[-2]
+ if type(prev_human_msg[1]) in (tuple, list):
+ prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
+ state.skip_next = False
+ return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
+
+
+def clear_history(request: gr.Request):
+ logger.info(f"clear_history. ip: {request.client.host}")
+ state = default_conversation.copy()
+ return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
+
+
+def add_text(state, text, image, image_process_mode, request: gr.Request):
+ logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
+ if len(text) <= 0 and image is None:
+ state.skip_next = True
+ return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5
+ if args.moderate:
+ flagged = violates_moderation(text)
+ if flagged:
+ state.skip_next = True
+ return (state, state.to_gradio_chatbot(), moderation_msg, None) + (
+ no_change_btn,) * 5
+
+ text = text[:1536] # Hard cut-off
+ if image is not None:
+ text = text[:1200] # Hard cut-off for images
+        if '<image>' not in text:
+            # text = '<image>' + text
+            text = text + '\n<image>'
+ text = (text, image, image_process_mode)
+ state = default_conversation.copy()
+ state.append_message(state.roles[0], text)
+ state.append_message(state.roles[1], None)
+ state.skip_next = False
+ return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
+
+
+def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request):
+ logger.info(f"http_bot. ip: {request.client.host}")
+ start_tstamp = time.time()
+ model_name = model_selector
+
+ if state.skip_next:
+ # This generate call is skipped due to invalid inputs
+ yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
+ return
+
+ if len(state.messages) == state.offset + 2:
+ # First round of conversation
+ if "llava" in model_name.lower():
+ if 'llama-2' in model_name.lower():
+ template_name = "llava_llama_2"
+ elif "mistral" in model_name.lower() or "mixtral" in model_name.lower():
+ if 'orca' in model_name.lower():
+ template_name = "mistral_orca"
+ elif 'hermes' in model_name.lower():
+ template_name = "chatml_direct"
+ else:
+ template_name = "mistral_instruct"
+ elif 'llava-v1.6-34b' in model_name.lower():
+ template_name = "chatml_direct"
+ elif "v1" in model_name.lower():
+ if 'mmtag' in model_name.lower():
+ template_name = "v1_mmtag"
+ elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower():
+ template_name = "v1_mmtag"
+ else:
+ template_name = "llava_v1"
+ elif "mpt" in model_name.lower():
+ template_name = "mpt"
+ else:
+ if 'mmtag' in model_name.lower():
+ template_name = "v0_mmtag"
+ elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower():
+ template_name = "v0_mmtag"
+ else:
+ template_name = "llava_v0"
+ elif "mpt" in model_name:
+ template_name = "mpt_text"
+ elif "llama-2" in model_name:
+ template_name = "llama_2"
+ else:
+ template_name = "vicuna_v1"
+ new_state = conv_templates[template_name].copy()
+ new_state.append_message(new_state.roles[0], state.messages[-2][1])
+ new_state.append_message(new_state.roles[1], None)
+ state = new_state
+
+ # Query worker address
+ controller_url = args.controller_url
+ ret = requests.post(controller_url + "/get_worker_address",
+ json={"model": model_name})
+ worker_addr = ret.json()["address"]
+ logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")
+
+ # No available worker
+ if worker_addr == "":
+ state.messages[-1][-1] = server_error_msg
+ yield (state, state.to_gradio_chatbot(), disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
+ return
+
+ # Construct prompt
+ prompt = state.get_prompt()
+
+ all_images = state.get_images(return_pil=True)
+ all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
+ for image, hash in zip(all_images, all_image_hash):
+ t = datetime.datetime.now()
+ filename = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash}.jpg")
+ if not os.path.isfile(filename):
+ os.makedirs(os.path.dirname(filename), exist_ok=True)
+ image.save(filename)
+
+ # Make requests
+ pload = {
+ "model": model_name,
+ "prompt": prompt,
+ "temperature": float(temperature),
+ "top_p": float(top_p),
+ "max_new_tokens": min(int(max_new_tokens), 1536),
+ "stop": state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2,
+ "images": f'List of {len(state.get_images())} images: {all_image_hash}',
+ }
+ logger.info(f"==== request ====\n{pload}")
+
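+    # The hash summary above is only for logging; send the actual base64-encoded images.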
+ pload['images'] = state.get_images()
+
+ state.messages[-1][-1] = "▌"
+ yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
+
+ try:
+ # Stream output
+ response = requests.post(worker_addr + "/worker_generate_stream",
+ headers=headers, json=pload, stream=True, timeout=10)
+ for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
+ if chunk:
+ data = json.loads(chunk.decode())
+ if data["error_code"] == 0:
+ output = data["text"][len(prompt):].strip()
+ state.messages[-1][-1] = output + "▌"
+ yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
+ else:
+ output = data["text"] + f" (error_code: {data['error_code']})"
+ state.messages[-1][-1] = output
+ yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
+ return
+ time.sleep(0.03)
+ except requests.exceptions.RequestException as e:
+ state.messages[-1][-1] = server_error_msg
+ yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
+ return
+
+ state.messages[-1][-1] = state.messages[-1][-1][:-1]
+ yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
+
+ finish_tstamp = time.time()
+ logger.info(f"{output}")
+
+ with open(get_conv_log_filename(), "a") as fout:
+ data = {
+ "tstamp": round(finish_tstamp, 4),
+ "type": "chat",
+ "model": model_name,
+ "start": round(start_tstamp, 4),
+ "finish": round(finish_tstamp, 4),
+ "state": state.dict(),
+ "images": all_image_hash,
+ "ip": request.client.host,
+ }
+ fout.write(json.dumps(data) + "\n")
+
+
+block_css = """
+
+#buttons button {
+ min-width: min(120px,100%);
+}
+
+"""
+
+def build_demo(embed_mode, cur_dir=None, concurrency_count=10):
+ textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
+ with gr.Blocks(title="Hunyuan", theme=gr.themes.Default(), css=block_css) as demo:
+ state = gr.State()
+
+
+ with gr.Row():
+ with gr.Column(scale=3):
+ with gr.Row(elem_id="model_selector_row"):
+ model_selector = gr.Dropdown(
+ choices=models,
+ value=models[0] if len(models) > 0 else "",
+ interactive=True,
+ show_label=False,
+ container=False)
+
+ imagebox = gr.Image(type="pil")
+ image_process_mode = gr.Radio(
+ ["Crop", "Resize", "Pad", "Default"],
+ value="Default",
+ label="Preprocess for non-square image", visible=False)
+
+ if cur_dir is None:
+ cur_dir = os.path.dirname(os.path.abspath(__file__))
+
+ with gr.Accordion("Parameters", open=False) as parameter_row:
+ temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label="Temperature",)
+ top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",)
+ max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)
+
+ with gr.Column(scale=8):
+ chatbot = gr.Chatbot(
+ elem_id="chatbot",
+ label="Hunyuan",
+ height=650,
+ layout="panel",
+ )
+ with gr.Row():
+ with gr.Column(scale=8):
+ textbox.render()
+ with gr.Column(scale=1, min_width=50):
+ submit_btn = gr.Button(value="Send", variant="primary")
+ with gr.Row(elem_id="buttons") as button_row:
+ regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
+ clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
+
+
+ url_params = gr.JSON(visible=False)
+
+
+ btn_list = [regenerate_btn, clear_btn]
+
+
+ regenerate_btn.click(
+ regenerate,
+ [state, image_process_mode],
+ [state, chatbot, textbox, imagebox] + btn_list
+ ).then(
+ http_bot,
+ [state, model_selector, temperature, top_p, max_output_tokens],
+ [state, chatbot] + btn_list,
+ # concurrency_limit=concurrency_count
+ )
+
+ clear_btn.click(
+ clear_history,
+ None,
+ [state, chatbot, textbox, imagebox] + btn_list,
+ queue=False
+ )
+
+ textbox.submit(
+ add_text,
+ [state, textbox, imagebox, image_process_mode],
+ [state, chatbot, textbox, imagebox] + btn_list,
+ queue=False
+ ).then(
+ http_bot,
+ [state, model_selector, temperature, top_p, max_output_tokens],
+        [state, chatbot] + btn_list,
+    )
+
+ submit_btn.click(
+ add_text,
+ [state, textbox, imagebox, image_process_mode],
+ [state, chatbot, textbox, imagebox] + btn_list
+ ).then(
+ http_bot,
+ [state, model_selector, temperature, top_p, max_output_tokens],
+        [state, chatbot] + btn_list,
+    )
+
+ if args.model_list_mode == "once":
+ demo.load(
+ load_demo,
+ [url_params],
+ [state, model_selector],
+ js=get_window_url_params
+ )
+ elif args.model_list_mode == "reload":
+ demo.load(
+ load_demo_refresh_model_list,
+ None,
+ [state, model_selector],
+ queue=False
+ )
+ else:
+ raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
+
+ return demo
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--host", type=str, default="0.0.0.0")
+ parser.add_argument("--port", type=int)
+ parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
+ parser.add_argument("--concurrency-count", type=int, default=16)
+ parser.add_argument("--model-list-mode", type=str, default="once",
+ choices=["once", "reload"])
+ parser.add_argument("--share", action="store_true")
+ parser.add_argument("--moderate", action="store_true")
+ parser.add_argument("--embed", action="store_true")
+ args = parser.parse_args()
+ logger.info(f"args: {args}")
+
+ models = get_model_list()
+
+ logger.info(args)
+ demo = build_demo(args.embed, concurrency_count=args.concurrency_count)
+ demo.queue(
+ api_open=False
+ ).launch(
+ server_name=args.host,
+ server_port=args.port,
+ share=args.share
+ )
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/mllm/llava/serve/model_worker.py b/PyTorch/built-in/mlm/HunyuanDiT/mllm/llava/serve/model_worker.py
new file mode 100644
index 0000000000000000000000000000000000000000..9144329893c51f402ff2e2f65d9fb7baf177bd52
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/mllm/llava/serve/model_worker.py
@@ -0,0 +1,288 @@
+"""
+A model worker executes the model.
+"""
+import argparse
+import asyncio
+import json
+import time
+import threading
+import uuid
+
+from fastapi import FastAPI, Request, BackgroundTasks
+from fastapi.responses import StreamingResponse
+import requests
+import torch
+import uvicorn
+from functools import partial
+
+from llava.constants import WORKER_HEART_BEAT_INTERVAL
+from llava.utils import (build_logger, server_error_msg,
+ pretty_print_semaphore)
+from llava.model.builder import load_pretrained_model
+from llava.mm_utils import process_images, load_image_from_base64, tokenizer_image_token
+from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
+from transformers import TextIteratorStreamer
+from threading import Thread
+
+
+GB = 1 << 30
+
+worker_id = str(uuid.uuid4())[:6]
+logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
+global_counter = 0
+
+model_semaphore = None
+
+
+def heart_beat_worker(controller):
+
+ while True:
+ time.sleep(WORKER_HEART_BEAT_INTERVAL)
+ controller.send_heart_beat()
+
+
+class ModelWorker:
+ def __init__(self, controller_addr, worker_addr,
+ worker_id, no_register,
+ model_path, model_base, model_name,
+ load_8bit, load_4bit, device, use_flash_attn=False):
+ self.controller_addr = controller_addr
+ self.worker_addr = worker_addr
+ self.worker_id = worker_id
+ if model_path.endswith("/"):
+ model_path = model_path[:-1]
+ if model_name is None:
+ model_paths = model_path.split("/")
+ if model_paths[-1].startswith('checkpoint-'):
+ self.model_name = model_paths[-2] + "_" + model_paths[-1]
+ else:
+ self.model_name = model_paths[-1]
+ else:
+ self.model_name = model_name
+
+ self.device = device
+ logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...")
+ self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
+ model_path, model_base, self.model_name, load_8bit, load_4bit, device=self.device, use_flash_attn=use_flash_attn)
+ self.is_multimodal = 'llava' in self.model_name.lower()
+
+ if not no_register:
+ self.register_to_controller()
+ self.heart_beat_thread = threading.Thread(
+ target=heart_beat_worker, args=(self,), daemon=True)
+ self.heart_beat_thread.start()
+
+ def register_to_controller(self):
+ logger.info("Register to controller")
+
+ url = self.controller_addr + "/register_worker"
+ data = {
+ "worker_name": self.worker_addr,
+ "check_heart_beat": True,
+ "worker_status": self.get_status()
+ }
+ r = requests.post(url, json=data)
+ assert r.status_code == 200
+
+ def send_heart_beat(self):
+ logger.info(f"Send heart beat. Models: {[self.model_name]}. "
+ f"Semaphore: {pretty_print_semaphore(model_semaphore)}. "
+ f"global_counter: {global_counter}")
+
+ url = self.controller_addr + "/receive_heart_beat"
+
+ while True:
+ try:
+ ret = requests.post(url, json={
+ "worker_name": self.worker_addr,
+ "queue_length": self.get_queue_length()}, timeout=5)
+ exist = ret.json()["exist"]
+ break
+ except requests.exceptions.RequestException as e:
+ logger.error(f"heart beat error: {e}")
+ time.sleep(5)
+
+ if not exist:
+ self.register_to_controller()
+
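+    # Queue length is derived from the concurrency semaphore: requests currently
+    # holding it (limit - value) plus any callers waiting to acquire it.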
+ def get_queue_length(self):
+ if model_semaphore is None:
+ return 0
+ else:
+ return args.limit_model_concurrency - model_semaphore._value + (len(
+ model_semaphore._waiters) if model_semaphore._waiters is not None else 0)
+
+ def get_status(self):
+ return {
+ "model_names": [self.model_name],
+ "speed": 1,
+ "queue_length": self.get_queue_length(),
+ }
+
+ @torch.inference_mode()
+ def generate_stream(self, params):
+ tokenizer, model, image_processor = self.tokenizer, self.model, self.image_processor
+
+ prompt = params["prompt"]
+ ori_prompt = prompt
+ images = params.get("images", None)
+ num_image_tokens = 0
+ if images is not None and len(images) > 0 and self.is_multimodal:
+ if len(images) > 0:
+ if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
+ raise ValueError("Number of images does not match number of tokens in prompt")
+
+ images = [load_image_from_base64(image) for image in images]
+ image_sizes = [image.size for image in images]
+ images = process_images(images, image_processor, model.config)
+
+ if type(images) is list:
+ images = [image.to(self.model.device, dtype=torch.float16) for image in images]
+ else:
+ images = images.to(self.model.device, dtype=torch.float16)
+
+ replace_token = DEFAULT_IMAGE_TOKEN
+ if getattr(self.model.config, 'mm_use_im_start_end', False):
+ replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
+ prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
+
+ num_image_tokens = prompt.count(replace_token) * model.get_vision_tower().num_patches
+ else:
+ images = None
+ image_sizes = None
+ image_args = {"images": images, "image_sizes": image_sizes}
+ else:
+ images = None
+ image_args = {}
+
+ temperature = float(params.get("temperature", 1.0))
+ top_p = float(params.get("top_p", 1.0))
+ max_context_length = getattr(model.config, 'max_position_embeddings', 2048)
+ max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024)
+ stop_str = params.get("stop", None)
+ do_sample = True if temperature > 0.001 else False
+
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
+ keywords = [stop_str]
+ # stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
+ streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
+
+ max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens)
+
+ if max_new_tokens < 1:
+ yield json.dumps({"text": ori_prompt + "Exceeds max token length. Please start a new conversation, thanks.", "error_code": 0}).encode() + b"\0"
+ return
+
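+        # Run generation on a background thread; TextIteratorStreamer yields the
+        # decoded text here so partial results can be streamed to the client.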
+ thread = Thread(target=model.generate, kwargs=dict(
+ inputs=input_ids,
+ do_sample=do_sample,
+ temperature=temperature,
+ top_p=top_p,
+ max_new_tokens=max_new_tokens,
+ streamer=streamer,
+ use_cache=True,
+ **image_args
+ ))
+ thread.start()
+
+ generated_text = ori_prompt
+ for new_text in streamer:
+ generated_text += new_text
+ if generated_text.endswith(stop_str):
+ generated_text = generated_text[:-len(stop_str)]
+ yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0"
+
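+    # Wrap the stream so any failure is reported to the client as a JSON chunk
+    # with error_code 1 instead of silently terminating the HTTP response.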
+ def generate_stream_gate(self, params):
+ try:
+ for x in self.generate_stream(params):
+ yield x
+ except ValueError as e:
+ print("Caught ValueError:", e)
+ ret = {
+ "text": server_error_msg,
+ "error_code": 1,
+ }
+ yield json.dumps(ret).encode() + b"\0"
+ except torch.cuda.CudaError as e:
+ print("Caught torch.cuda.CudaError:", e)
+ ret = {
+ "text": server_error_msg,
+ "error_code": 1,
+ }
+ yield json.dumps(ret).encode() + b"\0"
+ except Exception as e:
+ print("Caught Unknown Error", e)
+ ret = {
+ "text": server_error_msg,
+ "error_code": 1,
+ }
+ yield json.dumps(ret).encode() + b"\0"
+
+
+app = FastAPI()
+
+
+def release_model_semaphore(fn=None):
+ model_semaphore.release()
+ if fn is not None:
+ fn()
+
+
+@app.post("/worker_generate_stream")
+async def generate_stream(request: Request):
+ global model_semaphore, global_counter
+ global_counter += 1
+ params = await request.json()
+
+ if model_semaphore is None:
+ model_semaphore = asyncio.Semaphore(args.limit_model_concurrency)
+ await model_semaphore.acquire()
+ worker.send_heart_beat()
+ generator = worker.generate_stream_gate(params)
+ background_tasks = BackgroundTasks()
+ background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat))
+ return StreamingResponse(generator, background=background_tasks)
+
+
+@app.post("/worker_get_status")
+async def get_status(request: Request):
+ return worker.get_status()
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--host", type=str, default="localhost")
+ parser.add_argument("--port", type=int, default=21002)
+ parser.add_argument("--worker-address", type=str,
+ default="http://localhost:21002")
+ parser.add_argument("--controller-address", type=str,
+ default="http://localhost:21001")
+ parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
+ parser.add_argument("--model-base", type=str, default=None)
+ parser.add_argument("--model-name", type=str)
+ parser.add_argument("--device", type=str, default="cuda")
+ parser.add_argument("--multi-modal", action="store_true", help="Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.")
+ parser.add_argument("--limit-model-concurrency", type=int, default=5)
+ parser.add_argument("--stream-interval", type=int, default=1)
+ parser.add_argument("--no-register", action="store_true")
+ parser.add_argument("--load-8bit", action="store_true")
+ parser.add_argument("--load-4bit", action="store_true")
+ parser.add_argument("--use-flash-attn", action="store_true")
+ args = parser.parse_args()
+ logger.info(f"args: {args}")
+
+ if args.multi_modal:
+ logger.warning("Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.")
+
+ worker = ModelWorker(args.controller_address,
+ args.worker_address,
+ worker_id,
+ args.no_register,
+ args.model_path,
+ args.model_base,
+ args.model_name,
+ args.load_8bit,
+ args.load_4bit,
+ args.device,
+ use_flash_attn=args.use_flash_attn)
+ uvicorn.run(app, host=args.host, port=args.port, log_level="info")
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/mllm/llava/serve/register_worker.py b/PyTorch/built-in/mlm/HunyuanDiT/mllm/llava/serve/register_worker.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c2c40295e0351f25709ba25554c9329f15bf0d2
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/mllm/llava/serve/register_worker.py
@@ -0,0 +1,26 @@
+"""
+Manually register workers.
+
+Usage:
+python3 -m llava.serve.register_worker --controller-address http://localhost:21001 --worker-name http://localhost:21002
+"""
+
+import argparse
+
+import requests
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--controller-address", type=str)
+ parser.add_argument("--worker-name", type=str)
+ parser.add_argument("--check-heart-beat", action="store_true")
+ args = parser.parse_args()
+
+ url = args.controller_address + "/register_worker"
+ data = {
+ "worker_name": args.worker_name,
+ "check_heart_beat": args.check_heart_beat,
+ "worker_status": None,
+ }
+ r = requests.post(url, json=data)
+ assert r.status_code == 200
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/mllm/llava/serve/sglang_worker.py b/PyTorch/built-in/mlm/HunyuanDiT/mllm/llava/serve/sglang_worker.py
new file mode 100644
index 0000000000000000000000000000000000000000..a3297b7c295abddedfaac7f6fbe882d7b672487d
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/mllm/llava/serve/sglang_worker.py
@@ -0,0 +1,244 @@
+"""
+A model worker executes the model.
+"""
+import argparse
+import asyncio
+from concurrent.futures import ThreadPoolExecutor
+import json
+import time
+import threading
+import uuid
+
+from fastapi import FastAPI, Request, BackgroundTasks
+from fastapi.responses import StreamingResponse
+import requests
+import re
+import uvicorn
+from functools import partial
+
+from llava.constants import WORKER_HEART_BEAT_INTERVAL
+from llava.utils import (build_logger, server_error_msg,
+ pretty_print_semaphore)
+from llava.mm_utils import process_images, load_image_from_base64, tokenizer_image_token, expand2square
+from llava.constants import DEFAULT_IMAGE_TOKEN
+
+import sglang as sgl
+from sglang.backend.runtime_endpoint import RuntimeEndpoint
+
+
+GB = 1 << 30
+
+worker_id = str(uuid.uuid4())[:6]
+logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
+global_counter = 0
+
+model_semaphore = None
+
+
+def heart_beat_worker(controller):
+ while True:
+ time.sleep(WORKER_HEART_BEAT_INTERVAL)
+ controller.send_heart_beat()
+
+
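+# The prompt handed to this sglang program is a list interleaving text segments
+# and PIL images; text is appended verbatim and images are added via sgl.image().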
+@sgl.function
+def pipeline(s, prompt, max_tokens):
+ for p in prompt:
+ if type(p) is str:
+ s += p
+ else:
+ s += sgl.image(p)
+ s += sgl.gen("response", max_tokens=max_tokens)
+
+
+class ModelWorker:
+ def __init__(self, controller_addr, worker_addr, sgl_endpoint,
+ worker_id, no_register, model_name):
+ self.controller_addr = controller_addr
+ self.worker_addr = worker_addr
+ self.worker_id = worker_id
+
+ # Select backend
+ backend = RuntimeEndpoint(sgl_endpoint)
+ sgl.set_default_backend(backend)
+ model_path = backend.model_info["model_path"]
+
+ if model_path.endswith("/"):
+ model_path = model_path[:-1]
+ if model_name is None:
+ model_paths = model_path.split("/")
+ if model_paths[-1].startswith('checkpoint-'):
+ self.model_name = model_paths[-2] + "_" + model_paths[-1]
+ else:
+ self.model_name = model_paths[-1]
+ else:
+ self.model_name = model_name
+
+ logger.info(f"Loading the SGLANG model {self.model_name} on worker {worker_id} ...")
+
+ if not no_register:
+ self.register_to_controller()
+ self.heart_beat_thread = threading.Thread(
+ target=heart_beat_worker, args=(self,), daemon=True)
+ self.heart_beat_thread.start()
+
+ def register_to_controller(self):
+ logger.info("Register to controller")
+
+ url = self.controller_addr + "/register_worker"
+ data = {
+ "worker_name": self.worker_addr,
+ "check_heart_beat": True,
+ "worker_status": self.get_status()
+ }
+ r = requests.post(url, json=data)
+ assert r.status_code == 200
+
+ def send_heart_beat(self):
+ logger.info(f"Send heart beat. Models: {[self.model_name]}. "
+ f"Semaphore: {pretty_print_semaphore(model_semaphore)}. "
+ f"global_counter: {global_counter}")
+
+ url = self.controller_addr + "/receive_heart_beat"
+
+ while True:
+ try:
+ ret = requests.post(url, json={
+ "worker_name": self.worker_addr,
+ "queue_length": self.get_queue_length()}, timeout=5)
+ exist = ret.json()["exist"]
+ break
+ except requests.exceptions.RequestException as e:
+ logger.error(f"heart beat error: {e}")
+ time.sleep(5)
+
+ if not exist:
+ self.register_to_controller()
+
+ def get_queue_length(self):
+ if model_semaphore is None:
+ return 0
+ else:
+ return args.limit_model_concurrency - model_semaphore._value + (len(
+ model_semaphore._waiters) if model_semaphore._waiters is not None else 0)
+
+ def get_status(self):
+ return {
+ "model_names": [self.model_name],
+ "speed": 1,
+ "queue_length": self.get_queue_length(),
+ }
+
+ async def generate_stream(self, params):
+ ori_prompt = prompt = params["prompt"]
+ images = params.get("images", None)
+ if images is not None and len(images) > 0:
+ if len(images) > 0:
+ if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
+ raise ValueError("Number of images does not match number of tokens in prompt")
+
+ images = [load_image_from_base64(image) for image in images]
+
+ # FIXME: for image-start/end token
+ # replace_token = DEFAULT_IMAGE_TOKEN
+ # if getattr(self.model.config, 'mm_use_im_start_end', False):
+ # replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
+ # prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
+ prompt = prompt.replace(' ' + DEFAULT_IMAGE_TOKEN + '\n', DEFAULT_IMAGE_TOKEN)
+ prompt_split = prompt.split(DEFAULT_IMAGE_TOKEN)
+ prompt = []
+ for i in range(len(prompt_split)):
+ prompt.append(prompt_split[i])
+ if i < len(images):
+ prompt.append(images[i])
+ else:
+ prompt = [prompt]
+
+ temperature = float(params.get("temperature", 1.0))
+ top_p = float(params.get("top_p", 1.0))
+ # max_context_length = getattr(model.config, 'max_position_embeddings', 2048)
+ max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024)
+ stop_str = params.get("stop", None)
+ stop_str = [stop_str] if stop_str is not None else None
+
+ print({'prompt': prompt, 'max_new_tokens': max_new_tokens, 'temperature': temperature, 'top_p': top_p})
+ state = pipeline.run(prompt, max_new_tokens, temperature=temperature, top_p=top_p, stream=True)
+
+ generated_text = ori_prompt
+ async for text_outputs in state.text_async_iter(var_name="response"):
+ generated_text += text_outputs
+ yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0"
+
+ async def generate_stream_gate(self, params):
+ try:
+ async for x in self.generate_stream(params):
+ yield x
+ except ValueError as e:
+ print("Caught ValueError:", e)
+ ret = {
+ "text": server_error_msg,
+ "error_code": 1,
+ }
+ yield json.dumps(ret).encode() + b"\0"
+ except Exception as e:
+ print("Caught Unknown Error", e)
+ ret = {
+ "text": server_error_msg,
+ "error_code": 1,
+ }
+ yield json.dumps(ret).encode() + b"\0"
+
+
+app = FastAPI()
+
+
+def release_model_semaphore(fn=None):
+ model_semaphore.release()
+ if fn is not None:
+ fn()
+
+
+@app.post("/worker_generate_stream")
+async def generate_stream(request: Request):
+ global model_semaphore, global_counter
+ global_counter += 1
+ params = await request.json()
+
+ if model_semaphore is None:
+ model_semaphore = asyncio.Semaphore(args.limit_model_concurrency)
+ await model_semaphore.acquire()
+ worker.send_heart_beat()
+ generator = worker.generate_stream_gate(params)
+ background_tasks = BackgroundTasks()
+ background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat))
+ return StreamingResponse(generator, background=background_tasks)
+
+
+@app.post("/worker_get_status")
+async def get_status(request: Request):
+ return worker.get_status()
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--host", type=str, default="localhost")
+ parser.add_argument("--port", type=int, default=21002)
+ parser.add_argument("--worker-address", type=str,
+ default="http://localhost:21002")
+ parser.add_argument("--controller-address", type=str,
+ default="http://localhost:21001")
+ parser.add_argument("--model-name", type=str)
+ parser.add_argument("--sgl-endpoint", type=str)
+ parser.add_argument("--limit-model-concurrency", type=int, default=5)
+ parser.add_argument("--stream-interval", type=int, default=1)
+ parser.add_argument("--no-register", action="store_true")
+ args = parser.parse_args()
+ logger.info(f"args: {args}")
+
+ worker = ModelWorker(args.controller_address,
+ args.worker_address,
+ args.sgl_endpoint,
+ worker_id,
+ args.no_register,
+ args.model_name)
+ uvicorn.run(app, host=args.host, port=args.port, log_level="info")
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/mllm/llava/serve/test_message.py b/PyTorch/built-in/mlm/HunyuanDiT/mllm/llava/serve/test_message.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b090faed0e630b03b2294545050f1f4f5032cad
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/mllm/llava/serve/test_message.py
@@ -0,0 +1,62 @@
+import argparse
+import json
+
+import requests
+
+from llava.conversation import default_conversation
+
+
+def main():
+ if args.worker_address:
+ worker_addr = args.worker_address
+ else:
+ controller_addr = args.controller_address
+ ret = requests.post(controller_addr + "/refresh_all_workers")
+ ret = requests.post(controller_addr + "/list_models")
+ models = ret.json()["models"]
+ models.sort()
+ print(f"Models: {models}")
+
+ ret = requests.post(controller_addr + "/get_worker_address",
+ json={"model": args.model_name})
+ worker_addr = ret.json()["address"]
+ print(f"worker_addr: {worker_addr}")
+
+ if worker_addr == "":
+ return
+
+ conv = default_conversation.copy()
+ conv.append_message(conv.roles[0], args.message)
+ prompt = conv.get_prompt()
+
+ headers = {"User-Agent": "LLaVA Client"}
+ pload = {
+ "model": args.model_name,
+ "prompt": prompt,
+ "max_new_tokens": args.max_new_tokens,
+ "temperature": 0.7,
+ "stop": conv.sep,
+ }
+ response = requests.post(worker_addr + "/worker_generate_stream", headers=headers,
+ json=pload, stream=True)
+
+ print(prompt.replace(conv.sep, "\n"), end="")
+ for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
+ if chunk:
+ data = json.loads(chunk.decode("utf-8"))
+ output = data["text"].split(conv.sep)[-1]
+ print(output, end="\r")
+ print("")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--controller-address", type=str, default="http://localhost:21001")
+ parser.add_argument("--worker-address", type=str)
+ parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
+ parser.add_argument("--max-new-tokens", type=int, default=32)
+ parser.add_argument("--message", type=str, default=
+ "Tell me a story with more than 1000 words.")
+ args = parser.parse_args()
+
+ main()
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/mllm/llava/utils.py b/PyTorch/built-in/mlm/HunyuanDiT/mllm/llava/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..4006cf917e26c365080b0844c56fab78c48457c0
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/mllm/llava/utils.py
@@ -0,0 +1,126 @@
+import datetime
+import logging
+import logging.handlers
+import os
+import sys
+
+import requests
+
+from llava.constants import LOGDIR
+
+server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
+moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
+
+handler = None
+
+
+def build_logger(logger_name, logger_filename):
+ global handler
+
+ formatter = logging.Formatter(
+ fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
+ datefmt="%Y-%m-%d %H:%M:%S",
+ )
+
+ # Set the format of root handlers
+ if not logging.getLogger().handlers:
+ logging.basicConfig(level=logging.INFO)
+ logging.getLogger().handlers[0].setFormatter(formatter)
+
+ # Redirect stdout and stderr to loggers
+ stdout_logger = logging.getLogger("stdout")
+ stdout_logger.setLevel(logging.INFO)
+ sl = StreamToLogger(stdout_logger, logging.INFO)
+ sys.stdout = sl
+
+ stderr_logger = logging.getLogger("stderr")
+ stderr_logger.setLevel(logging.ERROR)
+ sl = StreamToLogger(stderr_logger, logging.ERROR)
+ sys.stderr = sl
+
+ # Get logger
+ logger = logging.getLogger(logger_name)
+ logger.setLevel(logging.INFO)
+
+ # Add a file handler for all loggers
+ if handler is None:
+ os.makedirs(LOGDIR, exist_ok=True)
+ filename = os.path.join(LOGDIR, logger_filename)
+ handler = logging.handlers.TimedRotatingFileHandler(
+ filename, when='D', utc=True, encoding='UTF-8')
+ handler.setFormatter(formatter)
+
+ for name, item in logging.root.manager.loggerDict.items():
+ if isinstance(item, logging.Logger):
+ item.addHandler(handler)
+
+ return logger
+
+
+class StreamToLogger(object):
+ """
+ Fake file-like stream object that redirects writes to a logger instance.
+ """
+ def __init__(self, logger, log_level=logging.INFO):
+ self.terminal = sys.stdout
+ self.logger = logger
+ self.log_level = log_level
+ self.linebuf = ''
+
+ def __getattr__(self, attr):
+ return getattr(self.terminal, attr)
+
+ def write(self, buf):
+ temp_linebuf = self.linebuf + buf
+ self.linebuf = ''
+ for line in temp_linebuf.splitlines(True):
+ # From the io.TextIOWrapper docs:
+ # On output, if newline is None, any '\n' characters written
+ # are translated to the system default line separator.
+ # By default sys.stdout.write() expects '\n' newlines and then
+ # translates them so this is still cross platform.
+ if line[-1] == '\n':
+ self.logger.log(self.log_level, line.rstrip())
+ else:
+ self.linebuf += line
+
+ def flush(self):
+ if self.linebuf != '':
+ self.logger.log(self.log_level, self.linebuf.rstrip())
+ self.linebuf = ''
+
+
+def disable_torch_init():
+ """
+ Disable the redundant torch default initialization to accelerate model creation.
+ """
+ import torch
+ setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
+ setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
+
+
+def violates_moderation(text):
+ """
+ Check whether the text violates OpenAI moderation API.
+ """
+ url = "https://api.openai.com/v1/moderations"
+ headers = {"Content-Type": "application/json",
+ "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
+ text = text.replace("\n", "")
+ data = "{" + '"input": ' + f'"{text}"' + "}"
+ data = data.encode("utf-8")
+ try:
+ ret = requests.post(url, headers=headers, data=data, timeout=5)
+ flagged = ret.json()["results"][0]["flagged"]
+ except requests.exceptions.RequestException as e:
+ flagged = False
+ except KeyError as e:
+ flagged = False
+
+ return flagged
+
+
+def pretty_print_semaphore(semaphore):
+ if semaphore is None:
+ return "None"
+ return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/mllm/make_csv.py b/PyTorch/built-in/mlm/HunyuanDiT/mllm/make_csv.py
new file mode 100644
index 0000000000000000000000000000000000000000..66678090997df93a97b559e608952d06b5765299
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/mllm/make_csv.py
@@ -0,0 +1,15 @@
+import os
+import argparse
+import pandas as pd
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--img_dir', type=str, default='mllm/images')
+ parser.add_argument('--input_file', type=str, default='mllm/images/demo.csv')
+ args = parser.parse_args()
+
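+    # Build a single-column CSV (img_path) listing every .png under img_dir.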
+ df = pd.DataFrame(columns=['img_path'])
+ df['img_path'] = [os.path.join(args.img_dir, fn) for fn in os.listdir(args.img_dir) if fn.endswith(".png")]
+
+ df.to_csv(args.input_file, index=False)
+ print("csv file saved to: ", args.input_file)
\ No newline at end of file
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/npu_sample.sh b/PyTorch/built-in/mlm/HunyuanDiT/npu_sample.sh
new file mode 100644
index 0000000000000000000000000000000000000000..df4d8f0fa7c2c440e622a217f13e4a57517d5a14
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/npu_sample.sh
@@ -0,0 +1,8 @@
+source /home/l50041210/cann-b20/ascend-toolkit/set_env.sh
+export ASCEND_RT_VISIBLE_DEVICES=0
+# Output host logs to the serial port, 0 - off / 1 - on
+export ASCEND_SLOG_PRINT_TO_STDOUT=0
+# Set the default log level, 0 - debug / 1 - info / 2 - warning / 3 - error
+export ASCEND_GLOBAL_LOG_LEVEL=3
+
+python sample_t2i.py --infer-mode fa --prompt "渔舟唱晚" --image-size 1280 768 --no-enhance
\ No newline at end of file
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/npu_train.sh b/PyTorch/built-in/mlm/HunyuanDiT/npu_train.sh
new file mode 100644
index 0000000000000000000000000000000000000000..57e991e0dc377a14ab7af0f5b3db3763b123b3cb
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/npu_train.sh
@@ -0,0 +1,7 @@
+source /home/l50041210/cann-b020/ascend-toolkit/set_env.sh
+# Output host logs to the serial port, 0 - off / 1 - on
+export ASCEND_SLOG_PRINT_TO_STDOUT=0
+# Set the default log level, 0 - debug / 1 - info / 2 - warning / 3 - error
+export ASCEND_GLOBAL_LOG_LEVEL=3
+
+PYTHONPATH=./ sh hydit/train.sh
\ No newline at end of file
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/requirements.txt b/PyTorch/built-in/mlm/HunyuanDiT/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8de35fd32eb220912bf782b94dad64beb2776e20
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/requirements.txt
@@ -0,0 +1,22 @@
+--extra-index-url https://pypi.ngc.nvidia.com
+timm==0.9.5
+diffusers==0.21.2
+peft==0.10.0
+protobuf==3.19.0
+torchvision==0.14.1
+transformers==4.39.1
+accelerate==0.29.3
+loguru==0.7.2
+einops==0.7.0
+sentencepiece==0.1.99
+cuda-python==11.7.1
+onnxruntime==1.12.1
+onnx==1.12.0
+nvidia-pyindex==1.0.9
+onnx-graphsurgeon==0.3.27
+polygraphy==0.47.1
+pandas==2.0.3
+gradio==3.50.2
+deepspeed==0.6.3
+pyarrow==16.1.0
\ No newline at end of file
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/sample_controlnet.py b/PyTorch/built-in/mlm/HunyuanDiT/sample_controlnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..c70d3eede8e586e23327004dd62fdfbae08349f3
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/sample_controlnet.py
@@ -0,0 +1,89 @@
+from pathlib import Path
+
+from loguru import logger
+
+from mllm.dialoggen_demo import DialogGen
+from hydit.config import get_args
+from hydit.inference_controlnet import End2End
+
+from torchvision import transforms as T
+import numpy as np
+
+norm_transform = T.Compose(
+ [
+ T.ToTensor(),
+ T.Normalize([0.5], [0.5]),
+ ]
+ )
+
+from PIL import Image
+
+
+def inferencer():
+ args = get_args()
+ models_root_path = Path(args.model_root)
+ if not models_root_path.exists():
+ raise ValueError(f"`models_root` not exists: {models_root_path}")
+
+ # Load models
+ gen = End2End(args, models_root_path)
+
+ # Try to enhance prompt
+ if args.enhance:
+ logger.info("Loading DialogGen model (for prompt enhancement)...")
+ enhancer = DialogGen(str(models_root_path / "dialoggen"), args.load_4bit)
+ logger.info("DialogGen model loaded.")
+ else:
+ enhancer = None
+
+ return args, gen, enhancer
+
+
+if __name__ == "__main__":
+
+ args, gen, enhancer = inferencer()
+
+ if enhancer:
+ logger.info("Prompt Enhancement...")
+ success, enhanced_prompt = enhancer(args.prompt)
+ if not success:
+ logger.info("Sorry, the prompt is not compliant, refuse to draw.")
+ exit()
+ logger.info(f"Enhanced prompt: {enhanced_prompt}")
+ else:
+ enhanced_prompt = None
+
+ # Run inference
+ logger.info("Generating images...")
+ height, width = args.image_size
+
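+    # Load the control condition image (e.g. canny/pose/depth map), resize it to the
+    # target resolution and normalize it to [-1, 1] before moving it to the GPU.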
+ condition = Image.open(args.condition_image_path).convert('RGB').resize((width, height))
+ image = norm_transform(condition)
+ image = image.unsqueeze(0).cuda()
+
+ results = gen.predict(args.prompt,
+ height=height,
+ width=width,
+ image=image,
+ seed=args.seed,
+ enhanced_prompt=enhanced_prompt,
+ negative_prompt=args.negative,
+ infer_steps=args.infer_steps,
+ guidance_scale=args.cfg_scale,
+ batch_size=args.batch_size,
+ src_size_cond=args.size_cond,
+ )
+ images = results['images']
+
+ # Save images
+ save_dir = Path('results')
+ save_dir.mkdir(exist_ok=True)
+ # Find the first available index
+ all_files = list(save_dir.glob('*.png'))
+ if all_files:
+ start = max([int(f.stem) for f in all_files]) + 1
+ else:
+ start = 0
+
+ for idx, pil_img in enumerate(images):
+ save_path = save_dir / f"{idx + start}.png"
+ pil_img.save(save_path)
+ logger.info(f"Save to {save_path}")
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/sample_t2i.py b/PyTorch/built-in/mlm/HunyuanDiT/sample_t2i.py
new file mode 100644
index 0000000000000000000000000000000000000000..1575fe980e522a887fbc2a73d67c65a4430978f7
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/sample_t2i.py
@@ -0,0 +1,72 @@
+from pathlib import Path
+
+from loguru import logger
+
+from mllm.dialoggen_demo import DialogGen
+from hydit.config import get_args
+from hydit.inference import End2End
+
+
+def inferencer():
+ args = get_args()
+ models_root_path = Path(args.model_root)
+ if not models_root_path.exists():
+ raise ValueError(f"`models_root` not exists: {models_root_path}")
+
+ # Load models
+ gen = End2End(args, models_root_path)
+
+ # Try to enhance prompt
+ if args.enhance:
+ logger.info("Loading DialogGen model (for prompt enhancement)...")
+ enhancer = DialogGen(str(models_root_path / "dialoggen"), args.load_4bit)
+ logger.info("DialogGen model loaded.")
+ else:
+ enhancer = None
+
+ return args, gen, enhancer
+
+
+if __name__ == "__main__":
+ args, gen, enhancer = inferencer()
+
+ if enhancer:
+ logger.info("Prompt Enhancement...")
+ success, enhanced_prompt = enhancer(args.prompt)
+ if not success:
+ logger.info("Sorry, the prompt is not compliant, refuse to draw.")
+ exit()
+ logger.info(f"Enhanced prompt: {enhanced_prompt}")
+ else:
+ enhanced_prompt = None
+
+ # Run inference
+ logger.info("Generating images...")
+ height, width = args.image_size
+ results = gen.predict(args.prompt,
+ height=height,
+ width=width,
+ seed=args.seed,
+ enhanced_prompt=enhanced_prompt,
+ negative_prompt=args.negative,
+ infer_steps=args.infer_steps,
+ guidance_scale=args.cfg_scale,
+ batch_size=args.batch_size,
+ src_size_cond=args.size_cond,
+ )
+ images = results['images']
+
+ # Save images
+ save_dir = Path('results')
+ save_dir.mkdir(exist_ok=True)
+ # Find the first available index
+ all_files = list(save_dir.glob('*.png'))
+ if all_files:
+ start = max([int(f.stem) for f in all_files]) + 1
+ else:
+ start = 0
+
+ for idx, pil_img in enumerate(images):
+ save_path = save_dir / f"{idx + start}.png"
+ pil_img.save(save_path)
+ logger.info(f"Save to {save_path}")
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/test/env_npu.sh b/PyTorch/built-in/mlm/HunyuanDiT/test/env_npu.sh
new file mode 100644
index 0000000000000000000000000000000000000000..cd24c60a82fa5e910db8306e559579206d7465d9
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/test/env_npu.sh
@@ -0,0 +1,36 @@
+#!/bin/bash
+CANN_INSTALL_PATH_CONF='/etc/Ascend/ascend_cann_install.info'
+
+if [ -f $CANN_INSTALL_PATH_CONF ]; then
+ CANN_INSTALL_PATH=$(cat $CANN_INSTALL_PATH_CONF | grep Install_Path | cut -d "=" -f 2)
+else
+ CANN_INSTALL_PATH="/usr/local/Ascend"
+fi
+
+if [ -d ${CANN_INSTALL_PATH}/ascend-toolkit/latest ]; then
+ source ${CANN_INSTALL_PATH}/ascend-toolkit/set_env.sh
+else
+ source ${CANN_INSTALL_PATH}/nnae/set_env.sh
+fi
+
+count=$(npu-smi info -l | grep -c "NPU ID")
+
+for ((i=0; i<${count}; i=i+1))
+do
+ msnpureport -g error -d ${i}
+done
+
+# Output host logs to the serial port, 0 - off / 1 - on
+export ASCEND_SLOG_PRINT_TO_STDOUT=0
+# Set the default log level, 0 - debug / 1 - info / 2 - warning / 3 - error
+export ASCEND_GLOBAL_LOG_LEVEL=3
+# Enable event logging, 0 - off / 1 - on
+export ASCEND_GLOBAL_EVENT_ENABLE=0
+# Enable the task queue, 0 - off / 1 - on
+export TASK_QUEUE_ENABLE=1
+# Enable the combined flag, 0 - off / 1 - on
+export COMBINED_ENABLE=1
+# HCCL whitelist switch, 1 - off / 0 - on
+export HCCL_WHITELIST_DISABLE=1
+export HCCL_IF_IP=$(hostname -I |awk '{print $1}')
+export HCCL_CONNECT_TIMEOUT=1200
\ No newline at end of file
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/test/train_full_8p_bf16.sh b/PyTorch/built-in/mlm/HunyuanDiT/test/train_full_8p_bf16.sh
new file mode 100644
index 0000000000000000000000000000000000000000..596ac19052029386566f1a0b10b81191da787a13
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/test/train_full_8p_bf16.sh
@@ -0,0 +1,126 @@
+# Checkpoint path generated by fine-tuning
+Network="HunyuanDiT"
+BATCH_SIZE=1
+max_train_steps=5000
+task_flag="dit_g2_full_1024p" # the task flag is used to identify folders.
+resume=./ckpts/t2i/model/ # checkpoint root for resume
+index_file=dataset/porcelain/jsons/porcelain_mt.json # index file for dataloader
+results_dir=./log_EXP # save root for results
+image_size=1024 # training image resolution
+grad_accu_steps=1 # gradient accumulation
+warmup_num_steps=0 # warm-up steps
+lr=0.0001 # learning rate
+ckpt_every=10000 # create a ckpt every a few steps.
+ckpt_latest_every=5000 # create a ckpt named `latest.pt` every a few steps.
+
+export WORLD_SIZE=8
+export MASTER_PORT=29500
+export MASTER_ADDR=127.0.0.1
+
+for para in $*
+do
+ if [[ $para == --batch_size* ]]; then
+ BATCH_SIZE=$(echo ${para#*=})
+ elif [[ $para == --max_train_steps* ]]; then
+ max_train_steps=$(echo ${para#*=})
+ fi
+done
+
+# cd to the directory at the same level as the test folder before running the script, for better compatibility; test_path_dir is the path that contains the test folder
+cur_path=$(pwd)
+cur_path_last_dirname=${cur_path##*/}
+if [ x"${cur_path_last_dirname}" == x"test" ]; then
+ test_path_dir=${cur_path}
+ cd ..
+ cur_path=$(pwd)
+else
+ test_path_dir=${cur_path}/test
+fi
+
+source ${test_path_dir}/env_npu.sh
+
+# Create the DeviceID output directory; no modification needed
+output_path=${cur_path}/test/output/${ASCEND_DEVICE_ID}
+
+mkdir -p ${output_path}
+
+# Training start time; no modification needed
+start_time=$(date +%s)
+echo "start_time: ${start_time}"
+
+model='DiT-g/2'
+params=" \
+ --qk-norm \
+ --model ${model} \
+ --rope-img base512 \
+ --rope-real \
+ "
+deepspeed --num_gpus ${WORLD_SIZE} --num_nodes 1 --master_port=${MASTER_PORT} hydit/train_deepspeed.py ${params} \
+ --task-flag ${task_flag} \
+ --noise-schedule scaled_linear --beta-start 0.00085 --beta-end 0.03 \
+ --predict-type v_prediction \
+ --uncond-p 0.44 \
+ --uncond-p-t5 0.44 \
+ --index-file ${index_file} \
+ --random-flip \
+ --lr ${lr} \
+ --batch-size ${BATCH_SIZE} \
+ --image-size ${image_size} \
+ --global-seed 999 \
+ --grad-accu-steps ${grad_accu_steps} \
+ --warmup-num-steps ${warmup_num_steps} \
+ --use-flash-attn \
+ --use-fp16 \
+ --use-ema \
+ --ema-dtype fp32 \
+ --results-dir ${results_dir} \
+ --resume-split \
+ --resume ${resume} \
+ --ckpt-every ${ckpt_every} \
+ --ckpt-latest-every ${ckpt_latest_every} \
+ --log-every 1 \
+ --deepspeed \
+ --deepspeed-optimizer \
+ --use-zero-stage 2 \
+ --multireso \
+ --reso-step 64 \
+ --max-training-steps ${max_train_steps} \
+ > ${test_path_dir}/output/$ASCEND_DEVICE_ID/train_${ASCEND_DEVICE_ID}.log 2>&1 &
+
+wait
+
+# Training end time; no modification needed
+end_time=$(date +%s)
+e2e_time=$(($end_time - $start_time))
+
+
+# Training case information; no modification needed
+BatchSize=${BATCH_SIZE}
+DeviceType=$(uname -m)
+CaseName=${Network}_bs${BatchSize}_${WORLD_SIZE}'p'_'acc'
+
+# Print the results; no modification needed
+echo "------------------ Final result ------------------"
+# Output performance (FPS); review and adjust the parsing per model
+FPS=`grep -a 'FPS' ${test_path_dir}/output/${ASCEND_DEVICE_ID}/train_${ASCEND_DEVICE_ID}.log|awk -F "FPS " '{print $2}' | tail -100 | awk '{a+=$1} END {if (NR != 0) printf("%.3f",a/NR)}'`
+# Print; no modification needed
+echo "Final Performance images/sec : $FPS"
+echo "E2E Training Duration sec : $e2e_time"
+
+
+# Performance monitoring result summary
+# Collect performance data; no modification needed
+# Throughput
+ActualFPS=${FPS}
+# Training time per iteration
+TrainingTime=$(awk 'BEGIN{printf "%.2f\n", '${BATCH_SIZE}'*8/'${FPS}'}')
+
+
+# Print key information to ${CaseName}.log; no modification needed
+echo "Network = ${Network}" >${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log
+echo "RankSize = ${WORLD_SIZE}" >>${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log
+echo "BatchSize = ${BatchSize}" >>${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log
+echo "DeviceType = ${DeviceType}" >>${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log
+echo "CaseName = ${CaseName}" >>${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log
+echo "TrainingTime = ${TrainingTime}" >>${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}_perf_report.log
+echo "E2ETrainingTime = ${e2e_time}" >>${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}_perf_report.log
\ No newline at end of file
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/trt/build_engine.sh b/PyTorch/built-in/mlm/HunyuanDiT/trt/build_engine.sh
new file mode 100644
index 0000000000000000000000000000000000000000..02549cc75558fd184e9a85d5c3ca49d48e00b8ed
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/trt/build_engine.sh
@@ -0,0 +1,70 @@
+# ==============================================================================
+# Description: Export ONNX model and build TensorRT engine.
+# ==============================================================================
+
+# Check if the model root path is exists or provided.
+if [ -z "$1" ]; then
+ if [ -d "ckpts" ]; then
+ echo "The model root directory is not provided. Use the default path 'ckpts'."
+ export MODEL_ROOT=ckpts
+ else
+ echo "Default model path 'ckpts' does not exist. Please provide the path of the model root directory."
+ exit 1
+ fi
+elif [ ! -d "$1" ]; then
+ echo "The model root directory ($1) does not exist."
+ exit 1
+else
+ export MODEL_ROOT=$(cd "$1"; pwd)
+fi
+
+export ONNX_WORKDIR=${MODEL_ROOT}/onnx_model
+echo "MODEL_ROOT=${MODEL_ROOT}"
+echo "ONNX_WORKDIR=${ONNX_WORKDIR}"
+
+# Remove old directories.
+if [ -d "${ONNX_WORKDIR}" ]; then
+ echo "Remove old ONNX directories..."
+ rm -r ${ONNX_WORKDIR}
+fi
+
+# Inspect the project directory.
+SCRIPT_PATH="$( cd "$( dirname "$0" )" && pwd )"
+PROJECT_DIR=$(dirname "$SCRIPT_PATH")
+export PYTHONPATH=${PROJECT_DIR}:${PYTHONPATH}
+echo "PYTHONPATH=${PYTHONPATH}"
+cd ${PROJECT_DIR}
+echo "Change directory to ${PROJECT_DIR}"
+
+# ----------------------------------------
+# 1. Export ONNX model.
+# ----------------------------------------
+
+# Sleep for reading the message.
+sleep 2s
+
+echo "Exporting ONNX model..."
+python trt/export_onnx.py --model-root ${MODEL_ROOT} --onnx-workdir ${ONNX_WORKDIR}
+echo "Exporting ONNX model finished"
+
+# ----------------------------------------
+# 2. Build TensorRT engine.
+# ----------------------------------------
+
+echo "Building TensorRT engine..."
+ENGINE_DIR="${MODEL_ROOT}/t2i/model_trt/engine"
+mkdir -p ${ENGINE_DIR}
+ENGINE_PATH=${ENGINE_DIR}/model_onnx.plan
+PLUGIN_PATH=${MODEL_ROOT}/t2i/model_trt/fmha_plugins/9.2_plugin_cuda11/fMHAPlugin.so
+
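+# Dynamic shape ranges below: the leading 2 is the CFG-doubled batch, and the latent
+# spatial sizes 90/128/160 correspond to 720/1024/1280 px images (assuming the usual
+# 8x VAE downsampling).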
+trtexec \
+ --onnx=${ONNX_WORKDIR}/export_modified_fmha/model.onnx \
+ --fp16 \
+ --saveEngine=${ENGINE_PATH} \
+ --minShapes=x:2x4x90x90,t:2,encoder_hidden_states:2x77x1024,text_embedding_mask:2x77,encoder_hidden_states_t5:2x256x2048,text_embedding_mask_t5:2x256,image_meta_size:2x6,style:2,cos_cis_img:2025x88,sin_cis_img:2025x88 \
+ --optShapes=x:2x4x128x128,t:2,encoder_hidden_states:2x77x1024,text_embedding_mask:2x77,encoder_hidden_states_t5:2x256x2048,text_embedding_mask_t5:2x256,image_meta_size:2x6,style:2,cos_cis_img:4096x88,sin_cis_img:4096x88 \
+ --maxShapes=x:2x4x160x160,t:2,encoder_hidden_states:2x77x1024,text_embedding_mask:2x77,encoder_hidden_states_t5:2x256x2048,text_embedding_mask_t5:2x256,image_meta_size:2x6,style:2,cos_cis_img:6400x88,sin_cis_img:6400x88 \
+ --shapes=x:2x4x128x128,t:2,encoder_hidden_states:2x77x1024,text_embedding_mask:2x77,encoder_hidden_states_t5:2x256x2048,text_embedding_mask_t5:2x256,image_meta_size:2x6,style:2,cos_cis_img:4096x88,sin_cis_img:4096x88 \
+ --verbose \
+ --builderOptimizationLevel=4 \
+ --staticPlugins=${PLUGIN_PATH}
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/trt/export_onnx.py b/PyTorch/built-in/mlm/HunyuanDiT/trt/export_onnx.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd2c8f749f55cbcc1f6e61666087bad8406cc2af
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/trt/export_onnx.py
@@ -0,0 +1,246 @@
+from pathlib import Path
+
+import torch
+from loguru import logger
+
+from hydit.config import get_args
+from hydit.modules.models import HunYuanDiT, HUNYUAN_DIT_CONFIG
+
+import numpy as np
+import onnx
+import onnx_graphsurgeon as gs
+import polygraphy.backend.onnx.loader
+
+
+def _to_tuple(val):
+ if isinstance(val, (list, tuple)):
+ if len(val) == 1:
+ val = [val[0], val[0]]
+ elif len(val) == 2:
+ val = tuple(val)
+ else:
+ raise ValueError(f"Invalid value: {val}")
+ elif isinstance(val, (int, float)):
+ val = (val, val)
+ else:
+ raise ValueError(f"Invalid value: {val}")
+ return val
+
+
+class ExportONNX(object):
+ def __init__(self, args, models_root_path):
+ self.args = args
+ self.model = None
+ # Set device and disable gradient
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
+ torch.set_grad_enabled(False)
+
+ # Check arguments
+ t2i_root_path = Path(models_root_path) / "t2i"
+ self.root = t2i_root_path
+ logger.info(f"Got text-to-image model root path: {t2i_root_path}")
+
+ # Create folder to save onnx model
+ onnx_workdir = Path(self.args.onnx_workdir)
+ self.onnx_workdir = onnx_workdir
+ self.onnx_export = self.onnx_workdir / "export/model.onnx"
+ self.onnx_export.parent.mkdir(parents=True, exist_ok=True)
+ self.onnx_modify = self.onnx_workdir / "export_modified/model.onnx"
+ self.onnx_modify.parent.mkdir(parents=True, exist_ok=True)
+ self.onnx_fmha = self.onnx_workdir / "export_modified_fmha/model.onnx"
+ self.onnx_fmha.parent.mkdir(parents=True, exist_ok=True)
+
+ def load_model(self):
+ # ========================================================================
+ # Create model structure and load the checkpoint
+ logger.info(f"Building HunYuan-DiT model...")
+ model_config = HUNYUAN_DIT_CONFIG[self.args.model]
+ image_size = _to_tuple(self.args.image_size)
+ latent_size = (image_size[0] // 8, image_size[1] // 8)
+
+ model_dir = self.root / "model"
+ model_path = model_dir / f"pytorch_model_{self.args.load_key}.pt"
+ if not model_path.exists():
+ raise ValueError(f"model_path not exists: {model_path}")
+
+ # Build model structure
+ self.model = HunYuanDiT(self.args,
+ input_size=latent_size,
+ **model_config,
+ log_fn=logger.info,
+ ).half().to(self.device) # Force to use fp16
+ # Load model checkpoint
+ logger.info(f"Loading torch model {model_path}...")
+ state_dict = torch.load(model_path, map_location=lambda storage, loc: storage)
+ self.model.load_state_dict(state_dict)
+ self.model.eval()
+ logger.info(f"Loading torch model finished")
+ logger.info("==================================================")
+ logger.info(f" Model is ready. ")
+ logger.info("==================================================")
+
+ def export(self):
+ if self.model is None:
+ self.load_model()
+
+ # Construct model inputs
+ latent_model_input = torch.randn(2, 4, 128, 128, device=self.device).half()
+ t_expand = torch.randint(0, 1000, [2], device=self.device).half()
+ prompt_embeds = torch.randn(2, 77, 1024, device=self.device).half()
+ attention_mask = torch.randint(0, 2, [2, 77], device=self.device).long()
+ prompt_embeds_t5 = torch.randn(2, 256, 2048, device=self.device).half()
+ attention_mask_t5 = torch.randint(0, 2, [2, 256], device=self.device).long()
+ ims = torch.tensor([[1024, 1024, 1024, 1024, 0, 0], [1024, 1024, 1024, 1024, 0, 0]], device=self.device).half()
+ style = torch.tensor([0, 0], device=self.device).long()
+ freqs_cis_img = (
+ torch.randn(4096, 88),
+ torch.randn(4096, 88),
+ )
+
+ save_to = self.onnx_export
+ logger.info(f"Exporting ONNX model {save_to}...")
+ logger.info(f"Exporting ONNX external data {save_to.parent}...")
+ model_args = (
+ latent_model_input,
+ t_expand,
+ prompt_embeds,
+ attention_mask,
+ prompt_embeds_t5,
+ attention_mask_t5,
+ ims, style,
+ freqs_cis_img[0],
+ freqs_cis_img[1]
+ )
+ torch.onnx.export(self.model,
+ model_args,
+ str(save_to),
+ export_params=True,
+ opset_version=17,
+ do_constant_folding=True,
+ input_names=["x", "t", "encoder_hidden_states", "text_embedding_mask",
+ "encoder_hidden_states_t5", "text_embedding_mask_t5", "image_meta_size", "style",
+ "cos_cis_img", "sin_cis_img"],
+ output_names=["output"],
+ dynamic_axes={"x": {0: "2B", 2: "H", 3: "W"}, "t": {0: "2B"},
+ "encoder_hidden_states": {0: "2B"},
+ "text_embedding_mask": {0: "2B"}, "encoder_hidden_states_t5": {0: "2B"},
+ "text_embedding_mask_t5": {0: "2B"},
+ "image_meta_size": {0: "2B"}, "style": {0: "2B"}, "cos_cis_img": {0: "seqlen"},
+ "sin_cis_img": {0: "seqlen"}},
+ )
+ logger.info("Exporting onnx finished")
+
+ def postprocessing(self):
+ load_from = self.onnx_export
+ save_to = self.onnx_modify
+ logger.info(f"Postprocessing ONNX model {load_from}...")
+
+ onnxModel = onnx.load(str(load_from), load_external_data=False)
+ onnx.load_external_data_for_model(onnxModel, str(load_from.parent))
+ graph = gs.import_onnx(onnxModel)
+
+ # ADD GAMMA BETA FOR LN
+ for node in graph.nodes:
+ if node.name == "/final_layer/norm_final/LayerNormalization":
+ constantKernel = gs.Constant("final_layer.norm_final.weight",
+ np.ascontiguousarray(np.ones((1408,), dtype=np.float16)))
+ constantBias = gs.Constant("final_layer.norm_final.bias",
+ np.ascontiguousarray(np.zeros((1408,), dtype=np.float16)))
+ node.inputs = [node.inputs[0], constantKernel, constantBias]
+
+ graph.cleanup().toposort()
+ onnx.save(gs.export_onnx(graph.cleanup()),
+ str(save_to),
+ save_as_external_data=True,
+ all_tensors_to_one_file=False,
+ location=str(save_to.parent),
+ )
+ logger.info(f"Postprocessing ONNX model finished: {save_to}")
+
+ def fuse_attn(self):
+ load_from = self.onnx_modify
+ save_to = self.onnx_fmha
+ logger.info(f"FuseAttn ONNX model {load_from}...")
+
+ onnx_graph = polygraphy.backend.onnx.loader.fold_constants(
+ onnx.load(str(load_from)),
+ allow_onnxruntime_shape_inference=True,
+ )
+ graph = gs.import_onnx(onnx_graph)
+
+ cnt = 0
+ for node in graph.nodes:
+
+ if node.op == "Softmax" and node.i().op == "MatMul" and node.o().op == "MatMul" and \
+ node.o().o().op == "Transpose":
+
+ if "pooler" in node.name:
+ continue
+
+ if "attn1" in node.name:
+ matmul_0 = node.i()
+ transpose = matmul_0.i(1, 0)
+ transpose.attrs["perm"] = [0, 2, 1, 3]
+ k = transpose.outputs[0]
+ q = gs.Variable("transpose_0_v_{}".format(cnt), np.dtype(np.float16))
+ transpose_0 = gs.Node("Transpose", "Transpose_0_{}".format(cnt),
+ attrs={"perm": [0, 2, 1, 3]},
+ inputs=[matmul_0.inputs[0]],
+ outputs=[q])
+ graph.nodes.append(transpose_0)
+
+ matmul_1 = node.o()
+ v = gs.Variable("transpose_1_v_{}".format(cnt), np.dtype(np.float16))
+ transpose_1 = gs.Node("Transpose", "Transpose_1_{}".format(cnt),
+ attrs={"perm": [0, 2, 1, 3]},
+ inputs=[matmul_1.inputs[1]],
+ outputs=[v])
+ graph.nodes.append(transpose_1)
+
+ output_variable = node.o().o().outputs[0]
+ # fMHA_v = gs.Variable("fMHA_v", np.dtype(np.float16))
+ fMHA = gs.Node("fMHAPlugin", "fMHAPlugin_1_{}".format(cnt),
+ # attrs={"scale": 1.0},
+ inputs=[q, k, v],
+ outputs=[output_variable])
+ graph.nodes.append(fMHA)
+ node.o().o().outputs = []
+ cnt = cnt + 1
+
+ elif "attn2" in node.name:
+ matmul_0 = node.i()
+ transpose_q = matmul_0.i()
+ transpose_k = matmul_0.i(1, 0)
+ matmul_1 = node.o()
+ transpose_v = matmul_1.i(1, 0)
+ q = transpose_q.inputs[0]
+ k = transpose_k.inputs[0]
+ v = transpose_v.inputs[0]
+ output_variable = node.o().o().outputs[0]
+ fMHA = gs.Node("fMHAPlugin", "fMHAPlugin_1_{}".format(cnt),
+ # attrs={"scale": 1.0},
+ inputs=[q, k, v],
+ outputs=[output_variable])
+ graph.nodes.append(fMHA)
+ node.o().o().outputs = []
+ cnt = cnt + 1
+
+        logger.info(f"mha count: {cnt}")
+
+ onnx.save(gs.export_onnx(graph.cleanup()),
+ str(save_to),
+ save_as_external_data=True,
+ )
+ logger.info(f"FuseAttn ONNX model finished: {save_to}")
+
+
+if __name__ == "__main__":
+ args = get_args()
+ models_root_path = Path(args.model_root)
+ if not models_root_path.exists():
+ raise ValueError(f"`models_root` not exists: {models_root_path}")
+
+ exporter = ExportONNX(args, models_root_path)
+ exporter.export()
+ exporter.postprocessing()
+ exporter.fuse_attn()
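`export_onnx.py` above writes three artifacts under the ONNX work directory: the raw export, a LayerNorm-patched copy, and the fMHA-fused graph consumed by `trtexec`. A small check like the following (not part of the repository; the work-directory path is an assumption) confirms the fused graph loads and that the attention blocks were actually rewritten into `fMHAPlugin` nodes:

```python
# Illustrative check, assuming --onnx-workdir=onnx_export was passed to export_onnx.py.
import onnx

model = onnx.load("onnx_export/export_modified_fmha/model.onnx")   # external data is resolved next to the file
ops = [node.op_type for node in model.graph.node]
print(f"{len(ops)} nodes in total, {ops.count('fMHAPlugin')} fMHAPlugin nodes")
```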
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/trt/install.sh b/PyTorch/built-in/mlm/HunyuanDiT/trt/install.sh
new file mode 100644
index 0000000000000000000000000000000000000000..0de9d8fe7876984ef609e71e95f7e2555ed15c78
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/trt/install.sh
@@ -0,0 +1,80 @@
+# ==============================================================================
+# Description: Install TensorRT and prepare the environment for TensorRT.
+# ==============================================================================
+
+# ----------------------------------------
+# Check the system, tools and arguments
+# ----------------------------------------
+
+# Check the system. TensorRT is only supported on Linux (macOS is not supported).
+if [ "$(uname)" != "Linux" ]; then
+    echo "TensorRT is only supported on Linux."
+ exit 1
+fi
+
+# Check if the model_trt path is provided. If not, use the default path.
+if [ -z "$1" ]; then
+ MODEL_TRT_DIR=$(cd ckpts/t2i/model_trt; pwd)
+else
+ MODEL_TRT_DIR=$(cd "$1"; pwd)
+fi
+
+# Check if the model_trt path exists.
+if [ ! -d "${MODEL_TRT_DIR}" ]; then
+    echo "The model_trt directory (${MODEL_TRT_DIR}) does not exist. Please specify the path with:"
+    echo "    sh trt/install.sh <model_trt_dir>"
+ exit 1
+fi
+
+# Check if ldconfig exists.
+if [ ! -x "$(command -v ldconfig)" ]; then
+ echo "ldconfig is not installed. Please install it first."
+ exit 1
+fi
+
+export TENSORRT_VERSION='9.2.0.5'
+TENSORRT_PACKAGE="${MODEL_TRT_DIR}/TensorRT-${TENSORRT_VERSION}.tar.gz"
+
+# Check if the TensorRT package is downloaded.
+if [ ! -f "${TENSORRT_PACKAGE}" ]; then
+    echo "The TensorRT package (${TENSORRT_PACKAGE}) does not exist. Please download it first with the following steps:"
+ echo "1. cd HunyuanDiT"
+ echo "2. huggingface-cli download Tencent-Hunyuan/HunyuanDiT-TensorRT --local-dir ./ckpts/t2i/model_trt"
+ exit 1
+else
+ echo "Found TensorRT package: ${TENSORRT_PACKAGE}"
+fi
+
+# ----------------------------------------
+# Start to install TensorRT
+# ----------------------------------------
+
+# Extract the TensorRT package.
+echo "Extracting the TensorRT package..."
+tar xf "${TENSORRT_PACKAGE}" -C "${MODEL_TRT_DIR}"
+TENSORRT_DIR="${MODEL_TRT_DIR}/TensorRT-${TENSORRT_VERSION}"
+echo "Extracting the TensorRT package finished"
+
+# Add the TensorRT library path to the system library path.
+echo "${MODEL_TRT_DIR}/lib/" >> /etc/ld.so.conf.d/nvidia.conf && ldconfig
+
+# Install the TensorRT Python wheel.
+echo "Installing the TensorRT Python wheel..."
+# Get python version, e.g., cp38 for Python 3.8; cp310 for Python 3.10
+PYTHON_VERSION=$(python -c 'import sys; print(f"cp{sys.version_info.major}{sys.version_info.minor}")')
+python -m pip install --no-cache-dir ${TENSORRT_DIR}/python/tensorrt*-${PYTHON_VERSION}*
+echo "Installing the TensorRT Python wheel finished"
+
+# Prepare activate.sh and deactivate.sh
+{
+ echo "TENSORRT_DIR=${TENSORRT_DIR}"
+ echo 'export LD_LIBRARY_PATH=${TENSORRT_DIR}/lib/:$LD_LIBRARY_PATH'
+ echo 'export LIBRARY_PATH=${TENSORRT_DIR}/lib/:$LIBRARY_PATH'
+ echo 'export PATH=${TENSORRT_DIR}/bin/:$PATH'
+} > $(dirname "$0")/activate.sh
+{
+ echo "TENSORRT_DIR=${TENSORRT_DIR}"
+ echo 'export LD_LIBRARY_PATH=${LD_LIBRARY_PATH/${TENSORRT_DIR}\/lib\/:}'
+ echo 'export LIBRARY_PATH=${LIBRARY_PATH/${TENSORRT_DIR}\/lib\/:}'
+ echo 'export PATH=${PATH/${TENSORRT_DIR}\/bin\/:}'
+} > $(dirname "$0")/deactivate.sh
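`install.sh` above unpacks the TensorRT tarball, installs the matching Python wheel, and writes `activate.sh`/`deactivate.sh` next to itself for toggling the library paths. After `source trt/activate.sh`, a quick import check (illustrative, not part of the repository) confirms the wheel is usable:

```python
# Illustrative post-install check for the TensorRT 9.2 wheel installed by trt/install.sh.
import tensorrt as trt

print("TensorRT version:", trt.__version__)
assert trt.__version__.startswith("9.2"), "unexpected TensorRT version"
```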
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/utils/__init__.py b/PyTorch/built-in/mlm/HunyuanDiT/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/utils/collect_env.py b/PyTorch/built-in/mlm/HunyuanDiT/utils/collect_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..53a71cdcf70b01fbf7bfa6e3def1bf512e0b90bb
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/utils/collect_env.py
@@ -0,0 +1,202 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+"""This file holding some environment constant for sharing by other files."""
+import os
+import os.path as osp
+import subprocess
+import sys
+from collections import OrderedDict, defaultdict
+
+import numpy as np
+import torch
+
+
+def is_rocm_pytorch() -> bool:
+    """Check whether PyTorch was compiled with ROCm support."""
+ is_rocm = False
+ if TORCH_VERSION != 'parrots':
+ try:
+ from torch.utils.cpp_extension import ROCM_HOME
+ is_rocm = True if ((torch.version.hip is not None) and
+ (ROCM_HOME is not None)) else False
+ except ImportError:
+ pass
+ return is_rocm
+
+TORCH_VERSION = torch.__version__
+
+def get_build_config():
+ """Obtain the build information of PyTorch or Parrots."""
+ if TORCH_VERSION == 'parrots':
+ from parrots.config import get_build_info
+ return get_build_info()
+ else:
+ return torch.__config__.show()
+
+try:
+ import torch_musa # noqa: F401
+ IS_MUSA_AVAILABLE = True
+except Exception:
+ IS_MUSA_AVAILABLE = False
+
+def is_musa_available() -> bool:
+ return IS_MUSA_AVAILABLE
+
+def is_cuda_available() -> bool:
+ """Returns True if cuda devices exist."""
+ return torch.cuda.is_available()
+
+def _get_cuda_home():
+ if TORCH_VERSION == 'parrots':
+ from parrots.utils.build_extension import CUDA_HOME
+ else:
+ if is_rocm_pytorch():
+ from torch.utils.cpp_extension import ROCM_HOME
+ CUDA_HOME = ROCM_HOME
+ else:
+ from torch.utils.cpp_extension import CUDA_HOME
+ return CUDA_HOME
+
+
+def _get_musa_home():
+ return os.environ.get('MUSA_HOME')
+
+
+def collect_env():
+ """Collect the information of the running environments.
+
+ Returns:
+ dict: The environment information. The following fields are contained.
+
+ - sys.platform: The variable of ``sys.platform``.
+ - Python: Python version.
+ - CUDA available: Bool, indicating if CUDA is available.
+ - GPU devices: Device type of each GPU.
+ - CUDA_HOME (optional): The env var ``CUDA_HOME``.
+ - NVCC (optional): NVCC version.
+ - GCC: GCC version, "n/a" if GCC is not installed.
+            - MSVC: Microsoft Visual C++ Compiler version, Windows only.
+ - PyTorch: PyTorch version.
+ - PyTorch compiling details: The output of \
+ ``torch.__config__.show()``.
+ - TorchVision (optional): TorchVision version.
+ - OpenCV (optional): OpenCV version.
+ """
+ from distutils import errors
+
+ env_info = OrderedDict()
+ env_info['sys.platform'] = sys.platform
+ env_info['Python'] = sys.version.replace('\n', '')
+
+ cuda_available = is_cuda_available()
+ musa_available = is_musa_available()
+ env_info['CUDA available'] = cuda_available
+ env_info['MUSA available'] = musa_available
+ env_info['numpy_random_seed'] = np.random.get_state()[1][0]
+
+ if cuda_available:
+ devices = defaultdict(list)
+ for k in range(torch.cuda.device_count()):
+ devices[torch.cuda.get_device_name(k)].append(str(k))
+ for name, device_ids in devices.items():
+ env_info['GPU ' + ','.join(device_ids)] = name
+
+ CUDA_HOME = _get_cuda_home()
+ env_info['CUDA_HOME'] = CUDA_HOME
+
+ if CUDA_HOME is not None and osp.isdir(CUDA_HOME):
+ if CUDA_HOME == '/opt/rocm':
+ try:
+ nvcc = osp.join(CUDA_HOME, 'hip/bin/hipcc')
+ nvcc = subprocess.check_output(
+ f'"{nvcc}" --version', shell=True)
+ nvcc = nvcc.decode('utf-8').strip()
+ release = nvcc.rfind('HIP version:')
+ build = nvcc.rfind('')
+ nvcc = nvcc[release:build].strip()
+ except subprocess.SubprocessError:
+ nvcc = 'Not Available'
+ else:
+ try:
+ nvcc = osp.join(CUDA_HOME, 'bin/nvcc')
+ nvcc = subprocess.check_output(f'"{nvcc}" -V', shell=True)
+ nvcc = nvcc.decode('utf-8').strip()
+ release = nvcc.rfind('Cuda compilation tools')
+ build = nvcc.rfind('Build ')
+ nvcc = nvcc[release:build].strip()
+ except subprocess.SubprocessError:
+ nvcc = 'Not Available'
+ env_info['NVCC'] = nvcc
+ elif musa_available:
+ devices = defaultdict(list)
+ for k in range(torch.musa.device_count()):
+ devices[torch.musa.get_device_name(k)].append(str(k))
+ for name, device_ids in devices.items():
+ env_info['GPU ' + ','.join(device_ids)] = name
+
+ MUSA_HOME = _get_musa_home()
+ env_info['MUSA_HOME'] = MUSA_HOME
+
+ if MUSA_HOME is not None and osp.isdir(MUSA_HOME):
+ try:
+ mcc = osp.join(MUSA_HOME, 'bin/mcc')
+ subprocess.check_output(f'"{mcc}" -v', shell=True)
+ except subprocess.SubprocessError:
+ mcc = 'Not Available'
+ env_info['mcc'] = mcc
+ try:
+ # Check C++ Compiler.
+ # For Unix-like, sysconfig has 'CC' variable like 'gcc -pthread ...',
+ # indicating the compiler used, we use this to get the compiler name
+ import io
+ import sysconfig
+ cc = sysconfig.get_config_var('CC')
+ if cc:
+ cc = osp.basename(cc.split()[0])
+ cc_info = subprocess.check_output(f'{cc} --version', shell=True)
+ env_info['GCC'] = cc_info.decode('utf-8').partition(
+ '\n')[0].strip()
+ else:
+ # on Windows, cl.exe is not in PATH. We need to find the path.
+ # distutils.ccompiler.new_compiler() returns a msvccompiler
+ # object and after initialization, path to cl.exe is found.
+ import locale
+ import os
+ from distutils.ccompiler import new_compiler
+ ccompiler = new_compiler()
+ ccompiler.initialize()
+ cc = subprocess.check_output(
+ f'{ccompiler.cc}', stderr=subprocess.STDOUT, shell=True)
+ encoding = os.device_encoding(
+ sys.stdout.fileno()) or locale.getpreferredencoding()
+ env_info['MSVC'] = cc.decode(encoding).partition('\n')[0].strip()
+ env_info['GCC'] = 'n/a'
+ except (subprocess.CalledProcessError, errors.DistutilsPlatformError):
+ env_info['GCC'] = 'n/a'
+ except io.UnsupportedOperation as e:
+ # JupyterLab on Windows changes sys.stdout, which has no `fileno` attr
+ # Refer to: https://github.com/open-mmlab/mmengine/issues/931
+ # TODO: find a solution to get compiler info in Windows JupyterLab,
+ # while preserving backward-compatibility in other systems.
+ env_info['MSVC'] = f'n/a, reason: {str(e)}'
+
+ env_info['PyTorch'] = torch.__version__
+ env_info['PyTorch compiling details'] = get_build_config()
+
+ try:
+ import torchvision
+ env_info['TorchVision'] = torchvision.__version__
+ except ModuleNotFoundError:
+ pass
+
+ try:
+ import cv2
+ env_info['OpenCV'] = cv2.__version__
+ except ImportError:
+ pass
+
+
+ return env_info
+
+if __name__ == '__main__':
+ for name, val in collect_env().items():
+ print(f'{name}: {val}')
\ No newline at end of file
diff --git a/PyTorch/built-in/mlm/HunyuanDiT/utils/npu_utils.py b/PyTorch/built-in/mlm/HunyuanDiT/utils/npu_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b8da7b9c077b5eed8654491cf6162a4e087b52c
--- /dev/null
+++ b/PyTorch/built-in/mlm/HunyuanDiT/utils/npu_utils.py
@@ -0,0 +1,76 @@
+import os
+import random
+from typing import List, Optional
+import importlib
+from functools import lru_cache
+
+import math
+import numpy as np
+import torch
+import torch.nn as nn
+from torch import Tensor
+from torch.optim.optimizer import Optimizer
+
+
+AUTOCAST_MAPPING = {
+ "bf16": torch.bfloat16,
+ "fp16": torch.float16,
+ "fp32": torch.float32,
+}
+
+@lru_cache
+def is_npu_available():
+ if importlib.util.find_spec("torch") is None or importlib.util.find_spec('torch_npu') is None:
+ return False
+
+ import torch
+ import torch_npu
+
+ try:
+ _ = torch.npu.device_count()
+ return torch.npu.is_available()
+ except RuntimeError:
+ return False
+
+
+if is_npu_available():
+ import torch_npu
+
+
+def seed_all(is_gpu=True, seed=1234, mode=False):
+    """Seed Python, NumPy and torch RNGs; `mode` toggles torch deterministic algorithms."""
+ random.seed(seed)
+ os.environ['PYTHONHASHSEED'] = str(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.use_deterministic_algorithms(mode)
+ if is_gpu:
+ torch.cuda.manual_seed_all(seed)
+ torch.cuda.manual_seed(seed)
+ torch.backends.cudnn.deterministic = True
+        torch.backends.cudnn.enabled = False
+ torch.backends.cudnn.benchmark = False
+ else:
+ torch_npu.npu.manual_seed_all(seed)
+ torch_npu.npu.manual_seed(seed)
+
+
+class FlashAttention(torch.nn.Module):
+ def __init__(self, attention_dropout=0):
+ super().__init__()
+ self.attention_dropout = attention_dropout
+
+ def forward(self, query, key, value):
+        # Inputs are laid out as (B, S, N, D); dim 2 is the number of attention heads.
+        heads = query.shape[2]
+ attention_mask = None
+ output = torch_npu.npu_fusion_attention(
+ query, key, value, heads, input_layout='BSND',
+ pse=None,
+ atten_mask=attention_mask,
+ scale=1.0 / math.sqrt(query.shape[-1]),
+ pre_tockens=65536,
+ next_tockens=65536,
+ keep_prob=1. - self.attention_dropout,
+ sync=False,
+ inner_precise=0,
+ )[0]
+ return output
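The `FlashAttention` module above forwards `(B, S, N, D)` tensors to the Ascend fused kernel `torch_npu.npu_fusion_attention` with no mask and no dropout. For parity checks on non-NPU hardware, a minimal reference of the same computation (an illustrative sketch, not repository code) can be written with PyTorch's `scaled_dot_product_attention`, whose default `1/sqrt(D)` scaling matches the `scale` argument passed above:

```python
# Reference sketch of the fused attention above, for CPU/GPU parity checks.
import torch
import torch.nn.functional as F


def reference_attention(query, key, value):
    # query/key/value: (B, S, N, D), matching the 'BSND' input_layout used by the NPU kernel.
    q, k, v = (t.transpose(1, 2) for t in (query, key, value))   # -> (B, N, S, D)
    out = F.scaled_dot_product_attention(q, k, v)                # default scale = 1/sqrt(D)
    return out.transpose(1, 2)                                   # back to (B, S, N, D)


q = k = v = torch.randn(2, 16, 8, 64)
print(reference_attention(q, k, v).shape)   # torch.Size([2, 16, 8, 64])
```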