diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/README.md b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/README.md index 4a807769b94dc98ee19a0b8941e79c69c344c339..c29b64bdff7bf59f2b7c75914b06a7de9cb2e9d1 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/README.md +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/README.md @@ -80,7 +80,7 @@ ``` - 执行命令: ```bash - python3 attention_patch.py + python3 attention_processor_patch.py ``` ## 模型推理 diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention.patch b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention.patch new file mode 100644 index 0000000000000000000000000000000000000000..ce183bd9141104a16b4e2e14173b1782f3b10bcb --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention.patch @@ -0,0 +1,19 @@ +--- attention.py 2024-09-04 09:22:31.768000000 +0000 ++++ attention.py 2024-09-04 09:17:12.680000000 +0000 +@@ -100,7 +100,7 @@ + processing of `context` conditions. + """ + +- def __init__(self, dim, num_attention_heads, attention_head_dim, context_pre_only=False): ++ def __init__(self, dim, num_attention_heads, attention_head_dim, context_pre_only=False, layer_idx=0): + super().__init__() + + self.context_pre_only = context_pre_only +@@ -134,6 +134,7 @@ + context_pre_only=context_pre_only, + bias=True, + processor=processor, ++ layer_idx=layer_idx + ) + + self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention_patch.py b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention_patch.py index 97585e6af71fd0df9f9af7ba99eff923e3352b15..2d289cb99ed59c8ea3ed16db71c9333b619a44a6 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention_patch.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention_patch.py @@ -22,8 +22,8 @@ def main(): diffusers_version = diffusers.__version__ assert diffusers_version == '0.29.0', "expectation diffusers==0.29.0" - result = subprocess.run(["patch", "-p0", f"{diffusers_path[0]}/models/attention_processor.py", - "attention_processor.patch"], capture_output=True, text=True) + result = subprocess.run(["patch", "-p0", f"{diffusers_path[0]}/models/attention.py", + "attention.patch"], capture_output=True, text=True) if result.returncode != 0: logging.error("Patch failed, error message: s%", result.stderr) diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention_processor.patch b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention_processor.patch index dc22a411eaf2ee4e13f6ea67a13981819e5c559a..8897fabdbb8a53585072ab7b1a43cc4029c2b3ca 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention_processor.patch +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention_processor.patch @@ -1,9 +1,41 @@ ---- attention_processor.py 2024-08-14 12:41:10.528000000 +0000 -+++ attention_processor.py 2024-08-14 12:42:25.620000000 +0000 -@@ -1082,6 +1082,29 @@ +--- attention_processor.py 2024-09-04 09:22:16.048000000 +0000 ++++ attention_processor.py 2024-09-04 09:57:55.928000000 +0000 +@@ -115,6 +115,7 @@ + residual_connection: bool = False, + _from_deprecated_attn_block: bool = False, + processor: Optional["AttnProcessor"] = None, ++ layer_idx: int = 0, + out_dim: int = None, + context_pre_only=None, + ): +@@ -132,13 +133,14 @@ + 
self.fused_projections = False + self.out_dim = out_dim if out_dim is not None else query_dim + self.context_pre_only = context_pre_only ++ self.layer_idx = layer_idx + + # we make use of this private variable to know whether this class is loaded + # with an deprecated state dict so that we can convert it on the fly + self._from_deprecated_attn_block = _from_deprecated_attn_block + + self.scale_qk = scale_qk +- self.scale = dim_head**-0.5 if self.scale_qk else 1.0 ++ self.scale = dim_head ** -0.5 if self.scale_qk else 1.0 + + self.heads = out_dim // dim_head if out_dim is not None else heads + # for slice_size > 0 the attention score computation +@@ -561,6 +563,7 @@ + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, ++ layer_idx=self.layer_idx, + **cross_attention_kwargs, + ) + +@@ -1082,6 +1085,29 @@ return hidden_states - - + + +def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, + scale=None) -> torch.Tensor: + # Efficient implementation equivalent to the following: @@ -20,7 +52,7 @@ + if attn_mask.dtype == torch.bool: + attn_mask.masked_fill_(attn_mask.logical_not(), float("-inf")) + else: -+ attn_bias += attn_mask ++ attn_bias = attn_mask + attn_weight = query @ key.transpose(-2, -1) * scale_factor + attn_weight = torch.softmax(attn_weight, dim=-1) + attn_weight = torch.dropout(attn_weight, dropout_p, train=True) @@ -29,13 +61,41 @@ + class JointAttnProcessor2_0: """Attention processor used typically in processing the SD3-like self-attention projections.""" - -@@ -1132,7 +1155,7 @@ + +@@ -1095,6 +1121,7 @@ + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, ++ layer_idx=0, + *args, + **kwargs, + ) -> torch.FloatTensor: +@@ -1112,7 +1139,19 @@ + batch_size = encoder_hidden_states.shape[0] + + # `sample` projections. 
++ # use todo + query = attn.to_q(hidden_states) ++ if layer_idx <= 11: ++ hidden_dim = hidden_states.shape[-1] ++ cur_h = int(math.sqrt(hidden_states.shape[1])) ++ cur_w = cur_h ++ hidden_states = hidden_states.transpose(1, 2).view(batch_size, hidden_dim, cur_h, cur_w) ++ new_h = int(cur_h / 2.2) ++ new_w = int(cur_w / 2.2) ++ item = F.interpolate(hidden_states, size=(new_h, new_w), ++ mode='bilinear') ++ item = item.permute(0, 2, 3, 1) ++ hidden_states = item.reshape(batch_size, new_h * new_w, -1) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + +@@ -1132,7 +1171,7 @@ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - + - hidden_states = hidden_states = F.scaled_dot_product_attention( + hidden_states = hidden_states = scaled_dot_product_attention( query, key, value, dropout_p=0.0, is_causal=False ) - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) \ No newline at end of file + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention_processor_patch.py b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention_processor_patch.py new file mode 100644 index 0000000000000000000000000000000000000000..97585e6af71fd0df9f9af7ba99eff923e3352b15 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention_processor_patch.py @@ -0,0 +1,32 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import subprocess
+import logging
+import diffusers
+
+
+def main():
+    diffusers_path = diffusers.__path__
+    diffusers_version = diffusers.__version__
+
+    assert diffusers_version == '0.29.0', "expected diffusers==0.29.0"
+    result = subprocess.run(["patch", "-p0", f"{diffusers_path[0]}/models/attention_processor.py",
+                             "attention_processor.patch"], capture_output=True, text=True)
+    if result.returncode != 0:
+        logging.error("Patch failed, error message: %s", result.stderr)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/compile_model.py b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/compile_model.py
index b48994f182031291c46b0fbf817b6b15da438d93..7bd43f2136d0a4e05f7a9b314ec71eced12f4bb2 100644
--- a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/compile_model.py
+++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/compile_model.py
@@ -18,6 +18,11 @@
 from typing import List
 import mindietorch
 from mindietorch import _enums
 
+SCHEDULER_SIGMAS = torch.tensor([1.0000, 0.9874, 0.9741, 0.9601, 0.9454, 0.9298, 0.9133, 0.8959, 0.8774,
+                                 0.8577, 0.8367, 0.8143, 0.7904, 0.7647, 0.7371, 0.7073, 0.6751, 0.6402,
+                                 0.6022, 0.5606, 0.5151, 0.4649, 0.4093, 0.3474, 0.2780, 0.1998, 0.1109,
+                                 0.0089, 0.0000], dtype=torch.float32)
+
 
 @dataclass
 class CompileParam:
@@ -69,11 +74,48 @@ class VaeExport(torch.nn.Module):
         image = self.vae_model.decode(latents, return_dict=False)[0]
         return image
 
+
 def compile_vae(model, inputs, vae_compiled_path, soc_version):
     vae_param = CompileParam(inputs, soc_version)
     common_compile(model, vae_compiled_path, vae_param)
 
+
+class Scheduler(torch.nn.Module):
+    def __init__(self):
+        super(Scheduler, self).__init__()
+        self.sigmas = SCHEDULER_SIGMAS
+
+    def forward(
+        self,
+        model_output: torch.FloatTensor,
+        sample: torch.FloatTensor,
+        step_index: torch.LongTensor
+    ):
+        # Upcast to avoid precision issues when computing prev_sample
+        sample = sample.to(torch.float32)
+        sigma = self.sigmas[step_index]
+
+        sigma_hat = sigma
+        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
+        # NOTE: "original_sample" should not be an expected prediction_type but is left in for
+        # backwards compatibility
+        denoised = sample - model_output * sigma
+        # 2. 
Convert to an ODE derivative + derivative = (sample - denoised) / sigma_hat + + dt = self.sigmas[step_index + 1] - sigma_hat + prev_sample = sample + derivative * dt + # Cast sample back to model compatible dtype + prev_sample = prev_sample.to(model_output.dtype) + + return prev_sample + + +def compile_scheduler(model, inputs, scheduler_compiled_path, soc_version): + scheduler_param = CompileParam(inputs, soc_version, True, True, False) + common_compile(model, scheduler_compiled_path, scheduler_param) + + class DiTExport(torch.nn.Module): def __init__(self, dit_model): super().__init__() @@ -93,3 +135,30 @@ class DiTExport(torch.nn.Module): def compile_dit(model, inputs, dit_compiled_path, soc_version): dit_param = CompileParam(inputs, soc_version) common_compile(model, dit_compiled_path, dit_param) + + +class DiTExportCache(torch.nn.Module): + def __init__(self, dit_cache_model): + super().__init__() + self.dit_cache_model = dit_cache_model + + def forward( + self, + hidden_states, + encoder_hidden_states, + pooled_projections, + timestep, + cache_dict, + if_skip: int = 0, + delta_cache: torch.FloatTensor = None, + delta_cache_hidden: torch.FloatTensor = None, + use_cache: bool = True, + ): + return self.dit_cache_model(hidden_states, encoder_hidden_states, pooled_projections, timestep, + cache_dict, if_skip, delta_cache, delta_cache_hidden, use_cache, + joint_attention_kwargs=None, return_dict=False) + + +def compile_dit_cache(model, inputs, dit_cache_compiled_path, soc_version): + dit_cache_param = CompileParam(inputs, soc_version, True, False, True) + common_compile(model, dit_cache_compiled_path, dit_cache_param) diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/export_model.py b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/export_model.py index 1cda87042f27c66fb2c5582ae190ab99564b9fcb..fb18b7eddcad63aaff056037f2415ccd7f5145cc 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/export_model.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/export_model.py @@ -17,6 +17,7 @@ import argparse from argparse import Namespace import torch +from diffusers import FlowMatchEulerDiscreteScheduler from diffusers import StableDiffusion3Pipeline import mindietorch from compile_model import * @@ -66,7 +67,7 @@ def parse_arguments() -> Namespace: ) parser.add_argument("-bs", "--batch_size", type=int, default=1, help="Batch size.") parser.add_argument("-steps", "--steps", type=int, default=28, help="steps.") - parser.add_argument("-guid", "--guidance_scale", type=float, default=5.0, help="guidance_scale") + parser.add_argument("-guid", "--guidance_scale", type=float, default=7.0, help="guidance_scale") parser.add_argument("--use_cache", action="store_true", help="Use cache during inference.") parser.add_argument("-p", "--parallel", action="store_true", help="Export the unet of bs=1 for parallel inferencing.") @@ -89,6 +90,12 @@ def parse_arguments() -> Namespace: type=int, help="image width" ) + parser.add_argument( + "--cache_dict", + type=str, + default="1,2,20,10", + help="Steps to use cache data." 
+ ) return parser.parse_args() @@ -254,11 +261,131 @@ def export_vae(sd_pipeline, args): logging.info("vae_compiled_path already exists.") -def export(args): - pipeline = StableDiffusion3Pipeline.from_pretrained(args.model).to('cpu') +def trace_scheduler(sd_pipeline, args, scheduler_pt_path): + batch_size = args.batch_size + if not os.path.exists(scheduler_pt_path): + dummy_input = ( + torch.randn([batch_size, 16, 128, 128], dtype=torch.float32), + torch.randn([batch_size, 16, 128, 128], dtype=torch.float32), + torch.ones([1], dtype=torch.int64) + ) + scheduler = FlowMatchEulerDiscreteScheduler.from_config(sd_pipeline.scheduler.config) + scheduler.set_timesteps(args.steps, device="cpu") + + new_scheduler = Scheduler() + new_scheduler.eval() + torch.jit.trace(new_scheduler, dummy_input).save(scheduler_pt_path) + + +def export_scheduler(sd_pipeline, args): + print("Exporting the scheduler...") + scheduler_path = os.path.join(args.output_dir, "scheduler") + if not os.path.exists(scheduler_path): + os.makedirs(scheduler_path, mode=0o744) + batch_size = args.batch_size + height_size, width_size = args.height // 8, args.width // 8 + scheduler_pt_path = os.path.join(scheduler_path, f"scheduler_bs{batch_size}.pt") + scheduler_compiled_path = os.path.join(scheduler_path, + f"scheduler_bs{batch_size}_compile_{args.height}x{args.width}.ts") + in_channels = 16 + + # trace + trace_scheduler(sd_pipeline, args, scheduler_pt_path) + + # compile + if not os.path.exists(scheduler_compiled_path): + model = torch.jit.load(scheduler_pt_path).eval() + inputs = [ + mindietorch.Input((batch_size, in_channels, height_size, width_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size, in_channels, height_size, width_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), dtype=mindietorch.dtype.INT64) + ] + compile_scheduler(model, inputs, scheduler_compiled_path, args.soc) + + +def export_dit_cache(sd_pipeline, args, if_skip, flag=""): + print("Exporting the dit_cache...") + cache_dict = torch.zeros([4], dtype=torch.int64) + cache_list = args.cache_dict.split(',') + cache_dict[0] = int(cache_list[0]) + cache_dict[1] = int(cache_list[1]) + cache_dict[2] = int(cache_list[2]) + cache_dict[3] = int(cache_list[3]) + dit_path = os.path.join(args.output_dir, "dit") + if not os.path.exists(dit_path): + os.makedirs(dit_path, mode=0o640) + + encoder_model = sd_pipeline.text_encoder + encoder_model_2 = sd_pipeline.text_encoder_2 + dit_model = sd_pipeline.transformer + if args.parallel or flag == "end": + batch_size = args.batch_size + else: + batch_size = args.batch_size * 2 + sample_size = dit_model.config.sample_size + in_channels = dit_model.config.in_channels + encoder_hidden_size_2 = encoder_model_2.config.hidden_size + encoder_hidden_size = encoder_model.config.hidden_size + encoder_hidden_size_2 + max_position_embeddings = encoder_model.config.max_position_embeddings * 2 + dit_cache_pt_path = os.path.join(dit_path, f"dit_bs{batch_size}_{if_skip}.pt") + dit_cache_compiled_path = os.path.join(dit_path, + f"dit_bs{batch_size}_{if_skip}_compile_{args.height}x{args.width}.ts") + + # trace + if not os.path.exists(dit_cache_pt_path): + dummy_input = ( + torch.ones([batch_size, in_channels, sample_size, sample_size], dtype=torch.float32), + torch.ones([batch_size, max_position_embeddings, encoder_hidden_size * 2], dtype=torch.float32), + torch.ones([batch_size, encoder_hidden_size], dtype=torch.float32), + torch.ones([1], dtype=torch.int64), + cache_dict, + torch.tensor([if_skip], 
dtype=torch.int64), + torch.ones([batch_size, 4096, 1536], dtype=torch.float32), + torch.ones([batch_size, 154, 1536], dtype=torch.float32), + ) + print("dummy_input.shape:") + for ele in dummy_input: + if isinstance(ele, torch.Tensor): + print(ele.shape) + dit = DiTExportCache(dit_model).eval() + torch.jit.trace(dit, dummy_input).save(dit_cache_pt_path) + + # compile + if not os.path.exists(dit_cache_compiled_path): + model = torch.jit.load(dit_cache_pt_path).eval() + inputs = [ + mindietorch.Input((batch_size, in_channels, sample_size, sample_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size, max_position_embeddings, encoder_hidden_size * 2), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size, encoder_hidden_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), dtype=mindietorch.dtype.INT64), + mindietorch.Input((4,), dtype=mindietorch.dtype.INT64), + mindietorch.Input((1,), dtype=mindietorch.dtype.INT64), + mindietorch.Input((batch_size, 4096, 1536), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size, 154, 1536), + dtype=mindietorch.dtype.FLOAT), + ] + compile_dit_cache(model, inputs, dit_cache_compiled_path, args.soc) + + +def export(args) -> None: + pipeline = StableDiffusion3Pipeline.from_pretrained(args.model).to("cpu") export_clip(pipeline, args) - export_dit(pipeline, args) + if args.use_cache: + export_dit_cache(pipeline, args, 0) + export_dit_cache(pipeline, args, 1) + if "B" in args.soc: + export_dit_cache(pipeline, args, 0, "end") + export_dit_cache(pipeline, args, 1, "end") + else: + export_dit(pipeline, args) export_vae(pipeline, args) + export_scheduler(pipeline, args) def main(args): diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline.py b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline.py index 5415517f7e0160959cff4fdad0a4456233dcc8e1..fd2c2d66b8438959bafe62de3a84f4ff56fb1cd3 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline.py @@ -725,6 +725,17 @@ def parse_arguments(): type=int, help="image width" ) + parser.add_argument( + "--use_cache", + action="store_true", + help="Use cache during inference." + ) + parser.add_argument( + "--cache_dict", + default="1,2,20,10", + type=str, + help="steps to use cache data" + ) return parser.parse_args() diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline_cache.py b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..f223729f0763493c7dd2b8ceed32bfbc9328c6d8 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline_cache.py @@ -0,0 +1,451 @@ +# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +import time +from typing import Any, Callable, Dict, List, Optional, Union +import torch +import mindietorch +from diffusers.pipelines.stable_diffusion_3.pipeline_output import StableDiffusion3PipelineOutput +from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import retrieve_timesteps +from stable_diffusion3_pipeline import PromptLoader, AIEStableDiffusion3Pipeline, parse_arguments + +tgate = 20 +dit_time = 0 +vae_time = 0 +scheduler_time = 0 +p1_time = 0 +p2_time = 0 +p3_time = 0 + + +class AIEStableDiffusion3CachePipeline(AIEStableDiffusion3Pipeline): + def compile_aie_model(self): + if self.is_init: + return + size = self.args.batch_size + batch_size = self.args.batch_size * 2 + tail = f"_{self.args.height}x{self.args.width}" + + vae_compiled_path = os.path.join(self.args.output_dir, + f"vae/vae_bs{size}_compile{tail}.ts") + self.compiled_vae_model = torch.jit.load(vae_compiled_path).eval() + + scheduler_compiled_path = os.path.join(self.args.output_dir, + f"scheduler/scheduler_bs{size}_compile{tail}.ts") + self.compiled_scheduler = torch.jit.load(scheduler_compiled_path).eval() + + clip1_compiled_path = os.path.join(self.args.output_dir, + f"clip/clip_bs{size}_compile{tail}.ts") + self.compiled_clip_model = torch.jit.load(clip1_compiled_path).eval() + + clip2_compiled_path = os.path.join(self.args.output_dir, + f"clip/clip2_bs{size}_compile{tail}.ts") + self.compiled_clip_model_2 = torch.jit.load(clip2_compiled_path).eval() + + t5_compiled_path = os.path.join(self.args.output_dir, + f"clip/t5_bs{size}_compile{tail}.ts") + self.compiled_t5_model = torch.jit.load(t5_compiled_path).eval() + + dit_cache_compiled_path = os.path.join(self.args.output_dir, + f"dit/dit_bs{batch_size}_0_compile{tail}.ts") + self.compiled_dit_cache_model = torch.jit.load(dit_cache_compiled_path).eval() + + if self.args.use_cache: + dit_skip_compiled_path = os.path.join(self.args.output_dir, + f"dit/dit_bs{batch_size}_1_compile{tail}.ts") + self.compiled_dit_skip_model = torch.jit.load(dit_skip_compiled_path).eval() + + dit_cache_end_compiled_path = os.path.join(self.args.output_dir, + f"dit/dit_bs{size}_0_compile{tail}.ts") + self.compiled_dit_cache_end_model = torch.jit.load(dit_cache_end_compiled_path).eval() + + dit_skip_end_compiled_path = os.path.join(self.args.output_dir, + f"dit/dit_bs{size}_1_compile{tail}.ts") + self.compiled_dit_skip_end_model = torch.jit.load(dit_skip_end_compiled_path).eval() + + self.is_init = True + + @torch.no_grad() + def dit_infer(self, compiled_model, latent_model_input, prompt_embeds, pooled_prompt_embeds, timestep_npu, + cache_dict, skip_flag, delta_cache, delta_cache_hidden): + (noise_pred, delta_cache, delta_cache_hidden) = compiled_model( + latent_model_input.to(f'npu:{self.device_0}'), + prompt_embeds.to(f'npu:{self.device_0}'), + pooled_prompt_embeds.to(f'npu:{self.device_0}'), + timestep_npu, + cache_dict.to(f'npu:{self.device_0}'), + skip_flag, + delta_cache.to(f'npu:{self.device_0}'), + delta_cache_hidden.to(f'npu:{self.device_0}'), + ) + noise_pred = noise_pred.to("cpu") + delta_cache = delta_cache.to("cpu") + delta_cache_hidden = delta_cache_hidden.to("cpu") + return (noise_pred, delta_cache, delta_cache_hidden) + + @torch.no_grad() + def forward( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + prompt_3: Optional[Union[str, List[str]]] = None, + height: 
Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 28, + timesteps: List[int] = None, + guidance_scale: float = 7.0, + cache_dict: torch.LongTensor = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + negative_prompt_3: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + ): + global p1_time, p2_time, p3_time + start = time.time() + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + prompt_3, + height, + width, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + device = self._execution_device + + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_3=prompt_3, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + do_classifier_free_guidance=self.do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + clip_skip=self.clip_skip, + num_images_per_prompt=num_images_per_prompt, + ) + + p1_time += (time.time() - start) + start1 = time.time() + + prompt_embeds_origin = prompt_embeds.clone() + pooled_prompt_embeds_origin = pooled_prompt_embeds.clone() + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 5. 
Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Denoising loop + global dit_time + global vae_time + global scheduler_time + + skip_flag_true = torch.ones([1], dtype=torch.long).to(f'npu:{self.device_0}') + skip_flag_false = torch.zeros([1], dtype=torch.long).to(f'npu:{self.device_0}') + + delta_cache = torch.zeros([2, 4096, 1536], dtype=torch.float32) + delta_cache_hidden = torch.zeros([2, 154, 1536], dtype=torch.float32) + + cache_interval = cache_dict[1] + step_contrast = cache_dict[3] % 2 + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + start = time.time() + timestep_npu = t.to(torch.int64)[None].to(f'npu:{self.device_0}') + if not self.args.use_cache: + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + (noise_pred, delta_cache, delta_cache_hidden) = self.dit_infer( + self.compiled_dit_cache_model, + latent_model_input, + prompt_embeds, + pooled_prompt_embeds, + timestep_npu, cache_dict, + skip_flag_true, delta_cache, + delta_cache_hidden) + else: + if i < tgate: + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + else: + if i == tgate: + _, delta_cache = delta_cache.chunk(2) + _, delta_cache_hidden = delta_cache_hidden.chunk(2) + latent_model_input = latents + + if i < cache_dict[3]: + (noise_pred, delta_cache, delta_cache_hidden) = self.dit_infer( + self.compiled_dit_cache_model, + latent_model_input, + prompt_embeds, + pooled_prompt_embeds, + timestep_npu, cache_dict, + skip_flag_true, delta_cache, + delta_cache_hidden) + else: + if i % cache_interval == step_contrast: + if i < tgate: + (noise_pred, delta_cache, delta_cache_hidden) = self.dit_infer( + self.compiled_dit_cache_model, + latent_model_input, + prompt_embeds, + pooled_prompt_embeds, + timestep_npu, cache_dict, + skip_flag_true, + delta_cache, + delta_cache_hidden) + else: + (noise_pred, delta_cache, delta_cache_hidden) = self.dit_infer( + self.compiled_dit_cache_end_model, + latent_model_input, + prompt_embeds_origin, + pooled_prompt_embeds_origin, + timestep_npu, cache_dict, + skip_flag_true, + delta_cache, + delta_cache_hidden) + else: + if i < tgate: + (noise_pred, delta_cache, delta_cache_hidden) = self.dit_infer( + self.compiled_dit_skip_model, + latent_model_input, + prompt_embeds, + pooled_prompt_embeds, + timestep_npu, cache_dict, + skip_flag_false, + delta_cache, + delta_cache_hidden) + else: + (noise_pred, delta_cache, delta_cache_hidden) = self.dit_infer( + self.compiled_dit_skip_end_model, + latent_model_input, + prompt_embeds_origin, + pooled_prompt_embeds_origin, + timestep_npu, cache_dict, + skip_flag_false, + delta_cache, + delta_cache_hidden) + + dit_time += (time.time() - start) + + # perform guidance + if self.do_classifier_free_guidance and i < tgate: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + start = time.time() + # latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + step_index = torch.tensor(i).long() + latents = 
self.compiled_scheduler( + noise_pred.to(f'npu:{self.device_0}'), + latents.to(f'npu:{self.device_0}'), + step_index[None].to(f'npu:{self.device_0}') + ).to('cpu') + scheduler_time += (time.time() - start) + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + p2_time += time.time() - start1 + start2 = time.time() + + if output_type == "latent": + image = latents + else: + start = time.time() + image = self.compiled_vae_model(latents.to(f'npu:{self.device_0}')).to("cpu") + vae_time += time.time() - start + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + p3_time += time.time() - start2 + + if not return_dict: + return (image,) + + return StableDiffusion3PipelineOutput(images=image) + + +def main(): + args = parse_arguments() + save_dir = args.save_dir + if not os.path.exists(save_dir): + os.makedirs(save_dir) + + if isinstance(args.device, list): + mindietorch.set_device(args.device[0]) + else: + mindietorch.set_device(args.device) + pipe = AIEStableDiffusion3CachePipeline.from_pretrained(args.model).to("cpu") + pipe.parser_args(args) + pipe.compile_aie_model() + + cache_dict = torch.zeros([4], dtype=torch.int64) + cache_list = args.cache_dict.split(',') + cache_dict[0] = int(cache_list[0]) + cache_dict[1] = int(cache_list[1]) + cache_dict[2] = int(cache_list[2]) + cache_dict[3] = int(cache_list[3]) + use_time = 0 + prompt_loader = PromptLoader(args.prompt_file, + args.prompt_file_type, + args.batch_size, + args.num_images_per_prompt, + args.max_num_prompts) + + infer_num = 0 + image_info = [] + current_prompt = None + for i, input_info in enumerate(prompt_loader): + prompts = input_info['prompts'] + catagories = input_info['catagories'] + save_names = input_info['save_names'] + n_prompts = input_info['n_prompts'] + + print(f"[{infer_num + n_prompts}/{len(prompt_loader)}]: {prompts}") + infer_num += args.batch_size + + start_time = time.time() + images = pipe.forward( + prompts, + negative_prompt="", + num_inference_steps=args.steps, + guidance_scale=7.0, + cache_dict=cache_dict + ) + if i > 4: # do not count the time spent inferring the first 0 to 4 images + use_time += time.time() - start_time + + for j in range(n_prompts): + image_save_path = os.path.join(save_dir, f"{save_names[j]}.png") + image = images[0][j] + image.save(image_save_path) + + if current_prompt != prompts[j]: + current_prompt = prompts[j] + image_info.append({'images': [], 'prompt': current_prompt, 'category': catagories[j]}) + + image_info[-1]['images'].append(image_save_path) + + print( + f"[info] infer number: {infer_num - 5}; use time: 
{use_time:.3f}s\n" + f"average time: {use_time / (infer_num - 5):.3f}s\n" + f"dit time: {dit_time / infer_num:.3f}s\n" + f"scheduler_time time: {scheduler_time / infer_num:.3f}s\n" + f"vae time: {vae_time / infer_num:.3f}s\n" + f"p1 time: {p1_time / infer_num:.3f}s\n" + f"p2 time: {p2_time / infer_num:.3f}s\n" + f"p3 time: {p3_time / infer_num:.3f}s\n" + ) + + # Save image information to a json file + if os.path.exists(args.info_file_save_path): + os.remove(args.info_file_save_path) + + with os.fdopen(os.open(args.info_file_save_path, os.O_RDWR | os.O_CREAT, 0o640), "w") as f: + json.dump(image_info, f) + mindietorch.finalize() + + +if __name__ == "__main__": + main() diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/transformer_sd3.patch b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/transformer_sd3.patch new file mode 100644 index 0000000000000000000000000000000000000000..6b18fd88314d32d4bd7317b38fa6e50a5f88ace9 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/transformer_sd3.patch @@ -0,0 +1,123 @@ +--- transformer_sd3.py 2024-09-04 09:21:58.280000000 +0000 ++++ transformer_sd3.py 2024-09-04 10:01:47.196000000 +0000 +@@ -97,6 +97,7 @@ + num_attention_heads=self.config.num_attention_heads, + attention_head_dim=self.inner_dim, + context_pre_only=i == num_layers - 1, ++ layer_idx=i + ) + for i in range(self.config.num_layers) + ] +@@ -106,6 +107,8 @@ + self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) + + self.gradient_checkpointing = False ++ self.delta_cache = None ++ self.delta_cache_hidden = None + + # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking + def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: +@@ -245,9 +248,14 @@ + encoder_hidden_states: torch.FloatTensor = None, + pooled_projections: torch.FloatTensor = None, + timestep: torch.LongTensor = None, ++ cache_dict: torch.LongTensor = None, ++ if_skip: int = 0, ++ delta_cache: torch.FloatTensor = None, ++ delta_cache_hidden: torch.FloatTensor = None, ++ use_cache: bool = False, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, +- ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: ++ ): + """ + The [`SD3Transformer2DModel`] forward method. + +@@ -281,10 +289,6 @@ + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) +- else: +- logger.warning( +- "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." 
+- ) + + height, width = hidden_states.shape[-2:] + +@@ -292,9 +296,8 @@ + temb = self.time_text_embed(timestep, pooled_projections) + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + +- for block in self.transformer_blocks: +- if self.training and self.gradient_checkpointing: +- ++ if self.training and self.gradient_checkpointing: ++ for block in self.transformer_blocks: + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: +@@ -312,11 +315,14 @@ + temb, + **ckpt_kwargs, + ) +- +- else: +- encoder_hidden_states, hidden_states = block( +- hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb +- ) ++ else: ++ ( ++ (encoder_hidden_states, hidden_states), ++ delta_cache, ++ delta_cache_hidden ++ ) = self.forward_blocks(hidden_states, encoder_hidden_states, temb, ++ use_cache, if_skip, cache_dict, delta_cache, ++ delta_cache_hidden) + + hidden_states = self.norm_out(hidden_states, temb) + hidden_states = self.proj_out(hidden_states) +@@ -339,6 +345,43 @@ + unscale_lora_layers(self, lora_scale) + + if not return_dict: +- return (output,) ++ return (output, delta_cache, delta_cache_hidden) + + return Transformer2DModelOutput(sample=output) ++ ++ def forward_blocks_range(self, hidden_states, encoder_hidden_states, temb, start_idx, end_idx): ++ for block_idx, block in enumerate(self.transformer_blocks[start_idx: end_idx]): ++ encoder_hidden_states, hidden_states = block( ++ hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb ++ ) ++ # self.x_out.append([torch.mean(x).to('cpu').numpy(), torch.var(x).to('cpu').numpy()]) ++ return hidden_states, encoder_hidden_states ++ ++ def forward_blocks(self, hidden_states, encoder_hidden_states, temb, use_cache, if_skip, cache_dict, delta_cache, ++ delta_cache_hidden): ++ if not use_cache: ++ hidden_states, encoder_hidden_states = self.forward_blocks_range(hidden_states, encoder_hidden_states, temb, ++ 0, len(self.transformer_blocks)) ++ else: ++ # infer [0, cache_start) ++ hidden_states, encoder_hidden_states = self.forward_blocks_range(hidden_states, encoder_hidden_states, temb, ++ 0, cache_dict[0]) ++ ++ # infer [cache_start, cache_end) ++ cache_end = cache_dict[0] + cache_dict[2] ++ hidden_states_before_cache = hidden_states.clone() ++ encoder_hidden_states_before_cache = encoder_hidden_states.clone() ++ if not if_skip: ++ hidden_states, encoder_hidden_states = self.forward_blocks_range(hidden_states, encoder_hidden_states, ++ temb, cache_dict[0], ++ cache_end) ++ delta_cache = hidden_states - hidden_states_before_cache ++ delta_cache_hidden = encoder_hidden_states - encoder_hidden_states_before_cache ++ else: ++ hidden_states = hidden_states_before_cache + delta_cache ++ encoder_hidden_states = encoder_hidden_states_before_cache + delta_cache_hidden ++ ++ # infer [cache_end, len(self.blocks)) ++ hidden_states, encoder_hidden_states = self.forward_blocks_range(hidden_states, encoder_hidden_states, temb, ++ cache_end, len(self.transformer_blocks)) ++ return (encoder_hidden_states, hidden_states), delta_cache, delta_cache_hidden diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/transformer_sd3_patch.py b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/transformer_sd3_patch.py new file mode 100644 index 0000000000000000000000000000000000000000..8556cbac2ba1974972fcf3965b78a9f7592b60c8 --- /dev/null +++ 
b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/transformer_sd3_patch.py
@@ -0,0 +1,33 @@
+# Copyright 2023 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import subprocess
+import logging
+import diffusers
+
+
+
+def main():
+    diffusers_path = diffusers.__path__
+    diffusers_version = diffusers.__version__
+
+    assert diffusers_version == '0.29.0', "expected diffusers==0.29.0"
+    result = subprocess.run(["patch", "-p0", f"{diffusers_path[0]}/models/transformers/transformer_sd3.py",
+                             "transformer_sd3.patch"], capture_output=True, text=True)
+    if result.returncode != 0:
+        logging.error("Patch failed, error message: %s", result.stderr)
+
+
+if __name__ == '__main__':
+    main()
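
For readers of this change, here is a minimal, standalone sketch of the delta-cache idea that `forward_blocks` in `transformer_sd3.patch` implements. The function name, the `nn.Linear` stand-in blocks, and the tensor shapes below are illustrative only, not the PR's API; the real code threads `delta_cache` and `delta_cache_hidden` through `SD3Transformer2DModel.forward` and decides between full and skipped steps in `stable_diffusion3_pipeline_cache.py`.

```python
# Illustrative sketch only: names and shapes are hypothetical, not the PR's API.
from typing import Optional, Tuple

import torch
import torch.nn as nn


def forward_blocks_with_cache(blocks: nn.ModuleList, x: torch.Tensor, skip: bool,
                              cache_start: int, cache_len: int,
                              delta_cache: Optional[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run a block list, either recomputing the middle segment or reusing its cached residual.

    cache_start : index of the first cached block (cache_dict[0] in the PR)
    cache_len   : number of consecutive cached blocks (cache_dict[2] in the PR)
    skip        : True replays the stored residual instead of running the middle blocks
    """
    cache_end = cache_start + cache_len

    # Head blocks: always computed.
    for block in blocks[:cache_start]:
        x = block(x)

    x_before = x.clone()
    if not skip:
        # Full step: run the middle segment and remember what it added.
        for block in blocks[cache_start:cache_end]:
            x = block(x)
        delta_cache = x - x_before
    else:
        # Cached step: approximate the middle segment with the stored residual.
        x = x_before + delta_cache

    # Tail blocks: always computed.
    for block in blocks[cache_end:]:
        x = block(x)
    return x, delta_cache


if __name__ == "__main__":
    torch.manual_seed(0)
    blocks = nn.ModuleList([nn.Linear(16, 16) for _ in range(8)])
    x = torch.randn(2, 16)

    # Step i: full compute fills the cache (skip flag 0 in the PR).
    y_full, delta = forward_blocks_with_cache(blocks, x, skip=False, cache_start=1, cache_len=4, delta_cache=None)
    # Step i+1: middle blocks are skipped and replaced by the cached delta (skip flag 1).
    y_cached, _ = forward_blocks_with_cache(blocks, x, skip=True, cache_start=1, cache_len=4, delta_cache=delta)
    print((y_full - y_cached).abs().max())  # ~0 here, since the input did not change between steps
```

Roughly, the default `--cache_dict "1,2,20,10"` maps onto this sketch as: start caching at block 1, refresh the cached delta every 2 denoising steps, cover 20 consecutive blocks, and run the first 10 steps without any skipping.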