diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention_patch.py b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention_patch.py
new file mode 100644
index 0000000000000000000000000000000000000000..97585e6af71fd0df9f9af7ba99eff923e3352b15
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention_patch.py
@@ -0,0 +1,32 @@
+# Copyright 2023 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import subprocess
+import logging
+import diffusers
+
+
+def main():
+    diffusers_path = diffusers.__path__
+    diffusers_version = diffusers.__version__
+
+    assert diffusers_version == '0.29.0', "expected diffusers==0.29.0"
+    result = subprocess.run(["patch", "-p0", f"{diffusers_path[0]}/models/attention_processor.py",
+                             "attention_processor.patch"], capture_output=True, text=True)
+    if result.returncode != 0:
+        logging.error("Patch failed, error message: %s", result.stderr)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention_processor.patch b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention_processor.patch
new file mode 100644
index 0000000000000000000000000000000000000000..dc22a411eaf2ee4e13f6ea67a13981819e5c559a
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention_processor.patch
@@ -0,0 +1,41 @@
+--- attention_processor.py	2024-08-14 12:41:10.528000000 +0000
++++ attention_processor.py	2024-08-14 12:42:25.620000000 +0000
+@@ -1082,6 +1082,29 @@
+         return hidden_states
+ 
+ 
++def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False,
++                                 scale=None) -> torch.Tensor:
++    # Efficient implementation equivalent to the following:
++    L, S = query.size(-2), key.size(-2)
++    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
++    attn_bias = torch.zeros(L, S, dtype=query.dtype)
++    if is_causal:
++        assert attn_mask is None
++        temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
++        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
++        attn_bias.to(query.dtype)
++
++    if attn_mask is not None:
++        if attn_mask.dtype == torch.bool:
++            attn_mask.masked_fill_(attn_mask.logical_not(), float("-inf"))
++        else:
++            attn_bias += attn_mask
++    attn_weight = query @ key.transpose(-2, -1) * scale_factor
++    attn_weight = torch.softmax(attn_weight, dim=-1)
++    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
++    return attn_weight @ value
++
++
+ class JointAttnProcessor2_0:
+     """Attention processor used typically in processing the SD3-like self-attention projections."""
+ 
+@@ -1132,7 +1155,7 @@
+         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ 
+-        hidden_states = hidden_states = F.scaled_dot_product_attention(
++        hidden_states = hidden_states = scaled_dot_product_attention(
+             query, key, value, dropout_p=0.0, is_causal=False
+         )
+         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
\ No newline at end of file
diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/background_runtime.py b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/background_runtime.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f4935af2d4301cf30f38bcab89e08e4a70fad9d
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/background_runtime.py
@@ -0,0 +1,191 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import multiprocessing as mp
+import time
+from dataclasses import dataclass
+from typing import List
+
+import numpy as np
+import torch
+import mindietorch
+
+
+NUM_LAYERS = 28
+
+
+@dataclass
+class RuntimeIOInfo:
+    input_shapes: List[tuple]
+    input_dtypes: List[type]
+    output_shapes: List[tuple]
+    output_dtypes: List[type]
+
+
+class BackgroundRuntime:
+    def __init__(
+            self,
+            device_id: int,
+            model_path: str,
+            io_info: RuntimeIOInfo
+    ):
+        # Create a pipe for process synchronization
+        self.sync_pipe, sync_pipe_peer = mp.Pipe(duplex=True)
+
+        # Create shared buffers
+        input_spaces = self.create_shared_buffers(io_info.input_shapes,
+                                                  io_info.input_dtypes)
+        output_spaces = self.create_shared_buffers(io_info.output_shapes,
+                                                   io_info.output_dtypes)
+
+        # Build numpy arrays on the shared buffers
+        self.input_arrays = [
+            np.frombuffer(b, dtype=t).reshape(s) for (b, s, t) in zip(
+                input_spaces, io_info.input_shapes, io_info.input_dtypes)
+        ]
+        self.output_arrays = [
+            np.frombuffer(b, dtype=t).reshape(s) for (b, s, t) in zip(
+                output_spaces, io_info.output_shapes, io_info.output_dtypes)
+        ]
+
+        mp.set_start_method('fork', force=True)
+        self.p = mp.Process(target=self.run_infer,
+                            args=[
+                                sync_pipe_peer, input_spaces, output_spaces,
+                                io_info, device_id, model_path
+                            ])
+        self.p.start()
+
+        # Wait until the sub process is ready
+        self.wait()
+
+    @staticmethod
+    def create_shared_buffers(shapes: List[tuple],
+                              dtypes: List[type]) -> List[mp.RawArray]:
+        buffers = []
+        for shape, dtype in zip(shapes, dtypes):
+            size = 1
+            for x in shape:
+                size *= x
+
+            raw_array = mp.RawArray(np.ctypeslib.as_ctypes_type(dtype), size)
+            buffers.append(raw_array)
+
+        return buffers
+
+    def infer_asyn(self, feeds: List[np.ndarray]) -> None:
+        for i, _ in enumerate(self.input_arrays):
+            self.input_arrays[i][:] = feeds[i][:]
+
+        self.sync_pipe.send('')
+
+    def wait(self) -> None:
+        self.sync_pipe.recv()
+
+    def get_outputs(self) -> List[np.ndarray]:
+        return self.output_arrays
+
+    def wait_and_get_outputs(self) -> List[np.ndarray]:
+        self.wait()
+        return self.get_outputs()
+
+    def infer(self, feeds: List[np.ndarray]) -> List[np.ndarray]:
+        self.infer_asyn(feeds)
+        return self.wait_and_get_outputs()
+
+    def stop(self):
+        # Stop the sub process
+        self.sync_pipe.send('STOP')
+
+    @staticmethod
+    def run_infer(
+            sync_pipe: mp.connection.Connection,
+            input_spaces: List[np.ndarray],
+            output_spaces: List[np.ndarray],
+            io_info: RuntimeIOInfo,
+            device_id: int,
+            model_path: str,
+    ) -> None:
+        # The sub process function
+        # Create a runtime
+        print(f"[info] bg device id: {device_id}")
+
+        # Load the compiled model
+        model = torch.jit.load(model_path).eval()
+
+        # Build numpy arrays on the shared buffers
+        input_arrays = [
+            np.frombuffer(b, dtype=t).reshape(s) for (b, s, t) in zip(
+                input_spaces, io_info.input_shapes, io_info.input_dtypes)
+        ]
+
+        output_arrays = [
+            np.frombuffer(b, dtype=t).reshape(s) for (b, s, t) in zip(
+                output_spaces, io_info.output_shapes, io_info.output_dtypes)
+        ]
+        mindietorch.set_device(device_id)
+
+        # Tell the main function that we are ready
+        sync_pipe.send('')
+
+        infer_num = 0
+        preprocess_time = 0
+        infer_time = 0
+        forward_time = 0
+
+        stream = mindietorch.npu.Stream(f"npu:{device_id}")
+
+        # Keep looping until received a 'STOP'
+        while sync_pipe.recv() != 'STOP':
+            start = time.time()
+            hidden_states, encoder_hidden_states, pooled_projections, timestep = [
+                torch.Tensor(input_array) for input_array in input_arrays
+            ]
+            hidden_states_npu = hidden_states.to(torch.float32).to(f"npu:{device_id}")
+            encoder_hidden_states_npu = encoder_hidden_states.to(torch.float32).to(f"npu:{device_id}")
+            pooled_projections_npu = pooled_projections.to(torch.float32).to(f"npu:{device_id}")
+            timestep_npu = timestep.to(torch.int64).to(f"npu:{device_id}")
+
+            preprocess_time += time.time() - start
+
+            start2 = time.time()
+            with mindietorch.npu.stream(stream):
+                inf_start = time.time()
+                output_npu = model(
+                    hidden_states_npu,
+                    encoder_hidden_states_npu,
+                    pooled_projections_npu,
+                    timestep_npu
+                )
+                stream.synchronize()
+                inf_end = time.time()
+
+            output_cpu = output_npu.to('cpu')
+            forward_time += inf_end - inf_start
+            infer_time += time.time() - start2
+
+            for i, _ in enumerate(output_arrays):
+                output = output_cpu.numpy()
+                output_arrays[i][:] = output[i][:]
+
+            infer_num += 1
+            sync_pipe.send('')
+
+        infer_num /= NUM_LAYERS
+
+    @classmethod
+    def clone(cls, device_id: int, model_path: str, runtime_info: RuntimeIOInfo) -> 'BackgroundRuntime':
+        # Get shapes, datatypes from an existing engine,
+        # then use them to create a BackgroundRuntime
+        return cls(device_id, model_path, runtime_info)
diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/export_model.py b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/export_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..3608813b1e1977445f5cf636bfe50803793ed91b
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/export_model.py
@@ -0,0 +1,273 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import argparse
+from argparse import Namespace
+
+import torch
+from diffusers import StableDiffusion3Pipeline
+import mindietorch
+from compile_model import *
+
+
+def check_owner(path: str):
+    """
+    check the path owner
+    param: the input path
+    return: whether the path owner is the current user or not
+    """
+    path_stat = os.stat(path)
+    path_owner, path_gid = path_stat.st_uid, path_stat.st_gid
+    user_check = path_owner == os.getuid() and path_owner == os.geteuid()
+    return path_owner == 0 or path_gid in os.getgroups() or user_check
+
+
+def path_check(path: str):
+    """
+    check path
+    param: path
+    return: the real path after checking
+    """
+    if os.path.islink(path) or path is None:
+        raise RuntimeError("The path should not be None or a symbolic link file.")
+    path = os.path.realpath(path)
+    if not check_owner(path):
+        raise RuntimeError("The path is not owned by current user or root.")
+    return path
+
+
+def parse_arguments() -> Namespace:
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "-o",
+        "--output_dir",
+        type=str,
+        default="./models",
+        help="Path of directory to save pt models.",
+    )
+    parser.add_argument(
+        "-m",
+        "--model",
+        type=str,
+        default="./stable-diffusion-3-medium-diffusers",
+        help="Path or name of the pre-trained model.",
+    )
+    parser.add_argument("-bs", "--batch_size", type=int, default=1, help="Batch size.")
+    parser.add_argument("-steps", "--steps", type=int, default=28, help="Number of inference steps.")
+    parser.add_argument("-guid", "--guidance_scale", type=float, default=5.0, help="Guidance scale.")
+    parser.add_argument("--use_cache", action="store_true", help="Use cache during inference.")
+    parser.add_argument("-p", "--parallel", action="store_true",
+                        help="Export the DiT with batch size 1 for parallel inference.")
+    parser.add_argument("--soc", choices=["Duo", "A2"], default="A2", help="soc_version.")
+    parser.add_argument(
+        "--device",
+        default=0,
+        type=int,
+        help="NPU device",
+    )
+    parser.add_argument(
+        "--height",
+        default=1024,
+        type=int,
+        help="image height",
+    )
+    parser.add_argument(
+        "--width",
+        default=1024,
+        type=int,
+        help="image width"
+    )
+    return parser.parse_args()
+
+
+def trace_clip(sd_pipeline, batch_size, clip_pt_path, clip2_pt_path, t5_pt_path):
+    encoder_model = sd_pipeline.text_encoder
+    encoder_2_model = sd_pipeline.text_encoder_2
+    t5_model = sd_pipeline.text_encoder_3
+    max_position_embeddings = encoder_model.config.max_position_embeddings
+    dummy_input = torch.ones([batch_size, max_position_embeddings], dtype=torch.int64)
+
+    if not os.path.exists(clip_pt_path):
+        clip_export = ClipExport(encoder_model)
+        torch.jit.trace(clip_export, dummy_input).save(clip_pt_path)
+    else:
+        logging.info("clip_pt_path already exists.")
+
+    if not os.path.exists(clip2_pt_path):
+        clip2_export = ClipExport(encoder_2_model)
+        torch.jit.trace(clip2_export, dummy_input).save(clip2_pt_path)
+    else:
+        logging.info("clip2_pt_path already exists.")
+
+    if not os.path.exists(t5_pt_path):
+        t5_export = ClipExport(t5_model)
+        torch.jit.trace(t5_export, dummy_input).save(t5_pt_path)
+    else:
+        logging.info("t5_pt_path already exists.")
+
+
+def export_clip(sd_pipeline, args):
+    print("Exporting the text encoder...")
+    standard_path = path_check(args.output_dir)
+    clip_path = os.path.join(standard_path, "clip")
+    if not os.path.exists(clip_path):
+        os.makedirs(clip_path, mode=0o640)
+    batch_size = args.batch_size
+    clip_pt_path = os.path.join(clip_path, f"clip_bs{batch_size}.pt")
+    clip2_pt_path = os.path.join(clip_path,
f"clip2_bs{batch_size}.pt") + t5_pt_path = os.path.join(clip_path, f"t5_bs{batch_size}.pt") + clip1_compiled_path = os.path.join(clip_path, + f"clip_bs{batch_size}_compile_{args.height}x{args.width}.ts") + clip2_compiled_path = os.path.join(clip_path, + f"clip2_bs{batch_size}_compile_{args.height}x{args.width}.ts") + t5_compiled_path = os.path.join(clip_path, + f"t5_bs{batch_size}_compile_{args.height}x{args.width}.ts") + + encoder_model = sd_pipeline.text_encoder + max_position_embeddings = encoder_model.config.max_position_embeddings + + # trace + trace_clip(sd_pipeline, batch_size, clip_pt_path, clip2_pt_path, t5_pt_path) + + # compile + if not os.path.exists(clip1_compiled_path): + model = torch.jit.load(clip_pt_path).eval() + inputs = [mindietorch.Input((batch_size, max_position_embeddings), dtype=mindietorch.dtype.INT64)] + compile_clip(model, inputs, clip1_compiled_path, args.soc) + else: + logging.info("clip1_compiled_path already exists.") + if not os.path.exists(clip2_compiled_path): + model = torch.jit.load(clip2_pt_path).eval() + inputs = [mindietorch.Input((batch_size, max_position_embeddings), dtype=mindietorch.dtype.INT64)] + compile_clip(model, inputs, clip2_compiled_path, args.soc) + else: + logging.info("clip2_compiled_path already exists.") + if not os.path.exists(t5_compiled_path): + model = torch.jit.load(t5_pt_path).eval() + inputs = [mindietorch.Input((batch_size, max_position_embeddings), dtype=mindietorch.dtype.INT64)] + compile_clip(model, inputs, t5_compiled_path, args.soc) + else: + logging.info("t5_compiled_path already exists.") + + +def export_dit(sd_pipeline, args): + print("Exporting the dit...") + standard_path = path_check(args.output_dir) + dit_path = os.path.join(standard_path, "dit") + if not os.path.exists(dit_path): + os.makedirs(dit_path, mode=0o640) + + dit_model = sd_pipeline.transformer + encoder_model = sd_pipeline.text_encoder + encoder_model_2 = sd_pipeline.text_encoder_2 + + if not args.parallel: + batch_size = args.batch_size * 2 + else: + batch_size = args.batch_size + sample_size = dit_model.config.sample_size + in_channels = dit_model.config.in_channels + encoder_hidden_size_2 = encoder_model_2.config.hidden_size + encoder_hidden_size = encoder_model.config.hidden_size + encoder_hidden_size_2 + max_position_embeddings = encoder_model.config.max_position_embeddings * 2 + + dit_pt_path = os.path.join(dit_path, f"dit_bs{batch_size}.pt") + dit_compiled_path = os.path.join(dit_path, + f"dit_bs{batch_size}_compile_{args.height}x{args.width}.ts") + + # trace + if not os.path.exists(dit_pt_path): + dummy_input = ( + torch.ones([batch_size, in_channels, sample_size, sample_size], dtype=torch.float32), + torch.ones( + [batch_size, max_position_embeddings, encoder_hidden_size * 2], dtype=torch.float32 + ), + torch.ones([batch_size, encoder_hidden_size], dtype=torch.float32), + torch.ones([batch_size], dtype=torch.int64) + ) + dit = DiTExport(dit_model).eval() + torch.jit.trace(dit, dummy_input).save(dit_pt_path) + else: + logging.info("dit_pt_path already exists.") + + # compile + if not os.path.exists(dit_compiled_path): + model = torch.jit.load(dit_pt_path).eval() + inputs = [ + mindietorch.Input((batch_size, in_channels, sample_size, sample_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size, max_position_embeddings, encoder_hidden_size * 2), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size, encoder_hidden_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size,), 
+                              dtype=mindietorch.dtype.INT64)]
+        compile_dit(model, inputs, dit_compiled_path, args.soc)
+    else:
+        logging.info("dit_compiled_path already exists.")
+
+
+def export_vae(sd_pipeline, args):
+    print("Exporting the image decoder...")
+    standard_path = path_check(args.output_dir)
+    vae_path = os.path.join(standard_path, "vae")
+    if not os.path.exists(vae_path):
+        os.makedirs(vae_path, mode=0o640)
+    batch_size = args.batch_size
+    vae_pt_path = os.path.join(vae_path, f"vae_bs{batch_size}.pt")
+    vae_compiled_path = os.path.join(vae_path,
+                                     f"vae_bs{batch_size}_compile_{args.height}x{args.width}.ts")
+
+    vae_model = sd_pipeline.vae
+    dit_model = sd_pipeline.transformer
+    scaling_factor = vae_model.config.scaling_factor
+    shift_factor = vae_model.config.shift_factor
+    in_channels = vae_model.config.latent_channels
+    sample_size = dit_model.config.sample_size
+
+    # trace
+    if not os.path.exists(vae_pt_path):
+        dummy_input = torch.ones([batch_size, in_channels, sample_size, sample_size], dtype=torch.float32)
+        vae_export = VaeExport(vae_model, scaling_factor, shift_factor)
+        torch.jit.trace(vae_export, dummy_input).save(vae_pt_path)
+    else:
+        logging.info("vae_pt_path already exists.")
+
+    # compile
+    if not os.path.exists(vae_compiled_path):
+        model = torch.jit.load(vae_pt_path).eval()
+        inputs = [
+            mindietorch.Input((batch_size, in_channels, sample_size, sample_size), dtype=mindietorch.dtype.FLOAT)]
+        compile_vae(model, inputs, vae_compiled_path, args.soc)
+    else:
+        logging.info("vae_compiled_path already exists.")
+
+
+def export(args):
+    pipeline = StableDiffusion3Pipeline.from_pretrained(args.model).to('cpu')
+    export_clip(pipeline, args)
+    export_dit(pipeline, args)
+    export_vae(pipeline, args)
+
+
+def main(args):
+    mindietorch.set_device(args.device)
+    export(args)
+    print("Done.")
+    mindietorch.finalize()
+
+
+if __name__ == "__main__":
+    args = parse_arguments()
+    main(args)
diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/prompts.txt b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/prompts.txt
new file mode 100644
index 0000000000000000000000000000000000000000..a375a0bb63931d0d5da6c6d91df1e14f870f47d0
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/prompts.txt
@@ -0,0 +1,16 @@
+Beautiful illustration of The ocean. in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper
+Beautiful illustration of Islands in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper
+Beautiful illustration of Seaports in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper
+Beautiful illustration of The waves. in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper
+Beautiful illustration of Grassland. in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper
+Beautiful illustration of Wheat. in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper
+Beautiful illustration of Hut Tong. in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper
+Beautiful illustration of The boat. in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper
+Beautiful illustration of Pine trees. in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper
+Beautiful illustration of Bamboo. in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper
+Beautiful illustration of The temple. in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper
+Beautiful illustration of Cloud in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper
+Beautiful illustration of Sun in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper
+Beautiful illustration of Spring. in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper
+Beautiful illustration of Lotus. in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper
+Beautiful illustration of Snow piles. in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper
\ No newline at end of file
diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/requirements.txt b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..602842d4964c5cf2e299c0f6b33678d6f90d6825
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/requirements.txt
@@ -0,0 +1,9 @@
+accelerate==0.31.0
+torch==2.1.0
+torchvision==0.16.0
+ftfy
+diffusers==0.29.0
+transformers>=4.41.2
+tensorboard
+Jinja2
+peft==0.11.1
\ No newline at end of file
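
Note on attention_processor.patch: the patch swaps F.scaled_dot_product_attention in JointAttnProcessor2_0 for an explicit, op-by-op implementation, presumably so the traced graph contains only primitive ops that mindietorch can compile. A minimal sanity-check sketch (not part of the patch set) for the case SD3 actually hits -- no mask, non-causal, dropout_p=0.0; the tensor shapes and tolerance below are illustrative assumptions:

# sanity_check_sdpa.py -- hypothetical helper, not included in this change
import math

import torch
import torch.nn.functional as F


def reference_sdpa(query, key, value, scale=None):
    # Same math as the patched-in helper when attn_mask is None and is_causal is False
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight = torch.softmax(attn_weight, dim=-1)
    return attn_weight @ value


if __name__ == "__main__":
    # (batch, heads, tokens, head_dim) -- illustrative sizes only
    q = torch.randn(2, 24, 16, 64)
    k = torch.randn(2, 24, 16, 64)
    v = torch.randn(2, 24, 16, 64)
    out_ref = reference_sdpa(q, k, v)
    out_fused = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
    print("max abs diff:", (out_ref - out_fused).abs().max().item())
    print("allclose:", torch.allclose(out_ref, out_fused, atol=1e-5))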
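
Note on background_runtime.py: BackgroundRuntime forks a worker process that owns a second compiled DiT on another NPU and exchanges tensors with the parent through shared memory, which is how the --parallel export appears intended to split work across two devices. A minimal usage sketch under assumed shapes for a 1024x1024 SD3 export (latents 16x128x128, prompt embeddings 154x4096, pooled projections 2048); the device id and .ts path are hypothetical:

# usage sketch -- assumptions noted above, not part of this change
import numpy as np

from background_runtime import BackgroundRuntime, RuntimeIOInfo

io_info = RuntimeIOInfo(
    input_shapes=[(1, 16, 128, 128), (1, 154, 4096), (1, 2048), (1,)],
    input_dtypes=[np.float32, np.float32, np.float32, np.int64],
    output_shapes=[(1, 16, 128, 128)],
    output_dtypes=[np.float32],
)

# Spawn the background worker on a second device (device id 1 is an assumption).
runtime = BackgroundRuntime.clone(1, "./models/dit/dit_bs1_compile_1024x1024.ts", io_info)

# Inside the denoising loop: hand one half of the batch to the worker,
# run the other half locally, then collect the worker's result.
feeds = [
    np.ones((1, 16, 128, 128), dtype=np.float32),   # latents (placeholder data)
    np.ones((1, 154, 4096), dtype=np.float32),      # prompt embeddings
    np.ones((1, 2048), dtype=np.float32),           # pooled projections
    np.array([999], dtype=np.int64),                # timestep
]
runtime.infer_asyn(feeds)
# ... run the local half of the classifier-free guidance batch here ...
noise_pred_bg = runtime.wait_and_get_outputs()[0]

runtime.stop()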