From b03f4ca070d5560c08190e819c095a0e4ca01a7f Mon Sep 17 00:00:00 2001 From: huanghao Date: Thu, 5 Sep 2024 19:02:40 +0800 Subject: [PATCH 01/13] =?UTF-8?q?DiTCache=E7=AC=AC=E4=B8=80=E9=83=A8?= =?UTF-8?q?=E5=88=86=E4=BB=A3=E7=A0=81=E4=B8=8A=E5=BA=93?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../stable_diffusion_3/attention.patch | 19 +++ .../stable_diffusion_3/attention_patch.py | 4 +- .../attention_processor.patch | 80 +++++++++-- .../attention_processor_patch.py | 32 +++++ .../stable_diffusion_3/compile_model.py | 69 +++++++++ .../stable_diffusion_3/export_model.py | 135 +++++++++++++++++- .../stable_diffusion3_pipeline.py | 11 ++ .../stable_diffusion_3/transformer_sd3.patch | 123 ++++++++++++++++ .../transformer_sd3_patch.py | 33 +++++ 9 files changed, 490 insertions(+), 16 deletions(-) create mode 100644 MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention.patch create mode 100644 MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention_processor_patch.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/transformer_sd3.patch create mode 100644 MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/transformer_sd3_patch.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention.patch b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention.patch new file mode 100644 index 0000000000..ce183bd914 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention.patch @@ -0,0 +1,19 @@ +--- attention.py 2024-09-04 09:22:31.768000000 +0000 ++++ attention.py 2024-09-04 09:17:12.680000000 +0000 +@@ -100,7 +100,7 @@ + processing of `context` conditions. + """ + +- def __init__(self, dim, num_attention_heads, attention_head_dim, context_pre_only=False): ++ def __init__(self, dim, num_attention_heads, attention_head_dim, context_pre_only=False, layer_idx=0): + super().__init__() + + self.context_pre_only = context_pre_only +@@ -134,6 +134,7 @@ + context_pre_only=context_pre_only, + bias=True, + processor=processor, ++ layer_idx=layer_idx + ) + + self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention_patch.py b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention_patch.py index 97585e6af7..2d289cb99e 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention_patch.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention_patch.py @@ -22,8 +22,8 @@ def main(): diffusers_version = diffusers.__version__ assert diffusers_version == '0.29.0', "expectation diffusers==0.29.0" - result = subprocess.run(["patch", "-p0", f"{diffusers_path[0]}/models/attention_processor.py", - "attention_processor.patch"], capture_output=True, text=True) + result = subprocess.run(["patch", "-p0", f"{diffusers_path[0]}/models/attention.py", + "attention.patch"], capture_output=True, text=True) if result.returncode != 0: logging.error("Patch failed, error message: s%", result.stderr) diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention_processor.patch b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention_processor.patch index dc22a411ea..8897fabdbb 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention_processor.patch +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention_processor.patch @@ -1,9 +1,41 @@ ---- attention_processor.py 2024-08-14 12:41:10.528000000 +0000 -+++ attention_processor.py 2024-08-14 12:42:25.620000000 +0000 -@@ -1082,6 +1082,29 @@ +--- attention_processor.py 2024-09-04 09:22:16.048000000 +0000 ++++ attention_processor.py 2024-09-04 09:57:55.928000000 +0000 +@@ -115,6 +115,7 @@ + residual_connection: bool = False, + _from_deprecated_attn_block: bool = False, + processor: Optional["AttnProcessor"] = None, ++ layer_idx: int = 0, + out_dim: int = None, + context_pre_only=None, + ): +@@ -132,13 +133,14 @@ + self.fused_projections = False + self.out_dim = out_dim if out_dim is not None else query_dim + self.context_pre_only = context_pre_only ++ self.layer_idx = layer_idx + + # we make use of this private variable to know whether this class is loaded + # with an deprecated state dict so that we can convert it on the fly + self._from_deprecated_attn_block = _from_deprecated_attn_block + + self.scale_qk = scale_qk +- self.scale = dim_head**-0.5 if self.scale_qk else 1.0 ++ self.scale = dim_head ** -0.5 if self.scale_qk else 1.0 + + self.heads = out_dim // dim_head if out_dim is not None else heads + # for slice_size > 0 the attention score computation +@@ -561,6 +563,7 @@ + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, ++ layer_idx=self.layer_idx, + **cross_attention_kwargs, + ) + +@@ -1082,6 +1085,29 @@ return hidden_states - - + + +def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, + scale=None) -> torch.Tensor: + # Efficient implementation equivalent to the following: @@ -20,7 +52,7 @@ + if attn_mask.dtype == torch.bool: + attn_mask.masked_fill_(attn_mask.logical_not(), float("-inf")) + else: -+ attn_bias += attn_mask ++ attn_bias = attn_mask + attn_weight = query @ key.transpose(-2, -1) * scale_factor + attn_weight = torch.softmax(attn_weight, dim=-1) + attn_weight = torch.dropout(attn_weight, dropout_p, train=True) @@ -29,13 +61,41 @@ + class JointAttnProcessor2_0: """Attention processor used typically in processing the SD3-like self-attention projections.""" - -@@ -1132,7 +1155,7 @@ + +@@ -1095,6 +1121,7 @@ + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, ++ layer_idx=0, + *args, + **kwargs, + ) -> torch.FloatTensor: +@@ -1112,7 +1139,19 @@ + batch_size = encoder_hidden_states.shape[0] + + # `sample` projections. ++ # use todo + query = attn.to_q(hidden_states) ++ if layer_idx <= 11: ++ hidden_dim = hidden_states.shape[-1] ++ cur_h = int(math.sqrt(hidden_states.shape[1])) ++ cur_w = cur_h ++ hidden_states = hidden_states.transpose(1, 2).view(batch_size, hidden_dim, cur_h, cur_w) ++ new_h = int(cur_h / 2.2) ++ new_w = int(cur_w / 2.2) ++ item = F.interpolate(hidden_states, size=(new_h, new_w), ++ mode='bilinear') ++ item = item.permute(0, 2, 3, 1) ++ hidden_states = item.reshape(batch_size, new_h * new_w, -1) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + +@@ -1132,7 +1171,7 @@ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - + - hidden_states = hidden_states = F.scaled_dot_product_attention( + hidden_states = hidden_states = scaled_dot_product_attention( query, key, value, dropout_p=0.0, is_causal=False ) - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) \ No newline at end of file + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention_processor_patch.py b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention_processor_patch.py new file mode 100644 index 0000000000..97585e6af7 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention_processor_patch.py @@ -0,0 +1,32 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import subprocess +import logging +import diffusers + + +def main(): + diffusers_path = diffusers.__path__ + diffusers_version = diffusers.__version__ + + assert diffusers_version == '0.29.0', "expectation diffusers==0.29.0" + result = subprocess.run(["patch", "-p0", f"{diffusers_path[0]}/models/attention_processor.py", + "attention_processor.patch"], capture_output=True, text=True) + if result.returncode != 0: + logging.error("Patch failed, error message: s%", result.stderr) + + +if __name__ == '__main__': + main() diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/compile_model.py b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/compile_model.py index b48994f182..7bd43f2136 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/compile_model.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/compile_model.py @@ -18,6 +18,11 @@ from typing import List import mindietorch from mindietorch import _enums +SCHEDULER_SIGMAS = torch.tensor([1.0000, 0.9874, 0.9741, 0.9601, 0.9454, 0.9298, 0.9133, 0.8959, 0.8774, + 0.8577, 0.8367, 0.8143, 0.7904, 0.7647, 0.7371, 0.7073, 0.6751, 0.6402, + 0.6022, 0.5606, 0.5151, 0.4649, 0.4093, 0.3474, 0.2780, 0.1998, 0.1109, + 0.0089, 0.0000], dtype=torch.float32) + @dataclass class CompileParam: @@ -69,11 +74,48 @@ class VaeExport(torch.nn.Module): image = self.vae_model.decode(latents, return_dict=False)[0] return image + def compile_vae(model, inputs, vae_compiled_path, soc_version): vae_param = CompileParam(inputs, soc_version) common_compile(model, vae_compiled_path, vae_param) +class Scheduler(torch.nn.Module): + def __init__(self): + super(Scheduler, self).__init__() + self.sigmas = SCHEDULER_SIGMAS + + def forward( + self, + model_output: torch.FloatTensor, + sample: torch.FloatTensor, + step_index: torch.LongTensor + ): + # Upcast to avoid precision issues when computing prev_sample + sample = sample.to(torch.float32) + sigma = self.sigmas[step_index] + + sigma_hat = sigma + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + # NOTE: "original_sample" should not be an expected prediction_type but is left in for + # backwards compatibility + denoised = sample - model_output * sigma + # 2. Convert to an ODE derivative + derivative = (sample - denoised) / sigma_hat + + dt = self.sigmas[step_index + 1] - sigma_hat + prev_sample = sample + derivative * dt + # Cast sample back to model compatible dtype + prev_sample = prev_sample.to(model_output.dtype) + + return prev_sample + + +def compile_scheduler(model, inputs, scheduler_compiled_path, soc_version): + scheduler_param = CompileParam(inputs, soc_version, True, True, False) + common_compile(model, scheduler_compiled_path, scheduler_param) + + class DiTExport(torch.nn.Module): def __init__(self, dit_model): super().__init__() @@ -93,3 +135,30 @@ class DiTExport(torch.nn.Module): def compile_dit(model, inputs, dit_compiled_path, soc_version): dit_param = CompileParam(inputs, soc_version) common_compile(model, dit_compiled_path, dit_param) + + +class DiTExportCache(torch.nn.Module): + def __init__(self, dit_cache_model): + super().__init__() + self.dit_cache_model = dit_cache_model + + def forward( + self, + hidden_states, + encoder_hidden_states, + pooled_projections, + timestep, + cache_dict, + if_skip: int = 0, + delta_cache: torch.FloatTensor = None, + delta_cache_hidden: torch.FloatTensor = None, + use_cache: bool = True, + ): + return self.dit_cache_model(hidden_states, encoder_hidden_states, pooled_projections, timestep, + cache_dict, if_skip, delta_cache, delta_cache_hidden, use_cache, + joint_attention_kwargs=None, return_dict=False) + + +def compile_dit_cache(model, inputs, dit_cache_compiled_path, soc_version): + dit_cache_param = CompileParam(inputs, soc_version, True, False, True) + common_compile(model, dit_cache_compiled_path, dit_cache_param) diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/export_model.py b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/export_model.py index 1cda87042f..fb18b7eddc 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/export_model.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/export_model.py @@ -17,6 +17,7 @@ import argparse from argparse import Namespace import torch +from diffusers import FlowMatchEulerDiscreteScheduler from diffusers import StableDiffusion3Pipeline import mindietorch from compile_model import * @@ -66,7 +67,7 @@ def parse_arguments() -> Namespace: ) parser.add_argument("-bs", "--batch_size", type=int, default=1, help="Batch size.") parser.add_argument("-steps", "--steps", type=int, default=28, help="steps.") - parser.add_argument("-guid", "--guidance_scale", type=float, default=5.0, help="guidance_scale") + parser.add_argument("-guid", "--guidance_scale", type=float, default=7.0, help="guidance_scale") parser.add_argument("--use_cache", action="store_true", help="Use cache during inference.") parser.add_argument("-p", "--parallel", action="store_true", help="Export the unet of bs=1 for parallel inferencing.") @@ -89,6 +90,12 @@ def parse_arguments() -> Namespace: type=int, help="image width" ) + parser.add_argument( + "--cache_dict", + type=str, + default="1,2,20,10", + help="Steps to use cache data." + ) return parser.parse_args() @@ -254,11 +261,131 @@ def export_vae(sd_pipeline, args): logging.info("vae_compiled_path already exists.") -def export(args): - pipeline = StableDiffusion3Pipeline.from_pretrained(args.model).to('cpu') +def trace_scheduler(sd_pipeline, args, scheduler_pt_path): + batch_size = args.batch_size + if not os.path.exists(scheduler_pt_path): + dummy_input = ( + torch.randn([batch_size, 16, 128, 128], dtype=torch.float32), + torch.randn([batch_size, 16, 128, 128], dtype=torch.float32), + torch.ones([1], dtype=torch.int64) + ) + scheduler = FlowMatchEulerDiscreteScheduler.from_config(sd_pipeline.scheduler.config) + scheduler.set_timesteps(args.steps, device="cpu") + + new_scheduler = Scheduler() + new_scheduler.eval() + torch.jit.trace(new_scheduler, dummy_input).save(scheduler_pt_path) + + +def export_scheduler(sd_pipeline, args): + print("Exporting the scheduler...") + scheduler_path = os.path.join(args.output_dir, "scheduler") + if not os.path.exists(scheduler_path): + os.makedirs(scheduler_path, mode=0o744) + batch_size = args.batch_size + height_size, width_size = args.height // 8, args.width // 8 + scheduler_pt_path = os.path.join(scheduler_path, f"scheduler_bs{batch_size}.pt") + scheduler_compiled_path = os.path.join(scheduler_path, + f"scheduler_bs{batch_size}_compile_{args.height}x{args.width}.ts") + in_channels = 16 + + # trace + trace_scheduler(sd_pipeline, args, scheduler_pt_path) + + # compile + if not os.path.exists(scheduler_compiled_path): + model = torch.jit.load(scheduler_pt_path).eval() + inputs = [ + mindietorch.Input((batch_size, in_channels, height_size, width_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size, in_channels, height_size, width_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), dtype=mindietorch.dtype.INT64) + ] + compile_scheduler(model, inputs, scheduler_compiled_path, args.soc) + + +def export_dit_cache(sd_pipeline, args, if_skip, flag=""): + print("Exporting the dit_cache...") + cache_dict = torch.zeros([4], dtype=torch.int64) + cache_list = args.cache_dict.split(',') + cache_dict[0] = int(cache_list[0]) + cache_dict[1] = int(cache_list[1]) + cache_dict[2] = int(cache_list[2]) + cache_dict[3] = int(cache_list[3]) + dit_path = os.path.join(args.output_dir, "dit") + if not os.path.exists(dit_path): + os.makedirs(dit_path, mode=0o640) + + encoder_model = sd_pipeline.text_encoder + encoder_model_2 = sd_pipeline.text_encoder_2 + dit_model = sd_pipeline.transformer + if args.parallel or flag == "end": + batch_size = args.batch_size + else: + batch_size = args.batch_size * 2 + sample_size = dit_model.config.sample_size + in_channels = dit_model.config.in_channels + encoder_hidden_size_2 = encoder_model_2.config.hidden_size + encoder_hidden_size = encoder_model.config.hidden_size + encoder_hidden_size_2 + max_position_embeddings = encoder_model.config.max_position_embeddings * 2 + dit_cache_pt_path = os.path.join(dit_path, f"dit_bs{batch_size}_{if_skip}.pt") + dit_cache_compiled_path = os.path.join(dit_path, + f"dit_bs{batch_size}_{if_skip}_compile_{args.height}x{args.width}.ts") + + # trace + if not os.path.exists(dit_cache_pt_path): + dummy_input = ( + torch.ones([batch_size, in_channels, sample_size, sample_size], dtype=torch.float32), + torch.ones([batch_size, max_position_embeddings, encoder_hidden_size * 2], dtype=torch.float32), + torch.ones([batch_size, encoder_hidden_size], dtype=torch.float32), + torch.ones([1], dtype=torch.int64), + cache_dict, + torch.tensor([if_skip], dtype=torch.int64), + torch.ones([batch_size, 4096, 1536], dtype=torch.float32), + torch.ones([batch_size, 154, 1536], dtype=torch.float32), + ) + print("dummy_input.shape:") + for ele in dummy_input: + if isinstance(ele, torch.Tensor): + print(ele.shape) + dit = DiTExportCache(dit_model).eval() + torch.jit.trace(dit, dummy_input).save(dit_cache_pt_path) + + # compile + if not os.path.exists(dit_cache_compiled_path): + model = torch.jit.load(dit_cache_pt_path).eval() + inputs = [ + mindietorch.Input((batch_size, in_channels, sample_size, sample_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size, max_position_embeddings, encoder_hidden_size * 2), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size, encoder_hidden_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), dtype=mindietorch.dtype.INT64), + mindietorch.Input((4,), dtype=mindietorch.dtype.INT64), + mindietorch.Input((1,), dtype=mindietorch.dtype.INT64), + mindietorch.Input((batch_size, 4096, 1536), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size, 154, 1536), + dtype=mindietorch.dtype.FLOAT), + ] + compile_dit_cache(model, inputs, dit_cache_compiled_path, args.soc) + + +def export(args) -> None: + pipeline = StableDiffusion3Pipeline.from_pretrained(args.model).to("cpu") export_clip(pipeline, args) - export_dit(pipeline, args) + if args.use_cache: + export_dit_cache(pipeline, args, 0) + export_dit_cache(pipeline, args, 1) + if "B" in args.soc: + export_dit_cache(pipeline, args, 0, "end") + export_dit_cache(pipeline, args, 1, "end") + else: + export_dit(pipeline, args) export_vae(pipeline, args) + export_scheduler(pipeline, args) def main(args): diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline.py b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline.py index 5415517f7e..fd2c2d66b8 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline.py @@ -725,6 +725,17 @@ def parse_arguments(): type=int, help="image width" ) + parser.add_argument( + "--use_cache", + action="store_true", + help="Use cache during inference." + ) + parser.add_argument( + "--cache_dict", + default="1,2,20,10", + type=str, + help="steps to use cache data" + ) return parser.parse_args() diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/transformer_sd3.patch b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/transformer_sd3.patch new file mode 100644 index 0000000000..6b18fd8831 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/transformer_sd3.patch @@ -0,0 +1,123 @@ +--- transformer_sd3.py 2024-09-04 09:21:58.280000000 +0000 ++++ transformer_sd3.py 2024-09-04 10:01:47.196000000 +0000 +@@ -97,6 +97,7 @@ + num_attention_heads=self.config.num_attention_heads, + attention_head_dim=self.inner_dim, + context_pre_only=i == num_layers - 1, ++ layer_idx=i + ) + for i in range(self.config.num_layers) + ] +@@ -106,6 +107,8 @@ + self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) + + self.gradient_checkpointing = False ++ self.delta_cache = None ++ self.delta_cache_hidden = None + + # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking + def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: +@@ -245,9 +248,14 @@ + encoder_hidden_states: torch.FloatTensor = None, + pooled_projections: torch.FloatTensor = None, + timestep: torch.LongTensor = None, ++ cache_dict: torch.LongTensor = None, ++ if_skip: int = 0, ++ delta_cache: torch.FloatTensor = None, ++ delta_cache_hidden: torch.FloatTensor = None, ++ use_cache: bool = False, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, +- ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: ++ ): + """ + The [`SD3Transformer2DModel`] forward method. + +@@ -281,10 +289,6 @@ + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) +- else: +- logger.warning( +- "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." +- ) + + height, width = hidden_states.shape[-2:] + +@@ -292,9 +296,8 @@ + temb = self.time_text_embed(timestep, pooled_projections) + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + +- for block in self.transformer_blocks: +- if self.training and self.gradient_checkpointing: +- ++ if self.training and self.gradient_checkpointing: ++ for block in self.transformer_blocks: + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: +@@ -312,11 +315,14 @@ + temb, + **ckpt_kwargs, + ) +- +- else: +- encoder_hidden_states, hidden_states = block( +- hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb +- ) ++ else: ++ ( ++ (encoder_hidden_states, hidden_states), ++ delta_cache, ++ delta_cache_hidden ++ ) = self.forward_blocks(hidden_states, encoder_hidden_states, temb, ++ use_cache, if_skip, cache_dict, delta_cache, ++ delta_cache_hidden) + + hidden_states = self.norm_out(hidden_states, temb) + hidden_states = self.proj_out(hidden_states) +@@ -339,6 +345,43 @@ + unscale_lora_layers(self, lora_scale) + + if not return_dict: +- return (output,) ++ return (output, delta_cache, delta_cache_hidden) + + return Transformer2DModelOutput(sample=output) ++ ++ def forward_blocks_range(self, hidden_states, encoder_hidden_states, temb, start_idx, end_idx): ++ for block_idx, block in enumerate(self.transformer_blocks[start_idx: end_idx]): ++ encoder_hidden_states, hidden_states = block( ++ hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb ++ ) ++ # self.x_out.append([torch.mean(x).to('cpu').numpy(), torch.var(x).to('cpu').numpy()]) ++ return hidden_states, encoder_hidden_states ++ ++ def forward_blocks(self, hidden_states, encoder_hidden_states, temb, use_cache, if_skip, cache_dict, delta_cache, ++ delta_cache_hidden): ++ if not use_cache: ++ hidden_states, encoder_hidden_states = self.forward_blocks_range(hidden_states, encoder_hidden_states, temb, ++ 0, len(self.transformer_blocks)) ++ else: ++ # infer [0, cache_start) ++ hidden_states, encoder_hidden_states = self.forward_blocks_range(hidden_states, encoder_hidden_states, temb, ++ 0, cache_dict[0]) ++ ++ # infer [cache_start, cache_end) ++ cache_end = cache_dict[0] + cache_dict[2] ++ hidden_states_before_cache = hidden_states.clone() ++ encoder_hidden_states_before_cache = encoder_hidden_states.clone() ++ if not if_skip: ++ hidden_states, encoder_hidden_states = self.forward_blocks_range(hidden_states, encoder_hidden_states, ++ temb, cache_dict[0], ++ cache_end) ++ delta_cache = hidden_states - hidden_states_before_cache ++ delta_cache_hidden = encoder_hidden_states - encoder_hidden_states_before_cache ++ else: ++ hidden_states = hidden_states_before_cache + delta_cache ++ encoder_hidden_states = encoder_hidden_states_before_cache + delta_cache_hidden ++ ++ # infer [cache_end, len(self.blocks)) ++ hidden_states, encoder_hidden_states = self.forward_blocks_range(hidden_states, encoder_hidden_states, temb, ++ cache_end, len(self.transformer_blocks)) ++ return (encoder_hidden_states, hidden_states), delta_cache, delta_cache_hidden diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/transformer_sd3_patch.py b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/transformer_sd3_patch.py new file mode 100644 index 0000000000..8556cbac2b --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/transformer_sd3_patch.py @@ -0,0 +1,33 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import subprocess +import logging +import diffusers + + + +def main(): + diffusers_path = diffusers.__path__ + diffusers_version = diffusers.__version__ + + assert diffusers_version == '0.29.0', "expectation diffusers==0.29.0" + result = subprocess.run(["patch", "-p0", f"{diffusers_path[0]}/models/transformers/transformer_sd3.py", + "transformer_sd3.patch"], capture_output=True, text=True) + if result.returncode != 0: + logging.error("Patch failed, error message: s%", result.stderr) + + +if __name__ == '__main__': + main() -- Gitee From 8fe3126739d9d7bfe59d52ab270f8de244f71b4d Mon Sep 17 00:00:00 2001 From: huanghao Date: Thu, 5 Sep 2024 19:07:39 +0800 Subject: [PATCH 02/13] =?UTF-8?q?DiTCache=E7=AC=AC=E4=B8=80=E9=83=A8?= =?UTF-8?q?=E5=88=86=E4=BB=A3=E7=A0=81=E4=B8=8A=E5=BA=93?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../attention_processor.patch | 101 ------------------ .../attention_processor_patch.py | 32 ------ 2 files changed, 133 deletions(-) delete mode 100644 MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention_processor.patch delete mode 100644 MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention_processor_patch.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention_processor.patch b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention_processor.patch deleted file mode 100644 index 8897fabdbb..0000000000 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention_processor.patch +++ /dev/null @@ -1,101 +0,0 @@ ---- attention_processor.py 2024-09-04 09:22:16.048000000 +0000 -+++ attention_processor.py 2024-09-04 09:57:55.928000000 +0000 -@@ -115,6 +115,7 @@ - residual_connection: bool = False, - _from_deprecated_attn_block: bool = False, - processor: Optional["AttnProcessor"] = None, -+ layer_idx: int = 0, - out_dim: int = None, - context_pre_only=None, - ): -@@ -132,13 +133,14 @@ - self.fused_projections = False - self.out_dim = out_dim if out_dim is not None else query_dim - self.context_pre_only = context_pre_only -+ self.layer_idx = layer_idx - - # we make use of this private variable to know whether this class is loaded - # with an deprecated state dict so that we can convert it on the fly - self._from_deprecated_attn_block = _from_deprecated_attn_block - - self.scale_qk = scale_qk -- self.scale = dim_head**-0.5 if self.scale_qk else 1.0 -+ self.scale = dim_head ** -0.5 if self.scale_qk else 1.0 - - self.heads = out_dim // dim_head if out_dim is not None else heads - # for slice_size > 0 the attention score computation -@@ -561,6 +563,7 @@ - hidden_states, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, -+ layer_idx=self.layer_idx, - **cross_attention_kwargs, - ) - -@@ -1082,6 +1085,29 @@ - return hidden_states - - -+def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, -+ scale=None) -> torch.Tensor: -+ # Efficient implementation equivalent to the following: -+ L, S = query.size(-2), key.size(-2) -+ scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale -+ attn_bias = torch.zeros(L, S, dtype=query.dtype) -+ if is_causal: -+ assert attn_mask is None -+ temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) -+ attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) -+ attn_bias.to(query.dtype) -+ -+ if attn_mask is not None: -+ if attn_mask.dtype == torch.bool: -+ attn_mask.masked_fill_(attn_mask.logical_not(), float("-inf")) -+ else: -+ attn_bias = attn_mask -+ attn_weight = query @ key.transpose(-2, -1) * scale_factor -+ attn_weight = torch.softmax(attn_weight, dim=-1) -+ attn_weight = torch.dropout(attn_weight, dropout_p, train=True) -+ return attn_weight @ value -+ -+ - class JointAttnProcessor2_0: - """Attention processor used typically in processing the SD3-like self-attention projections.""" - -@@ -1095,6 +1121,7 @@ - hidden_states: torch.FloatTensor, - encoder_hidden_states: torch.FloatTensor = None, - attention_mask: Optional[torch.FloatTensor] = None, -+ layer_idx=0, - *args, - **kwargs, - ) -> torch.FloatTensor: -@@ -1112,7 +1139,19 @@ - batch_size = encoder_hidden_states.shape[0] - - # `sample` projections. -+ # use todo - query = attn.to_q(hidden_states) -+ if layer_idx <= 11: -+ hidden_dim = hidden_states.shape[-1] -+ cur_h = int(math.sqrt(hidden_states.shape[1])) -+ cur_w = cur_h -+ hidden_states = hidden_states.transpose(1, 2).view(batch_size, hidden_dim, cur_h, cur_w) -+ new_h = int(cur_h / 2.2) -+ new_w = int(cur_w / 2.2) -+ item = F.interpolate(hidden_states, size=(new_h, new_w), -+ mode='bilinear') -+ item = item.permute(0, 2, 3, 1) -+ hidden_states = item.reshape(batch_size, new_h * new_w, -1) - key = attn.to_k(hidden_states) - value = attn.to_v(hidden_states) - -@@ -1132,7 +1171,7 @@ - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - -- hidden_states = hidden_states = F.scaled_dot_product_attention( -+ hidden_states = hidden_states = scaled_dot_product_attention( - query, key, value, dropout_p=0.0, is_causal=False - ) - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention_processor_patch.py b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention_processor_patch.py deleted file mode 100644 index 97585e6af7..0000000000 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention_processor_patch.py +++ /dev/null @@ -1,32 +0,0 @@ -# Copyright 2023 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import subprocess -import logging -import diffusers - - -def main(): - diffusers_path = diffusers.__path__ - diffusers_version = diffusers.__version__ - - assert diffusers_version == '0.29.0', "expectation diffusers==0.29.0" - result = subprocess.run(["patch", "-p0", f"{diffusers_path[0]}/models/attention_processor.py", - "attention_processor.patch"], capture_output=True, text=True) - if result.returncode != 0: - logging.error("Patch failed, error message: s%", result.stderr) - - -if __name__ == '__main__': - main() -- Gitee From 9fc7af8e5ff35cbe2d801a1c364c3b4a6bb8eb0d Mon Sep 17 00:00:00 2001 From: huanghao Date: Thu, 5 Sep 2024 19:09:51 +0800 Subject: [PATCH 03/13] =?UTF-8?q?DiTCache=E7=AC=AC=E4=B8=80=E9=83=A8?= =?UTF-8?q?=E5=88=86=E4=BB=A3=E7=A0=81=E4=B8=8A=E5=BA=93?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../attention_processor.patch | 101 ++++++++++++++ ..._patch.py => attention_processor_patch.py} | 5 +- .../stable_diffusion_3/transformer_sd3.patch | 123 ------------------ 3 files changed, 103 insertions(+), 126 deletions(-) create mode 100644 MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention_processor.patch rename MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/{transformer_sd3_patch.py => attention_processor_patch.py} (88%) delete mode 100644 MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/transformer_sd3.patch diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention_processor.patch b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention_processor.patch new file mode 100644 index 0000000000..8897fabdbb --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention_processor.patch @@ -0,0 +1,101 @@ +--- attention_processor.py 2024-09-04 09:22:16.048000000 +0000 ++++ attention_processor.py 2024-09-04 09:57:55.928000000 +0000 +@@ -115,6 +115,7 @@ + residual_connection: bool = False, + _from_deprecated_attn_block: bool = False, + processor: Optional["AttnProcessor"] = None, ++ layer_idx: int = 0, + out_dim: int = None, + context_pre_only=None, + ): +@@ -132,13 +133,14 @@ + self.fused_projections = False + self.out_dim = out_dim if out_dim is not None else query_dim + self.context_pre_only = context_pre_only ++ self.layer_idx = layer_idx + + # we make use of this private variable to know whether this class is loaded + # with an deprecated state dict so that we can convert it on the fly + self._from_deprecated_attn_block = _from_deprecated_attn_block + + self.scale_qk = scale_qk +- self.scale = dim_head**-0.5 if self.scale_qk else 1.0 ++ self.scale = dim_head ** -0.5 if self.scale_qk else 1.0 + + self.heads = out_dim // dim_head if out_dim is not None else heads + # for slice_size > 0 the attention score computation +@@ -561,6 +563,7 @@ + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, ++ layer_idx=self.layer_idx, + **cross_attention_kwargs, + ) + +@@ -1082,6 +1085,29 @@ + return hidden_states + + ++def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, ++ scale=None) -> torch.Tensor: ++ # Efficient implementation equivalent to the following: ++ L, S = query.size(-2), key.size(-2) ++ scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale ++ attn_bias = torch.zeros(L, S, dtype=query.dtype) ++ if is_causal: ++ assert attn_mask is None ++ temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) ++ attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) ++ attn_bias.to(query.dtype) ++ ++ if attn_mask is not None: ++ if attn_mask.dtype == torch.bool: ++ attn_mask.masked_fill_(attn_mask.logical_not(), float("-inf")) ++ else: ++ attn_bias = attn_mask ++ attn_weight = query @ key.transpose(-2, -1) * scale_factor ++ attn_weight = torch.softmax(attn_weight, dim=-1) ++ attn_weight = torch.dropout(attn_weight, dropout_p, train=True) ++ return attn_weight @ value ++ ++ + class JointAttnProcessor2_0: + """Attention processor used typically in processing the SD3-like self-attention projections.""" + +@@ -1095,6 +1121,7 @@ + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, ++ layer_idx=0, + *args, + **kwargs, + ) -> torch.FloatTensor: +@@ -1112,7 +1139,19 @@ + batch_size = encoder_hidden_states.shape[0] + + # `sample` projections. ++ # use todo + query = attn.to_q(hidden_states) ++ if layer_idx <= 11: ++ hidden_dim = hidden_states.shape[-1] ++ cur_h = int(math.sqrt(hidden_states.shape[1])) ++ cur_w = cur_h ++ hidden_states = hidden_states.transpose(1, 2).view(batch_size, hidden_dim, cur_h, cur_w) ++ new_h = int(cur_h / 2.2) ++ new_w = int(cur_w / 2.2) ++ item = F.interpolate(hidden_states, size=(new_h, new_w), ++ mode='bilinear') ++ item = item.permute(0, 2, 3, 1) ++ hidden_states = item.reshape(batch_size, new_h * new_w, -1) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + +@@ -1132,7 +1171,7 @@ + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + +- hidden_states = hidden_states = F.scaled_dot_product_attention( ++ hidden_states = hidden_states = scaled_dot_product_attention( + query, key, value, dropout_p=0.0, is_causal=False + ) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/transformer_sd3_patch.py b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention_processor_patch.py similarity index 88% rename from MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/transformer_sd3_patch.py rename to MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention_processor_patch.py index 8556cbac2b..97585e6af7 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/transformer_sd3_patch.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention_processor_patch.py @@ -17,14 +17,13 @@ import logging import diffusers - def main(): diffusers_path = diffusers.__path__ diffusers_version = diffusers.__version__ assert diffusers_version == '0.29.0', "expectation diffusers==0.29.0" - result = subprocess.run(["patch", "-p0", f"{diffusers_path[0]}/models/transformers/transformer_sd3.py", - "transformer_sd3.patch"], capture_output=True, text=True) + result = subprocess.run(["patch", "-p0", f"{diffusers_path[0]}/models/attention_processor.py", + "attention_processor.patch"], capture_output=True, text=True) if result.returncode != 0: logging.error("Patch failed, error message: s%", result.stderr) diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/transformer_sd3.patch b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/transformer_sd3.patch deleted file mode 100644 index 6b18fd8831..0000000000 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/transformer_sd3.patch +++ /dev/null @@ -1,123 +0,0 @@ ---- transformer_sd3.py 2024-09-04 09:21:58.280000000 +0000 -+++ transformer_sd3.py 2024-09-04 10:01:47.196000000 +0000 -@@ -97,6 +97,7 @@ - num_attention_heads=self.config.num_attention_heads, - attention_head_dim=self.inner_dim, - context_pre_only=i == num_layers - 1, -+ layer_idx=i - ) - for i in range(self.config.num_layers) - ] -@@ -106,6 +107,8 @@ - self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) - - self.gradient_checkpointing = False -+ self.delta_cache = None -+ self.delta_cache_hidden = None - - # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking - def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: -@@ -245,9 +248,14 @@ - encoder_hidden_states: torch.FloatTensor = None, - pooled_projections: torch.FloatTensor = None, - timestep: torch.LongTensor = None, -+ cache_dict: torch.LongTensor = None, -+ if_skip: int = 0, -+ delta_cache: torch.FloatTensor = None, -+ delta_cache_hidden: torch.FloatTensor = None, -+ use_cache: bool = False, - joint_attention_kwargs: Optional[Dict[str, Any]] = None, - return_dict: bool = True, -- ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: -+ ): - """ - The [`SD3Transformer2DModel`] forward method. - -@@ -281,10 +289,6 @@ - if USE_PEFT_BACKEND: - # weight the lora layers by setting `lora_scale` for each PEFT layer - scale_lora_layers(self, lora_scale) -- else: -- logger.warning( -- "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." -- ) - - height, width = hidden_states.shape[-2:] - -@@ -292,9 +296,8 @@ - temb = self.time_text_embed(timestep, pooled_projections) - encoder_hidden_states = self.context_embedder(encoder_hidden_states) - -- for block in self.transformer_blocks: -- if self.training and self.gradient_checkpointing: -- -+ if self.training and self.gradient_checkpointing: -+ for block in self.transformer_blocks: - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: -@@ -312,11 +315,14 @@ - temb, - **ckpt_kwargs, - ) -- -- else: -- encoder_hidden_states, hidden_states = block( -- hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb -- ) -+ else: -+ ( -+ (encoder_hidden_states, hidden_states), -+ delta_cache, -+ delta_cache_hidden -+ ) = self.forward_blocks(hidden_states, encoder_hidden_states, temb, -+ use_cache, if_skip, cache_dict, delta_cache, -+ delta_cache_hidden) - - hidden_states = self.norm_out(hidden_states, temb) - hidden_states = self.proj_out(hidden_states) -@@ -339,6 +345,43 @@ - unscale_lora_layers(self, lora_scale) - - if not return_dict: -- return (output,) -+ return (output, delta_cache, delta_cache_hidden) - - return Transformer2DModelOutput(sample=output) -+ -+ def forward_blocks_range(self, hidden_states, encoder_hidden_states, temb, start_idx, end_idx): -+ for block_idx, block in enumerate(self.transformer_blocks[start_idx: end_idx]): -+ encoder_hidden_states, hidden_states = block( -+ hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb -+ ) -+ # self.x_out.append([torch.mean(x).to('cpu').numpy(), torch.var(x).to('cpu').numpy()]) -+ return hidden_states, encoder_hidden_states -+ -+ def forward_blocks(self, hidden_states, encoder_hidden_states, temb, use_cache, if_skip, cache_dict, delta_cache, -+ delta_cache_hidden): -+ if not use_cache: -+ hidden_states, encoder_hidden_states = self.forward_blocks_range(hidden_states, encoder_hidden_states, temb, -+ 0, len(self.transformer_blocks)) -+ else: -+ # infer [0, cache_start) -+ hidden_states, encoder_hidden_states = self.forward_blocks_range(hidden_states, encoder_hidden_states, temb, -+ 0, cache_dict[0]) -+ -+ # infer [cache_start, cache_end) -+ cache_end = cache_dict[0] + cache_dict[2] -+ hidden_states_before_cache = hidden_states.clone() -+ encoder_hidden_states_before_cache = encoder_hidden_states.clone() -+ if not if_skip: -+ hidden_states, encoder_hidden_states = self.forward_blocks_range(hidden_states, encoder_hidden_states, -+ temb, cache_dict[0], -+ cache_end) -+ delta_cache = hidden_states - hidden_states_before_cache -+ delta_cache_hidden = encoder_hidden_states - encoder_hidden_states_before_cache -+ else: -+ hidden_states = hidden_states_before_cache + delta_cache -+ encoder_hidden_states = encoder_hidden_states_before_cache + delta_cache_hidden -+ -+ # infer [cache_end, len(self.blocks)) -+ hidden_states, encoder_hidden_states = self.forward_blocks_range(hidden_states, encoder_hidden_states, temb, -+ cache_end, len(self.transformer_blocks)) -+ return (encoder_hidden_states, hidden_states), delta_cache, delta_cache_hidden -- Gitee From 33b4b6461541111abe3f4a100944e53e8af9fbab Mon Sep 17 00:00:00 2001 From: huanghao Date: Thu, 5 Sep 2024 19:19:44 +0800 Subject: [PATCH 04/13] =?UTF-8?q?DiTCache=E7=AC=AC=E4=B8=80=E9=83=A8?= =?UTF-8?q?=E5=88=86=E4=BB=A3=E7=A0=81=E4=B8=8A=E5=BA=93?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../foundation/stable_diffusion_3/attention_processor.patch | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention_processor.patch b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention_processor.patch index 8897fabdbb..ff9320f9ed 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention_processor.patch +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention_processor.patch @@ -52,7 +52,7 @@ + if attn_mask.dtype == torch.bool: + attn_mask.masked_fill_(attn_mask.logical_not(), float("-inf")) + else: -+ attn_bias = attn_mask ++ attn_bias += attn_mask + attn_weight = query @ key.transpose(-2, -1) * scale_factor + attn_weight = torch.softmax(attn_weight, dim=-1) + attn_weight = torch.dropout(attn_weight, dropout_p, train=True) -- Gitee From 0f9408af46a02844f2e323c103ae126bad4e2865 Mon Sep 17 00:00:00 2001 From: huanghao Date: Thu, 5 Sep 2024 19:56:45 +0800 Subject: [PATCH 05/13] =?UTF-8?q?DiTCache=E7=AC=AC=E4=B8=80=E9=83=A8?= =?UTF-8?q?=E5=88=86=E4=BB=A3=E7=A0=81=E4=B8=8A=E5=BA=93?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../built-in/foundation/stable_diffusion_3/compile_model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/compile_model.py b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/compile_model.py index 7bd43f2136..749f999f76 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/compile_model.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/compile_model.py @@ -18,6 +18,7 @@ from typing import List import mindietorch from mindietorch import _enums +# Scheduler coefficient, compute coefficient manually in advance to compile scheduler npu model. SCHEDULER_SIGMAS = torch.tensor([1.0000, 0.9874, 0.9741, 0.9601, 0.9454, 0.9298, 0.9133, 0.8959, 0.8774, 0.8577, 0.8367, 0.8143, 0.7904, 0.7647, 0.7371, 0.7073, 0.6751, 0.6402, 0.6022, 0.5606, 0.5151, 0.4649, 0.4093, 0.3474, 0.2780, 0.1998, 0.1109, -- Gitee From 4a6bbcf00e37f88d67975da201c6c7fca4cdd785 Mon Sep 17 00:00:00 2001 From: huanghao Date: Thu, 5 Sep 2024 20:02:00 +0800 Subject: [PATCH 06/13] =?UTF-8?q?DiTCache=E7=AC=AC=E4=B8=80=E9=83=A8?= =?UTF-8?q?=E5=88=86=E4=BB=A3=E7=A0=81=E4=B8=8A=E5=BA=93?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../stable_diffusion_3/compile_model.py | 4 ++-- .../stable_diffusion_3/export_model.py | 16 ++++++++-------- .../stable_diffusion3_pipeline.py | 2 +- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/compile_model.py b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/compile_model.py index 749f999f76..9b41ada366 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/compile_model.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/compile_model.py @@ -149,14 +149,14 @@ class DiTExportCache(torch.nn.Module): encoder_hidden_states, pooled_projections, timestep, - cache_dict, + cache_param, if_skip: int = 0, delta_cache: torch.FloatTensor = None, delta_cache_hidden: torch.FloatTensor = None, use_cache: bool = True, ): return self.dit_cache_model(hidden_states, encoder_hidden_states, pooled_projections, timestep, - cache_dict, if_skip, delta_cache, delta_cache_hidden, use_cache, + cache_param, if_skip, delta_cache, delta_cache_hidden, use_cache, joint_attention_kwargs=None, return_dict=False) diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/export_model.py b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/export_model.py index fb18b7eddc..26f23048cf 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/export_model.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/export_model.py @@ -91,7 +91,7 @@ def parse_arguments() -> Namespace: help="image width" ) parser.add_argument( - "--cache_dict", + "--cache_param", type=str, default="1,2,20,10", help="Steps to use cache data." @@ -307,12 +307,12 @@ def export_scheduler(sd_pipeline, args): def export_dit_cache(sd_pipeline, args, if_skip, flag=""): print("Exporting the dit_cache...") - cache_dict = torch.zeros([4], dtype=torch.int64) - cache_list = args.cache_dict.split(',') - cache_dict[0] = int(cache_list[0]) - cache_dict[1] = int(cache_list[1]) - cache_dict[2] = int(cache_list[2]) - cache_dict[3] = int(cache_list[3]) + cache_param = torch.zeros([4], dtype=torch.int64) + cache_list = args.cache_param.split(',') + cache_param[0] = int(cache_list[0]) + cache_param[1] = int(cache_list[1]) + cache_param[2] = int(cache_list[2]) + cache_param[3] = int(cache_list[3]) dit_path = os.path.join(args.output_dir, "dit") if not os.path.exists(dit_path): os.makedirs(dit_path, mode=0o640) @@ -340,7 +340,7 @@ def export_dit_cache(sd_pipeline, args, if_skip, flag=""): torch.ones([batch_size, max_position_embeddings, encoder_hidden_size * 2], dtype=torch.float32), torch.ones([batch_size, encoder_hidden_size], dtype=torch.float32), torch.ones([1], dtype=torch.int64), - cache_dict, + cache_param, torch.tensor([if_skip], dtype=torch.int64), torch.ones([batch_size, 4096, 1536], dtype=torch.float32), torch.ones([batch_size, 154, 1536], dtype=torch.float32), diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline.py b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline.py index fd2c2d66b8..75e6f7c8de 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline.py @@ -731,7 +731,7 @@ def parse_arguments(): help="Use cache during inference." ) parser.add_argument( - "--cache_dict", + "--cache_param", default="1,2,20,10", type=str, help="steps to use cache data" -- Gitee From 80b5a08c7907145d514a288616564c2096e9ade4 Mon Sep 17 00:00:00 2001 From: huanghao7 Date: Thu, 5 Sep 2024 12:26:38 +0000 Subject: [PATCH 07/13] update MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/compile_model.py. Signed-off-by: huanghao7 --- .../built-in/foundation/stable_diffusion_3/compile_model.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/compile_model.py b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/compile_model.py index 9b41ada366..4001de18cd 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/compile_model.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/compile_model.py @@ -18,7 +18,8 @@ from typing import List import mindietorch from mindietorch import _enums -# Scheduler coefficient, compute coefficient manually in advance to compile scheduler npu model. +# Scheduler coefficient, compute coefficient manually in advance to compile scheduler npu model. For details, see: +# https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py SCHEDULER_SIGMAS = torch.tensor([1.0000, 0.9874, 0.9741, 0.9601, 0.9454, 0.9298, 0.9133, 0.8959, 0.8774, 0.8577, 0.8367, 0.8143, 0.7904, 0.7647, 0.7371, 0.7073, 0.6751, 0.6402, 0.6022, 0.5606, 0.5151, 0.4649, 0.4093, 0.3474, 0.2780, 0.1998, 0.1109, -- Gitee From 0530693a1a0c5f720b4aba52fddd2c364ff09dbb Mon Sep 17 00:00:00 2001 From: huanghao Date: Thu, 5 Sep 2024 20:35:54 +0800 Subject: [PATCH 08/13] =?UTF-8?q?DiTCache=E7=AC=AC2=E9=83=A8=E5=88=86?= =?UTF-8?q?=E4=BB=A3=E7=A0=81=E4=B8=8A=E5=BA=93?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../stable_diffusion3_pipeline_cache.py | 451 ++++++++++++++++++ 1 file changed, 451 insertions(+) create mode 100644 MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline_cache.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline_cache.py b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline_cache.py new file mode 100644 index 0000000000..31ccc987cd --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline_cache.py @@ -0,0 +1,451 @@ +# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +import time +from typing import Any, Callable, Dict, List, Optional, Union +import torch +import mindietorch +from diffusers.pipelines.stable_diffusion_3.pipeline_output import StableDiffusion3PipelineOutput +from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import retrieve_timesteps +from stable_diffusion3_pipeline import PromptLoader, AIEStableDiffusion3Pipeline, parse_arguments + +tgate = 20 +dit_time = 0 +vae_time = 0 +scheduler_time = 0 +p1_time = 0 +p2_time = 0 +p3_time = 0 + + +class AIEStableDiffusion3CachePipeline(AIEStableDiffusion3Pipeline): + def compile_aie_model(self): + if self.is_init: + return + size = self.args.batch_size + batch_size = self.args.batch_size * 2 + tail = f"_{self.args.height}x{self.args.width}" + + vae_compiled_path = os.path.join(self.args.output_dir, + f"vae/vae_bs{size}_compile{tail}.ts") + self.compiled_vae_model = torch.jit.load(vae_compiled_path).eval() + + scheduler_compiled_path = os.path.join(self.args.output_dir, + f"scheduler/scheduler_bs{size}_compile{tail}.ts") + self.compiled_scheduler = torch.jit.load(scheduler_compiled_path).eval() + + clip1_compiled_path = os.path.join(self.args.output_dir, + f"clip/clip_bs{size}_compile{tail}.ts") + self.compiled_clip_model = torch.jit.load(clip1_compiled_path).eval() + + clip2_compiled_path = os.path.join(self.args.output_dir, + f"clip/clip2_bs{size}_compile{tail}.ts") + self.compiled_clip_model_2 = torch.jit.load(clip2_compiled_path).eval() + + t5_compiled_path = os.path.join(self.args.output_dir, + f"clip/t5_bs{size}_compile{tail}.ts") + self.compiled_t5_model = torch.jit.load(t5_compiled_path).eval() + + dit_cache_compiled_path = os.path.join(self.args.output_dir, + f"dit/dit_bs{batch_size}_0_compile{tail}.ts") + self.compiled_dit_cache_model = torch.jit.load(dit_cache_compiled_path).eval() + + if self.args.use_cache: + dit_skip_compiled_path = os.path.join(self.args.output_dir, + f"dit/dit_bs{batch_size}_1_compile{tail}.ts") + self.compiled_dit_skip_model = torch.jit.load(dit_skip_compiled_path).eval() + + dit_cache_end_compiled_path = os.path.join(self.args.output_dir, + f"dit/dit_bs{size}_0_compile{tail}.ts") + self.compiled_dit_cache_end_model = torch.jit.load(dit_cache_end_compiled_path).eval() + + dit_skip_end_compiled_path = os.path.join(self.args.output_dir, + f"dit/dit_bs{size}_1_compile{tail}.ts") + self.compiled_dit_skip_end_model = torch.jit.load(dit_skip_end_compiled_path).eval() + + self.is_init = True + + @torch.no_grad() + def dit_infer(self, compiled_model, latent_model_input, prompt_embeds, pooled_prompt_embeds, timestep_npu, + cache_param, skip_flag, delta_cache, delta_cache_hidden): + (noise_pred, delta_cache, delta_cache_hidden) = compiled_model( + latent_model_input.to(f'npu:{self.device_0}'), + prompt_embeds.to(f'npu:{self.device_0}'), + pooled_prompt_embeds.to(f'npu:{self.device_0}'), + timestep_npu, + cache_param.to(f'npu:{self.device_0}'), + skip_flag, + delta_cache.to(f'npu:{self.device_0}'), + delta_cache_hidden.to(f'npu:{self.device_0}'), + ) + noise_pred = noise_pred.to("cpu") + delta_cache = delta_cache.to("cpu") + delta_cache_hidden = delta_cache_hidden.to("cpu") + return (noise_pred, delta_cache, delta_cache_hidden) + + @torch.no_grad() + def forward( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + prompt_3: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 28, + timesteps: List[int] = None, + guidance_scale: float = 7.0, + cache_param: torch.LongTensor = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + negative_prompt_3: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + ): + global p1_time, p2_time, p3_time + start = time.time() + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + prompt_3, + height, + width, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + device = self._execution_device + + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_3=prompt_3, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + do_classifier_free_guidance=self.do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + clip_skip=self.clip_skip, + num_images_per_prompt=num_images_per_prompt, + ) + + p1_time += (time.time() - start) + start1 = time.time() + + prompt_embeds_origin = prompt_embeds.clone() + pooled_prompt_embeds_origin = pooled_prompt_embeds.clone() + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Denoising loop + global dit_time + global vae_time + global scheduler_time + + skip_flag_true = torch.ones([1], dtype=torch.long).to(f'npu:{self.device_0}') + skip_flag_false = torch.zeros([1], dtype=torch.long).to(f'npu:{self.device_0}') + + delta_cache = torch.zeros([2, 4096, 1536], dtype=torch.float32) + delta_cache_hidden = torch.zeros([2, 154, 1536], dtype=torch.float32) + + cache_interval = cache_param[1] + step_contrast = cache_param[3] % 2 + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + start = time.time() + timestep_npu = t.to(torch.int64)[None].to(f'npu:{self.device_0}') + if not self.args.use_cache: + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + (noise_pred, delta_cache, delta_cache_hidden) = self.dit_infer( + self.compiled_dit_cache_model, + latent_model_input, + prompt_embeds, + pooled_prompt_embeds, + timestep_npu, cache_param, + skip_flag_true, delta_cache, + delta_cache_hidden) + else: + if i < tgate: + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + else: + if i == tgate: + _, delta_cache = delta_cache.chunk(2) + _, delta_cache_hidden = delta_cache_hidden.chunk(2) + latent_model_input = latents + + if i < cache_param[3]: + (noise_pred, delta_cache, delta_cache_hidden) = self.dit_infer( + self.compiled_dit_cache_model, + latent_model_input, + prompt_embeds, + pooled_prompt_embeds, + timestep_npu, cache_param, + skip_flag_true, delta_cache, + delta_cache_hidden) + else: + if i % cache_interval == step_contrast: + if i < tgate: + (noise_pred, delta_cache, delta_cache_hidden) = self.dit_infer( + self.compiled_dit_cache_model, + latent_model_input, + prompt_embeds, + pooled_prompt_embeds, + timestep_npu, cache_param, + skip_flag_true, + delta_cache, + delta_cache_hidden) + else: + (noise_pred, delta_cache, delta_cache_hidden) = self.dit_infer( + self.compiled_dit_cache_end_model, + latent_model_input, + prompt_embeds_origin, + pooled_prompt_embeds_origin, + timestep_npu, cache_param, + skip_flag_true, + delta_cache, + delta_cache_hidden) + else: + if i < tgate: + (noise_pred, delta_cache, delta_cache_hidden) = self.dit_infer( + self.compiled_dit_skip_model, + latent_model_input, + prompt_embeds, + pooled_prompt_embeds, + timestep_npu, cache_param, + skip_flag_false, + delta_cache, + delta_cache_hidden) + else: + (noise_pred, delta_cache, delta_cache_hidden) = self.dit_infer( + self.compiled_dit_skip_end_model, + latent_model_input, + prompt_embeds_origin, + pooled_prompt_embeds_origin, + timestep_npu, cache_param, + skip_flag_false, + delta_cache, + delta_cache_hidden) + + dit_time += (time.time() - start) + + # perform guidance + if self.do_classifier_free_guidance and i < tgate: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + start = time.time() + # latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + step_index = torch.tensor(i).long() + latents = self.compiled_scheduler( + noise_pred.to(f'npu:{self.device_0}'), + latents.to(f'npu:{self.device_0}'), + step_index[None].to(f'npu:{self.device_0}') + ).to('cpu') + scheduler_time += (time.time() - start) + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + p2_time += time.time() - start1 + start2 = time.time() + + if output_type == "latent": + image = latents + else: + start = time.time() + image = self.compiled_vae_model(latents.to(f'npu:{self.device_0}')).to("cpu") + vae_time += time.time() - start + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + p3_time += time.time() - start2 + + if not return_dict: + return (image,) + + return StableDiffusion3PipelineOutput(images=image) + + +def main(): + args = parse_arguments() + save_dir = args.save_dir + if not os.path.exists(save_dir): + os.makedirs(save_dir) + + if isinstance(args.device, list): + mindietorch.set_device(args.device[0]) + else: + mindietorch.set_device(args.device) + pipe = AIEStableDiffusion3CachePipeline.from_pretrained(args.model).to("cpu") + pipe.parser_args(args) + pipe.compile_aie_model() + + cache_param = torch.zeros([4], dtype=torch.int64) + cache_list = args.cache_param.split(',') + cache_param[0] = int(cache_list[0]) + cache_param[1] = int(cache_list[1]) + cache_param[2] = int(cache_list[2]) + cache_param[3] = int(cache_list[3]) + use_time = 0 + prompt_loader = PromptLoader(args.prompt_file, + args.prompt_file_type, + args.batch_size, + args.num_images_per_prompt, + args.max_num_prompts) + + infer_num = 0 + image_info = [] + current_prompt = None + for i, input_info in enumerate(prompt_loader): + prompts = input_info['prompts'] + catagories = input_info['catagories'] + save_names = input_info['save_names'] + n_prompts = input_info['n_prompts'] + + print(f"[{infer_num + n_prompts}/{len(prompt_loader)}]: {prompts}") + infer_num += args.batch_size + + start_time = time.time() + images = pipe.forward( + prompts, + negative_prompt="", + num_inference_steps=args.steps, + guidance_scale=7.0, + cache_param=cache_param + ) + if i > 4: # do not count the time spent inferring the first 0 to 4 images + use_time += time.time() - start_time + + for j in range(n_prompts): + image_save_path = os.path.join(save_dir, f"{save_names[j]}.png") + image = images[0][j] + image.save(image_save_path) + + if current_prompt != prompts[j]: + current_prompt = prompts[j] + image_info.append({'images': [], 'prompt': current_prompt, 'category': catagories[j]}) + + image_info[-1]['images'].append(image_save_path) + + print( + f"[info] infer number: {infer_num - 5}; use time: {use_time:.3f}s\n" + f"average time: {use_time / (infer_num - 5):.3f}s\n" + f"dit time: {dit_time / infer_num:.3f}s\n" + f"scheduler_time time: {scheduler_time / infer_num:.3f}s\n" + f"vae time: {vae_time / infer_num:.3f}s\n" + f"p1 time: {p1_time / infer_num:.3f}s\n" + f"p2 time: {p2_time / infer_num:.3f}s\n" + f"p3 time: {p3_time / infer_num:.3f}s\n" + ) + + # Save image information to a json file + if os.path.exists(args.info_file_save_path): + os.remove(args.info_file_save_path) + + with os.fdopen(os.open(args.info_file_save_path, os.O_RDWR | os.O_CREAT, 0o640), "w") as f: + json.dump(image_info, f) + mindietorch.finalize() + + +if __name__ == "__main__": + main() -- Gitee From 54de70c0d93eb37f7df9d21dddac263aa7bd6425 Mon Sep 17 00:00:00 2001 From: huanghao Date: Thu, 5 Sep 2024 20:39:24 +0800 Subject: [PATCH 09/13] =?UTF-8?q?DiTCache=E7=AC=AC2=E9=83=A8=E5=88=86?= =?UTF-8?q?=E4=BB=A3=E7=A0=81=E4=B8=8A=E5=BA=93?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../stable_diffusion3_pipeline_cache.py | 451 ------------------ 1 file changed, 451 deletions(-) delete mode 100644 MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline_cache.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline_cache.py b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline_cache.py deleted file mode 100644 index 31ccc987cd..0000000000 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline_cache.py +++ /dev/null @@ -1,451 +0,0 @@ -# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json -import os -import time -from typing import Any, Callable, Dict, List, Optional, Union -import torch -import mindietorch -from diffusers.pipelines.stable_diffusion_3.pipeline_output import StableDiffusion3PipelineOutput -from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import retrieve_timesteps -from stable_diffusion3_pipeline import PromptLoader, AIEStableDiffusion3Pipeline, parse_arguments - -tgate = 20 -dit_time = 0 -vae_time = 0 -scheduler_time = 0 -p1_time = 0 -p2_time = 0 -p3_time = 0 - - -class AIEStableDiffusion3CachePipeline(AIEStableDiffusion3Pipeline): - def compile_aie_model(self): - if self.is_init: - return - size = self.args.batch_size - batch_size = self.args.batch_size * 2 - tail = f"_{self.args.height}x{self.args.width}" - - vae_compiled_path = os.path.join(self.args.output_dir, - f"vae/vae_bs{size}_compile{tail}.ts") - self.compiled_vae_model = torch.jit.load(vae_compiled_path).eval() - - scheduler_compiled_path = os.path.join(self.args.output_dir, - f"scheduler/scheduler_bs{size}_compile{tail}.ts") - self.compiled_scheduler = torch.jit.load(scheduler_compiled_path).eval() - - clip1_compiled_path = os.path.join(self.args.output_dir, - f"clip/clip_bs{size}_compile{tail}.ts") - self.compiled_clip_model = torch.jit.load(clip1_compiled_path).eval() - - clip2_compiled_path = os.path.join(self.args.output_dir, - f"clip/clip2_bs{size}_compile{tail}.ts") - self.compiled_clip_model_2 = torch.jit.load(clip2_compiled_path).eval() - - t5_compiled_path = os.path.join(self.args.output_dir, - f"clip/t5_bs{size}_compile{tail}.ts") - self.compiled_t5_model = torch.jit.load(t5_compiled_path).eval() - - dit_cache_compiled_path = os.path.join(self.args.output_dir, - f"dit/dit_bs{batch_size}_0_compile{tail}.ts") - self.compiled_dit_cache_model = torch.jit.load(dit_cache_compiled_path).eval() - - if self.args.use_cache: - dit_skip_compiled_path = os.path.join(self.args.output_dir, - f"dit/dit_bs{batch_size}_1_compile{tail}.ts") - self.compiled_dit_skip_model = torch.jit.load(dit_skip_compiled_path).eval() - - dit_cache_end_compiled_path = os.path.join(self.args.output_dir, - f"dit/dit_bs{size}_0_compile{tail}.ts") - self.compiled_dit_cache_end_model = torch.jit.load(dit_cache_end_compiled_path).eval() - - dit_skip_end_compiled_path = os.path.join(self.args.output_dir, - f"dit/dit_bs{size}_1_compile{tail}.ts") - self.compiled_dit_skip_end_model = torch.jit.load(dit_skip_end_compiled_path).eval() - - self.is_init = True - - @torch.no_grad() - def dit_infer(self, compiled_model, latent_model_input, prompt_embeds, pooled_prompt_embeds, timestep_npu, - cache_param, skip_flag, delta_cache, delta_cache_hidden): - (noise_pred, delta_cache, delta_cache_hidden) = compiled_model( - latent_model_input.to(f'npu:{self.device_0}'), - prompt_embeds.to(f'npu:{self.device_0}'), - pooled_prompt_embeds.to(f'npu:{self.device_0}'), - timestep_npu, - cache_param.to(f'npu:{self.device_0}'), - skip_flag, - delta_cache.to(f'npu:{self.device_0}'), - delta_cache_hidden.to(f'npu:{self.device_0}'), - ) - noise_pred = noise_pred.to("cpu") - delta_cache = delta_cache.to("cpu") - delta_cache_hidden = delta_cache_hidden.to("cpu") - return (noise_pred, delta_cache, delta_cache_hidden) - - @torch.no_grad() - def forward( - self, - prompt: Union[str, List[str]] = None, - prompt_2: Optional[Union[str, List[str]]] = None, - prompt_3: Optional[Union[str, List[str]]] = None, - height: Optional[int] = None, - width: Optional[int] = None, - num_inference_steps: int = 28, - timesteps: List[int] = None, - guidance_scale: float = 7.0, - cache_param: torch.LongTensor = None, - negative_prompt: Optional[Union[str, List[str]]] = None, - negative_prompt_2: Optional[Union[str, List[str]]] = None, - negative_prompt_3: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: Optional[int] = 1, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.FloatTensor] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - pooled_prompt_embeds: Optional[torch.FloatTensor] = None, - negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - joint_attention_kwargs: Optional[Dict[str, Any]] = None, - clip_skip: Optional[int] = None, - callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, - callback_on_step_end_tensor_inputs: List[str] = ["latents"], - ): - global p1_time, p2_time, p3_time - start = time.time() - height = height or self.default_sample_size * self.vae_scale_factor - width = width or self.default_sample_size * self.vae_scale_factor - - # 1. Check inputs. Raise error if not correct - self.check_inputs( - prompt, - prompt_2, - prompt_3, - height, - width, - negative_prompt=negative_prompt, - negative_prompt_2=negative_prompt_2, - negative_prompt_3=negative_prompt_3, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - pooled_prompt_embeds=pooled_prompt_embeds, - negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, - callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, - ) - - self._guidance_scale = guidance_scale - self._clip_skip = clip_skip - self._joint_attention_kwargs = joint_attention_kwargs - self._interrupt = False - - # 2. Define call parameters - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - device = self._execution_device - - ( - prompt_embeds, - negative_prompt_embeds, - pooled_prompt_embeds, - negative_pooled_prompt_embeds, - ) = self.encode_prompt( - prompt=prompt, - prompt_2=prompt_2, - prompt_3=prompt_3, - negative_prompt=negative_prompt, - negative_prompt_2=negative_prompt_2, - negative_prompt_3=negative_prompt_3, - do_classifier_free_guidance=self.do_classifier_free_guidance, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - pooled_prompt_embeds=pooled_prompt_embeds, - negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, - clip_skip=self.clip_skip, - num_images_per_prompt=num_images_per_prompt, - ) - - p1_time += (time.time() - start) - start1 = time.time() - - prompt_embeds_origin = prompt_embeds.clone() - pooled_prompt_embeds_origin = pooled_prompt_embeds.clone() - if self.do_classifier_free_guidance: - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) - pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) - - # 4. Prepare timesteps - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) - num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) - self._num_timesteps = len(timesteps) - - # 5. Prepare latent variables - num_channels_latents = self.transformer.config.in_channels - latents = self.prepare_latents( - batch_size * num_images_per_prompt, - num_channels_latents, - height, - width, - prompt_embeds.dtype, - device, - generator, - latents, - ) - - # 6. Denoising loop - global dit_time - global vae_time - global scheduler_time - - skip_flag_true = torch.ones([1], dtype=torch.long).to(f'npu:{self.device_0}') - skip_flag_false = torch.zeros([1], dtype=torch.long).to(f'npu:{self.device_0}') - - delta_cache = torch.zeros([2, 4096, 1536], dtype=torch.float32) - delta_cache_hidden = torch.zeros([2, 154, 1536], dtype=torch.float32) - - cache_interval = cache_param[1] - step_contrast = cache_param[3] % 2 - - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - if self.interrupt: - continue - - start = time.time() - timestep_npu = t.to(torch.int64)[None].to(f'npu:{self.device_0}') - if not self.args.use_cache: - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents - (noise_pred, delta_cache, delta_cache_hidden) = self.dit_infer( - self.compiled_dit_cache_model, - latent_model_input, - prompt_embeds, - pooled_prompt_embeds, - timestep_npu, cache_param, - skip_flag_true, delta_cache, - delta_cache_hidden) - else: - if i < tgate: - latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents - else: - if i == tgate: - _, delta_cache = delta_cache.chunk(2) - _, delta_cache_hidden = delta_cache_hidden.chunk(2) - latent_model_input = latents - - if i < cache_param[3]: - (noise_pred, delta_cache, delta_cache_hidden) = self.dit_infer( - self.compiled_dit_cache_model, - latent_model_input, - prompt_embeds, - pooled_prompt_embeds, - timestep_npu, cache_param, - skip_flag_true, delta_cache, - delta_cache_hidden) - else: - if i % cache_interval == step_contrast: - if i < tgate: - (noise_pred, delta_cache, delta_cache_hidden) = self.dit_infer( - self.compiled_dit_cache_model, - latent_model_input, - prompt_embeds, - pooled_prompt_embeds, - timestep_npu, cache_param, - skip_flag_true, - delta_cache, - delta_cache_hidden) - else: - (noise_pred, delta_cache, delta_cache_hidden) = self.dit_infer( - self.compiled_dit_cache_end_model, - latent_model_input, - prompt_embeds_origin, - pooled_prompt_embeds_origin, - timestep_npu, cache_param, - skip_flag_true, - delta_cache, - delta_cache_hidden) - else: - if i < tgate: - (noise_pred, delta_cache, delta_cache_hidden) = self.dit_infer( - self.compiled_dit_skip_model, - latent_model_input, - prompt_embeds, - pooled_prompt_embeds, - timestep_npu, cache_param, - skip_flag_false, - delta_cache, - delta_cache_hidden) - else: - (noise_pred, delta_cache, delta_cache_hidden) = self.dit_infer( - self.compiled_dit_skip_end_model, - latent_model_input, - prompt_embeds_origin, - pooled_prompt_embeds_origin, - timestep_npu, cache_param, - skip_flag_false, - delta_cache, - delta_cache_hidden) - - dit_time += (time.time() - start) - - # perform guidance - if self.do_classifier_free_guidance and i < tgate: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - latents_dtype = latents.dtype - start = time.time() - # latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] - step_index = torch.tensor(i).long() - latents = self.compiled_scheduler( - noise_pred.to(f'npu:{self.device_0}'), - latents.to(f'npu:{self.device_0}'), - step_index[None].to(f'npu:{self.device_0}') - ).to('cpu') - scheduler_time += (time.time() - start) - - if latents.dtype != latents_dtype: - if torch.backends.mps.is_available(): - # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 - latents = latents.to(latents_dtype) - - if callback_on_step_end is not None: - callback_kwargs = {} - for k in callback_on_step_end_tensor_inputs: - callback_kwargs[k] = locals()[k] - callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) - - latents = callback_outputs.pop("latents", latents) - prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) - negative_pooled_prompt_embeds = callback_outputs.pop( - "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds - ) - - # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - - p2_time += time.time() - start1 - start2 = time.time() - - if output_type == "latent": - image = latents - else: - start = time.time() - image = self.compiled_vae_model(latents.to(f'npu:{self.device_0}')).to("cpu") - vae_time += time.time() - start - image = self.image_processor.postprocess(image, output_type=output_type) - - # Offload all models - self.maybe_free_model_hooks() - - p3_time += time.time() - start2 - - if not return_dict: - return (image,) - - return StableDiffusion3PipelineOutput(images=image) - - -def main(): - args = parse_arguments() - save_dir = args.save_dir - if not os.path.exists(save_dir): - os.makedirs(save_dir) - - if isinstance(args.device, list): - mindietorch.set_device(args.device[0]) - else: - mindietorch.set_device(args.device) - pipe = AIEStableDiffusion3CachePipeline.from_pretrained(args.model).to("cpu") - pipe.parser_args(args) - pipe.compile_aie_model() - - cache_param = torch.zeros([4], dtype=torch.int64) - cache_list = args.cache_param.split(',') - cache_param[0] = int(cache_list[0]) - cache_param[1] = int(cache_list[1]) - cache_param[2] = int(cache_list[2]) - cache_param[3] = int(cache_list[3]) - use_time = 0 - prompt_loader = PromptLoader(args.prompt_file, - args.prompt_file_type, - args.batch_size, - args.num_images_per_prompt, - args.max_num_prompts) - - infer_num = 0 - image_info = [] - current_prompt = None - for i, input_info in enumerate(prompt_loader): - prompts = input_info['prompts'] - catagories = input_info['catagories'] - save_names = input_info['save_names'] - n_prompts = input_info['n_prompts'] - - print(f"[{infer_num + n_prompts}/{len(prompt_loader)}]: {prompts}") - infer_num += args.batch_size - - start_time = time.time() - images = pipe.forward( - prompts, - negative_prompt="", - num_inference_steps=args.steps, - guidance_scale=7.0, - cache_param=cache_param - ) - if i > 4: # do not count the time spent inferring the first 0 to 4 images - use_time += time.time() - start_time - - for j in range(n_prompts): - image_save_path = os.path.join(save_dir, f"{save_names[j]}.png") - image = images[0][j] - image.save(image_save_path) - - if current_prompt != prompts[j]: - current_prompt = prompts[j] - image_info.append({'images': [], 'prompt': current_prompt, 'category': catagories[j]}) - - image_info[-1]['images'].append(image_save_path) - - print( - f"[info] infer number: {infer_num - 5}; use time: {use_time:.3f}s\n" - f"average time: {use_time / (infer_num - 5):.3f}s\n" - f"dit time: {dit_time / infer_num:.3f}s\n" - f"scheduler_time time: {scheduler_time / infer_num:.3f}s\n" - f"vae time: {vae_time / infer_num:.3f}s\n" - f"p1 time: {p1_time / infer_num:.3f}s\n" - f"p2 time: {p2_time / infer_num:.3f}s\n" - f"p3 time: {p3_time / infer_num:.3f}s\n" - ) - - # Save image information to a json file - if os.path.exists(args.info_file_save_path): - os.remove(args.info_file_save_path) - - with os.fdopen(os.open(args.info_file_save_path, os.O_RDWR | os.O_CREAT, 0o640), "w") as f: - json.dump(image_info, f) - mindietorch.finalize() - - -if __name__ == "__main__": - main() -- Gitee From ae91b99560c01b8b349cf7c3f610e59733dd8bf3 Mon Sep 17 00:00:00 2001 From: huanghao Date: Thu, 5 Sep 2024 20:40:19 +0800 Subject: [PATCH 10/13] =?UTF-8?q?DiTCache=E7=AC=AC2=E9=83=A8=E5=88=86?= =?UTF-8?q?=E4=BB=A3=E7=A0=81=E4=B8=8A=E5=BA=93?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../stable_diffusion3_pipeline_cache.py | 451 ++++++++++++++++++ 1 file changed, 451 insertions(+) create mode 100644 MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline_cache.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline_cache.py b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline_cache.py new file mode 100644 index 0000000000..31ccc987cd --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline_cache.py @@ -0,0 +1,451 @@ +# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +import time +from typing import Any, Callable, Dict, List, Optional, Union +import torch +import mindietorch +from diffusers.pipelines.stable_diffusion_3.pipeline_output import StableDiffusion3PipelineOutput +from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import retrieve_timesteps +from stable_diffusion3_pipeline import PromptLoader, AIEStableDiffusion3Pipeline, parse_arguments + +tgate = 20 +dit_time = 0 +vae_time = 0 +scheduler_time = 0 +p1_time = 0 +p2_time = 0 +p3_time = 0 + + +class AIEStableDiffusion3CachePipeline(AIEStableDiffusion3Pipeline): + def compile_aie_model(self): + if self.is_init: + return + size = self.args.batch_size + batch_size = self.args.batch_size * 2 + tail = f"_{self.args.height}x{self.args.width}" + + vae_compiled_path = os.path.join(self.args.output_dir, + f"vae/vae_bs{size}_compile{tail}.ts") + self.compiled_vae_model = torch.jit.load(vae_compiled_path).eval() + + scheduler_compiled_path = os.path.join(self.args.output_dir, + f"scheduler/scheduler_bs{size}_compile{tail}.ts") + self.compiled_scheduler = torch.jit.load(scheduler_compiled_path).eval() + + clip1_compiled_path = os.path.join(self.args.output_dir, + f"clip/clip_bs{size}_compile{tail}.ts") + self.compiled_clip_model = torch.jit.load(clip1_compiled_path).eval() + + clip2_compiled_path = os.path.join(self.args.output_dir, + f"clip/clip2_bs{size}_compile{tail}.ts") + self.compiled_clip_model_2 = torch.jit.load(clip2_compiled_path).eval() + + t5_compiled_path = os.path.join(self.args.output_dir, + f"clip/t5_bs{size}_compile{tail}.ts") + self.compiled_t5_model = torch.jit.load(t5_compiled_path).eval() + + dit_cache_compiled_path = os.path.join(self.args.output_dir, + f"dit/dit_bs{batch_size}_0_compile{tail}.ts") + self.compiled_dit_cache_model = torch.jit.load(dit_cache_compiled_path).eval() + + if self.args.use_cache: + dit_skip_compiled_path = os.path.join(self.args.output_dir, + f"dit/dit_bs{batch_size}_1_compile{tail}.ts") + self.compiled_dit_skip_model = torch.jit.load(dit_skip_compiled_path).eval() + + dit_cache_end_compiled_path = os.path.join(self.args.output_dir, + f"dit/dit_bs{size}_0_compile{tail}.ts") + self.compiled_dit_cache_end_model = torch.jit.load(dit_cache_end_compiled_path).eval() + + dit_skip_end_compiled_path = os.path.join(self.args.output_dir, + f"dit/dit_bs{size}_1_compile{tail}.ts") + self.compiled_dit_skip_end_model = torch.jit.load(dit_skip_end_compiled_path).eval() + + self.is_init = True + + @torch.no_grad() + def dit_infer(self, compiled_model, latent_model_input, prompt_embeds, pooled_prompt_embeds, timestep_npu, + cache_param, skip_flag, delta_cache, delta_cache_hidden): + (noise_pred, delta_cache, delta_cache_hidden) = compiled_model( + latent_model_input.to(f'npu:{self.device_0}'), + prompt_embeds.to(f'npu:{self.device_0}'), + pooled_prompt_embeds.to(f'npu:{self.device_0}'), + timestep_npu, + cache_param.to(f'npu:{self.device_0}'), + skip_flag, + delta_cache.to(f'npu:{self.device_0}'), + delta_cache_hidden.to(f'npu:{self.device_0}'), + ) + noise_pred = noise_pred.to("cpu") + delta_cache = delta_cache.to("cpu") + delta_cache_hidden = delta_cache_hidden.to("cpu") + return (noise_pred, delta_cache, delta_cache_hidden) + + @torch.no_grad() + def forward( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + prompt_3: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 28, + timesteps: List[int] = None, + guidance_scale: float = 7.0, + cache_param: torch.LongTensor = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + negative_prompt_3: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + ): + global p1_time, p2_time, p3_time + start = time.time() + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + prompt_3, + height, + width, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + device = self._execution_device + + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_3=prompt_3, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + do_classifier_free_guidance=self.do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + clip_skip=self.clip_skip, + num_images_per_prompt=num_images_per_prompt, + ) + + p1_time += (time.time() - start) + start1 = time.time() + + prompt_embeds_origin = prompt_embeds.clone() + pooled_prompt_embeds_origin = pooled_prompt_embeds.clone() + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Denoising loop + global dit_time + global vae_time + global scheduler_time + + skip_flag_true = torch.ones([1], dtype=torch.long).to(f'npu:{self.device_0}') + skip_flag_false = torch.zeros([1], dtype=torch.long).to(f'npu:{self.device_0}') + + delta_cache = torch.zeros([2, 4096, 1536], dtype=torch.float32) + delta_cache_hidden = torch.zeros([2, 154, 1536], dtype=torch.float32) + + cache_interval = cache_param[1] + step_contrast = cache_param[3] % 2 + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + start = time.time() + timestep_npu = t.to(torch.int64)[None].to(f'npu:{self.device_0}') + if not self.args.use_cache: + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + (noise_pred, delta_cache, delta_cache_hidden) = self.dit_infer( + self.compiled_dit_cache_model, + latent_model_input, + prompt_embeds, + pooled_prompt_embeds, + timestep_npu, cache_param, + skip_flag_true, delta_cache, + delta_cache_hidden) + else: + if i < tgate: + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + else: + if i == tgate: + _, delta_cache = delta_cache.chunk(2) + _, delta_cache_hidden = delta_cache_hidden.chunk(2) + latent_model_input = latents + + if i < cache_param[3]: + (noise_pred, delta_cache, delta_cache_hidden) = self.dit_infer( + self.compiled_dit_cache_model, + latent_model_input, + prompt_embeds, + pooled_prompt_embeds, + timestep_npu, cache_param, + skip_flag_true, delta_cache, + delta_cache_hidden) + else: + if i % cache_interval == step_contrast: + if i < tgate: + (noise_pred, delta_cache, delta_cache_hidden) = self.dit_infer( + self.compiled_dit_cache_model, + latent_model_input, + prompt_embeds, + pooled_prompt_embeds, + timestep_npu, cache_param, + skip_flag_true, + delta_cache, + delta_cache_hidden) + else: + (noise_pred, delta_cache, delta_cache_hidden) = self.dit_infer( + self.compiled_dit_cache_end_model, + latent_model_input, + prompt_embeds_origin, + pooled_prompt_embeds_origin, + timestep_npu, cache_param, + skip_flag_true, + delta_cache, + delta_cache_hidden) + else: + if i < tgate: + (noise_pred, delta_cache, delta_cache_hidden) = self.dit_infer( + self.compiled_dit_skip_model, + latent_model_input, + prompt_embeds, + pooled_prompt_embeds, + timestep_npu, cache_param, + skip_flag_false, + delta_cache, + delta_cache_hidden) + else: + (noise_pred, delta_cache, delta_cache_hidden) = self.dit_infer( + self.compiled_dit_skip_end_model, + latent_model_input, + prompt_embeds_origin, + pooled_prompt_embeds_origin, + timestep_npu, cache_param, + skip_flag_false, + delta_cache, + delta_cache_hidden) + + dit_time += (time.time() - start) + + # perform guidance + if self.do_classifier_free_guidance and i < tgate: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + start = time.time() + # latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + step_index = torch.tensor(i).long() + latents = self.compiled_scheduler( + noise_pred.to(f'npu:{self.device_0}'), + latents.to(f'npu:{self.device_0}'), + step_index[None].to(f'npu:{self.device_0}') + ).to('cpu') + scheduler_time += (time.time() - start) + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + p2_time += time.time() - start1 + start2 = time.time() + + if output_type == "latent": + image = latents + else: + start = time.time() + image = self.compiled_vae_model(latents.to(f'npu:{self.device_0}')).to("cpu") + vae_time += time.time() - start + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + p3_time += time.time() - start2 + + if not return_dict: + return (image,) + + return StableDiffusion3PipelineOutput(images=image) + + +def main(): + args = parse_arguments() + save_dir = args.save_dir + if not os.path.exists(save_dir): + os.makedirs(save_dir) + + if isinstance(args.device, list): + mindietorch.set_device(args.device[0]) + else: + mindietorch.set_device(args.device) + pipe = AIEStableDiffusion3CachePipeline.from_pretrained(args.model).to("cpu") + pipe.parser_args(args) + pipe.compile_aie_model() + + cache_param = torch.zeros([4], dtype=torch.int64) + cache_list = args.cache_param.split(',') + cache_param[0] = int(cache_list[0]) + cache_param[1] = int(cache_list[1]) + cache_param[2] = int(cache_list[2]) + cache_param[3] = int(cache_list[3]) + use_time = 0 + prompt_loader = PromptLoader(args.prompt_file, + args.prompt_file_type, + args.batch_size, + args.num_images_per_prompt, + args.max_num_prompts) + + infer_num = 0 + image_info = [] + current_prompt = None + for i, input_info in enumerate(prompt_loader): + prompts = input_info['prompts'] + catagories = input_info['catagories'] + save_names = input_info['save_names'] + n_prompts = input_info['n_prompts'] + + print(f"[{infer_num + n_prompts}/{len(prompt_loader)}]: {prompts}") + infer_num += args.batch_size + + start_time = time.time() + images = pipe.forward( + prompts, + negative_prompt="", + num_inference_steps=args.steps, + guidance_scale=7.0, + cache_param=cache_param + ) + if i > 4: # do not count the time spent inferring the first 0 to 4 images + use_time += time.time() - start_time + + for j in range(n_prompts): + image_save_path = os.path.join(save_dir, f"{save_names[j]}.png") + image = images[0][j] + image.save(image_save_path) + + if current_prompt != prompts[j]: + current_prompt = prompts[j] + image_info.append({'images': [], 'prompt': current_prompt, 'category': catagories[j]}) + + image_info[-1]['images'].append(image_save_path) + + print( + f"[info] infer number: {infer_num - 5}; use time: {use_time:.3f}s\n" + f"average time: {use_time / (infer_num - 5):.3f}s\n" + f"dit time: {dit_time / infer_num:.3f}s\n" + f"scheduler_time time: {scheduler_time / infer_num:.3f}s\n" + f"vae time: {vae_time / infer_num:.3f}s\n" + f"p1 time: {p1_time / infer_num:.3f}s\n" + f"p2 time: {p2_time / infer_num:.3f}s\n" + f"p3 time: {p3_time / infer_num:.3f}s\n" + ) + + # Save image information to a json file + if os.path.exists(args.info_file_save_path): + os.remove(args.info_file_save_path) + + with os.fdopen(os.open(args.info_file_save_path, os.O_RDWR | os.O_CREAT, 0o640), "w") as f: + json.dump(image_info, f) + mindietorch.finalize() + + +if __name__ == "__main__": + main() -- Gitee From ff0a848c99fc8020b6e4eac8e31bdde14adc2a82 Mon Sep 17 00:00:00 2001 From: huanghao Date: Fri, 6 Sep 2024 11:48:05 +0800 Subject: [PATCH 11/13] =?UTF-8?q?DiTCache=E7=AC=AC3=E9=83=A8=E5=88=86?= =?UTF-8?q?=E4=BB=A3=E7=A0=81=E4=B8=8A=E5=BA=93?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../foundation/stable_diffusion_3/README.md | 30 +++- .../stable_diffusion3_pipeline_cache.py | 129 +++++++++++++++++- .../stable_diffusion_3/transformer_sd3.patch | 123 +++++++++++++++++ .../transformer_sd3_patch.py | 33 +++++ 4 files changed, 310 insertions(+), 5 deletions(-) create mode 100644 MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/transformer_sd3.patch create mode 100644 MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/transformer_sd3_patch.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/README.md b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/README.md index 4a807769b9..58089608ff 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/README.md +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/README.md @@ -67,12 +67,14 @@ ```bash # 安装mindie + source /usr/local/Ascend/ascend-toolkit/set_env.sh chmod +x ./Ascend-mindie_xxx.run ./Ascend-mindie_xxx.run --install source /usr/local/Ascend/mindie/set_env.sh ``` -3. 代码修改 +3. 代码修改(可选) +(1)若需要开启DiTCache、序列压缩等优化,需要执行以下代码修改操作: - 若环境没有patch工具,请自行安装: ```bash apt update @@ -81,6 +83,8 @@ - 执行命令: ```bash python3 attention_patch.py + python3 attention_processor_patch.py + python3 transformer_sd3_patch.py ``` ## 模型推理 @@ -117,8 +121,10 @@ (3) 执行export命令 ```bash - # 800I A2,非并行 + # 800I A2,非并行,未加DiTCache优化 python3 export_model.py --model ${model_base} --output_dir ./models --batch_size 1 --soc Ascend910B4 --device 0 + # 800I A2,非并行。开启DiTCache优化 + python3 export_model.py --model ${model_base} --output_dir ./models --batch_size 1 --soc Ascend910B4 --device 0 --use_cache # 300I Duo,并行 python3 export_model.py --model ${model_base} --output_dir ./models --parallel --batch_size 1 --soc Ascend310P3 --device 0 @@ -130,6 +136,7 @@ - --batch_size: 设置batch_size, 默认值为1, 当前仅支持batch_size=1的场景 - --soc:只支持Ascend910B4和Ascend310P3。默认为Ascend910B4。 - --device:推理设备ID + - --use_cache:开启DiTCache优化,不配置则不开启 注意:trace+compile耗时较长且占用较多的CPU资源,请勿在执行export命令时运行其他占用CPU内存的任务,避免程序意外退出。 2. 开始推理验证。 @@ -163,7 +170,7 @@ 3. 执行推理脚本。 ```bash - # 不使用unetCache策略 + # 不使用DiTCache,单卡推理,适用800I A2场景 numactl -C 0-23 python3 stable_diffusion3_pipeline.py \ --model ${model_base} \ --prompt_file ./prompts.txt \ @@ -176,7 +183,7 @@ --width 1024 \ --batch_size 1 - # 使用UnetCache策略,同时使用双卡并行策略 + # 不使用DiTCache,使用双卡并行推理,适用300I DUO场景 numactl -C 0-23 python3 stable_diffusion3_pipeline.py \ --model ${model_base} \ --prompt_file ./prompts.txt \ @@ -189,6 +196,20 @@ --width 1024 \ --batch_size 1 \ --parallel + + # 使用DiTCache,单卡推理,适用800I A2场景 + numactl -C 0-23 python3 stable_diffusion3_pipeline_cache.py \ + --model ${model_base} \ + --prompt_file ./prompts.txt \ + --prompt_file_type plain \ + --device 0 \ + --save_dir ./results \ + --steps 28 \ + --output_dir ./models \ + --height 1024 \ + --width 1024 \ + --batch_size 1 \ + --use_cache ``` 参数说明: @@ -202,6 +223,7 @@ - --device:推理设备ID;可用逗号分割传入两个设备ID,此时会使用并行方式进行推理。 - --height:生成图像高度,当前只支持1024 - --width:生成图像宽度,当前只支持1024 + - --use_cache:开启DiTCache优化,不配置则不开启 非并行策略,执行完成后在`./results`目录下生成推理图片,在当前目录生成一个`image_info.json`文件,记录着图片和prompt的对应关系,并在终端显示推理时间。 并行策略,同时使用双卡并行策略,执行完成后在`./results_parallel`目录下生成推理图片,在当前目录生成一个`image_info.json`文件,记录着图片和prompt的对应关系。并在终端显示推理时间。 diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline_cache.py b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline_cache.py index 31ccc987cd..5f980cb366 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline_cache.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline_cache.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import argparse import json import os import time @@ -20,7 +21,7 @@ import torch import mindietorch from diffusers.pipelines.stable_diffusion_3.pipeline_output import StableDiffusion3PipelineOutput from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import retrieve_timesteps -from stable_diffusion3_pipeline import PromptLoader, AIEStableDiffusion3Pipeline, parse_arguments +from stable_diffusion3_pipeline import PromptLoader, AIEStableDiffusion3Pipeline tgate = 20 dit_time = 0 @@ -366,6 +367,132 @@ class AIEStableDiffusion3CachePipeline(AIEStableDiffusion3Pipeline): return StableDiffusion3PipelineOutput(images=image) +def check_device_range_valid(value): + # if contain , split to int list + min_value = 0 + max_value = 255 + if ',' in value: + ilist = [int(v) for v in value.split(',')] + for ivalue in ilist[:2]: + if ivalue < min_value or ivalue > max_value: + raise argparse.ArgumentTypeError( + "{} of device:{} is invalid. valid value range is [{}, {}]" + .format(ivalue, value, min_value, max_value)) + return ilist[:2] + else: + # default as single int value + ivalue = int(value) + if ivalue < min_value or ivalue > max_value: + raise argparse.ArgumentTypeError( + "device:{} is invalid. valid value range is [{}, {}]".format( + ivalue, min_value, max_value)) + return ivalue + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument( + "-m", + "--model", + type=str, + default="./stable-diffusion-3-medium-diffusers", + help="Path or name of the pre-trained model.", + ) + parser.add_argument( + "--prompt_file", + type=str, + default="./prompts.txt", + help="A text file of prompts for generating images.", + ) + parser.add_argument( + "--prompt_file_type", + choices=["plain", "parti", "hpsv2"], + default="plain", + help="Type of prompt file.", + ) + parser.add_argument( + "--save_dir", + type=str, + default="./results", + help="Path to save result images.", + ) + parser.add_argument( + "--info_file_save_path", + type=str, + default="./image_info.json", + help="Path to save image information file.", + ) + parser.add_argument( + "--steps", + type=int, + default=28, + help="Number of inference steps.", + ) + parser.add_argument( + "--device", + type=check_device_range_valid, + default=0, + help="NPU device id. Give 2 ids to enable parallel inferencing.", + ) + parser.add_argument( + "--num_images_per_prompt", + default=1, + type=int, + help="Number of images generated for each prompt.", + ) + parser.add_argument( + "--max_num_prompts", + default=0, + type=int, + help="Limit the number of prompts (0: no limit).", + ) + parser.add_argument( + "-bs", + "--batch_size", + type=int, + default=1, + help="Batch size." + ) + parser.add_argument( + "-o", + "--output_dir", + type=str, + default="./models", + help="Path of directory to save compiled models.", + ) + parser.add_argument( + "--scheduler", + choices=["FlowMatchEuler"], + default="FlowMatchEuler", + help="Type of Sampling methods. Default FlowMatchEuler", + ) + parser.add_argument( + "--height", + default=1024, + type=int, + help="image height", + ) + parser.add_argument( + "--width", + default=1024, + type=int, + help="image width" + ) + parser.add_argument( + "--use_cache", + action="store_true", + help="Use cache during inference." + ) + parser.add_argument( + "--cache_param", + default="1,2,20,10", + type=str, + help="steps to use cache data" + ) + + return parser.parse_args() + + def main(): args = parse_arguments() save_dir = args.save_dir diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/transformer_sd3.patch b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/transformer_sd3.patch new file mode 100644 index 0000000000..6b18fd8831 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/transformer_sd3.patch @@ -0,0 +1,123 @@ +--- transformer_sd3.py 2024-09-04 09:21:58.280000000 +0000 ++++ transformer_sd3.py 2024-09-04 10:01:47.196000000 +0000 +@@ -97,6 +97,7 @@ + num_attention_heads=self.config.num_attention_heads, + attention_head_dim=self.inner_dim, + context_pre_only=i == num_layers - 1, ++ layer_idx=i + ) + for i in range(self.config.num_layers) + ] +@@ -106,6 +107,8 @@ + self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) + + self.gradient_checkpointing = False ++ self.delta_cache = None ++ self.delta_cache_hidden = None + + # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking + def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: +@@ -245,9 +248,14 @@ + encoder_hidden_states: torch.FloatTensor = None, + pooled_projections: torch.FloatTensor = None, + timestep: torch.LongTensor = None, ++ cache_dict: torch.LongTensor = None, ++ if_skip: int = 0, ++ delta_cache: torch.FloatTensor = None, ++ delta_cache_hidden: torch.FloatTensor = None, ++ use_cache: bool = False, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, +- ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: ++ ): + """ + The [`SD3Transformer2DModel`] forward method. + +@@ -281,10 +289,6 @@ + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) +- else: +- logger.warning( +- "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." +- ) + + height, width = hidden_states.shape[-2:] + +@@ -292,9 +296,8 @@ + temb = self.time_text_embed(timestep, pooled_projections) + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + +- for block in self.transformer_blocks: +- if self.training and self.gradient_checkpointing: +- ++ if self.training and self.gradient_checkpointing: ++ for block in self.transformer_blocks: + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: +@@ -312,11 +315,14 @@ + temb, + **ckpt_kwargs, + ) +- +- else: +- encoder_hidden_states, hidden_states = block( +- hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb +- ) ++ else: ++ ( ++ (encoder_hidden_states, hidden_states), ++ delta_cache, ++ delta_cache_hidden ++ ) = self.forward_blocks(hidden_states, encoder_hidden_states, temb, ++ use_cache, if_skip, cache_dict, delta_cache, ++ delta_cache_hidden) + + hidden_states = self.norm_out(hidden_states, temb) + hidden_states = self.proj_out(hidden_states) +@@ -339,6 +345,43 @@ + unscale_lora_layers(self, lora_scale) + + if not return_dict: +- return (output,) ++ return (output, delta_cache, delta_cache_hidden) + + return Transformer2DModelOutput(sample=output) ++ ++ def forward_blocks_range(self, hidden_states, encoder_hidden_states, temb, start_idx, end_idx): ++ for block_idx, block in enumerate(self.transformer_blocks[start_idx: end_idx]): ++ encoder_hidden_states, hidden_states = block( ++ hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb ++ ) ++ # self.x_out.append([torch.mean(x).to('cpu').numpy(), torch.var(x).to('cpu').numpy()]) ++ return hidden_states, encoder_hidden_states ++ ++ def forward_blocks(self, hidden_states, encoder_hidden_states, temb, use_cache, if_skip, cache_dict, delta_cache, ++ delta_cache_hidden): ++ if not use_cache: ++ hidden_states, encoder_hidden_states = self.forward_blocks_range(hidden_states, encoder_hidden_states, temb, ++ 0, len(self.transformer_blocks)) ++ else: ++ # infer [0, cache_start) ++ hidden_states, encoder_hidden_states = self.forward_blocks_range(hidden_states, encoder_hidden_states, temb, ++ 0, cache_dict[0]) ++ ++ # infer [cache_start, cache_end) ++ cache_end = cache_dict[0] + cache_dict[2] ++ hidden_states_before_cache = hidden_states.clone() ++ encoder_hidden_states_before_cache = encoder_hidden_states.clone() ++ if not if_skip: ++ hidden_states, encoder_hidden_states = self.forward_blocks_range(hidden_states, encoder_hidden_states, ++ temb, cache_dict[0], ++ cache_end) ++ delta_cache = hidden_states - hidden_states_before_cache ++ delta_cache_hidden = encoder_hidden_states - encoder_hidden_states_before_cache ++ else: ++ hidden_states = hidden_states_before_cache + delta_cache ++ encoder_hidden_states = encoder_hidden_states_before_cache + delta_cache_hidden ++ ++ # infer [cache_end, len(self.blocks)) ++ hidden_states, encoder_hidden_states = self.forward_blocks_range(hidden_states, encoder_hidden_states, temb, ++ cache_end, len(self.transformer_blocks)) ++ return (encoder_hidden_states, hidden_states), delta_cache, delta_cache_hidden diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/transformer_sd3_patch.py b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/transformer_sd3_patch.py new file mode 100644 index 0000000000..8556cbac2b --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/transformer_sd3_patch.py @@ -0,0 +1,33 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import subprocess +import logging +import diffusers + + + +def main(): + diffusers_path = diffusers.__path__ + diffusers_version = diffusers.__version__ + + assert diffusers_version == '0.29.0', "expectation diffusers==0.29.0" + result = subprocess.run(["patch", "-p0", f"{diffusers_path[0]}/models/transformers/transformer_sd3.py", + "transformer_sd3.patch"], capture_output=True, text=True) + if result.returncode != 0: + logging.error("Patch failed, error message: s%", result.stderr) + + +if __name__ == '__main__': + main() -- Gitee From 83ce2a6c523d112dd988a939f52089ea261b6126 Mon Sep 17 00:00:00 2001 From: huanghao Date: Fri, 6 Sep 2024 16:57:13 +0800 Subject: [PATCH 12/13] =?UTF-8?q?DiTCache=E7=AC=AC3=E9=83=A8=E5=88=86?= =?UTF-8?q?=E4=BB=A3=E7=A0=81=E4=B8=8A=E5=BA=93?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../foundation/stable_diffusion_3/README.md | 27 +++++++++++++++---- .../stable_diffusion_3/export_model.py | 3 ++- 2 files changed, 24 insertions(+), 6 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/README.md b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/README.md index 58089608ff..ab02527595 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/README.md +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/README.md @@ -117,24 +117,41 @@ ```bash mkdir ./models ``` + (3) 执行命令查看芯片名称($\{chip\_name\})。 - (3) 执行export命令 + ``` + npu-smi info + #该设备芯片chip_name=310P3 (自行替换) + 回显如下: + +-------------------+-----------------+------------------------------------------------------+ + | NPU Name | Health | Power(W) Temp(C) Hugepages-Usage(page) | + | Chip Device | Bus-Id | AICore(%) Memory-Usage(MB) | + +===================+=================+======================================================+ + | 0 310P3 | OK | 15.8 42 0 / 0 | + | 0 0 | 0000:82:00.0 | 0 1074 / 21534 | + +===================+=================+======================================================+ + | 1 310P3 | OK | 15.4 43 0 / 0 | + | 0 1 | 0000:89:00.0 | 0 1070 / 21534 | + +===================+=================+======================================================+ + ``` + + (4) 执行export命令 ```bash # 800I A2,非并行,未加DiTCache优化 - python3 export_model.py --model ${model_base} --output_dir ./models --batch_size 1 --soc Ascend910B4 --device 0 + python3 export_model.py --model ${model_base} --output_dir ./models --batch_size 1 --soc Ascend${chip_name} --device_type A2 --device 0 # 800I A2,非并行。开启DiTCache优化 - python3 export_model.py --model ${model_base} --output_dir ./models --batch_size 1 --soc Ascend910B4 --device 0 --use_cache + python3 export_model.py --model ${model_base} --output_dir ./models --batch_size 1 --soc Ascend${chip_name} --device_type A2 --device 0 --use_cache # 300I Duo,并行 - python3 export_model.py --model ${model_base} --output_dir ./models --parallel --batch_size 1 --soc Ascend310P3 --device 0 + python3 export_model.py --model ${model_base} --output_dir ./models --parallel --batch_size 1 --soc Ascend${chip_name} --device_type Duo --device 0 ``` 参数说明: - --model:模型权重路径 - --output_dir: 存放导出模型的路径 - --parallel: 【可选】导出适用于并行方案的模型 - --batch_size: 设置batch_size, 默认值为1, 当前仅支持batch_size=1的场景 - - --soc:只支持Ascend910B4和Ascend310P3。默认为Ascend910B4。 + - --soc:处理器型号。 - --device:推理设备ID - --use_cache:开启DiTCache优化,不配置则不开启 注意:trace+compile耗时较长且占用较多的CPU资源,请勿在执行export命令时运行其他占用CPU内存的任务,避免程序意外退出。 diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/export_model.py b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/export_model.py index 26f23048cf..ea11642ab6 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/export_model.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/export_model.py @@ -72,6 +72,7 @@ def parse_arguments() -> Namespace: parser.add_argument("-p", "--parallel", action="store_true", help="Export the unet of bs=1 for parallel inferencing.") parser.add_argument("--soc", help="soc_version.") + parser.add_argument("--device_type", choices=["A2", "Duo"], default="A2", help="device type.") parser.add_argument( "--device", default=0, @@ -379,7 +380,7 @@ def export(args) -> None: if args.use_cache: export_dit_cache(pipeline, args, 0) export_dit_cache(pipeline, args, 1) - if "B" in args.soc: + if args.device_type == "A2": export_dit_cache(pipeline, args, 0, "end") export_dit_cache(pipeline, args, 1, "end") else: -- Gitee From df5ad76d79d126bc5df8a6aea83db1af7dacc7d8 Mon Sep 17 00:00:00 2001 From: huanghao Date: Fri, 6 Sep 2024 17:21:14 +0800 Subject: [PATCH 13/13] =?UTF-8?q?DiTCache=E7=AC=AC3=E9=83=A8=E5=88=86?= =?UTF-8?q?=E4=BB=A3=E7=A0=81=E4=B8=8A=E5=BA=93?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../built-in/foundation/stable_diffusion_3/README.md | 1 + .../foundation/stable_diffusion_3/transformer_sd3.patch | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/README.md b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/README.md index ab02527595..c02b45094d 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/README.md +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/README.md @@ -152,6 +152,7 @@ - --parallel: 【可选】导出适用于并行方案的模型 - --batch_size: 设置batch_size, 默认值为1, 当前仅支持batch_size=1的场景 - --soc:处理器型号。 + - --device_type: 设备形态,当前支持A2、Duo两种形态。 - --device:推理设备ID - --use_cache:开启DiTCache优化,不配置则不开启 注意:trace+compile耗时较长且占用较多的CPU资源,请勿在执行export命令时运行其他占用CPU内存的任务,避免程序意外退出。 diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/transformer_sd3.patch b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/transformer_sd3.patch index 6b18fd8831..bdbbb9c867 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/transformer_sd3.patch +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/transformer_sd3.patch @@ -90,7 +90,7 @@ + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb + ) -+ # self.x_out.append([torch.mean(x).to('cpu').numpy(), torch.var(x).to('cpu').numpy()]) ++ + return hidden_states, encoder_hidden_states + + def forward_blocks(self, hidden_states, encoder_hidden_states, temb, use_cache, if_skip, cache_dict, delta_cache, -- Gitee