diff --git a/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/README.md b/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..528d70cf0a09cc4b8eb2f52a64e046b41ea4082c
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/README.md
@@ -0,0 +1,314 @@
+## 1. Prepare the Runtime Environment
+
+  **Table 1** Version compatibility
+
+  | Component | Version | Setup guide |
+  | ----- | ----- | ----- |
+  | Python | 3.10.2 | - |
+  | torch | 2.1.0 | - |
+
+### 1.1 Get the CANN & MindIE packages and prepare the environment
+- [800I A2](https://www.hiascend.com/developer/download/community/result?module=pt+ie+cann&product=4&model=32)
+- [Duo cards](https://www.hiascend.com/developer/download/community/result?module=pt+ie+cann&product=2&model=17)
+- [Environment setup guide](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha002/softwareinst/instg/instg_0001.html)
+
+### 1.2 Install CANN
+```shell
+# Make the packages executable. {version} is the package version, {arch} is the CPU architecture, {soc} is the Ascend AI processor version.
+chmod +x ./Ascend-cann-toolkit_{version}_linux-{arch}.run
+chmod +x ./Ascend-cann-kernels-{soc}_{version}_linux.run
+# Verify the consistency and integrity of the installation packages
+./Ascend-cann-toolkit_{version}_linux-{arch}.run --check
+./Ascend-cann-kernels-{soc}_{version}_linux.run --check
+# Install
+./Ascend-cann-toolkit_{version}_linux-{arch}.run --install
+./Ascend-cann-kernels-{soc}_{version}_linux.run --install
+
+# Set environment variables
+source /usr/local/Ascend/ascend-toolkit/set_env.sh
+```
+
+### 1.3 Install MindIE
+```shell
+# Make the package executable. {version} is the package version, {arch} is the CPU architecture.
+chmod +x ./Ascend-mindie_${version}_linux-${arch}.run
+./Ascend-mindie_${version}_linux-${arch}.run --check
+
+# Option 1: install to the default path
+./Ascend-mindie_${version}_linux-${arch}.run --install
+# Set environment variables
+cd /usr/local/Ascend/mindie && source set_env.sh
+
+# Option 2: install to a custom path
+./Ascend-mindie_${version}_linux-${arch}.run --install-path=${AieInstallPath}
+# Set environment variables
+cd ${AieInstallPath}/mindie && source set_env.sh
+```
+
+### 1.4 Install torch_npu
+Install the PyTorch framework, version 2.1.0 ([package download](https://download.pytorch.org/whl/cpu/torch/)).
+
+Install with pip:
+```shell
+# {version} is the package version, {arch} is the CPU architecture.
+pip install torch-${version}-cp310-cp310-linux_${arch}.whl
+```
+Download pytorch_v{pytorchversion}_py{pythonversion}.tar.gz, then:
+```shell
+tar -xzvf pytorch_v{pytorchversion}_py{pythonversion}.tar.gz
+# The archive contains the torch_npu wheel
+pip install torch_npu-{pytorchversion}.xxxx.{arch}.whl
+```
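+
+After installation, a quick sanity check can confirm that the NPU backend is visible to PyTorch. This snippet is not part of the original guide; it only assumes torch and torch_npu were installed as described above:
+```python
+import torch
+import torch_npu  # importing torch_npu registers the NPU backend with PyTorch
+
+print(torch.__version__)          # expected: 2.1.0
+print(torch.npu.is_available())   # True if an NPU device is visible
+print(torch.npu.device_count())   # number of visible NPU devices
+```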
+
+## 2. Download This Repository
+
+### 2.1 Clone to a local machine
+```shell
+git clone https://gitee.com/ascend/ModelZoo-PyTorch.git
+```
+
+## 3. Using HunyuanDiT
+
+### 3.1 Weights and configuration files
+1. text_encoder weights:
+```shell
+https://huggingface.co/Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers/tree/main/text_encoder
+```
+2. text_encoder_2 weights:
+```shell
+https://huggingface.co/Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers/tree/main/text_encoder_2
+```
+3. tokenizer weights:
+```shell
+https://huggingface.co/Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers/tree/main/tokenizer
+```
+4. tokenizer_2 weights:
+```shell
+https://huggingface.co/Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers/tree/main/tokenizer_2
+```
+5. transformer weights:
+```shell
+https://huggingface.co/Tencent-Hunyuan/HunyuanDiT-v1.2/tree/main/t2i/model
+```
+- Modify the config.json of these weights to:
+```json
+{
+    "architectures": [
+        "HunyuanDiT2DModel"
+    ],
+    "input_size": [
+        null,
+        null
+    ],
+    "patch_size": 2,
+    "in_channels": 4,
+    "hidden_size": 1408,
+    "depth": 40,
+    "num_heads": 16,
+    "mlp_ratio": 4.3637,
+    "text_states_dim": 1024,
+    "text_states_dim_t5": 2048,
+    "text_len": 77,
+    "text_len_t5": 256
+}
+```
+6. vae weights:
+```shell
+https://huggingface.co/Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers/tree/main/vae
+```
+- Modify the config.json of these weights to:
+```json
+{
+    "architectures": [
+        "AutoencoderKL"
+    ],
+    "in_channels": 3,
+    "out_channels": 3,
+    "down_block_types": [
+        "DownEncoderBlock2D",
+        "DownEncoderBlock2D",
+        "DownEncoderBlock2D",
+        "DownEncoderBlock2D"
+    ],
+    "up_block_types": [
+        "UpDecoderBlock2D",
+        "UpDecoderBlock2D",
+        "UpDecoderBlock2D",
+        "UpDecoderBlock2D"
+    ],
+    "block_out_channels": [
+        128,
+        256,
+        512,
+        512
+    ],
+    "layers_per_block": 2,
+    "act_fn": "silu",
+    "latent_channels": 4,
+    "norm_num_groups": 32,
+    "sample_size": 512,
+    "scaling_factor": 0.13025,
+    "shift_factor": null,
+    "latents_mean": null,
+    "latents_std": null,
+    "force_upcast": false,
+    "use_quant_conv": true,
+    "use_post_quant_conv": true
+}
+```
+7. scheduler:
+- Create a new scheduler_config.json with the following content:
+```json
+{
+    "_class_name": "DDPMScheduler",
+    "_mindiesd_version": "1.0.0",
+    "steps_offset": 1,
+    "beta_start": 0.00085,
+    "beta_end": 0.02,
+    "num_train_timesteps": 1000
+}
+```
+8. Add model_index.json
+Place the weights downloaded in the steps above under one directory and add a model_index.json file with the following content:
+```json
+{
+    "_class_name": "HunyuanDiTPipeline",
+    "_mindiesd_version": "1.0.RC3",
+    "scheduler": [
+        "mindiesd",
+        "DDPMScheduler"
+    ],
+    "text_encoder": [
+        "transformers",
+        "BertModel"
+    ],
+    "text_encoder_2": [
+        "transformers",
+        "T5EncoderModel"
+    ],
+    "tokenizer": [
+        "transformers",
+        "BertTokenizer"
+    ],
+    "tokenizer_2": [
+        "transformers",
+        "T5Tokenizer"
+    ],
+    "transformer": [
+        "mindiesd",
+        "HunyuanDiT2DModel"
+    ],
+    "vae": [
+        "mindiesd",
+        "AutoencoderKL"
+    ]
+}
+```
+9. The expected layout of the configuration and weight files is shown below.
+```commandline
+|----hunyuandit
+|    |---- model_index.json
+|    |---- scheduler
+|    |     |---- scheduler_config.json
+|    |---- text_encoder
+|    |     |---- config.json
+|    |     |---- model weights
+|    |---- text_encoder_2
+|    |     |---- config.json
+|    |     |---- model weights
+|    |---- tokenizer
+|    |     |---- config.json
+|    |     |---- model weights
+|    |---- tokenizer_2
+|    |     |---- config.json
+|    |     |---- model weights
+|    |---- transformer
+|    |     |---- config.json
+|    |     |---- model weights
+|    |---- vae
+|    |     |---- config.json
+|    |     |---- model weights
+```
+
+### 3.2 Single-card, single-prompt functional test
+Set the weight path:
+```shell
+path="ckpts/hydit"
+```
+Run the command:
+```shell
+python inference_hydit.py \
+    --path ${path} \
+    --device_id 0 \
+    --prompt "青花瓷风格,一只小狗" \
+    --input_size 1024 1024 \
+    --seed 42 \
+    --infer_steps 25
+```
+Parameters:
+- path: weight directory containing the configuration files and weights of the seven components: scheduler, text_encoder, text_encoder_2, tokenizer, tokenizer_2, transformer, vae.
+- device_id: inference device ID.
+- prompt: text prompt used for image generation.
+- input_size: output image size.
+- seed: random seed, default 42.
+- infer_steps: number of inference (denoising) steps.
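+
+The command above is a thin wrapper around the Python API shipped in this directory. For reference, a minimal sketch of the equivalent programmatic calls, mirroring what inference_hydit.py does (paths, prompt, and the output filename below are just the example values from this guide):
+```python
+import torch
+from hydit import HunyuanDiTPipeline, compile_pipe, set_seeds_generator
+
+torch.npu.set_device(0)  # same as --device_id 0
+
+# Load the seven components described in section 3.1 from the weight directory.
+pipe = HunyuanDiTPipeline.from_pretrained(
+    model_path="ckpts/hydit", input_size=(1024, 1024), dtype=torch.float16)
+pipe = compile_pipe(pipe)  # move all sub-modules to the NPU
+
+generator = set_seeds_generator(42, device="npu")
+images = pipe(prompt="青花瓷风格,一只小狗",
+              num_images_per_prompt=1,
+              num_inference_steps=25,
+              seed_generator=generator)[0]
+images[0].save("result.png")
+```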
+
+### 3.3 Single-card, multi-prompt performance/accuracy test
+Set the weight path:
+```shell
+path="ckpts/hydit"
+```
+Run the command:
+```shell
+python inference_hydit.py \
+    --path ${path} \
+    --device_id 0 \
+    --test_acc \
+    --prompt_list "prompts/example_prompts.txt" \
+    --input_size 1024 1024 \
+    --seed 42 \
+    --infer_steps 25
+```
+Parameters:
+- path: weight directory containing the configuration files and weights of the seven components: scheduler, text_encoder, text_encoder_2, tokenizer, tokenizer_2, transformer, vae.
+- device_id: inference device ID.
+- test_acc: pass --test_acc to generate images for the full prompt list, for performance/accuracy testing. Leave it off for the single-prompt functional test.
+- prompt_list: path to the text file listing the prompts used for image generation.
+- input_size: output image size.
+- seed: random seed, default 42.
+- infer_steps: number of inference (denoising) steps.
+
+### 3.4 Testing with LoRA
+Set the weight path:
+```shell
+path="ckpts/hydit"
+```
+LoRA weights:
+```shell
+https://huggingface.co/Tencent-Hunyuan/HYDiT-LoRA/tree/main
+```
+Set the LoRA weight path:
+```shell
+lora_path="ckpts/lora"
+```
+Run the command:
+```shell
+python inference_hydit.py \
+    --path ${path} \
+    --device_id 0 \
+    --prompt "青花瓷风格,一只小狗" \
+    --input_size 1024 1024 \
+    --seed 42 \
+    --infer_steps 25 \
+    --use_lora \
+    --lora_ckpt ${lora_path}
+```
+Parameters:
+- path: weight directory containing the configuration files and weights of the seven components: scheduler, text_encoder, text_encoder_2, tokenizer, tokenizer_2, transformer, vae.
+- device_id: inference device ID.
+- prompt: text prompt used for image generation.
+- input_size: output image size.
+- seed: random seed, default 42.
+- infer_steps: number of inference (denoising) steps.
+- use_lora: pass --use_lora to enable LoRA style switching.
+- lora_ckpt: LoRA weight path.
\ No newline at end of file
diff --git a/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/pipeline/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/pipeline/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2bd1e7e1224af63cedf46e193d5851d4edcdd29e
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/pipeline/__init__.py
@@ -0,0 +1,19 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from .hydit_pipeline import HunyuanDiTPipeline
+from .compile_pipe import compile_pipe
\ No newline at end of file
diff --git a/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/pipeline/compile_pipe.py b/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/pipeline/compile_pipe.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b70a081b244b2163b7786e4f0d3193784340a59
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/pipeline/compile_pipe.py
@@ -0,0 +1,35 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright(C) 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
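+
+# compile_pipe moves each loadable sub-module of the pipeline (text_encoder,
+# text_encoder_2, transformer, vae) onto the NPU device and returns the pipeline;
+# it raises RuntimeError when no NPU is available.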
+ + +import torch.nn as nn +from ..utils import is_npu_available + + +def compile_pipe(pipe): + if is_npu_available(): + device = 'npu' + if hasattr(pipe, "text_encoder") and isinstance(pipe.text_encoder, nn.Module): + pipe.text_encoder.to(device) + if hasattr(pipe, "text_encoder_2") and isinstance(pipe.text_encoder_2, nn.Module): + pipe.text_encoder_2.to(device) + if hasattr(pipe, "transformer") and isinstance(pipe.transformer, nn.Module): + pipe.transformer.to(device) + if hasattr(pipe, "vae") and isinstance(pipe.vae, nn.Module): + pipe.vae.to(device) + return pipe + else: + raise RuntimeError("NPU is not available.") \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/pipeline/hydit_pipeline.py b/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/pipeline/hydit_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..c4811de3e4642cd1779d436a7af3855e390b2a8f --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/pipeline/hydit_pipeline.py @@ -0,0 +1,386 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import List, Union, Tuple +import logging + +import torch +import torch_npu +from tqdm import tqdm +import numpy as np + +from .pipeline_utils import HunYuanPipeline +from ..layers.embedding import RotaryPositionEmbedding +from ..utils import postprocess_pil, randn_tensor + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +torch_npu.npu.config.allow_internal_format = False + +MAX_PROMPT_LENGTH = 1024 +NEGATIVE_PROMPT = '错误的眼睛,糟糕的人脸,毁容,糟糕的艺术,变形,多余的肢体,模糊的颜色,模糊,重复,病态,残缺,' +STANDARD_RATIO = np.array( + [ + 1.0, # 1:1 + 4.0 / 3.0, # 4:3 + 3.0 / 4.0, # 3:4 + 16.0 / 9.0, # 16:9 + 9.0 / 16.0, # 9:16 + ] +) +STANDARD_SHAPE = [ + [(1024, 1024), (1280, 1280)], # 1:1 + [(1024, 768), (1152, 864), (1280, 960)], # 4:3 + [(768, 1024), (864, 1152), (960, 1280)], # 3:4 + [(1280, 768)], # 16:9 + [(768, 1280)], # 9:16 +] +STANDARD_AREA = [np.array([w * h for w, h in shapes]) for shapes in STANDARD_SHAPE] +SUPPORTED_SHAPE = [ + (1024, 1024), + (1280, 1280), # 1:1 + (1024, 768), + (1152, 864), + (1280, 960), # 4:3 + (768, 1024), + (864, 1152), + (960, 1280), # 3:4 + (1280, 768), # 16:9 + (768, 1280), # 9:16 +] +USE_CACHE = True +CACHE_RESERVE = 9 + + +class HunyuanDiTPipeline(HunYuanPipeline): + + def __init__( + self, + scheduler, + text_encoder, + text_encoder_2, + tokenizer, + tokenizer_2, + transformer, + vae, + input_size: Tuple[int, int] = (1024, 1024), + dtype: torch.dtype = torch.float16 + ): + super().__init__() + torch.set_grad_enabled(False) + + self.scheduler = scheduler + self.text_encoder = text_encoder + self.text_encoder_2 = text_encoder_2 + self.tokenizer = tokenizer + self.tokenizer_2 = tokenizer_2 + self.transformer = transformer + self.vae = vae + self.input_size = input_size + self.dtype = dtype + self._check_init_input() + + self.text_encoder.to(self.dtype) + 
self.text_encoder_2.to(self.dtype) + + self.device = torch.device("npu") + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + + # Set image height and width. + height = self.input_size[0] + width = self.input_size[1] + self.height = int((height // 16) * 16) + self.width = int((width // 16) * 16) + if (self.height, self.width) not in SUPPORTED_SHAPE: + width, height = map_to_standard_shapes(self.width, self.height) + self.height = int(height) + self.width = int(width) + logger.warning(f"Reshaped to ({self.height}, {self.width}), Supported shapes are {SUPPORTED_SHAPE}") + + # Create image rotary position embedding + self.rotary_pos_emb = self._get_rotary_pos_emb() + + # Use DiT Cache + self.use_cache = USE_CACHE + if self.use_cache: + self.skip_flag_true = torch.ones([1], dtype=torch.int64).to(self.device) + self.skip_flag_false = torch.zeros([1], dtype=torch.int64).to(self.device) + self.cache_dict = torch.tensor([5, 2, 30, 9], dtype=torch.int64).to(self.device) + + self.cache_interval = self.cache_dict[1] + self.step_contrast = self.cache_dict[3] % 2 + self.reserve = 0 + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + num_inference_steps: int = 100, + seed_generator: torch.Generator = None + ): + # 1. Check inputs. Raise error if not correct + check_call_input(prompt, num_images_per_prompt, num_inference_steps, seed_generator) + + # 2. Define prompt and negative_prompt + if prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = 1 + negative_prompt = NEGATIVE_PROMPT + if prompt is not None and not isinstance(prompt, type(negative_prompt)): + raise ValueError( + f"negative_prompt should be the same type to prompt, " + f"but got {type(negative_prompt)} != {type(prompt)}." + ) + prompt_info = (prompt, negative_prompt, num_images_per_prompt) + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds, attention_mask, uncond_attention_mask = \ + self._encode_prompt(prompt_info, embedder_t5=False, batch_size=batch_size) + prompt_embeds_t5, negative_prompt_embeds_t5, attention_mask_t5, uncond_attention_mask_t5 = \ + self._encode_prompt(prompt_info, embedder_t5=True, batch_size=batch_size) + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]).half() + attention_mask = torch.cat([uncond_attention_mask, attention_mask]).half() + prompt_embeds_t5 = torch.cat([negative_prompt_embeds_t5, prompt_embeds_t5]).half() + attention_mask_t5 = torch.cat([uncond_attention_mask_t5, attention_mask_t5]).half() + transformer_input = (prompt_embeds, attention_mask, prompt_embeds_t5, attention_mask_t5) + torch.npu.empty_cache() + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=self.device) + timesteps = self.scheduler.timesteps + step = (timesteps, num_inference_steps) + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + shape = (batch_size * num_images_per_prompt, + num_channels_latents, + self.height // self.vae_scale_factor, + self.width // self.vae_scale_factor) + latents = randn_tensor(shape, generator=seed_generator, device=self.device, dtype=prompt_embeds.dtype) * 1.0 + + # 6. 
Denoising loop + latents = self._sampling(latents, step, transformer_input, seed_generator) + image = self.vae.decode(latents / self.vae.config.scaling_factor)[0] + image = postprocess_pil(image) + + return (image, None) + + + def _check_init_input(self): + if not isinstance(self.input_size, tuple): + raise ValueError(f"The type of input_size must be tuple, but got {type(self.input_size)}.") + if len(self.input_size) != 2: + raise ValueError(f"The length of input_size must be 2, but got {len(self.input_size)}.") + if self.input_size[0] % 8 != 0 or self.input_size[0] <= 0: + raise ValueError( + f"The height of input_size must be divisible by 8 and greater than 0, but got {self.input_size[0]}.") + if self.input_size[1] % 8 != 0 or self.input_size[1] <= 0: + raise ValueError( + f"The width of input_size must be divisible by 8 and greater than 0, but got {self.input_size[1]}.") + if self.dtype != torch.float16 and self.dtype != torch.bfloat16: + raise ValueError(f"The input dtype must be float16 or bfloat16, but got {self.dtype}.") + + + def _get_rotary_pos_emb(self): + grid_height = self.height // 8 // self.transformer.config.patch_size + grid_width = self.width // 8 // self.transformer.config.patch_size + base_size = 512 // 8 // self.transformer.config.patch_size + head_dim = self.transformer.config.hidden_size // self.transformer.config.num_heads + + rope = RotaryPositionEmbedding(head_dim) + freqs_cis_img = rope.get_2d_rotary_pos_embed(grid_height, grid_width, base_size) + if isinstance(freqs_cis_img, tuple) and len(freqs_cis_img) == 2: + return (freqs_cis_img[0].half().to(self.device), freqs_cis_img[1].half().to(self.device)) + else: + raise ValueError(f"The type of rotary_pos_emb must be tuple and the length must be 2.") + + + def _encode_prompt(self, prompt_info, embedder_t5=False, batch_size=1): + if not embedder_t5: + text_encoder = self.text_encoder + tokenizer = self.tokenizer + max_length = self.tokenizer.model_max_length + else: + text_encoder = self.text_encoder_2 + tokenizer = self.tokenizer_2 + max_length = self.tokenizer_2.model_max_length + + prompt, negative_prompt, num_images_per_prompt = prompt_info + # prompt_embeds + prompt_embeds, attention_mask = self._encode_embeds( + prompt, tokenizer, text_encoder, max_length, num_images_per_prompt) + if text_encoder is not None: + prompt_embeds_dtype = text_encoder.dtype + elif self.transformer is not None: + prompt_embeds_dtype = self.transformer.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=self.device) + bs_embed, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + # negative_prompt_embeds + negative_prompt_embeds, uncond_attention_mask = self._encode_negative_embeds( + negative_prompt, tokenizer, text_encoder, prompt_embeds, num_images_per_prompt) + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=self.device) + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds, negative_prompt_embeds, attention_mask, uncond_attention_mask + + + def _encode_embeds(self, prompt, tokenizer, text_encoder, max_length, num_images_per_prompt): + text_inputs = tokenizer( + prompt, + 
padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + + attention_mask = text_inputs.attention_mask.to(self.device) + prompt_embeds = text_encoder( + text_input_ids.to(self.device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + attention_mask = attention_mask.repeat(num_images_per_prompt, 1) + + return prompt_embeds, attention_mask + + + def _encode_negative_embeds(self, negative_prompt, tokenizer, text_encoder, prompt_embeds, num_images_per_prompt): + uncond_tokens: List[str] + if isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + else: + uncond_tokens = negative_prompt + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + uncond_attention_mask = uncond_input.attention_mask.to(self.device) + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(self.device), + attention_mask=uncond_attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + uncond_attention_mask = uncond_attention_mask.repeat(num_images_per_prompt, 1) + + return negative_prompt_embeds, uncond_attention_mask + + + def _sampling(self, latents, step, transformer_input, seed_generator): + + timesteps, num_inference_steps = step + + if self.use_cache: + delta_cache = torch.zeros([2, 3840, 1408], dtype=self.dtype).to(self.device) + self.reserve = max(CACHE_RESERVE, num_inference_steps // 3) + + num_warmup_steps = len(timesteps) - num_inference_steps + with self._progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + latent_model_input = torch.cat([latents] * 2) + t_expand = torch.tensor([t] * latent_model_input.shape[0], device=latent_model_input.device) + + # if use_fp16 + latent_model_input = latent_model_input.half() + t_expand = t_expand.half() + # predict the noise residual + tensor_input = (latent_model_input, t_expand, transformer_input, self.rotary_pos_emb) + if not self.use_cache: + noise_pred = self.transformer(tensor_input) + else: + cache_params = (self.cache_dict, delta_cache.half()) + inputs = [tensor_input, self.use_cache, cache_params, self.skip_flag_false] + if i < self.reserve: + noise_pred, delta_cache = self.transformer(*inputs) + else: + if i % self.cache_interval == self.step_contrast: + noise_pred, delta_cache = self.transformer(*inputs) + else: + inputs[-1] = self.skip_flag_true + noise_pred, delta_cache = self.transformer(*inputs) + + # if learn_sigma + noise_pred, _ = noise_pred.chunk(2, dim=1) + # perform guidance + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + guidance_scale = 6.0 + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, seed_generator) + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps): + progress_bar.update() + + return latents + + + def _progress_bar(self, total=None): + if not hasattr(self, "_progress_bar_config"): + self._progress_bar_config = {} + elif not isinstance(self._progress_bar_config, dict): + raise ValueError(f"_progress_bar_config should be dict, but is {type(self._progress_bar_config)}.") + + if total is not None: + return tqdm(total=total, **self._progress_bar_config) + else: + raise ValueError("total has to be 
defined.") + + +def check_call_input(prompt, num_images_per_prompt, num_inference_steps, seed_generator): + if not isinstance(prompt, str): + raise ValueError("The input prompt type must be strings.") + if len(prompt) == 0 or len(prompt) >= MAX_PROMPT_LENGTH: + raise ValueError( + f"The length of the prompt should be (0, {MAX_PROMPT_LENGTH}), but got {len(prompt)}.") + if not isinstance(num_images_per_prompt, int): + raise ValueError("The input num_images_per_prompt type must be an instance of int.") + if num_images_per_prompt < 0: + raise ValueError( + f"Input num_images_per_prompt should be a non-negative integer, but got {num_images_per_prompt}.") + if not isinstance(num_inference_steps, int): + raise ValueError("The input num_inference_steps type must be an instance of int.") + if num_inference_steps < 0: + raise ValueError( + f"Input num_inference_steps should be a non-negative integer, but got {num_inference_steps}.") + if not isinstance(seed_generator, torch.Generator): + raise ValueError( + f"The type of input seed_generator must be torch.Generator, but got {type(seed_generator)}.") + + +def map_to_standard_shapes(target_width, target_height): + target_ratio = target_width / target_height + closest_ratio_idx = np.argmin(np.abs(STANDARD_RATIO - target_ratio)) + closest_area_idx = np.argmin(np.abs(STANDARD_AREA[closest_ratio_idx] - target_width * target_height)) + width, height = STANDARD_SHAPE[closest_ratio_idx][closest_area_idx] + return width, height \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/pipeline/pipeline_utils.py b/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/pipeline/pipeline_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c37672ee41d7f27747f09082c2c2959356250203 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/hydit/pipeline/pipeline_utils.py @@ -0,0 +1,95 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
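+
+# HunYuanPipeline.from_pretrained assembles the pipeline from model_index.json:
+# each of the seven components is resolved to a (library, class) pair, where
+# "mindiesd" entries are imported from the local hydit package and the others from
+# the named library (e.g. transformers), loaded from the sub-folder of the same
+# name, and the pipeline is then constructed with the requested input_size and dtype.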
+ + +import os +import inspect +import importlib + +import torch +from tqdm import tqdm + +from mindiesd import ConfigMixin + + +PIPELINE_CONFIG_NAME = "model_index.json" +SCHEDULER = 'scheduler' +TEXT_ENCODER = 'text_encoder' +TEXT_ENCODER_2 = 'text_encoder_2' +TOKENIZER = 'tokenizer' +TOKENIZER_2 = 'tokenizer_2' +TRANSFORMER = 'transformer' +VAE = 'vae' + +HUNYUAN_DEFAULT_IMAGE_SIZE = (1024, 1024) +HUNYUAN_INPUT_SIZE = "input_size" +HUNYUAN_DEFAULT_DTYPE = torch.float16 +HUNYUAN_DTYPE = "dtype" + + +class HunYuanPipeline(ConfigMixin): + config_name = PIPELINE_CONFIG_NAME + + def __init__(self): + super().__init__() + + @classmethod + def from_pretrained(cls, model_path, **kwargs): + input_size = kwargs.pop(HUNYUAN_INPUT_SIZE, HUNYUAN_DEFAULT_IMAGE_SIZE) + dtype = kwargs.pop(HUNYUAN_DTYPE, HUNYUAN_DEFAULT_DTYPE) + if model_path is None: + raise ValueError("The model_path should not be None.") + init_dict, _ = cls.load_config(model_path, **kwargs) + + init_list = [SCHEDULER, TEXT_ENCODER, TEXT_ENCODER_2, TOKENIZER, TOKENIZER_2, TRANSFORMER, VAE] + pipe_init_dict = {} + + all_parameters = inspect.signature(cls.__init__).parameters + + required_param = {k: v for k, v in all_parameters.items() if v.default == inspect.Parameter.empty} + expected_modules = set(required_param.keys()) - {"self"} + # init the module from kwargs + passed_module = {k: kwargs.pop(k) for k in expected_modules if k in kwargs} + + for key in tqdm(init_list, desc="Loading hunyuan-dit-pipeline components"): + if key not in init_dict: + raise ValueError(f"Get {key} from init config failed!") + if key in passed_module: + pipe_init_dict[key] = passed_module.pop(key) + else: + modules, cls_name = init_dict[key] + if modules == "mindiesd": + library = importlib.import_module("hydit") + else: + library = importlib.import_module(modules) + class_obj = getattr(library, cls_name) + + sub_folder = os.path.join(model_path, key) + + if key == TRANSFORMER: + pipe_init_dict[key] = class_obj.from_pretrained( + sub_folder, input_size=input_size, dtype=dtype, **kwargs) + elif key == VAE: + pipe_init_dict[key] = class_obj.from_pretrained(sub_folder, dtype=dtype, **kwargs) + elif key == SCHEDULER: + pipe_init_dict[key] = class_obj.from_config(sub_folder) + else: + pipe_init_dict[key] = class_obj.from_pretrained(sub_folder, **kwargs) + + pipe_init_dict[HUNYUAN_INPUT_SIZE] = input_size + pipe_init_dict[HUNYUAN_DTYPE] = dtype + + return cls(**pipe_init_dict) \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/inference_hydit.py b/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/inference_hydit.py new file mode 100644 index 0000000000000000000000000000000000000000..5bea2f2a2dfc7b8f34e4676ece7fc98aae950663 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/hunyuan_dit/inference_hydit.py @@ -0,0 +1,153 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
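+
+# Command-line entry point. It parses the arguments documented in the README, loads
+# HunyuanDiTPipeline from --path, moves it to the NPU with compile_pipe, optionally
+# merges LoRA weights, generates one image per prompt into results/<timestamp>/, and
+# with --test_acc reports the average pipeline latency, excluding the first two
+# warm-up iterations.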
+ + +import os +import random +import argparse +import time +from pathlib import Path +import logging + +import torch + +from hydit import HunyuanDiTPipeline, compile_pipe, set_seeds_generator +from lora import multi_lora + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument("--path", type=str, default="ckpts/hydit", help="Path to the model directory") + parser.add_argument("--device_id", type=int, default=0, help="NPU device id") + parser.add_argument("--device", type=str, default="npu", help="NPU") + parser.add_argument("--prompt", type=str, default="渔舟唱晚", help="The prompt for generating images") + parser.add_argument("--prompt_list", type=str, default="prompts/example_prompts.txt", help="The prompt list") + parser.add_argument("--test_acc", action="store_true", help="Run or not 'example_prompts.txt'") + parser.add_argument("--input_size", type=int, nargs='+', default=[1024, 1024], help='Image size (h, w)') + parser.add_argument("--type", type=str, default="fp16", help="The torch type is fp16 or bf16") + parser.add_argument("--batch_size", type=int, default=1, help="Per-NPU batch size") + parser.add_argument('--seed', type=int, default=42, help="A seed for all the prompts") + parser.add_argument("--infer_steps", type=int, default=25, help="Inference steps") + parser.add_argument("--use_lora", action="store_true", help="Use LoRA checkpoint") + parser.add_argument("--lora_ckpt", type=str, default="ckpts/lora", help="LoRA checkpoint") + return parser.parse_args() + + +def get_dtype(args): + dtype = torch.bfloat16 + if args.type == 'bf16': + dtype = torch.bfloat16 + elif args.type == 'fp16': + dtype = torch.float16 + else: + logger.error("Not supported.") + return dtype + + +def get_seed(args): + seed = args.seed + if seed is None: + seed = random.randint(0, 1_000_000) + if not isinstance(seed, int): + raise ValueError(f"The type of seed must be int, but got {type(seed)}.") + if seed < 0: + raise ValueError(f"Input seed must be a non-negative integer, but got {seed}.") + return set_seeds_generator(seed, device=args.device) + + +def get_prompts(args): + if not args.test_acc: + prompts = args.prompt + prompts = [prompts.strip()] + else: + lines_list = [] + prompt_list_path = os.path.join(args.path, args.prompt_list) + with open(prompt_list_path, 'r') as file: + for line in file: + line = line.strip() + lines_list.append(line) + prompts = lines_list + return prompts + + +def infer(args): + if not Path(args.path).exists(): + raise ValueError(f"args.path not exists: {Path(args.path)}") + if len(args.input_size) != 2: + raise ValueError(f"The length of args.input_size must be 2, but got {len(args.input_size)}") + input_size = (args.input_size[0], args.input_size[1]) + + torch.npu.set_device(args.device_id) + dtype = get_dtype(args) + seed_generator = get_seed(args) + + pipeline = HunyuanDiTPipeline.from_pretrained(model_path=args.path, input_size=input_size, dtype=dtype) + pipeline = compile_pipe(pipeline) + + if args.use_lora: + merge_state_dict = multi_lora(args, pipeline) + pipeline.transformer.load_state_dict(merge_state_dict) + + prompts = get_prompts(args) + loops = len(prompts) + + save_dir = Path('results') + if not save_dir.exists(): + save_dir.mkdir(exist_ok=True) + + now_time = time.localtime(time.time()) + time_dir_name = time.strftime("%m%d%H%M%S", now_time) + time_dir = save_dir / Path(time_dir_name) + time_dir.mkdir(exist_ok=True) + + pipeline_total_time = 0.0 + for i 
in range(loops):
+        prompt = prompts[i]
+
+        start_time = time.time()
+
+        result_images = pipeline(
+            prompt=prompt,
+            num_images_per_prompt=args.batch_size,
+            num_inference_steps=args.infer_steps,
+            seed_generator=seed_generator,
+        )[0]
+
+        pipeline_time = time.time() - start_time
+        logger.info("HunyuanDiT No.%d time: %.3f s", i, pipeline_time)
+        torch.npu.empty_cache()
+
+        # Skip the first two iterations (warm-up) when accumulating latency.
+        if i >= 2:
+            pipeline_total_time += pipeline_time
+
+        save_path = time_dir / f"{i}.png"
+        result_images[0].save(save_path)
+        torch.npu.empty_cache()
+
+    if args.test_acc:
+        if loops <= 2:
+            raise ValueError(f"The number of prompts must be larger than 2, but got {loops}.")
+        pipeline_average_time = pipeline_total_time / (loops - 2)
+        logger.info("HunyuanDiT pipeline_average_time: %.3f s", pipeline_average_time)
+    torch.npu.empty_cache()
+
+
+if __name__ == "__main__":
+    inference_args = parse_arguments()
+    infer(inference_args)
\ No newline at end of file