From 1aee19f11b54bcfbfc2cab6645d4f00a7f3875d4 Mon Sep 17 00:00:00 2001 From: huanghao Date: Tue, 3 Sep 2024 09:23:15 +0800 Subject: [PATCH 01/20] =?UTF-8?q?=E8=A7=A3=E5=86=B3SD3=20DUO=E5=8D=A1?= =?UTF-8?q?=E5=9C=BA=E6=99=AF=E5=A4=9Abatch=E6=8E=A8=E7=90=86=E5=A4=B1?= =?UTF-8?q?=E8=B4=A5=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../foundation/stable_diffusion_3/README.md | 65 ++++++++++++++++--- .../stable_diffusion_3/background_runtime.py | 2 +- .../stable_diffusion_3/export_model.py | 4 +- .../stable_diffusion3_pipeline.py | 5 +- 4 files changed, 60 insertions(+), 16 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/README.md b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/README.md index 25f45779a7..4a807769b9 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/README.md +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/README.md @@ -100,8 +100,8 @@ git clone https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers ``` - 1. 导出pt模型并进行编译。(可选) - + 1. 导出pt模型并进行编译。 + (1) 设置模型权重的路径 ```bash # sd3 (执行时下载权重) model_base="stabilityai/stable-diffusion-3-medium-diffusers" @@ -109,8 +109,12 @@ # sd3 (使用上一步下载的权重) model_base="./stable-diffusion-3-medium-diffusers" ``` + (2) 创建文件夹./models存放导出的模型 + ```bash + mkdir ./models + ``` - 执行命令: + (3) 执行export命令 ```bash # 800I A2,非并行 @@ -126,6 +130,7 @@ - --batch_size: 设置batch_size, 默认值为1, 当前仅支持batch_size=1的场景 - --soc:只支持Ascend910B4和Ascend310P3。默认为Ascend910B4。 - --device:推理设备ID + 注意:trace+compile耗时较长且占用较多的CPU资源,请勿在执行export命令时运行其他占用CPU内存的任务,避免程序意外退出。 2. 开始推理验证。 @@ -200,6 +205,7 @@ 非并行策略,执行完成后在`./results`目录下生成推理图片,在当前目录生成一个`image_info.json`文件,记录着图片和prompt的对应关系,并在终端显示推理时间。 并行策略,同时使用双卡并行策略,执行完成后在`./results_parallel`目录下生成推理图片,在当前目录生成一个`image_info.json`文件,记录着图片和prompt的对应关系。并在终端显示推理时间。 + 注意:当前MindIE-Torch和torch_npu的synchronizing stream不兼容,为避免出错,建议在运行推理前先卸载torch_npu。 ## 精度验证 @@ -209,11 +215,13 @@ 注意,由于要生成的图片数量较多,进行完整的精度验证需要耗费很长的时间。 - 1. 下载Parti数据集 + 1. 下载Parti数据集和hpsv2数据集 ```bash + # 下载Parti数据集 wget https://raw.githubusercontent.com/google-research/parti/main/PartiPrompts.tsv --no-check-certificate ``` + hpsv2数据集下载链接:https://gitee.com/ascend/ModelZoo-PyTorch/blob/master/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_xl/hpsv2_benchmark_prompts.json 2. 下载模型权重 @@ -279,7 +287,44 @@ 不使用并行策略,执行完成后在`./results_PartiPrompts`目录下生成推理图片,在当前目录生成一个`image_info.json`文件,记录着图片和prompt的对应关系,并在终端显示推理时间。 使用双卡并行策略,执行完成后在`./results_PartiPrompts_parallel`目录下生成推理图片,在当前目录生成一个`image_info.json`文件,记录着图片和prompt的对应关系。并在终端显示推理时间。 - 4. 计算精度指标 +4. 
使用推理脚本读取hpsv2数据集,生成图片 + + ```bash + # 不使用并行 + python3 stable_diffusion3_pipeline.py \ + --model ${model_base} \ + --prompt_file_type hpsv2 \ + --num_images_per_prompt 1 \ + --info_file_save_path ./image_info_hpsv2.json \ + --device 0 \ + --save_dir ./results_hpsv2 \ + --steps 28 \ + --output_dir ./models \ + --height 1024 \ + --width 1024 \ + --batch_size 1 + + # 使用双卡并行策略 + python3 stable_diffusion3_pipeline.py \ + --model ${model_base} \ + --prompt_file_type hpsv2 \ + --num_images_per_prompt 1 \ + --info_file_save_path ./image_info_hpsv2.json \ + --device 0,1 \ + --save_dir ./results_hpsv2_parallel \ + --steps 28 \ + --output_dir ./models \ + --height 1024 \ + --width 1024 \ + --batch_size 1 + ``` + 参数说明: + - --info_file_save_path:生成图片信息的json文件路径。 + + 不使用并行策略,执行完成后在`./results_hpsv2`目录下生成推理图片,在当前目录生成一个`image_info_hpsv2.json`文件,记录着图片和prompt的对应关系,并在终端显示推理时间。 + 使用双卡并行策略,执行完成后在`./results_hpsv2_parallel`目录下生成推理图片,在当前目录生成一个`image_info_hpsv2.json`文件,记录着图片和prompt的对应关系。并在终端显示推理时间。 + +5. 计算精度指标 1. CLIP-score ```bash python3 clip_score.py \ @@ -290,27 +335,27 @@ ``` 参数说明: - - --device: 推理设备。 + - --device: 推理设备,默认为"cpu",如果是cuda设备可设置为"cuda"。 - --image_info: 上一步生成的`image_info.json`文件。 - --model_name: Clip模型名称。 - --model_weights_path: Clip模型权重文件路径。 - 执行完成后会在屏幕打印出精度计算结果。 + clip_score.py脚本可参考[SDXL](https://gitee.com/ascend/ModelZoo-PyTorch/blob/master/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_xl/clip_score.py),执行完成后会在屏幕打印出精度计算结果。 2. HPSv2 ```bash python3 hpsv2_score.py \ - --image_info="image_info.json" \ + --image_info="image_info_hpsv2.json" \ --HPSv2_checkpoint="./HPS_v2_compressed.pt" \ --clip_checkpoint="./CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin" ``` 参数说明: - - --image_info: 上一步生成的`image_info.json`文件。 + - --image_info: 上一步生成的`image_info_hpsv2.json`文件。 - --HPSv2_checkpoint: HPSv2模型权重文件路径。 - --clip_checkpointh: Clip模型权重文件路径。 - 执行完成后会在屏幕打印出精度计算结果。 + hpsv2_score.py脚本可参考[SDXL](https://gitee.com/ascend/ModelZoo-PyTorch/blob/master/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_xl/hpsv2_score.py),执行完成后会在屏幕打印出精度计算结果。 # 模型推理性能&精度 diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/background_runtime.py b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/background_runtime.py index 6f4935af2d..482d82309c 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/background_runtime.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/background_runtime.py @@ -177,7 +177,7 @@ class BackgroundRuntime: for i, _ in enumerate(output_arrays): output = output_cpu.numpy() - output_arrays[i][:] = output[i][:] + output_arrays[i][:] = output[:] infer_num += 1 sync_pipe.send('') diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/export_model.py b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/export_model.py index 35dcf2e898..1cda87042f 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/export_model.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/export_model.py @@ -195,7 +195,7 @@ def export_dit(sd_pipeline, args): [batch_size, max_position_embeddings, encoder_hidden_size * 2], dtype=torch.float32 ), torch.ones([batch_size, encoder_hidden_size], dtype=torch.float32), - torch.ones([batch_size], dtype=torch.int64) + torch.ones([1], dtype=torch.int64) ) dit = DiTExport(dit_model).eval() torch.jit.trace(dit, dummy_input).save(dit_pt_path) @@ -212,7 +212,7 @@ def export_dit(sd_pipeline, args): dtype=mindietorch.dtype.FLOAT), mindietorch.Input((batch_size, 
encoder_hidden_size), dtype=mindietorch.dtype.FLOAT), - mindietorch.Input((batch_size,), dtype=mindietorch.dtype.INT64)] + mindietorch.Input((1,), dtype=mindietorch.dtype.INT64)] compile_dit(model, inputs, dit_compiled_path, args.soc) else: logging.info("dit_compiled_path already exists.") diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline.py b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline.py index e5cea8855a..5415517f7e 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline.py @@ -190,7 +190,7 @@ class AIEStableDiffusion3Pipeline(StableDiffusion3Pipeline): (batch_size, in_channels, sample_size, sample_size), (batch_size, max_position_embeddings, encoder_hidden_size * 2), (batch_size, encoder_hidden_size), - (batch_size,), + (1,), ], input_dtypes=[np.float32, np.float32, np.float32, np.int64], output_shapes=[(batch_size, in_channels, sample_size, sample_size)], @@ -534,8 +534,7 @@ class AIEStableDiffusion3Pipeline(StableDiffusion3Pipeline): if not self.use_parallel_inferencing and self.do_classifier_free_guidance: latent_model_input = torch.cat([latents] * 2) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timestep = t.expand(latent_model_input.shape[0]).to(torch.int64) - timestep_npu = timestep.to(f"npu:{self.device_0}") + timestep = t.to(torch.int64)[None].to(f"npu:{self.device_0}") else: latent_model_input = latents timestep = t.to(torch.int64) -- Gitee From fedc6b328069afba69898e3fd83871495ae15de1 Mon Sep 17 00:00:00 2001 From: huanghao Date: Thu, 5 Sep 2024 10:21:01 +0800 Subject: [PATCH 02/20] =?UTF-8?q?DITCache=E5=92=8CTODO=E4=BC=98=E5=8C=96?= =?UTF-8?q?=E4=BB=A3=E7=A0=81=E4=B8=8A=E5=BA=93?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../foundation/stable_diffusion_3/README.md | 2 +- .../stable_diffusion_3/attention.patch | 41 +++ .../attention_processor_patch.py | 32 ++ .../stable_diffusion_3/compile_model.py | 9 + .../stable_diffusion_3/export_model.py | 4 + .../stable_diffusion3_pipeline.py | 6 + .../stable_diffusion3_pipeline_cache.py | 317 ++++++++++++++++++ .../stable_diffusion_3/transformer_sd3.patch | 41 +++ .../transformer_sd3_patch.py | 32 ++ 9 files changed, 483 insertions(+), 1 deletion(-) create mode 100644 MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention.patch create mode 100644 MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention_processor_patch.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline_cache.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/transformer_sd3.patch create mode 100644 MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/transformer_sd3_patch.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/README.md b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/README.md index 4a807769b9..c29b64bdff 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/README.md +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/README.md @@ -80,7 +80,7 @@ ``` - 执行命令: ```bash - python3 attention_patch.py + python3 attention_processor_patch.py ``` ## 模型推理 diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention.patch 
b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention.patch new file mode 100644 index 0000000000..dc22a411ea --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention.patch @@ -0,0 +1,41 @@ +--- attention_processor.py 2024-08-14 12:41:10.528000000 +0000 ++++ attention_processor.py 2024-08-14 12:42:25.620000000 +0000 +@@ -1082,6 +1082,29 @@ + return hidden_states + + ++def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, ++ scale=None) -> torch.Tensor: ++ # Efficient implementation equivalent to the following: ++ L, S = query.size(-2), key.size(-2) ++ scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale ++ attn_bias = torch.zeros(L, S, dtype=query.dtype) ++ if is_causal: ++ assert attn_mask is None ++ temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) ++ attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) ++ attn_bias.to(query.dtype) ++ ++ if attn_mask is not None: ++ if attn_mask.dtype == torch.bool: ++ attn_mask.masked_fill_(attn_mask.logical_not(), float("-inf")) ++ else: ++ attn_bias += attn_mask ++ attn_weight = query @ key.transpose(-2, -1) * scale_factor ++ attn_weight = torch.softmax(attn_weight, dim=-1) ++ attn_weight = torch.dropout(attn_weight, dropout_p, train=True) ++ return attn_weight @ value ++ ++ + class JointAttnProcessor2_0: + """Attention processor used typically in processing the SD3-like self-attention projections.""" + +@@ -1132,7 +1155,7 @@ + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + +- hidden_states = hidden_states = F.scaled_dot_product_attention( ++ hidden_states = hidden_states = scaled_dot_product_attention( + query, key, value, dropout_p=0.0, is_causal=False + ) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention_processor_patch.py b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention_processor_patch.py new file mode 100644 index 0000000000..97585e6af7 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention_processor_patch.py @@ -0,0 +1,32 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import subprocess +import logging +import diffusers + + +def main(): + diffusers_path = diffusers.__path__ + diffusers_version = diffusers.__version__ + + assert diffusers_version == '0.29.0', "expectation diffusers==0.29.0" + result = subprocess.run(["patch", "-p0", f"{diffusers_path[0]}/models/attention_processor.py", + "attention_processor.patch"], capture_output=True, text=True) + if result.returncode != 0: + logging.error("Patch failed, error message: s%", result.stderr) + + +if __name__ == '__main__': + main() diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/compile_model.py b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/compile_model.py index b48994f182..ab3e74e749 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/compile_model.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/compile_model.py @@ -93,3 +93,12 @@ class DiTExport(torch.nn.Module): def compile_dit(model, inputs, dit_compiled_path, soc_version): dit_param = CompileParam(inputs, soc_version) common_compile(model, dit_compiled_path, dit_param) + + +class Scheduler(torch.nn.Module): + def __init__(self, ): + super().__init__() + self.sigmas = None + + def forward(self): + pass diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/export_model.py b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/export_model.py index 1cda87042f..0b0131d9af 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/export_model.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/export_model.py @@ -254,6 +254,10 @@ def export_vae(sd_pipeline, args): logging.info("vae_compiled_path already exists.") +def export_scheduler(sd_pipeline, args): + pass + + def export(args): pipeline = StableDiffusion3Pipeline.from_pretrained(args.model).to('cpu') export_clip(pipeline, args) diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline.py b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline.py index 5415517f7e..4968b6832d 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline.py @@ -725,6 +725,12 @@ def parse_arguments(): type=int, help="image width" ) + parser.add_argument( + "--cache_dict", + default="1,2,20,10", + type=str, + help="steps to use cache data" + ) return parser.parse_args() diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline_cache.py b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline_cache.py new file mode 100644 index 0000000000..5cd48ad34d --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline_cache.py @@ -0,0 +1,317 @@ +# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import json +import os +import time +from typing import Any, Callable, Dict, List, Optional, Union +import numpy as np +import torch +import mindietorch +from diffusers.pipelines.stable_diffusion_3.pipeline_output import StableDiffusion3PipelineOutput +from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import retrieve_timesteps +from stable_diffusion3_pipeline import PromptLoader, AIEStableDiffusion3Pipeline, parse_arguments + +clip_time = 0 +t5_time = 0 +dit_time = 0 +vae_time = 0 +p1_time = 0 +p2_time = 0 +p3_time = 0 + + +class AIEStableDiffusion3CachePipeline(AIEStableDiffusion3Pipeline): + def compile_aie_model(self): + pass + + @torch.no_grad() + def forward( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + prompt_3: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 28, + timesteps: List[int] = None, + guidance_scale: float = 7.0, + cache_dict: torch.LongTensor = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + negative_prompt_3: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + ): + global p1_time, p2_time, p3_time + start = time.time() + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + prompt_3, + height, + width, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. 
Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + device = self._execution_device + + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_3=prompt_3, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + do_classifier_free_guidance=self.do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + clip_skip=self.clip_skip, + num_images_per_prompt=num_images_per_prompt, + ) + + p1_time += (time.time() - start) + start1 = time.time() + + if self.do_classifier_free_guidance and not self.use_parallel_inferencing: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) + else: + prompt_embeds, prompt_embeds_1 = negative_prompt_embeds, prompt_embeds + pooled_prompt_embeds, pooled_prompt_embeds_1 = negative_pooled_prompt_embeds, pooled_prompt_embeds + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. 
Denoising loop + global dit_time + global vae_time + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + if not self.use_parallel_inferencing and self.do_classifier_free_guidance: + latent_model_input = torch.cat([latents] * 2) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.to(torch.int64)[None].to(f"npu:{self.device_0}") + else: + latent_model_input = latents + timestep = t.to(torch.int64) + self.dit_bg.infer_asyn([ + latent_model_input.numpy(), + prompt_embeds_1.numpy(), + pooled_prompt_embeds_1.numpy(), + timestep[None].numpy().astype(np.int64) + ]) + timestep_npu = timestep[None].to(f"npu:{self.device_0}") + + latent_model_input_npu = latent_model_input.to(f"npu:{self.device_0}") + prompt_embeds_npu = prompt_embeds.to(f"npu:{self.device_0}") + pooled_prompt_embeds_npu = pooled_prompt_embeds.to(f"npu:{self.device_0}") + + start = time.time() + noise_pred = self.compiled_dit_model( + latent_model_input_npu, + prompt_embeds_npu, + pooled_prompt_embeds_npu, + timestep_npu + ).to("cpu") + dit_time += (time.time() - start) + + # perform guidance + if self.do_classifier_free_guidance: + if self.use_parallel_inferencing: + noise_pred_text = torch.from_numpy(self.dit_bg.wait_and_get_outputs()[0]) + else: + noise_pred, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred + self.guidance_scale * (noise_pred_text - noise_pred) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + p2_time += (time.time() - start1) + start2 = time.time() + + if output_type == "latent": + image = latents + else: + start = time.time() + image = self.compiled_vae_model(latents.to(f"npu:{self.device_0}")).to("cpu") + vae_time += (time.time() - start) + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + p3_time += (time.time() - start2) + + if not return_dict: + return (image,) + + return StableDiffusion3PipelineOutput(images=image) + + +def main(): + args = parse_arguments() + save_dir = args.save_dir + if not os.path.exists(save_dir): + os.makedirs(save_dir) + + pipe = AIEStableDiffusion3Pipeline.from_pretrained(args.model).to("cpu") + pipe.parser_args(args) + pipe.compile_aie_model() + if isinstance(args.device, list): + mindietorch.set_device(args.device[0]) + else: + mindietorch.set_device(args.device) + + use_time = 0 + prompt_loader = PromptLoader(args.prompt_file, + args.prompt_file_type, + args.batch_size, + args.num_images_per_prompt, + args.max_num_prompts) + + infer_num = 0 + image_info = [] + current_prompt = None + for i, input_info in enumerate(prompt_loader): + prompts = input_info['prompts'] + catagories = input_info['catagories'] + save_names = input_info['save_names'] + n_prompts = input_info['n_prompts'] + + print(f"[{infer_num + n_prompts}/{len(prompt_loader)}]: {prompts}") + infer_num += args.batch_size + + start_time = time.time() + images = pipe.forward( + prompts, + negative_prompt="", + num_inference_steps=args.steps, + guidance_scale=7.5, + ) + if i > 4: # do not count the time spent inferring the first 0 to 4 images + use_time += time.time() - start_time + + for j in range(n_prompts): + image_save_path = os.path.join(save_dir, f"{save_names[j]}.png") + image = images[0][j] + image.save(image_save_path) + + if current_prompt != prompts[j]: + current_prompt = prompts[j] + image_info.append({'images': [], 'prompt': current_prompt, 'category': catagories[j]}) + + image_info[-1]['images'].append(image_save_path) + + infer_num = infer_num - 5 # do not count the time spent inferring the first 5 images + print(f"[info] infer number: {infer_num}; use time: {use_time:.3f}s\n" + f"average time: {use_time / infer_num:.3f}s\n") + + if hasattr(pipe, 'device_1'): + if (pipe.dit_bg): + pipe.dit_bg.stop() + + # Save image information to a json file + if os.path.exists(args.info_file_save_path): + os.remove(args.info_file_save_path) + + with os.fdopen(os.open(args.info_file_save_path, os.O_RDWR | os.O_CREAT, 0o640), "w") as f: + json.dump(image_info, f) + mindietorch.finalize() + + +if __name__ == "__main__": + main() diff --git 
a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/transformer_sd3.patch b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/transformer_sd3.patch new file mode 100644 index 0000000000..dc22a411ea --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/transformer_sd3.patch @@ -0,0 +1,41 @@ +--- attention_processor.py 2024-08-14 12:41:10.528000000 +0000 ++++ attention_processor.py 2024-08-14 12:42:25.620000000 +0000 +@@ -1082,6 +1082,29 @@ + return hidden_states + + ++def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, ++ scale=None) -> torch.Tensor: ++ # Efficient implementation equivalent to the following: ++ L, S = query.size(-2), key.size(-2) ++ scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale ++ attn_bias = torch.zeros(L, S, dtype=query.dtype) ++ if is_causal: ++ assert attn_mask is None ++ temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) ++ attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) ++ attn_bias.to(query.dtype) ++ ++ if attn_mask is not None: ++ if attn_mask.dtype == torch.bool: ++ attn_mask.masked_fill_(attn_mask.logical_not(), float("-inf")) ++ else: ++ attn_bias += attn_mask ++ attn_weight = query @ key.transpose(-2, -1) * scale_factor ++ attn_weight = torch.softmax(attn_weight, dim=-1) ++ attn_weight = torch.dropout(attn_weight, dropout_p, train=True) ++ return attn_weight @ value ++ ++ + class JointAttnProcessor2_0: + """Attention processor used typically in processing the SD3-like self-attention projections.""" + +@@ -1132,7 +1155,7 @@ + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + +- hidden_states = hidden_states = F.scaled_dot_product_attention( ++ hidden_states = hidden_states = scaled_dot_product_attention( + query, key, value, dropout_p=0.0, is_causal=False + ) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/transformer_sd3_patch.py b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/transformer_sd3_patch.py new file mode 100644 index 0000000000..97585e6af7 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/transformer_sd3_patch.py @@ -0,0 +1,32 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import subprocess +import logging +import diffusers + + +def main(): + diffusers_path = diffusers.__path__ + diffusers_version = diffusers.__version__ + + assert diffusers_version == '0.29.0', "expectation diffusers==0.29.0" + result = subprocess.run(["patch", "-p0", f"{diffusers_path[0]}/models/attention_processor.py", + "attention_processor.patch"], capture_output=True, text=True) + if result.returncode != 0: + logging.error("Patch failed, error message: s%", result.stderr) + + +if __name__ == '__main__': + main() -- Gitee From 697c181e2de8a73ee31ce548b7c124e94a837478 Mon Sep 17 00:00:00 2001 From: huanghao Date: Thu, 5 Sep 2024 10:29:34 +0800 Subject: [PATCH 03/20] =?UTF-8?q?DITCache=E5=92=8CTODO=E4=BC=98=E5=8C=96?= =?UTF-8?q?=E4=BB=A3=E7=A0=81=E4=B8=8A=E5=BA=93?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../foundation/stable_diffusion_3/transformer_sd3_patch.py | 1 + 1 file changed, 1 insertion(+) diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/transformer_sd3_patch.py b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/transformer_sd3_patch.py index 97585e6af7..88033355c6 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/transformer_sd3_patch.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/transformer_sd3_patch.py @@ -17,6 +17,7 @@ import logging import diffusers + def main(): diffusers_path = diffusers.__path__ diffusers_version = diffusers.__version__ -- Gitee From 2302e40d4c9ed34e4802092f5439a7696dddc2c1 Mon Sep 17 00:00:00 2001 From: huanghao7 Date: Thu, 5 Sep 2024 02:34:18 +0000 Subject: [PATCH 04/20] update MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention.patch. Signed-off-by: huanghao7 --- .../stable_diffusion_3/attention.patch | 58 ++++++------------- 1 file changed, 18 insertions(+), 40 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention.patch b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention.patch index dc22a411ea..ce183bd914 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention.patch +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention.patch @@ -1,41 +1,19 @@ ---- attention_processor.py 2024-08-14 12:41:10.528000000 +0000 -+++ attention_processor.py 2024-08-14 12:42:25.620000000 +0000 -@@ -1082,6 +1082,29 @@ - return hidden_states - - -+def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, -+ scale=None) -> torch.Tensor: -+ # Efficient implementation equivalent to the following: -+ L, S = query.size(-2), key.size(-2) -+ scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale -+ attn_bias = torch.zeros(L, S, dtype=query.dtype) -+ if is_causal: -+ assert attn_mask is None -+ temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) -+ attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) -+ attn_bias.to(query.dtype) -+ -+ if attn_mask is not None: -+ if attn_mask.dtype == torch.bool: -+ attn_mask.masked_fill_(attn_mask.logical_not(), float("-inf")) -+ else: -+ attn_bias += attn_mask -+ attn_weight = query @ key.transpose(-2, -1) * scale_factor -+ attn_weight = torch.softmax(attn_weight, dim=-1) -+ attn_weight = torch.dropout(attn_weight, dropout_p, train=True) -+ return attn_weight @ value -+ -+ - class JointAttnProcessor2_0: - """Attention processor used typically in processing the SD3-like self-attention projections.""" - -@@ -1132,7 
+1155,7 @@ - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - -- hidden_states = hidden_states = F.scaled_dot_product_attention( -+ hidden_states = hidden_states = scaled_dot_product_attention( - query, key, value, dropout_p=0.0, is_causal=False +--- attention.py 2024-09-04 09:22:31.768000000 +0000 ++++ attention.py 2024-09-04 09:17:12.680000000 +0000 +@@ -100,7 +100,7 @@ + processing of `context` conditions. + """ + +- def __init__(self, dim, num_attention_heads, attention_head_dim, context_pre_only=False): ++ def __init__(self, dim, num_attention_heads, attention_head_dim, context_pre_only=False, layer_idx=0): + super().__init__() + + self.context_pre_only = context_pre_only +@@ -134,6 +134,7 @@ + context_pre_only=context_pre_only, + bias=True, + processor=processor, ++ layer_idx=layer_idx ) - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) \ No newline at end of file + + self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) \ No newline at end of file -- Gitee From 1c1c47cab8e6cb9d954236b1ffbbab28b79d3498 Mon Sep 17 00:00:00 2001 From: huanghao7 Date: Thu, 5 Sep 2024 02:36:01 +0000 Subject: [PATCH 05/20] update MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/transformer_sd3.patch. Signed-off-by: huanghao7 --- .../stable_diffusion_3/transformer_sd3.patch | 156 +++++++++++++----- 1 file changed, 119 insertions(+), 37 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/transformer_sd3.patch b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/transformer_sd3.patch index dc22a411ea..6b18fd8831 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/transformer_sd3.patch +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/transformer_sd3.patch @@ -1,41 +1,123 @@ ---- attention_processor.py 2024-08-14 12:41:10.528000000 +0000 -+++ attention_processor.py 2024-08-14 12:42:25.620000000 +0000 -@@ -1082,6 +1082,29 @@ - return hidden_states - - -+def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, -+ scale=None) -> torch.Tensor: -+ # Efficient implementation equivalent to the following: -+ L, S = query.size(-2), key.size(-2) -+ scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale -+ attn_bias = torch.zeros(L, S, dtype=query.dtype) -+ if is_causal: -+ assert attn_mask is None -+ temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) -+ attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) -+ attn_bias.to(query.dtype) +--- transformer_sd3.py 2024-09-04 09:21:58.280000000 +0000 ++++ transformer_sd3.py 2024-09-04 10:01:47.196000000 +0000 +@@ -97,6 +97,7 @@ + num_attention_heads=self.config.num_attention_heads, + attention_head_dim=self.inner_dim, + context_pre_only=i == num_layers - 1, ++ layer_idx=i + ) + for i in range(self.config.num_layers) + ] +@@ -106,6 +107,8 @@ + self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) + + self.gradient_checkpointing = False ++ self.delta_cache = None ++ self.delta_cache_hidden = None + + # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking + def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: +@@ -245,9 +248,14 @@ + encoder_hidden_states: torch.FloatTensor = None, + pooled_projections: torch.FloatTensor = None, + 
timestep: torch.LongTensor = None, ++ cache_dict: torch.LongTensor = None, ++ if_skip: int = 0, ++ delta_cache: torch.FloatTensor = None, ++ delta_cache_hidden: torch.FloatTensor = None, ++ use_cache: bool = False, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, +- ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: ++ ): + """ + The [`SD3Transformer2DModel`] forward method. + +@@ -281,10 +289,6 @@ + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) +- else: +- logger.warning( +- "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." +- ) + + height, width = hidden_states.shape[-2:] + +@@ -292,9 +296,8 @@ + temb = self.time_text_embed(timestep, pooled_projections) + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + +- for block in self.transformer_blocks: +- if self.training and self.gradient_checkpointing: +- ++ if self.training and self.gradient_checkpointing: ++ for block in self.transformer_blocks: + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: +@@ -312,11 +315,14 @@ + temb, + **ckpt_kwargs, + ) +- +- else: +- encoder_hidden_states, hidden_states = block( +- hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb +- ) ++ else: ++ ( ++ (encoder_hidden_states, hidden_states), ++ delta_cache, ++ delta_cache_hidden ++ ) = self.forward_blocks(hidden_states, encoder_hidden_states, temb, ++ use_cache, if_skip, cache_dict, delta_cache, ++ delta_cache_hidden) + + hidden_states = self.norm_out(hidden_states, temb) + hidden_states = self.proj_out(hidden_states) +@@ -339,6 +345,43 @@ + unscale_lora_layers(self, lora_scale) + + if not return_dict: +- return (output,) ++ return (output, delta_cache, delta_cache_hidden) + + return Transformer2DModelOutput(sample=output) ++ ++ def forward_blocks_range(self, hidden_states, encoder_hidden_states, temb, start_idx, end_idx): ++ for block_idx, block in enumerate(self.transformer_blocks[start_idx: end_idx]): ++ encoder_hidden_states, hidden_states = block( ++ hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb ++ ) ++ # self.x_out.append([torch.mean(x).to('cpu').numpy(), torch.var(x).to('cpu').numpy()]) ++ return hidden_states, encoder_hidden_states + -+ if attn_mask is not None: -+ if attn_mask.dtype == torch.bool: -+ attn_mask.masked_fill_(attn_mask.logical_not(), float("-inf")) ++ def forward_blocks(self, hidden_states, encoder_hidden_states, temb, use_cache, if_skip, cache_dict, delta_cache, ++ delta_cache_hidden): ++ if not use_cache: ++ hidden_states, encoder_hidden_states = self.forward_blocks_range(hidden_states, encoder_hidden_states, temb, ++ 0, len(self.transformer_blocks)) + else: -+ attn_bias += attn_mask -+ attn_weight = query @ key.transpose(-2, -1) * scale_factor -+ attn_weight = torch.softmax(attn_weight, dim=-1) -+ attn_weight = torch.dropout(attn_weight, dropout_p, train=True) -+ return attn_weight @ value ++ # infer [0, cache_start) ++ hidden_states, encoder_hidden_states = self.forward_blocks_range(hidden_states, encoder_hidden_states, temb, ++ 0, cache_dict[0]) + ++ # infer [cache_start, cache_end) ++ cache_end = cache_dict[0] + cache_dict[2] ++ hidden_states_before_cache = hidden_states.clone() ++ encoder_hidden_states_before_cache = encoder_hidden_states.clone() ++ if not if_skip: ++ hidden_states, encoder_hidden_states = 
self.forward_blocks_range(hidden_states, encoder_hidden_states, ++ temb, cache_dict[0], ++ cache_end) ++ delta_cache = hidden_states - hidden_states_before_cache ++ delta_cache_hidden = encoder_hidden_states - encoder_hidden_states_before_cache ++ else: ++ hidden_states = hidden_states_before_cache + delta_cache ++ encoder_hidden_states = encoder_hidden_states_before_cache + delta_cache_hidden + - class JointAttnProcessor2_0: - """Attention processor used typically in processing the SD3-like self-attention projections.""" - -@@ -1132,7 +1155,7 @@ - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - -- hidden_states = hidden_states = F.scaled_dot_product_attention( -+ hidden_states = hidden_states = scaled_dot_product_attention( - query, key, value, dropout_p=0.0, is_causal=False - ) - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) \ No newline at end of file ++ # infer [cache_end, len(self.blocks)) ++ hidden_states, encoder_hidden_states = self.forward_blocks_range(hidden_states, encoder_hidden_states, temb, ++ cache_end, len(self.transformer_blocks)) ++ return (encoder_hidden_states, hidden_states), delta_cache, delta_cache_hidden -- Gitee From 25602876e34bb43740974f4e73897dfb8b786897 Mon Sep 17 00:00:00 2001 From: huanghao7 Date: Thu, 5 Sep 2024 06:31:12 +0000 Subject: [PATCH 06/20] update stable_diffusion_3/stable_diffusion3_pipeline_cache.py. Signed-off-by: huanghao7 --- .../stable_diffusion3_pipeline_cache.py | 229 ++++++++++++++---- 1 file changed, 176 insertions(+), 53 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline_cache.py b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline_cache.py index 5cd48ad34d..d3b4c483ff 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline_cache.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline_cache.py @@ -16,17 +16,18 @@ import json import os import time from typing import Any, Callable, Dict, List, Optional, Union -import numpy as np import torch import mindietorch from diffusers.pipelines.stable_diffusion_3.pipeline_output import StableDiffusion3PipelineOutput from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import retrieve_timesteps from stable_diffusion3_pipeline import PromptLoader, AIEStableDiffusion3Pipeline, parse_arguments +tgate = 20 clip_time = 0 t5_time = 0 dit_time = 0 vae_time = 0 +scheduler_time = 0 p1_time = 0 p2_time = 0 p3_time = 0 @@ -34,7 +35,50 @@ p3_time = 0 class AIEStableDiffusion3CachePipeline(AIEStableDiffusion3Pipeline): def compile_aie_model(self): - pass + if self.is_init: + return + size = self.args.batch_size + batch_size = self.args.batch_size * 2 + tail = f"_{self.args.height}x{self.args.width}" + + vae_compiled_path = os.path.join(self.args.output_dir, + f"vae/vae_bs{size}_compile{tail}.ts") + self.compiled_vae_model = torch.jit.load(vae_compiled_path).eval() + + scheduler_compiled_path = os.path.join(self.args.output_dir, + f"scheduler/scheduler_bs{size}_compile{tail}.ts") + self.compiled_scheduler = torch.jit.load(scheduler_compiled_path).eval() + + clip1_compiled_path = os.path.join(self.args.output_dir, + f"clip/clip_bs{size}_compile{tail}.ts") + self.compiled_clip_model = torch.jit.load(clip1_compiled_path).eval() + + clip2_compiled_path = 
os.path.join(self.args.output_dir, + f"clip/clip2_bs{size}_compile{tail}.ts") + self.compiled_clip_model_2 = torch.jit.load(clip2_compiled_path).eval() + + clip3_compiled_path = os.path.join(self.args.output_dir, + f"clip/clip3_bs{size}_compile{tail}.ts") + self.compiled_clip_model_3 = torch.jit.load(clip3_compiled_path).eval() + + dit_cache_compiled_path = os.path.join(self.args.output_dir, + f"dit/dit_bs{batch_size}_0_compile{tail}.ts") + self.compiled_dit_cache_model = torch.jit.load(dit_cache_compiled_path).eval() + + if self.args.use_cache: + dit_skip_compiled_path = os.path.join(self.args.output_dir, + f"dit/dit_bs{batch_size}_1_compile{tail}.ts") + self.compiled_dit_skip_model = torch.jit.load(dit_skip_compiled_path).eval() + + dit_cache_end_compiled_path = os.path.join(self.args.output_dir, + f"dit/dit_bs{size}_0_compile{tail}.ts") + self.compiled_dit_cache_end_model = torch.jit.load(dit_cache_end_compiled_path).eval() + + dit_skip_end_compiled_path = os.path.join(self.args.output_dir, + f"dit/dit_bs{size}_1_compile{tail}.ts") + self.compiled_dit_skip_end_model = torch.jit.load(dit_skip_end_compiled_path).eval() + + self.is_init = True @torch.no_grad() def forward( @@ -125,12 +169,11 @@ class AIEStableDiffusion3CachePipeline(AIEStableDiffusion3Pipeline): p1_time += (time.time() - start) start1 = time.time() - if self.do_classifier_free_guidance and not self.use_parallel_inferencing: + prompt_embeds_origin = prompt_embeds.clone() + pooled_prompt_embeds_origin = pooled_prompt_embeds.clone() + if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) - else: - prompt_embeds, prompt_embeds_1 = negative_prompt_embeds, prompt_embeds - pooled_prompt_embeds, pooled_prompt_embeds_1 = negative_pooled_prompt_embeds, pooled_prompt_embeds # 4. Prepare timesteps timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) @@ -153,51 +196,118 @@ class AIEStableDiffusion3CachePipeline(AIEStableDiffusion3Pipeline): # 6. 
Denoising loop global dit_time global vae_time + global scheduler_time + global dit_1_time + global dit_2_time + global dit_3_time + global dit_2_end_time + global dit_3_end_time + + skip_flag_true = torch.ones([1], dtype=torch.long).to(f'npu:{self.device_0}') + skip_flag_false = torch.zeros([1], dtype=torch.long).to(f'npu:{self.device_0}') + + delta_cache = torch.zeros([2, 4096, 1536], dtype=torch.float32) + delta_cache_hidden = torch.zeros([2, 154, 1536], dtype=torch.float32) + + cache_interval = cache_dict[1] + step_contrast = cache_dict[3] % 2 + with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: continue - # expand the latents if we are doing classifier free guidance - if not self.use_parallel_inferencing and self.do_classifier_free_guidance: - latent_model_input = torch.cat([latents] * 2) - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timestep = t.to(torch.int64)[None].to(f"npu:{self.device_0}") - else: - latent_model_input = latents - timestep = t.to(torch.int64) - self.dit_bg.infer_asyn([ - latent_model_input.numpy(), - prompt_embeds_1.numpy(), - pooled_prompt_embeds_1.numpy(), - timestep[None].numpy().astype(np.int64) - ]) - timestep_npu = timestep[None].to(f"npu:{self.device_0}") - - latent_model_input_npu = latent_model_input.to(f"npu:{self.device_0}") - prompt_embeds_npu = prompt_embeds.to(f"npu:{self.device_0}") - pooled_prompt_embeds_npu = pooled_prompt_embeds.to(f"npu:{self.device_0}") - start = time.time() - noise_pred = self.compiled_dit_model( - latent_model_input_npu, - prompt_embeds_npu, - pooled_prompt_embeds_npu, - timestep_npu - ).to("cpu") - dit_time += (time.time() - start) + timestep_npu = t.to(torch.int64)[None].to(f'npu:{self.device_0}') + if not self.args.use_cache: + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + (noise_pred, delta_cache, delta_cache_hidden) = self.dit_infer(latent_model_input, + prompt_embeds, + pooled_prompt_embeds, + timestep_npu, cache_dict, + skip_flag_true, delta_cache, + delta_cache_hidden) + else: + if i < tgate: + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + else: + if i == tgate: + _, delta_cache = delta_cache.chunk(2) + _, delta_cache_hidden = delta_cache_hidden.chunk(2) + latent_model_input = latents + + if i < cache_dict[3]: + start = time.time() + (noise_pred, delta_cache, delta_cache_hidden) = self.dit_infer(latent_model_input, + prompt_embeds, + pooled_prompt_embeds, + timestep_npu, cache_dict, + skip_flag_true, delta_cache, + delta_cache_hidden) + dit_1_time += (time.time() - start) + else: + if i % cache_interval == step_contrast: + if i < tgate: + start = time.time() + (noise_pred, delta_cache, delta_cache_hidden) = self.dit_infer(latent_model_input, + prompt_embeds, + pooled_prompt_embeds, + timestep_npu, cache_dict, + skip_flag_true, + delta_cache, + delta_cache_hidden) + dit_2_time += (time.time() - start) + else: + start = time.time() + (noise_pred, delta_cache, delta_cache_hidden) = self.dit_infer(latent_model_input, + prompt_embeds_origin, + pooled_prompt_embeds_origin, + timestep_npu, cache_dict, + skip_flag_true, + delta_cache, + delta_cache_hidden) + dit_2_end_time += (time.time() - start) + else: + if i < tgate: + start = time.time() + (noise_pred, delta_cache, delta_cache_hidden) = self.dit_infer(latent_model_input, + prompt_embeds, + 
pooled_prompt_embeds, + timestep_npu, cache_dict, + skip_flag_false, + delta_cache, + delta_cache_hidden) + dit_3_time += (time.time() - start) + else: + start = time.time() + (noise_pred, delta_cache, delta_cache_hidden) = self.dit_infer(latent_model_input, + prompt_embeds_origin, + pooled_prompt_embeds_origin, + timestep_npu, cache_dict, + skip_flag_false, + delta_cache, + delta_cache_hidden) + dit_3_time += (time.time() - start) + + dit_time += time.time() - start # perform guidance - if self.do_classifier_free_guidance: - if self.use_parallel_inferencing: - noise_pred_text = torch.from_numpy(self.dit_bg.wait_and_get_outputs()[0]) - else: - noise_pred, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred + self.guidance_scale * (noise_pred_text - noise_pred) + if self.do_classifier_free_guidance and i < tgate: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + start = time.time() + # latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + step_index = torch.tensor(i).long() + latents = self.compiled_scheduler( + noise_pred.to(f'npu:{self.device_0}'), + latents.to(f'npu:{self.device_0}'), + step_index[None].to(f'npu:{self.device_0}') + ).to('cpu') + scheduler_time += (time.time() - start) if latents.dtype != latents_dtype: if torch.backends.mps.is_available(): @@ -221,20 +331,21 @@ class AIEStableDiffusion3CachePipeline(AIEStableDiffusion3Pipeline): if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() - p2_time += (time.time() - start1) + p2_time += time.time() - start1 start2 = time.time() if output_type == "latent": image = latents else: start = time.time() - image = self.compiled_vae_model(latents.to(f"npu:{self.device_0}")).to("cpu") - vae_time += (time.time() - start) + image = self.compiled_vae_model(latents.to(f'npu:{self.device_0}')).to("cpu") + vae_time += time.time() - start image = self.image_processor.postprocess(image, output_type=output_type) # Offload all models self.maybe_free_model_hooks() - p3_time += (time.time() - start2) + + p3_time += time.time() - start2 if not return_dict: return (image,) @@ -248,7 +359,7 @@ def main(): if not os.path.exists(save_dir): os.makedirs(save_dir) - pipe = AIEStableDiffusion3Pipeline.from_pretrained(args.model).to("cpu") + pipe = AIEStableDiffusion3CachePipeline.from_pretrained(args.model).to("cpu") pipe.parser_args(args) pipe.compile_aie_model() if isinstance(args.device, list): @@ -256,6 +367,12 @@ def main(): else: mindietorch.set_device(args.device) + cache_dict = torch.zeros([4], dtype=torch.int64) + cache_list = args.cache_dict.split(',') + cache_dict[0] = int(cache_list[0]) + cache_dict[1] = int(cache_list[1]) + cache_dict[2] = int(cache_list[2]) + cache_dict[3] = int(cache_list[3]) use_time = 0 prompt_loader = PromptLoader(args.prompt_file, args.prompt_file_type, @@ -280,7 +397,8 @@ def main(): prompts, negative_prompt="", num_inference_steps=args.steps, - guidance_scale=7.5, + guidance_scale=args.guidance_scale, + cache_dict=cache_dict ) if i > 4: # do not count the time spent inferring the first 0 to 4 images use_time += time.time() - start_time @@ -296,13 +414,18 @@ def main(): image_info[-1]['images'].append(image_save_path) - infer_num = infer_num - 5 # do not count 
the time spent inferring the first 5 images - print(f"[info] infer number: {infer_num}; use time: {use_time:.3f}s\n" - f"average time: {use_time / infer_num:.3f}s\n") - - if hasattr(pipe, 'device_1'): - if (pipe.dit_bg): - pipe.dit_bg.stop() + print( + f"[info] infer number: {infer_num - 5}; use time: {use_time:.3f}s\n" + f"average time: {use_time / (infer_num - 5):.3f}s\n" + f"clip time: {clip_time / infer_num:.3f}s\n" + f"t5 time: {t5_time / infer_num:.3f}s\n" + f"dit time: {dit_time / infer_num:.3f}s\n" + f"scheduler_time time: {scheduler_time / infer_num:.3f}s\n" + f"vae time: {vae_time / infer_num:.3f}s\n" + f"p1 time: {p1_time / infer_num:.3f}s\n" + f"p2 time: {p2_time / infer_num:.3f}s\n" + f"p3 time: {p3_time / infer_num:.3f}s\n" + ) # Save image information to a json file if os.path.exists(args.info_file_save_path): -- Gitee From c6f9bbde45f970ddc420c0660ea3a9cfdf204d7c Mon Sep 17 00:00:00 2001 From: huanghao7 Date: Thu, 5 Sep 2024 06:34:28 +0000 Subject: [PATCH 07/20] update MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/transformer_sd3_patch.py. Signed-off-by: huanghao7 --- .../foundation/stable_diffusion_3/transformer_sd3_patch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/transformer_sd3_patch.py b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/transformer_sd3_patch.py index 88033355c6..8556cbac2b 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/transformer_sd3_patch.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/transformer_sd3_patch.py @@ -23,8 +23,8 @@ def main(): diffusers_version = diffusers.__version__ assert diffusers_version == '0.29.0', "expectation diffusers==0.29.0" - result = subprocess.run(["patch", "-p0", f"{diffusers_path[0]}/models/attention_processor.py", - "attention_processor.patch"], capture_output=True, text=True) + result = subprocess.run(["patch", "-p0", f"{diffusers_path[0]}/models/transformers/transformer_sd3.py", + "transformer_sd3.patch"], capture_output=True, text=True) if result.returncode != 0: logging.error("Patch failed, error message: s%", result.stderr) -- Gitee From 72740ce4714abbd3e62fb191516bbe7603a83d03 Mon Sep 17 00:00:00 2001 From: huanghao7 Date: Thu, 5 Sep 2024 06:35:49 +0000 Subject: [PATCH 08/20] update MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention_patch.py. 
Signed-off-by: huanghao7 --- .../built-in/foundation/stable_diffusion_3/attention_patch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention_patch.py b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention_patch.py index 97585e6af7..2d289cb99e 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention_patch.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention_patch.py @@ -22,8 +22,8 @@ def main(): diffusers_version = diffusers.__version__ assert diffusers_version == '0.29.0', "expectation diffusers==0.29.0" - result = subprocess.run(["patch", "-p0", f"{diffusers_path[0]}/models/attention_processor.py", - "attention_processor.patch"], capture_output=True, text=True) + result = subprocess.run(["patch", "-p0", f"{diffusers_path[0]}/models/attention.py", + "attention.patch"], capture_output=True, text=True) if result.returncode != 0: logging.error("Patch failed, error message: s%", result.stderr) -- Gitee From 6e6eb77c38f0824fab413eee3d0d8632e4587672 Mon Sep 17 00:00:00 2001 From: huanghao7 Date: Thu, 5 Sep 2024 06:40:13 +0000 Subject: [PATCH 09/20] update MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/export_model.py. Signed-off-by: huanghao7 --- .../stable_diffusion_3/export_model.py | 133 +++++++++++++++++- 1 file changed, 128 insertions(+), 5 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/export_model.py b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/export_model.py index 0b0131d9af..33609258de 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/export_model.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/export_model.py @@ -17,6 +17,7 @@ import argparse from argparse import Namespace import torch +from diffusers import FlowMatchEulerDiscreteScheduler from diffusers import StableDiffusion3Pipeline import mindietorch from compile_model import * @@ -66,7 +67,7 @@ def parse_arguments() -> Namespace: ) parser.add_argument("-bs", "--batch_size", type=int, default=1, help="Batch size.") parser.add_argument("-steps", "--steps", type=int, default=28, help="steps.") - parser.add_argument("-guid", "--guidance_scale", type=float, default=5.0, help="guidance_scale") + parser.add_argument("-guid", "--guidance_scale", type=float, default=7.0, help="guidance_scale") parser.add_argument("--use_cache", action="store_true", help="Use cache during inference.") parser.add_argument("-p", "--parallel", action="store_true", help="Export the unet of bs=1 for parallel inferencing.") @@ -89,6 +90,12 @@ def parse_arguments() -> Namespace: type=int, help="image width" ) + parser.add_argument( + "--cache_dict", + type=str, + default="1,2,20,10", + help="Steps to use cache data." 
+ ) return parser.parse_args() @@ -254,15 +261,131 @@ def export_vae(sd_pipeline, args): logging.info("vae_compiled_path already exists.") +def trace_scheduler(sd_pipeline, args, scheduler_pt_path): + batch_size = args.batch_size + if not os.path.exists(scheduler_pt_path): + dummy_input = ( + torch.randn([batch_size, 16, 128, 128], dtype=torch.float32), + torch.randn([batch_size, 16, 128, 128], dtype=torch.float32), + torch.ones([1], dtype=torch.int64) + ) + scheduler = FlowMatchEulerDiscreteScheduler.from_config(sd_pipeline.scheduler.config) + scheduler.set_timesteps(args.steps, device="cpu") + + new_scheduler = Scheduler() + new_scheduler.eval() + torch.jit.trace(new_scheduler, dummy_input).save(scheduler_pt_path) + + def export_scheduler(sd_pipeline, args): - pass + print("Exporting the scheduler...") + scheduler_path = os.path.join(args.output_dir, "scheduler") + if not os.path.exists(scheduler_path): + os.makedirs(scheduler_path, mode=0o744) + batch_size = args.batch_size + height_size, width_size = args.height // 8, args.width // 8 + scheduler_pt_path = os.path.join(scheduler_path, f"scheduler_bs{batch_size}.pt") + scheduler_compiled_path = os.path.join(scheduler_path, + f"scheduler_bs{batch_size}_compile_{args.height}x{args.width}.ts") + in_channels = 16 + + # trace + trace_scheduler(sd_pipeline, args, scheduler_pt_path) + + # compile + if not os.path.exists(scheduler_compiled_path): + model = torch.jit.load(scheduler_pt_path).eval() + inputs = [ + mindietorch.Input((batch_size, in_channels, height_size, width_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size, in_channels, height_size, width_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), dtype=mindietorch.dtype.INT64) + ] + compile_scheduler(model, inputs, scheduler_compiled_path, args.soc_version) + + +def export_dit_cache(sd_pipeline, args, if_skip, flag=""): + print("Exporting the dit_cache...") + cache_dict = torch.zeros([4], dtype=torch.int64) + cache_list = args.cache_dict.split(',') + cache_dict[0] = int(cache_list[0]) + cache_dict[1] = int(cache_list[1]) + cache_dict[2] = int(cache_list[2]) + cache_dict[3] = int(cache_list[3]) + dit_path = os.path.join(args.output_dir, "dit") + if not os.path.exists(dit_path): + os.makedirs(dit_path, mode=0o640) + + encoder_model = sd_pipeline.text_encoder + encoder_model_2 = sd_pipeline.text_encoder_2 + dit_model = sd_pipeline.transformer + if args.parallel or flag == "end": + batch_size = args.batch_size + else: + batch_size = args.batch_size * 2 + sample_size = dit_model.config.sample_size + in_channels = dit_model.config.in_channels + encoder_hidden_size_2 = encoder_model_2.config.hidden_size + encoder_hidden_size = encoder_model.config.hidden_size + encoder_hidden_size_2 + max_position_embeddings = encoder_model.config.max_position_embeddings * 2 + dit_cache_pt_path = os.path.join(dit_path, f"dit_bs{batch_size}_{if_skip}.pt") + dit_cache_compiled_path = os.path.join(dit_path, + f"dit_bs{batch_size}_{if_skip}_compile_{args.height}x{args.width}.ts") + # trace + if not os.path.exists(dit_cache_pt_path): + dummy_input = ( + torch.ones([batch_size, in_channels, sample_size, sample_size], dtype=torch.float32), + torch.ones([batch_size, max_position_embeddings, encoder_hidden_size * 2], dtype=torch.float32), + torch.ones([batch_size, encoder_hidden_size], dtype=torch.float32), + torch.ones([1], dtype=torch.int64), + cache_dict, + torch.tensor([if_skip], dtype=torch.int64), + torch.ones([batch_size, 4096, 1536], dtype=torch.float32), + 
torch.ones([batch_size, 154, 1536], dtype=torch.float32), + ) + print("dummy_input.shape:") + for ele in dummy_input: + if isinstance(ele, torch.Tensor): + print(ele.shape) + dit = DiTExportCache(dit_model).eval() + torch.jit.trace(dit, dummy_input).save(dit_cache_pt_path) -def export(args): - pipeline = StableDiffusion3Pipeline.from_pretrained(args.model).to('cpu') + # compile + if not os.path.exists(dit_cache_compiled_path): + model = torch.jit.load(dit_cache_pt_path).eval() + inputs = [ + mindietorch.Input((batch_size, in_channels, sample_size, sample_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size, max_position_embeddings, encoder_hidden_size * 2), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size, encoder_hidden_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), dtype=mindietorch.dtype.INT64), + mindietorch.Input((4,), dtype=mindietorch.dtype.INT64), + mindietorch.Input((1,), dtype=mindietorch.dtype.INT64), + mindietorch.Input((batch_size, 4096, 1536), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size, 154, 1536), + dtype=mindietorch.dtype.FLOAT), + ] + compile_dit_cache(model, inputs, dit_cache_compiled_path, args.soc_version) + + +def export(args) -> None: + pipeline = StableDiffusion3Pipeline.from_pretrained(args.model).to("cpu") export_clip(pipeline, args) - export_dit(pipeline, args) + if args.use_cache: + export_dit_cache(pipeline, args, 0) + export_dit_cache(pipeline, args, 1) + if "B" in args.soc_version: + export_dit_cache(pipeline, args, 0, "end") + export_dit_cache(pipeline, args, 1, "end") + else: + export_dit(pipeline, args) export_vae(pipeline, args) + export_scheduler(pipeline, args) def main(args): -- Gitee From 50cfcad016a3a6c007d29a4fe6bb45f6188124f1 Mon Sep 17 00:00:00 2001 From: huanghao7 Date: Thu, 5 Sep 2024 07:04:46 +0000 Subject: [PATCH 10/20] update MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/compile_model.py. 
Signed-off-by: huanghao7 --- .../stable_diffusion_3/compile_model.py | 70 +++++++++++++++++-- 1 file changed, 65 insertions(+), 5 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/compile_model.py b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/compile_model.py index ab3e74e749..7bd43f2136 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/compile_model.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/compile_model.py @@ -18,6 +18,11 @@ from typing import List import mindietorch from mindietorch import _enums +SCHEDULER_SIGMAS = torch.tensor([1.0000, 0.9874, 0.9741, 0.9601, 0.9454, 0.9298, 0.9133, 0.8959, 0.8774, + 0.8577, 0.8367, 0.8143, 0.7904, 0.7647, 0.7371, 0.7073, 0.6751, 0.6402, + 0.6022, 0.5606, 0.5151, 0.4649, 0.4093, 0.3474, 0.2780, 0.1998, 0.1109, + 0.0089, 0.0000], dtype=torch.float32) + @dataclass class CompileParam: @@ -69,11 +74,48 @@ class VaeExport(torch.nn.Module): image = self.vae_model.decode(latents, return_dict=False)[0] return image + def compile_vae(model, inputs, vae_compiled_path, soc_version): vae_param = CompileParam(inputs, soc_version) common_compile(model, vae_compiled_path, vae_param) +class Scheduler(torch.nn.Module): + def __init__(self): + super(Scheduler, self).__init__() + self.sigmas = SCHEDULER_SIGMAS + + def forward( + self, + model_output: torch.FloatTensor, + sample: torch.FloatTensor, + step_index: torch.LongTensor + ): + # Upcast to avoid precision issues when computing prev_sample + sample = sample.to(torch.float32) + sigma = self.sigmas[step_index] + + sigma_hat = sigma + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + # NOTE: "original_sample" should not be an expected prediction_type but is left in for + # backwards compatibility + denoised = sample - model_output * sigma + # 2. 
Convert to an ODE derivative + derivative = (sample - denoised) / sigma_hat + + dt = self.sigmas[step_index + 1] - sigma_hat + prev_sample = sample + derivative * dt + # Cast sample back to model compatible dtype + prev_sample = prev_sample.to(model_output.dtype) + + return prev_sample + + +def compile_scheduler(model, inputs, scheduler_compiled_path, soc_version): + scheduler_param = CompileParam(inputs, soc_version, True, True, False) + common_compile(model, scheduler_compiled_path, scheduler_param) + + class DiTExport(torch.nn.Module): def __init__(self, dit_model): super().__init__() @@ -95,10 +137,28 @@ def compile_dit(model, inputs, dit_compiled_path, soc_version): common_compile(model, dit_compiled_path, dit_param) -class Scheduler(torch.nn.Module): - def __init__(self, ): +class DiTExportCache(torch.nn.Module): + def __init__(self, dit_cache_model): super().__init__() - self.sigmas = None + self.dit_cache_model = dit_cache_model + + def forward( + self, + hidden_states, + encoder_hidden_states, + pooled_projections, + timestep, + cache_dict, + if_skip: int = 0, + delta_cache: torch.FloatTensor = None, + delta_cache_hidden: torch.FloatTensor = None, + use_cache: bool = True, + ): + return self.dit_cache_model(hidden_states, encoder_hidden_states, pooled_projections, timestep, + cache_dict, if_skip, delta_cache, delta_cache_hidden, use_cache, + joint_attention_kwargs=None, return_dict=False) + - def forward(self): - pass +def compile_dit_cache(model, inputs, dit_cache_compiled_path, soc_version): + dit_cache_param = CompileParam(inputs, soc_version, True, False, True) + common_compile(model, dit_cache_compiled_path, dit_cache_param) -- Gitee From 4dafeee1b0fc4083e0734f4b68b74c8100e8f7db Mon Sep 17 00:00:00 2001 From: huanghao7 Date: Thu, 5 Sep 2024 07:41:55 +0000 Subject: [PATCH 11/20] update MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention_processor.patch. 
Signed-off-by: huanghao7 --- .../attention_processor.patch | 80 ++++++++++++++++--- 1 file changed, 70 insertions(+), 10 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention_processor.patch b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention_processor.patch index dc22a411ea..8897fabdbb 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention_processor.patch +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention_processor.patch @@ -1,9 +1,41 @@ ---- attention_processor.py 2024-08-14 12:41:10.528000000 +0000 -+++ attention_processor.py 2024-08-14 12:42:25.620000000 +0000 -@@ -1082,6 +1082,29 @@ +--- attention_processor.py 2024-09-04 09:22:16.048000000 +0000 ++++ attention_processor.py 2024-09-04 09:57:55.928000000 +0000 +@@ -115,6 +115,7 @@ + residual_connection: bool = False, + _from_deprecated_attn_block: bool = False, + processor: Optional["AttnProcessor"] = None, ++ layer_idx: int = 0, + out_dim: int = None, + context_pre_only=None, + ): +@@ -132,13 +133,14 @@ + self.fused_projections = False + self.out_dim = out_dim if out_dim is not None else query_dim + self.context_pre_only = context_pre_only ++ self.layer_idx = layer_idx + + # we make use of this private variable to know whether this class is loaded + # with an deprecated state dict so that we can convert it on the fly + self._from_deprecated_attn_block = _from_deprecated_attn_block + + self.scale_qk = scale_qk +- self.scale = dim_head**-0.5 if self.scale_qk else 1.0 ++ self.scale = dim_head ** -0.5 if self.scale_qk else 1.0 + + self.heads = out_dim // dim_head if out_dim is not None else heads + # for slice_size > 0 the attention score computation +@@ -561,6 +563,7 @@ + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, ++ layer_idx=self.layer_idx, + **cross_attention_kwargs, + ) + +@@ -1082,6 +1085,29 @@ return hidden_states - - + + +def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, + scale=None) -> torch.Tensor: + # Efficient implementation equivalent to the following: @@ -20,7 +52,7 @@ + if attn_mask.dtype == torch.bool: + attn_mask.masked_fill_(attn_mask.logical_not(), float("-inf")) + else: -+ attn_bias += attn_mask ++ attn_bias = attn_mask + attn_weight = query @ key.transpose(-2, -1) * scale_factor + attn_weight = torch.softmax(attn_weight, dim=-1) + attn_weight = torch.dropout(attn_weight, dropout_p, train=True) @@ -29,13 +61,41 @@ + class JointAttnProcessor2_0: """Attention processor used typically in processing the SD3-like self-attention projections.""" - -@@ -1132,7 +1155,7 @@ + +@@ -1095,6 +1121,7 @@ + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, ++ layer_idx=0, + *args, + **kwargs, + ) -> torch.FloatTensor: +@@ -1112,7 +1139,19 @@ + batch_size = encoder_hidden_states.shape[0] + + # `sample` projections. 
++ # use todo + query = attn.to_q(hidden_states) ++ if layer_idx <= 11: ++ hidden_dim = hidden_states.shape[-1] ++ cur_h = int(math.sqrt(hidden_states.shape[1])) ++ cur_w = cur_h ++ hidden_states = hidden_states.transpose(1, 2).view(batch_size, hidden_dim, cur_h, cur_w) ++ new_h = int(cur_h / 2.2) ++ new_w = int(cur_w / 2.2) ++ item = F.interpolate(hidden_states, size=(new_h, new_w), ++ mode='bilinear') ++ item = item.permute(0, 2, 3, 1) ++ hidden_states = item.reshape(batch_size, new_h * new_w, -1) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + +@@ -1132,7 +1171,7 @@ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - + - hidden_states = hidden_states = F.scaled_dot_product_attention( + hidden_states = hidden_states = scaled_dot_product_attention( query, key, value, dropout_p=0.0, is_causal=False ) - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) \ No newline at end of file + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) -- Gitee From aeff65e2268c5a41168cc7cda22217e5991c3808 Mon Sep 17 00:00:00 2001 From: huanghao7 Date: Thu, 5 Sep 2024 08:14:08 +0000 Subject: [PATCH 12/20] update MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/export_model.py. Signed-off-by: huanghao7 --- .../built-in/foundation/stable_diffusion_3/export_model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/export_model.py b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/export_model.py index 33609258de..fb18b7eddc 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/export_model.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/export_model.py @@ -302,7 +302,7 @@ def export_scheduler(sd_pipeline, args): dtype=mindietorch.dtype.FLOAT), mindietorch.Input((1,), dtype=mindietorch.dtype.INT64) ] - compile_scheduler(model, inputs, scheduler_compiled_path, args.soc_version) + compile_scheduler(model, inputs, scheduler_compiled_path, args.soc) def export_dit_cache(sd_pipeline, args, if_skip, flag=""): @@ -370,7 +370,7 @@ def export_dit_cache(sd_pipeline, args, if_skip, flag=""): mindietorch.Input((batch_size, 154, 1536), dtype=mindietorch.dtype.FLOAT), ] - compile_dit_cache(model, inputs, dit_cache_compiled_path, args.soc_version) + compile_dit_cache(model, inputs, dit_cache_compiled_path, args.soc) def export(args) -> None: @@ -379,7 +379,7 @@ def export(args) -> None: if args.use_cache: export_dit_cache(pipeline, args, 0) export_dit_cache(pipeline, args, 1) - if "B" in args.soc_version: + if "B" in args.soc: export_dit_cache(pipeline, args, 0, "end") export_dit_cache(pipeline, args, 1, "end") else: -- Gitee From 4dbfc988d75f2195574996205dd906a4e45d45e3 Mon Sep 17 00:00:00 2001 From: huanghao7 Date: Thu, 5 Sep 2024 08:36:43 +0000 Subject: [PATCH 13/20] update MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline.py. 
Signed-off-by: huanghao7 --- .../stable_diffusion_3/stable_diffusion3_pipeline.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline.py b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline.py index 4968b6832d..fd2c2d66b8 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline.py @@ -725,6 +725,11 @@ def parse_arguments(): type=int, help="image width" ) + parser.add_argument( + "--use_cache", + action="store_true", + help="Use cache during inference." + ) parser.add_argument( "--cache_dict", default="1,2,20,10", -- Gitee From 10e8908e583a5caf921ea063615f0c67a24cb14f Mon Sep 17 00:00:00 2001 From: huanghao7 Date: Thu, 5 Sep 2024 08:40:41 +0000 Subject: [PATCH 14/20] update stable_diffusion_3/stable_diffusion3_pipeline_cache.py. Signed-off-by: huanghao7 --- .../stable_diffusion3_pipeline_cache.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline_cache.py b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline_cache.py index d3b4c483ff..832f96fad5 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline_cache.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline_cache.py @@ -58,7 +58,7 @@ class AIEStableDiffusion3CachePipeline(AIEStableDiffusion3Pipeline): self.compiled_clip_model_2 = torch.jit.load(clip2_compiled_path).eval() clip3_compiled_path = os.path.join(self.args.output_dir, - f"clip/clip3_bs{size}_compile{tail}.ts") + f"clip/t5_bs{size}_compile{tail}.ts") self.compiled_clip_model_3 = torch.jit.load(clip3_compiled_path).eval() dit_cache_compiled_path = os.path.join(self.args.output_dir, @@ -359,13 +359,13 @@ def main(): if not os.path.exists(save_dir): os.makedirs(save_dir) - pipe = AIEStableDiffusion3CachePipeline.from_pretrained(args.model).to("cpu") - pipe.parser_args(args) - pipe.compile_aie_model() if isinstance(args.device, list): mindietorch.set_device(args.device[0]) else: mindietorch.set_device(args.device) + pipe = AIEStableDiffusion3CachePipeline.from_pretrained(args.model).to("cpu") + pipe.parser_args(args) + pipe.compile_aie_model() cache_dict = torch.zeros([4], dtype=torch.int64) cache_list = args.cache_dict.split(',') -- Gitee From 951254784393e0b55c65b362ef47e15f3d48b393 Mon Sep 17 00:00:00 2001 From: huanghao7 Date: Thu, 5 Sep 2024 08:43:18 +0000 Subject: [PATCH 15/20] update stable_diffusion_3/stable_diffusion3_pipeline_cache.py. 
Signed-off-by: huanghao7 --- .../stable_diffusion_3/stable_diffusion3_pipeline_cache.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline_cache.py b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline_cache.py index 832f96fad5..2d14d313b8 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline_cache.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline_cache.py @@ -57,9 +57,9 @@ class AIEStableDiffusion3CachePipeline(AIEStableDiffusion3Pipeline): f"clip/clip2_bs{size}_compile{tail}.ts") self.compiled_clip_model_2 = torch.jit.load(clip2_compiled_path).eval() - clip3_compiled_path = os.path.join(self.args.output_dir, + t5_compiled_path = os.path.join(self.args.output_dir, f"clip/t5_bs{size}_compile{tail}.ts") - self.compiled_clip_model_3 = torch.jit.load(clip3_compiled_path).eval() + self.compiled_t5_model = torch.jit.load(t5_compiled_path).eval() dit_cache_compiled_path = os.path.join(self.args.output_dir, f"dit/dit_bs{batch_size}_0_compile{tail}.ts") -- Gitee From c74a26780adc5921411853876a39f589f074df91 Mon Sep 17 00:00:00 2001 From: huanghao7 Date: Thu, 5 Sep 2024 08:48:37 +0000 Subject: [PATCH 16/20] update stable_diffusion_3/stable_diffusion3_pipeline_cache.py. Signed-off-by: huanghao7 --- .../stable_diffusion_3/stable_diffusion3_pipeline_cache.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline_cache.py b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline_cache.py index 2d14d313b8..8328b9d7dc 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline_cache.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline_cache.py @@ -397,7 +397,7 @@ def main(): prompts, negative_prompt="", num_inference_steps=args.steps, - guidance_scale=args.guidance_scale, + guidance_scale=7.0, cache_dict=cache_dict ) if i > 4: # do not count the time spent inferring the first 0 to 4 images -- Gitee From 66e5c322a83d565d6f9418ffbc721f67b3ec2787 Mon Sep 17 00:00:00 2001 From: huanghao7 Date: Thu, 5 Sep 2024 08:56:11 +0000 Subject: [PATCH 17/20] update stable_diffusion_3/stable_diffusion3_pipeline_cache.py. 
Signed-off-by: huanghao7 --- .../stable_diffusion3_pipeline_cache.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline_cache.py b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline_cache.py index 8328b9d7dc..19d2f61b9a 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline_cache.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline_cache.py @@ -80,6 +80,24 @@ class AIEStableDiffusion3CachePipeline(AIEStableDiffusion3Pipeline): self.is_init = True + @torch.no_grad() + def dit_infer(self, latent_model_input, prompt_embeds, pooled_prompt_embeds, timestep_npu, cache_dict, skip_flag, + delta_cache, delta_cache_hidden): + (noise_pred, delta_cache, delta_cache_hidden) = self.compiled_dit_cache_model( + latent_model_input.to(f'npu:{self.device_0}'), + prompt_embeds.to(f'npu:{self.device_0}'), + pooled_prompt_embeds.to(f'npu:{self.device_0}'), + timestep_npu, + cache_dict.to(f'npu:{self.device_0}'), + skip_flag, + delta_cache.to(f'npu:{self.device_0}'), + delta_cache_hidden.to(f'npu:{self.device_0}'), + ) + noise_pred = noise_pred.to("cpu") + delta_cache = delta_cache.to("cpu") + delta_cache_hidden = delta_cache_hidden.to("cpu") + return (noise_pred, delta_cache, delta_cache_hidden) + @torch.no_grad() def forward( self, -- Gitee From 7ade50b9f86ac47cbae527ff436b9f69fa40c662 Mon Sep 17 00:00:00 2001 From: huanghao7 Date: Thu, 5 Sep 2024 09:00:12 +0000 Subject: [PATCH 18/20] update stable_diffusion_3/stable_diffusion3_pipeline_cache.py. Signed-off-by: huanghao7 --- .../stable_diffusion_3/stable_diffusion3_pipeline_cache.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline_cache.py b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline_cache.py index 19d2f61b9a..1881a001ef 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline_cache.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline_cache.py @@ -31,6 +31,11 @@ scheduler_time = 0 p1_time = 0 p2_time = 0 p3_time = 0 +dit_1_time = 0 +dit_2_time = 0 +dit_3_time = 0 +dit_2_end_time = 0 +dit_3_end_time = 0 class AIEStableDiffusion3CachePipeline(AIEStableDiffusion3Pipeline): -- Gitee From e51ff243552a1b6a1cda102bac86d59063c78ad5 Mon Sep 17 00:00:00 2001 From: huanghao7 Date: Thu, 5 Sep 2024 09:28:41 +0000 Subject: [PATCH 19/20] update stable_diffusion_3/stable_diffusion3_pipeline_cache.py. 
Signed-off-by: huanghao7 --- .../stable_diffusion3_pipeline_cache.py | 128 ++++++++---------- 1 file changed, 59 insertions(+), 69 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline_cache.py b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline_cache.py index 1881a001ef..f3d5bda62e 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline_cache.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline_cache.py @@ -23,19 +23,12 @@ from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import r from stable_diffusion3_pipeline import PromptLoader, AIEStableDiffusion3Pipeline, parse_arguments tgate = 20 -clip_time = 0 -t5_time = 0 dit_time = 0 vae_time = 0 scheduler_time = 0 p1_time = 0 p2_time = 0 p3_time = 0 -dit_1_time = 0 -dit_2_time = 0 -dit_3_time = 0 -dit_2_end_time = 0 -dit_3_end_time = 0 class AIEStableDiffusion3CachePipeline(AIEStableDiffusion3Pipeline): @@ -63,7 +56,7 @@ class AIEStableDiffusion3CachePipeline(AIEStableDiffusion3Pipeline): self.compiled_clip_model_2 = torch.jit.load(clip2_compiled_path).eval() t5_compiled_path = os.path.join(self.args.output_dir, - f"clip/t5_bs{size}_compile{tail}.ts") + f"clip/t5_bs{size}_compile{tail}.ts") self.compiled_t5_model = torch.jit.load(t5_compiled_path).eval() dit_cache_compiled_path = os.path.join(self.args.output_dir, @@ -86,9 +79,9 @@ class AIEStableDiffusion3CachePipeline(AIEStableDiffusion3Pipeline): self.is_init = True @torch.no_grad() - def dit_infer(self, latent_model_input, prompt_embeds, pooled_prompt_embeds, timestep_npu, cache_dict, skip_flag, - delta_cache, delta_cache_hidden): - (noise_pred, delta_cache, delta_cache_hidden) = self.compiled_dit_cache_model( + def dit_infer(self, compiled_model, latent_model_input, prompt_embeds, pooled_prompt_embeds, timestep_npu, + cache_dict, skip_flag, delta_cache, delta_cache_hidden): + (noise_pred, delta_cache, delta_cache_hidden) = compiled_model( latent_model_input.to(f'npu:{self.device_0}'), prompt_embeds.to(f'npu:{self.device_0}'), pooled_prompt_embeds.to(f'npu:{self.device_0}'), @@ -220,11 +213,6 @@ class AIEStableDiffusion3CachePipeline(AIEStableDiffusion3Pipeline): global dit_time global vae_time global scheduler_time - global dit_1_time - global dit_2_time - global dit_3_time - global dit_2_end_time - global dit_3_end_time skip_flag_true = torch.ones([1], dtype=torch.long).to(f'npu:{self.device_0}') skip_flag_false = torch.zeros([1], dtype=torch.long).to(f'npu:{self.device_0}') @@ -245,12 +233,14 @@ class AIEStableDiffusion3CachePipeline(AIEStableDiffusion3Pipeline): if not self.args.use_cache: # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents - (noise_pred, delta_cache, delta_cache_hidden) = self.dit_infer(latent_model_input, - prompt_embeds, - pooled_prompt_embeds, - timestep_npu, cache_dict, - skip_flag_true, delta_cache, - delta_cache_hidden) + (noise_pred, delta_cache, delta_cache_hidden) = self.dit_infer( + self.compiled_dit_cache_model, + latent_model_input, + prompt_embeds, + pooled_prompt_embeds, + timestep_npu, cache_dict, + skip_flag_true, delta_cache, + delta_cache_hidden) else: if i < tgate: latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents @@ -261,59 +251,59 @@ class AIEStableDiffusion3CachePipeline(AIEStableDiffusion3Pipeline): 
latent_model_input = latents if i < cache_dict[3]: - start = time.time() - (noise_pred, delta_cache, delta_cache_hidden) = self.dit_infer(latent_model_input, - prompt_embeds, - pooled_prompt_embeds, - timestep_npu, cache_dict, - skip_flag_true, delta_cache, - delta_cache_hidden) - dit_1_time += (time.time() - start) + (noise_pred, delta_cache, delta_cache_hidden) = self.dit_infer( + self.compiled_dit_cache_model, + latent_model_input, + prompt_embeds, + pooled_prompt_embeds, + timestep_npu, cache_dict, + skip_flag_true, delta_cache, + delta_cache_hidden) else: if i % cache_interval == step_contrast: if i < tgate: - start = time.time() - (noise_pred, delta_cache, delta_cache_hidden) = self.dit_infer(latent_model_input, - prompt_embeds, - pooled_prompt_embeds, - timestep_npu, cache_dict, - skip_flag_true, - delta_cache, - delta_cache_hidden) - dit_2_time += (time.time() - start) + (noise_pred, delta_cache, delta_cache_hidden) = self.dit_infer( + self.compiled_dit_cache_model, + latent_model_input, + prompt_embeds, + pooled_prompt_embeds, + timestep_npu, cache_dict, + skip_flag_true, + delta_cache, + delta_cache_hidden) else: - start = time.time() - (noise_pred, delta_cache, delta_cache_hidden) = self.dit_infer(latent_model_input, - prompt_embeds_origin, - pooled_prompt_embeds_origin, - timestep_npu, cache_dict, - skip_flag_true, - delta_cache, - delta_cache_hidden) - dit_2_end_time += (time.time() - start) + (noise_pred, delta_cache, delta_cache_hidden) = self.dit_infer( + self.compiled_dit_cache_end_model, + latent_model_input, + prompt_embeds_origin, + pooled_prompt_embeds_origin, + timestep_npu, cache_dict, + skip_flag_true, + delta_cache, + delta_cache_hidden) else: if i < tgate: - start = time.time() - (noise_pred, delta_cache, delta_cache_hidden) = self.dit_infer(latent_model_input, - prompt_embeds, - pooled_prompt_embeds, - timestep_npu, cache_dict, - skip_flag_false, - delta_cache, - delta_cache_hidden) - dit_3_time += (time.time() - start) + (noise_pred, delta_cache, delta_cache_hidden) = self.dit_infer( + self.compiled_dit_skip_model, + latent_model_input, + prompt_embeds, + pooled_prompt_embeds, + timestep_npu, cache_dict, + skip_flag_false, + delta_cache, + delta_cache_hidden) else: - start = time.time() - (noise_pred, delta_cache, delta_cache_hidden) = self.dit_infer(latent_model_input, - prompt_embeds_origin, - pooled_prompt_embeds_origin, - timestep_npu, cache_dict, - skip_flag_false, - delta_cache, - delta_cache_hidden) - dit_3_time += (time.time() - start) - - dit_time += time.time() - start + (noise_pred, delta_cache, delta_cache_hidden) = self.dit_infer( + self.compiled_dit_skip_end_model, + latent_model_input, + prompt_embeds_origin, + pooled_prompt_embeds_origin, + timestep_npu, cache_dict, + skip_flag_false, + delta_cache, + delta_cache_hidden) + + dit_time += (time.time() - start) # perform guidance if self.do_classifier_free_guidance and i < tgate: @@ -423,7 +413,7 @@ def main(): guidance_scale=7.0, cache_dict=cache_dict ) - if i > 4: # do not count the time spent inferring the first 0 to 4 images + if i > 4: # do not count the time spent inferring the first 0 to 4 images use_time += time.time() - start_time for j in range(n_prompts): -- Gitee From 05a68446435f5d1d0c538101e83794539c2d6d75 Mon Sep 17 00:00:00 2001 From: huanghao7 Date: Thu, 5 Sep 2024 09:30:00 +0000 Subject: [PATCH 20/20] update stable_diffusion_3/stable_diffusion3_pipeline_cache.py. 
Signed-off-by: huanghao7 --- .../stable_diffusion_3/stable_diffusion3_pipeline_cache.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline_cache.py b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline_cache.py index f3d5bda62e..f223729f07 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline_cache.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline_cache.py @@ -430,8 +430,6 @@ def main(): print( f"[info] infer number: {infer_num - 5}; use time: {use_time:.3f}s\n" f"average time: {use_time / (infer_num - 5):.3f}s\n" - f"clip time: {clip_time / infer_num:.3f}s\n" - f"t5 time: {t5_time / infer_num:.3f}s\n" f"dit time: {dit_time / infer_num:.3f}s\n" f"scheduler_time time: {scheduler_time / infer_num:.3f}s\n" f"vae time: {vae_time / infer_num:.3f}s\n" -- Gitee
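
Note on the traced scheduler step (PATCH 09/10): the series replaces the Python-side FlowMatchEulerDiscreteScheduler step with a TorchScript module that carries a hard-coded 28-step sigma table, so the per-step update can be compiled by mindietorch like the other submodels. The sketch below is a minimal, standalone illustration of that Euler update in plain PyTorch; the `EulerStep` name, the 4-entry placeholder sigma table, and the dummy tensor shapes are illustrative assumptions, not the repository's compile path.

```python
# Minimal sketch (assumption: plain PyTorch, no mindietorch) of the Euler step
# that the patch traces into a TorchScript module with a fixed sigma table.
import torch

# Placeholder sigmas; the patch hardcodes the full 28-step table (SCHEDULER_SIGMAS).
SIGMAS = torch.tensor([1.0, 0.5, 0.25, 0.0], dtype=torch.float32)


class EulerStep(torch.nn.Module):
    def __init__(self, sigmas: torch.Tensor):
        super().__init__()
        self.register_buffer("sigmas", sigmas)

    def forward(self, model_output, sample, step_index):
        sample = sample.to(torch.float32)
        sigma = self.sigmas[step_index]
        denoised = sample - model_output * sigma      # predicted x_0
        derivative = (sample - denoised) / sigma      # equals model_output here
        dt = self.sigmas[step_index + 1] - sigma
        prev_sample = sample + derivative * dt        # first-order Euler update
        return prev_sample.to(model_output.dtype)


# Tracing works because step_index is passed as an int64 tensor of shape [1],
# mirroring the mindietorch.Input((1,), dtype=INT64) used in export_scheduler.
traced = torch.jit.trace(
    EulerStep(SIGMAS).eval(),
    (torch.randn(1, 16, 128, 128), torch.randn(1, 16, 128, 128),
     torch.ones([1], dtype=torch.int64)),
)
```

Because sigma_hat equals sigma in this prediction type, the update algebraically collapses to `prev_sample = sample + model_output * (sigma_next - sigma)`, which is why a fixed sigma table is all the traced graph needs per step.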
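
Note on the cached-DiT dispatch (PATCH 19): the refactor threads the target compiled graph through `dit_infer` and, per denoising step, picks one of four graphs — full compute vs. cache-skip, each with and without the classifier-free-guidance batch (the "end" graphs used after the t-gate step). The sketch below is a hedged reconstruction of that selection only; the meaning of the first three `--cache_dict` entries (they are consumed inside the transformer graph) and the `cache_interval`/`step_contrast` values are assumptions inferred from the diff, not authoritative semantics.

```python
# Hedged sketch of the per-step choice among the four compiled DiT graphs.
from types import SimpleNamespace


def pick_dit_graph(pipe, i, tgate, cache_dict, cache_interval, step_contrast):
    """Return (compiled_graph, use_cfg_batch) for denoising step i."""
    if i < cache_dict[3]:
        # warm-up steps: always full compute with the CFG-doubled batch
        return pipe.compiled_dit_cache_model, True
    if i % cache_interval == step_contrast:
        # refresh step: full compute, rewriting delta_cache / delta_cache_hidden
        if i < tgate:
            return pipe.compiled_dit_cache_model, True
        # after the t-gate step the negative prompt is dropped (bs=1 "end" graph)
        return pipe.compiled_dit_cache_end_model, False
    # in-between steps: skip graph that reuses the cached deltas
    if i < tgate:
        return pipe.compiled_dit_skip_model, True
    return pipe.compiled_dit_skip_end_model, False


# Toy usage with stand-in graph handles (assumed cache_interval=2, step_contrast=1).
pipe = SimpleNamespace(
    compiled_dit_cache_model="dit_full_cfg",
    compiled_dit_cache_end_model="dit_full",
    compiled_dit_skip_model="dit_skip_cfg",
    compiled_dit_skip_end_model="dit_skip",
)
for i in range(28):
    graph, use_cfg = pick_dit_graph(pipe, i, tgate=20, cache_dict=[1, 2, 20, 10],
                                    cache_interval=2, step_contrast=1)
```

The `use_cfg_batch` flag mirrors the pipeline's `latent_model_input = torch.cat([latents] * 2)` before the t-gate step and the switch to `prompt_embeds_origin` afterwards; the actual calls additionally pass the `skip_flag_true`/`skip_flag_false` tensor shown in the diff to the compiled graph.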