From 3c1b5205206b448d053f578481bf55372b315a54 Mon Sep 17 00:00:00 2001 From: huanghao Date: Mon, 26 Aug 2024 11:43:15 +0800 Subject: [PATCH 1/7] =?UTF-8?q?SD3=E4=B8=8A=E5=BA=93=E7=AC=AC=E4=B8=80?= =?UTF-8?q?=E9=83=A8=E5=88=86=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../stable_diffusion_3/attention_patch.py | 32 ++ .../attention_processor.patch | 41 +++ .../stable_diffusion_3/background_runtime.py | 191 ++++++++++++ .../stable_diffusion_3/export_model.py | 273 ++++++++++++++++++ .../foundation/stable_diffusion_3/prompts.txt | 16 + .../stable_diffusion_3/requirements.txt | 9 + 6 files changed, 562 insertions(+) create mode 100644 MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention_patch.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention_processor.patch create mode 100644 MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/background_runtime.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/export_model.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/prompts.txt create mode 100644 MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/requirements.txt diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention_patch.py b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention_patch.py new file mode 100644 index 0000000000..97585e6af7 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention_patch.py @@ -0,0 +1,32 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
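+
+# NOTE: this helper patches the installed diffusers==0.29.0 package in place. It locates
+# diffusers/models/attention_processor.py and applies attention_processor.patch to it with
+# the system `patch` tool (see the "代码修改" step in the README). Run it once before
+# exporting the models:
+#
+#     python3 attention_patch.py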
+
+import subprocess
+import logging
+import diffusers
+
+
+def main():
+    diffusers_path = diffusers.__path__
+    diffusers_version = diffusers.__version__
+
+    assert diffusers_version == '0.29.0', "expected diffusers==0.29.0"
+    result = subprocess.run(["patch", "-p0", f"{diffusers_path[0]}/models/attention_processor.py",
+                             "attention_processor.patch"], capture_output=True, text=True)
+    if result.returncode != 0:
+        logging.error("Patch failed, error message: %s", result.stderr)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention_processor.patch b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention_processor.patch
new file mode 100644
index 0000000000..dc22a411ea
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/attention_processor.patch
@@ -0,0 +1,41 @@
+--- attention_processor.py 2024-08-14 12:41:10.528000000 +0000
++++ attention_processor.py 2024-08-14 12:42:25.620000000 +0000
+@@ -1082,6 +1082,29 @@
+         return hidden_states
+ 
+ 
++def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False,
++                                 scale=None) -> torch.Tensor:
++    # Efficient implementation equivalent to the following:
++    L, S = query.size(-2), key.size(-2)
++    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
++    attn_bias = torch.zeros(L, S, dtype=query.dtype)
++    if is_causal:
++        assert attn_mask is None
++        temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
++        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
++        attn_bias.to(query.dtype)
++
++    if attn_mask is not None:
++        if attn_mask.dtype == torch.bool:
++            attn_mask.masked_fill_(attn_mask.logical_not(), float("-inf"))
++        else:
++            attn_bias += attn_mask
++    attn_weight = query @ key.transpose(-2, -1) * scale_factor
++    attn_weight = torch.softmax(attn_weight, dim=-1)
++    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
++    return attn_weight @ value
++
++
+ class JointAttnProcessor2_0:
+     """Attention processor used typically in processing the SD3-like self-attention projections."""
+ 
+@@ -1132,7 +1155,7 @@
+         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ 
+-        hidden_states = hidden_states = F.scaled_dot_product_attention(
++        hidden_states = hidden_states = scaled_dot_product_attention(
+             query, key, value, dropout_p=0.0, is_causal=False
+         )
+         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
\ No newline at end of file
diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/background_runtime.py b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/background_runtime.py
new file mode 100644
index 0000000000..6f4935af2d
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/background_runtime.py
@@ -0,0 +1,191 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. + +import multiprocessing as mp +import time +from dataclasses import dataclass +from typing import List + +import numpy as np +import torch +import mindietorch + + +NUM_LAYERS = 28 + + +@dataclass +class RuntimeIOInfo: + input_shapes: List[tuple] + input_dtypes: List[type] + output_shapes: List[tuple] + output_dtypes: List[type] + + +class BackgroundRuntime: + def __init__( + self, + device_id: int, + model_path: str, + io_info: RuntimeIOInfo + ): + # Create a pipe for process synchronization + self.sync_pipe, sync_pipe_peer = mp.Pipe(duplex=True) + + # Create shared buffers + input_spaces = self.create_shared_buffers(io_info.input_shapes, + io_info.input_dtypes) + output_spaces = self.create_shared_buffers(io_info.output_shapes, + io_info.output_dtypes) + + # Build numpy arrays on the shared buffers + self.input_arrays = [ + np.frombuffer(b, dtype=t).reshape(s) for (b, s, t) in zip( + input_spaces, io_info.input_shapes, io_info.input_dtypes) + ] + self.output_arrays = [ + np.frombuffer(b, dtype=t).reshape(s) for (b, s, t) in zip( + output_spaces, io_info.output_shapes, io_info.output_dtypes) + ] + + mp.set_start_method('fork', force=True) + self.p = mp.Process(target=self.run_infer, + args=[ + sync_pipe_peer, input_spaces, output_spaces, + io_info, device_id, model_path + ]) + self.p.start() + + # Wait until the sub process is ready + self.wait() + + @staticmethod + def create_shared_buffers(shapes: List[tuple], + dtypes: List[type]) -> List[mp.RawArray]: + buffers = [] + for shape, dtype in zip(shapes, dtypes): + size = 1 + for x in shape: + size *= x + + raw_array = mp.RawArray(np.ctypeslib.as_ctypes_type(dtype), size) + buffers.append(raw_array) + + return buffers + + def infer_asyn(self, feeds: List[np.ndarray]) -> None: + for i, _ in enumerate(self.input_arrays): + self.input_arrays[i][:] = feeds[i][:] + + self.sync_pipe.send('') + + def wait(self) -> None: + self.sync_pipe.recv() + + def get_outputs(self) -> List[np.ndarray]: + return self.output_arrays + + def wait_and_get_outputs(self) -> List[np.ndarray]: + self.wait() + return self.get_outputs() + + def infer(self, feeds: List[np.ndarray]) -> List[np.ndarray]: + self.infer_asyn(feeds) + return self.wait_and_get_outputs() + + def stop(self): + # Stop the sub process + self.sync_pipe.send('STOP') + + @staticmethod + def run_infer( + sync_pipe: mp.connection.Connection, + input_spaces: List[np.ndarray], + output_spaces: List[np.ndarray], + io_info: RuntimeIOInfo, + device_id: int, + model_path: str, + ) -> None: + # The sub process function + # Create a runtime + print(f"[info] bg device id: {device_id}") + + # Tell the main function that we are ready + model = torch.jit.load(model_path).eval() + + # Build numpy arrays on the shared buffers + input_arrays = [ + np.frombuffer(b, dtype=t).reshape(s) for (b, s, t) in zip( + input_spaces, io_info.input_shapes, io_info.input_dtypes) + ] + + output_arrays = [ + np.frombuffer(b, dtype=t).reshape(s) for (b, s, t) in zip( + output_spaces, io_info.output_shapes, io_info.output_dtypes) + ] + mindietorch.set_device(device_id) + + # Tell the main function that we are ready + sync_pipe.send('') + + infer_num = 0 + preprocess_time = 0 + infer_time = 0 + forward_time = 0 + + stream = mindietorch.npu.Stream(f"npu:{device_id}") + + # Keep looping until recived a 'STOP' + while sync_pipe.recv() != 'STOP': + start = time.time() + hidden_states, encoder_hidden_states, pooled_projections, timestep = 
[ + torch.Tensor(input_array) for input_array in input_arrays + ] + hidden_states_npu = hidden_states.to(torch.float32).to(f"npu:{device_id}") + encoder_hidden_states_npu = encoder_hidden_states.to(torch.float32).to(f"npu:{device_id}") + pooled_projections_npu = pooled_projections.to(torch.float32).to(f"npu:{device_id}") + timestep_npu = timestep.to(torch.int64).to(f"npu:{device_id}") + + preprocess_time += time.time() - start + + start2 = time.time() + with mindietorch.npu.stream(stream): + inf_start = time.time() + output_npu = model( + hidden_states_npu, + encoder_hidden_states_npu, + pooled_projections_npu, + timestep_npu + ) + stream.synchronize() + inf_end = time.time() + + output_cpu = output_npu.to('cpu') + forward_time += inf_end - inf_start + infer_time += time.time() - start2 + + for i, _ in enumerate(output_arrays): + output = output_cpu.numpy() + output_arrays[i][:] = output[i][:] + + infer_num += 1 + sync_pipe.send('') + + infer_num /= NUM_LAYERS + + @classmethod + def clone(cls, device_id: int, model_path: str, runtime_info: RuntimeIOInfo) -> 'BackgroundRuntime': + # Get shapes, datatypes from an existed engine, + # then use them to create a BackgroundRuntime + return cls(device_id, model_path, runtime_info) diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/export_model.py b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/export_model.py new file mode 100644 index 0000000000..3608813b1e --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/export_model.py @@ -0,0 +1,273 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
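+
+# Overview (descriptive comment): this script traces the SD3 sub-models (two CLIP text
+# encoders, the T5 encoder, the MMDiT transformer and the VAE decoder) with torch.jit.trace
+# and then compiles them with mindietorch. With the defaults (--batch_size 1, 1024x1024),
+# the files produced under --output_dir look like:
+#     clip/clip_bs1.pt,  clip/clip_bs1_compile_1024x1024.ts    (text_encoder)
+#     clip/clip2_bs1.pt, clip/clip2_bs1_compile_1024x1024.ts   (text_encoder_2)
+#     clip/t5_bs1.pt,    clip/t5_bs1_compile_1024x1024.ts      (text_encoder_3)
+#     dit/dit_bs2.pt,    dit/dit_bs2_compile_1024x1024.ts      (batch doubled unless --parallel)
+#     vae/vae_bs1.pt,    vae/vae_bs1_compile_1024x1024.ts
+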
+import logging +import os +import argparse +from argparse import Namespace + +import torch +from diffusers import StableDiffusion3Pipeline +import mindietorch +from compile_model import * + + +def check_owner(path: str): + """ + check the path owner + param: the input path + return: whether the path owner is current user or not + """ + path_stat = os.stat(path) + path_owner, path_gid = path_stat.st_uid, path_stat.st_gid + user_check = path_owner == os.getuid() and path_owner == os.geteuid() + return path_owner == 0 or path_gid in os.getgroups() or user_check + + +def path_check(path: str): + """ + check path + param: path + return: data real path after check + """ + if os.path.islink(path) or path is None: + raise RuntimeError("The path should not be None or a symbolic link file.") + path = os.path.realpath(path) + if not check_owner(path): + raise RuntimeError("The path is not owned by current user or root.") + return path + + +def parse_arguments() -> Namespace: + parser = argparse.ArgumentParser() + parser.add_argument( + "-o", + "--output_dir", + type=str, + default="./models", + help="Path of directory to save pt models.", + ) + parser.add_argument( + "-m", + "--model", + type=str, + default="./stable-diffusion-3-medium-diffusers", + help="Path or name of the pre-trained model.", + ) + parser.add_argument("-bs", "--batch_size", type=int, default=1, help="Batch size.") + parser.add_argument("-steps", "--steps", type=int, default=28, help="steps.") + parser.add_argument("-guid", "--guidance_scale", type=float, default=5.0, help="guidance_scale") + parser.add_argument("--use_cache", action="store_true", help="Use cache during inference.") + parser.add_argument("-p", "--parallel", action="store_true", + help="Export the unet of bs=1 for parallel inferencing.") + parser.add_argument("--soc", choices=["Duo", "A2"], default="A2", help="soc_version.") + parser.add_argument( + "--device", + default=0, + type=int, + help="NPU device", + ) + parser.add_argument( + "--height", + default=1024, + type=int, + help="image height", + ) + parser.add_argument( + "--width", + default=1024, + type=int, + help="image width" + ) + return parser.parse_args() + + +def trace_clip(sd_pipeline, batch_size, clip_pt_path, clip2_pt_path, t5_pt_path): + encoder_model = sd_pipeline.text_encoder + encoder_2_model = sd_pipeline.text_encoder_2 + t5_model = sd_pipeline.text_encoder_3 + max_position_embeddings = encoder_model.config.max_position_embeddings + dummy_input = torch.ones([batch_size, max_position_embeddings], dtype=torch.int64) + + if not os.path.exists(clip_pt_path): + clip_export = ClipExport(encoder_model) + torch.jit.trace(clip_export, dummy_input).save(clip_pt_path) + else: + logging.info("clip_pt_path already exists.") + + if not os.path.exists(clip2_pt_path): + clip2_export = ClipExport(encoder_2_model) + torch.jit.trace(clip2_export, dummy_input).save(clip2_pt_path) + else: + logging.info("clip2_pt_path already exists.") + + if not os.path.exists(t5_pt_path): + t5_export = ClipExport(t5_model) + torch.jit.trace(t5_export, dummy_input).save(t5_pt_path) + else: + logging.info("t5_pt_path already exists.") + + +def export_clip(sd_pipeline, args): + print("Exporting the text encoder...") + standard_path = path_check(args.output_dir) + clip_path = os.path.join(standard_path, "clip") + if not os.path.exists(clip_path): + os.makedirs(clip_path, mode=0o640) + batch_size = args.batch_size + clip_pt_path = os.path.join(clip_path, f"clip_bs{batch_size}.pt") + clip2_pt_path = os.path.join(clip_path, 
f"clip2_bs{batch_size}.pt") + t5_pt_path = os.path.join(clip_path, f"t5_bs{batch_size}.pt") + clip1_compiled_path = os.path.join(clip_path, + f"clip_bs{batch_size}_compile_{args.height}x{args.width}.ts") + clip2_compiled_path = os.path.join(clip_path, + f"clip2_bs{batch_size}_compile_{args.height}x{args.width}.ts") + t5_compiled_path = os.path.join(clip_path, + f"t5_bs{batch_size}_compile_{args.height}x{args.width}.ts") + + encoder_model = sd_pipeline.text_encoder + max_position_embeddings = encoder_model.config.max_position_embeddings + + # trace + trace_clip(sd_pipeline, batch_size, clip_pt_path, clip2_pt_path, t5_pt_path) + + # compile + if not os.path.exists(clip1_compiled_path): + model = torch.jit.load(clip_pt_path).eval() + inputs = [mindietorch.Input((batch_size, max_position_embeddings), dtype=mindietorch.dtype.INT64)] + compile_clip(model, inputs, clip1_compiled_path, args.soc) + else: + logging.info("clip1_compiled_path already exists.") + if not os.path.exists(clip2_compiled_path): + model = torch.jit.load(clip2_pt_path).eval() + inputs = [mindietorch.Input((batch_size, max_position_embeddings), dtype=mindietorch.dtype.INT64)] + compile_clip(model, inputs, clip2_compiled_path, args.soc) + else: + logging.info("clip2_compiled_path already exists.") + if not os.path.exists(t5_compiled_path): + model = torch.jit.load(t5_pt_path).eval() + inputs = [mindietorch.Input((batch_size, max_position_embeddings), dtype=mindietorch.dtype.INT64)] + compile_clip(model, inputs, t5_compiled_path, args.soc) + else: + logging.info("t5_compiled_path already exists.") + + +def export_dit(sd_pipeline, args): + print("Exporting the dit...") + standard_path = path_check(args.output_dir) + dit_path = os.path.join(standard_path, "dit") + if not os.path.exists(dit_path): + os.makedirs(dit_path, mode=0o640) + + dit_model = sd_pipeline.transformer + encoder_model = sd_pipeline.text_encoder + encoder_model_2 = sd_pipeline.text_encoder_2 + + if not args.parallel: + batch_size = args.batch_size * 2 + else: + batch_size = args.batch_size + sample_size = dit_model.config.sample_size + in_channels = dit_model.config.in_channels + encoder_hidden_size_2 = encoder_model_2.config.hidden_size + encoder_hidden_size = encoder_model.config.hidden_size + encoder_hidden_size_2 + max_position_embeddings = encoder_model.config.max_position_embeddings * 2 + + dit_pt_path = os.path.join(dit_path, f"dit_bs{batch_size}.pt") + dit_compiled_path = os.path.join(dit_path, + f"dit_bs{batch_size}_compile_{args.height}x{args.width}.ts") + + # trace + if not os.path.exists(dit_pt_path): + dummy_input = ( + torch.ones([batch_size, in_channels, sample_size, sample_size], dtype=torch.float32), + torch.ones( + [batch_size, max_position_embeddings, encoder_hidden_size * 2], dtype=torch.float32 + ), + torch.ones([batch_size, encoder_hidden_size], dtype=torch.float32), + torch.ones([batch_size], dtype=torch.int64) + ) + dit = DiTExport(dit_model).eval() + torch.jit.trace(dit, dummy_input).save(dit_pt_path) + else: + logging.info("dit_pt_path already exists.") + + # compile + if not os.path.exists(dit_compiled_path): + model = torch.jit.load(dit_pt_path).eval() + inputs = [ + mindietorch.Input((batch_size, in_channels, sample_size, sample_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size, max_position_embeddings, encoder_hidden_size * 2), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size, encoder_hidden_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size,), 
dtype=mindietorch.dtype.INT64)] + compile_dit(model, inputs, dit_compiled_path, args.soc) + else: + logging.info("dit_compiled_path already exists.") + + +def export_vae(sd_pipeline, args): + print("Exporting the image decoder...") + standard_path = path_check(args.output_dir) + vae_path = os.path.join(standard_path, "vae") + if not os.path.exists(vae_path): + os.makedirs(vae_path, mode=0o640) + batch_size = args.batch_size + vae_pt_path = os.path.join(vae_path, f"vae_bs{batch_size}.pt") + vae_compiled_path = os.path.join(vae_path, + f"vae_bs{batch_size}_compile_{args.height}x{args.width}.ts") + + vae_model = sd_pipeline.vae + dit_model = sd_pipeline.transformer + scaling_factor = vae_model.config.scaling_factor + shift_factor = vae_model.config.shift_factor + in_channels = vae_model.config.latent_channels + sample_size = dit_model.config.sample_size + + # trace + if not os.path.exists(vae_pt_path): + dummy_input = torch.ones([batch_size, in_channels, sample_size, sample_size], dtype=torch.float32) + vae_export = VaeExport(vae_model, scaling_factor, shift_factor) + torch.jit.trace(vae_export, dummy_input).save(vae_pt_path) + else: + logging.info("vae_pt_path already exists.") + + # compile + if not os.path.exists(vae_compiled_path): + model = torch.jit.load(vae_pt_path).eval() + inputs = [ + mindietorch.Input((batch_size, in_channels, sample_size, sample_size), dtype=mindietorch.dtype.FLOAT)] + compile_vae(model, inputs, vae_compiled_path, args.soc) + else: + logging.info("vae_compiled_path already exists.") + + +def export(args): + pipeline = StableDiffusion3Pipeline.from_pretrained(args.model).to('cpu') + export_clip(pipeline, args) + export_dit(pipeline, args) + export_vae(pipeline, args) + + +def main(args): + mindietorch.set_device(args.device) + export(args) + print("Done.") + mindietorch.finalize() + + +if __name__ == "__main__": + args = parse_arguments() + main(args) diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/prompts.txt b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/prompts.txt new file mode 100644 index 0000000000..a375a0bb63 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/prompts.txt @@ -0,0 +1,16 @@ +Beautiful illustration of The ocean. in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper +Beautiful illustration of Islands in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper +Beautiful illustration of Seaports in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper +Beautiful illustration of The waves. in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper +Beautiful illustration of Grassland. in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper +Beautiful illustration of Wheat. in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper +Beautiful illustration of Hut Tong. in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper +Beautiful illustration of The boat. 
in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper
+Beautiful illustration of Pine trees. in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper
+Beautiful illustration of Bamboo. in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper
+Beautiful illustration of The temple. in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper
+Beautiful illustration of Cloud in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper
+Beautiful illustration of Sun in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper
+Beautiful illustration of Spring. in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper
+Beautiful illustration of Lotus. in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper
+Beautiful illustration of Snow piles. in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper
\ No newline at end of file
diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/requirements.txt b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/requirements.txt
new file mode 100644
index 0000000000..602842d496
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/requirements.txt
@@ -0,0 +1,9 @@
+accelerate==0.31.0
+torch==2.1.0
+torchvision==0.16.0
+ftfy
+diffusers==0.29.0
+transformers>=4.41.2
+tensorboard
+Jinja2
+peft==0.11.1
\ No newline at end of file
-- 
Gitee

From 0d340be666727952c134698a8535070b9b1743f7 Mon Sep 17 00:00:00 2001
From: huanghao
Date: Mon, 26 Aug 2024 12:41:37 +0800
Subject: [PATCH 2/7] =?UTF-8?q?SD3=E4=B8=8A=E5=BA=93=E7=AC=AC=E4=BA=8C?=
 =?UTF-8?q?=E9=83=A8=E5=88=86=E4=BB=A3=E7=A0=81?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../foundation/stable_diffusion_3/README.md   | 330 ++++++++++++++++++
 .../stable_diffusion_3/compile_model.py       |  89 +++++
 2 files changed, 419 insertions(+)
 create mode 100644 MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/README.md
 create mode 100644 MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/compile_model.py

diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/README.md b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/README.md
new file mode 100644
index 0000000000..8f18fa737c
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/README.md
@@ -0,0 +1,330 @@
+# stable-diffusion3模型-推理指导
+
+- [概述](#ZH-CN_TOPIC_0000001172161501)
+
+  - [输入输出数据](#section540883920406)
+
+- [推理环境准备](#ZH-CN_TOPIC_0000001126281702)
+
+- [快速上手](#ZH-CN_TOPIC_0000001126281700)
+
+  - [获取源码](#section4622531142816)
+  - [模型推理](#section741711594517)
+
+- [模型推理性能&精度](#ZH-CN_TOPIC_0000001172201573)
+
+# 概述
+
+  SD3(Stable Diffusion 3 Medium)是基于 Multimodal Diffusion Transformer(MMDiT)架构的文生图潜在扩散模型:两个 CLIP 文本编码器与一个 T5 文本编码器对提示词进行编码,MMDiT 在潜在空间中迭代去噪,最后由 VAE 解码器生成图像。模型权重可[此处获得](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers/)
+
+- 参考实现:
+  ```bash
+  # StableDiffusion3
+  https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers
+ ``` + +## 输入输出数据 + +- 输入数据 + + | 输入数据 | 大小 | 数据类型 | 数据排布格式 | + | -------- | -------- | ------------------------- | ------------ | + | prompt | 1 x 77 | INT64| ND| + + +- 输出数据 + + | 输出数据 | 大小 | 数据类型 | 数据排布格式 | + | -------- | -------- | -------- | ----------- | + | output1 | 1 x 3 x 1024 x 1024 | FLOAT32 | NCHW | + +# 推理环境准备 + +- 该模型需要以下插件与驱动 + + **表 1** 版本配套表 + + | 配套 | 版本 | 环境准备指导 | + | -------- | -------- |--------| + | Python | 3.10.2 | - | +- | torch | 2.1.0 | - | + +该模型性能受CPU规格影响,建议使用64核CPU(arm)以复现性能 + +# 快速上手 +## 获取源码 +1. 安装依赖。 + ```bash + pip3 install -r requirements.txt + + # 若要使用hpsv2验证精度,则还需要按照以下步骤安装hpsv2 + git clone https://github.com/tgxs002/HPSv2.git + cd HPSv2 + pip3 install -e . + ``` + +2. 安装mindie包 + + ```bash + # 安装mindie + chmod +x ./Ascend-mindie_xxx.run + ./Ascend-mindie_xxx.run --install + source /usr/local/Ascend/mindie/set_env.sh + ``` + +3. 代码修改 + + 执行命令: + + ```bash + # 若环境没有patch工具,请自行安装 + ``` + + ```bash + python3 attention_patch.py + ``` + +## 准备数据集 + +1. 获取原始数据集。 + 本模型输入文本信息生成图片,无需数据集。 + +## 模型推理 + +1. 模型转换。 + 使用Pytorch导出pt模型,然后使用MindIE推理引擎转换为适配昇腾的模型。 + + 0. 获取权重(可选) + + 可提前下载权重,放到代码同级目录下,以避免执行后面步骤时可能会出现下载失败。 + + ```bash + # 需要使用 git-lfs (https://git-lfs.com) + git lfs install + + # xl + git clone https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers + ``` + + 1. 导出pt模型并进行编译。(可选) + + ```bash + # xl (执行时下载权重) + model_base="stabilityai/stable-diffusion-3-medium-diffusers" + + xl (使用上一步下载的权重) + model_base="./stable-diffusion-3-medium-diffusers" + ``` + + 执行命令: + + ```bash + # 800I A2,非并行 + python3 export_model.py --model ${model_base} --output_dir ./models --batch_size 1 --soc Ascend910B4 --device 0 + + # 300I Duo,并行 + python3 export_model.py --model ${model_base} --output_dir ./models --parallel --batch_size 1 --soc Ascend310P3 --device 0 + ``` + 参数说明: + - --model:模型权重路径 + - --output_dir: 存放导出模型的路径 + - --parallel: 【可选】导出适用于并行方案的模型 + - --batch_size: 设置batch_size, 默认值为1, 当前仅支持batch_size=1的场景 + - --soc:只支持Ascend910B4和Ascend310P3。默认为Ascend910B4。 + - --device:推理设备ID + +2. 开始推理验证。 + + 1. 开启cpu高性能模式 + ```bash + echo performance |tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor + sysctl -w vm.swappiness=0 + sysctl -w kernel.numa_balancing=0 + ``` + + 2. 安装绑核工具 + ```bash + apt-get update + apt-get install numactl + ``` + 查询卡的NUMA node + ```shell + lspci -vs bus-id + ``` + bus-id可通过npu-smi info获得,查询到NUMA node,在推理命令前加上对应的数字 + + 可通过lscpu获得NUMA node对应的CPU核数 + ```shell + NUMA node0: 0-23 + NUMA node1: 24-47 + NUMA node2: 48-71 + NUMA node3: 72-95 + ``` + 当前查到NUMA node是0,对应0-23,推荐绑定其中单核以获得更好的性能。 + + 3. 
执行推理脚本。 + ```bash + # 不使用unetCache策略 + numactl -C 0-23 python3 stable_diffusion3_pipeline.py \ + --model ${model_base} \ + --prompt_file ./prompts.txt \ + --prompt_file_type plain \ + --device 0 \ + --save_dir ./results \ + --steps 28 \ + --output_dir ./models \ + --height 1024 \ + --width 1024 \ + --batch_size 1 + + # 使用UnetCache策略,同时使用双卡并行策略 + numactl -C 0-23 python3 stable_diffusion3_pipeline.py \ + --model ${model_base} \ + --prompt_file ./prompts.txt \ + --prompt_file_type plain \ + --device 0,1 \ + --save_dir ./results_parallel \ + --steps 28 \ + --output_dir ./models \ + --height 1024 \ + --width 1024 \ + --batch_size 1 \ + --parallel + ``` + + 参数说明: + - --model:模型权重路径。 + - --output_dir:存放导出模型的目录。 + - --prompt_file:提示词文件。 + - --prompt_file_type: prompt文件类型,用于指定读取方式,可选plain,parti,hpsv2。 + - --save_dir:生成图片的存放目录。 + - --batch_size:模型batch size。 + - --steps:生成图片迭代次数。 + - --device:推理设备ID;可用逗号分割传入两个设备ID,此时会使用并行方式进行推理。 + - --height:生成图像高度,当前只支持1024 + - --width:生成图像宽度,当前只支持1024 + + 非并行策略,执行完成后在`./results`目录下生成推理图片,在当前目录生成一个`image_info.json`文件,记录着图片和prompt的对应关系,并在终端显示推理时间。 + 并行策略,同时使用双卡并行策略,执行完成后在`./results_parallel`目录下生成推理图片,在当前目录生成一个`image_info.json`文件,记录着图片和prompt的对应关系。并在终端显示推理时间。 + +## 精度验证 + + 由于生成的图片存在随机性,提供两种精度验证方法: + 1. CLIP-score(文图匹配度量):评估图片和输入文本的相关性,分数的取值范围为[-1, 1],越高越好。使用Parti数据集进行验证。 + 2. HPSv2(图片美学度量):评估生成图片的人类偏好评分,分数的取值范围为[0, 1],越高越好。使用HPSv2数据集进行验证 + + 注意,由于要生成的图片数量较多,进行完整的精度验证需要耗费很长的时间。 + + 1. 下载Parti数据集 + + ```bash + wget https://raw.githubusercontent.com/google-research/parti/main/PartiPrompts.tsv --no-check-certificate + ``` + + 2. 下载模型权重 + + ```bash + # Clip Score和HPSv2均需要使用的权重 + GIT_LFS_SKIP_SMUDGE=1 + git clone https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K + cd ./CLIP-ViT-H-14-laion2B-s32B-b79K + + # HPSv2权重 + wget https://huggingface.co/spaces/xswu/HPSv2/resolve/main/HPS_v2_compressed.pt --no-check-certificate + ``` + 也可手动下载[权重](https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K/blob/main/open_clip_pytorch_model.bin) + 将权重放到`CLIP-ViT-H-14-laion2B-s32B-b79K`目录下,手动下载[HPSv2权重](https://huggingface.co/spaces/xswu/HPSv2/resolve/main/HPS_v2_compressed.pt)放到当前路径 + + 3. 
使用推理脚本读取Parti数据集,生成图片 + + ```bash + # 不使用并行 + python3 stable_diffusion3_pipeline.py \ + --model ${model_base} \ + --prompt_file ./PartiPrompts.tsv \ + --prompt_file_type parti \ + --num_images_per_prompt 4 \ + --max_num_prompts 0 \ + --device 0 \ + --save_dir ./results_PartiPrompts \ + --steps 28 \ + --output_dir ./models \ + --height 1024 \ + --width 1024 \ + --batch_size 1 + + # 使用双卡并行策略 + python3 stable_diffusion3_pipeline.py \ + --model ${model_base} \ + --prompt_file ./PartiPrompts.tsv \ + --prompt_file_type parti \ + --num_images_per_prompt 4 \ + --max_num_prompts 0 \ + --device 0,1 \ + --save_dir ./results_PartiPrompts_parallel \ + --steps 28 \ + --output_dir ./models \ + --use_cache \ + --height 1024 \ + --width 1024 \ + --batch_size 1 + ``` + + 参数说明: + - --model:模型权重路径。 + - --output_dir:存放导出模型的目录。 + - --prompt_file:提示词文件。 + - --prompt_file_type: prompt文件类型,用于指定读取方式,可选plain,parti,hpsv2。注意使用hpsv2时,设置num_images_per_prompt=1即可。 + - --num_images_per_prompt: 每个prompt生成的图片数量。注意使用hpsv2时,设置num_images_per_prompt=1即可。 + - --max_num_prompts:限制prompt数量为前X个,0表示不限制。 + - --save_dir:生成图片的存放目录。 + - --batch_size:模型batch size。 + - --steps:生成图片迭代次数。 + - --device:推理设备ID;可用逗号分割传入两个设备ID,此时会使用并行方式进行推理。 + + 不使用并行策略,执行完成后在`./results_PartiPrompts`目录下生成推理图片,在当前目录生成一个`image_info.json`文件,记录着图片和prompt的对应关系,并在终端显示推理时间。 + 使用双卡并行策略,执行完成后在`./results_PartiPrompts_parallel`目录下生成推理图片,在当前目录生成一个`image_info.json`文件,记录着图片和prompt的对应关系。并在终端显示推理时间。 + + 4. 计算精度指标 + 1. CLIP-score + ```bash + python3 clip_score.py \ + --device=cpu \ + --image_info="image_info.json" \ + --model_name="ViT-H-14" \ + --model_weights_path="./CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin" + ``` + + 参数说明: + - --device: 推理设备。 + - --image_info: 上一步生成的`image_info.json`文件。 + - --model_name: Clip模型名称。 + - --model_weights_path: Clip模型权重文件路径。 + + 执行完成后会在屏幕打印出精度计算结果。 + + 2. HPSv2 + ```bash + python3 hpsv2_score.py \ + --image_info="image_info.json" \ + --HPSv2_checkpoint="./HPS_v2_compressed.pt" \ + --clip_checkpoint="./CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin" + ``` + + 参数说明: + - --image_info: 上一步生成的`image_info.json`文件。 + - --HPSv2_checkpoint: HPSv2模型权重文件路径。 + - --clip_checkpointh: Clip模型权重文件路径。 + + 执行完成后会在屏幕打印出精度计算结果。 + +# 模型推理性能&精度 + +调用ACL接口推理计算,性能参考下列数据。 + +### StableDiffusionxl +| 硬件形态 | cpu规格 | batch size | 迭代次数 | 优化手段 | 平均耗时 | 精度 | +| :------: | :------: | :------: |:----:| :------: |:-----:|:----------------:| +| A2 | 64核(arm) | 1 | 28 | w/o UnetCache | 6.15s | clip score 0.380 | + +性能测试需要独占npu和cpu \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/compile_model.py b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/compile_model.py new file mode 100644 index 0000000000..c250153eba --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/compile_model.py @@ -0,0 +1,89 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
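+
+# Usage sketch (illustrative): this mirrors how export_model.py calls the helpers below; the
+# paths, shapes and soc_version value are placeholders for whatever was actually exported.
+#
+#     model = torch.jit.load("models/clip/clip_bs1.pt").eval()
+#     inputs = [mindietorch.Input((1, 77), dtype=mindietorch.dtype.INT64)]
+#     compile_clip(model, inputs, "models/clip/clip_bs1_compile_1024x1024.ts", soc_version)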
+ +import torch +import mindietorch +from mindietorch import _enums + +class ClipExport(torch.nn.Module): + def __init__(self, clip_model): + super().__init__() + self.clip_model = clip_model + + def forward(self, x, output_hidden_states=True, return_dict=False): + return self.clip_model(x, output_hidden_states=output_hidden_states, return_dict=return_dict) + +def compile_clip(model, inputs, clip_compiled_path, soc_version): + compiled_clip_model = ( + mindietorch.compile(model, + inputs=inputs, + allow_tensor_replace_int=True, + require_full_compilation=False, + truncate_long_and_double=False, + min_block_size=1, + soc_version=soc_version, + precision_policy=_enums.PrecisionPolicy.FP16, + optimization_level=0)) + torch.jit.save(compiled_clip_model, clip_compiled_path) + +class VaeExport(torch.nn.Module): + def __init__(self, vae_model, scaling_factor, shift_factor): + super().__init__() + self.vae_model = vae_model + self.scaling_factor = scaling_factor + self.shift_factor = shift_factor + + def forward(self, latents): + latents = (latents / self.scaling_factor) + self.shift_factor + image = self.vae_model.decode(latents, return_dict=False)[0] + return image + +def compile_vae(model, inputs, vae_compiled_path, soc_version): + compiled_vae_model = ( + mindietorch.compile(model, + inputs=inputs, + allow_tensor_replace_int=True, + require_full_compilation=True, + truncate_long_and_double=True, + soc_version=soc_version, + precision_policy=_enums.PrecisionPolicy.FP16, + optimization_level=0)) + torch.jit.save(compiled_vae_model, vae_compiled_path) + +class DiTExport(torch.nn.Module): + def __init__(self, dit_model): + super().__init__() + self.dit_model = dit_model + + def forward( + self, + hidden_states, + encoder_hidden_states, + pooled_projections, + timestep, + ): + return self.dit_model(hidden_states, encoder_hidden_states, pooled_projections, + timestep, None, False)[0] + +def compile_dit(model, inputs, dit_compiled_path, soc_version): + compiled_dit_model = ( + mindietorch.compile(model, + inputs=inputs, + allow_tensor_replace_int=True, + require_full_compilation=True, + truncate_long_and_double=True, + soc_version=soc_version, + precision_policy=_enums.PrecisionPolicy.FP16, + optimization_level=0)) + torch.jit.save(compiled_dit_model, dit_compiled_path) -- Gitee From 3379181a6017db3a68d5425d159d7c0f12579d8d Mon Sep 17 00:00:00 2001 From: huanghao Date: Mon, 26 Aug 2024 15:46:10 +0800 Subject: [PATCH 3/7] =?UTF-8?q?SD3=E4=B8=8A=E5=BA=93=E7=AC=AC=E4=BA=8C?= =?UTF-8?q?=E9=83=A8=E5=88=86=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../foundation/stable_diffusion_3/README.md | 30 ++++---- .../stable_diffusion_3/compile_model.py | 68 ++++++++++--------- 2 files changed, 49 insertions(+), 49 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/README.md b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/README.md index 8f18fa737c..25f45779a7 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/README.md +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/README.md @@ -44,10 +44,10 @@ **表 1** 版本配套表 - | 配套 | 版本 | 环境准备指导 | - | -------- | -------- |--------| - | Python | 3.10.2 | - | -- | torch | 2.1.0 | - | + | 配套 | 版本 | 环境准备指导 | + | ----- | ----- |-----| + | Python | 3.10.2 | - | + | torch | 2.1.0 | - | 该模型性能受CPU规格影响,建议使用64核CPU(arm)以复现性能 @@ -73,21 +73,15 @@ ``` 3. 
代码修改 - - 执行命令: - +- 若环境没有patch工具,请自行安装: ```bash - # 若环境没有patch工具,请自行安装 + apt update + apt install patch ``` - +- 执行命令: ```bash python3 attention_patch.py ``` - -## 准备数据集 - -1. 获取原始数据集。 - 本模型输入文本信息生成图片,无需数据集。 ## 模型推理 @@ -102,17 +96,17 @@ # 需要使用 git-lfs (https://git-lfs.com) git lfs install - # xl + # 下载sd3权重 git clone https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers ``` 1. 导出pt模型并进行编译。(可选) ```bash - # xl (执行时下载权重) + # sd3 (执行时下载权重) model_base="stabilityai/stable-diffusion-3-medium-diffusers" - xl (使用上一步下载的权重) + # sd3 (使用上一步下载的权重) model_base="./stable-diffusion-3-medium-diffusers" ``` @@ -322,7 +316,7 @@ 调用ACL接口推理计算,性能参考下列数据。 -### StableDiffusionxl +### StableDiffusion3 | 硬件形态 | cpu规格 | batch size | 迭代次数 | 优化手段 | 平均耗时 | 精度 | | :------: | :------: | :------: |:----:| :------: |:-----:|:----------------:| | A2 | 64核(arm) | 1 | 28 | w/o UnetCache | 6.15s | clip score 0.380 | diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/compile_model.py b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/compile_model.py index c250153eba..b48994f182 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/compile_model.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/compile_model.py @@ -13,9 +13,36 @@ # limitations under the License. import torch +from dataclasses import dataclass +from typing import List import mindietorch from mindietorch import _enums + +@dataclass +class CompileParam: + inputs: List[mindietorch.Input] = None + soc_version: str = "" + allow_tensor_replace_int: bool = True + require_full_compilation: bool = True + truncate_long_and_double: bool = True + min_block_size: int = 1 + + +def common_compile(model, compiled_path, compile_param): + compiled_model = ( + mindietorch.compile(model, + inputs=compile_param.inputs, + allow_tensor_replace_int=compile_param.allow_tensor_replace_int, + require_full_compilation=compile_param.require_full_compilation, + truncate_long_and_double=compile_param.truncate_long_and_double, + min_block_size=compile_param.min_block_size, + soc_version=compile_param.soc_version, + precision_policy=_enums.PrecisionPolicy.FP16, + optimization_level=0)) + torch.jit.save(compiled_model, compiled_path) + + class ClipExport(torch.nn.Module): def __init__(self, clip_model): super().__init__() @@ -24,18 +51,11 @@ class ClipExport(torch.nn.Module): def forward(self, x, output_hidden_states=True, return_dict=False): return self.clip_model(x, output_hidden_states=output_hidden_states, return_dict=return_dict) + def compile_clip(model, inputs, clip_compiled_path, soc_version): - compiled_clip_model = ( - mindietorch.compile(model, - inputs=inputs, - allow_tensor_replace_int=True, - require_full_compilation=False, - truncate_long_and_double=False, - min_block_size=1, - soc_version=soc_version, - precision_policy=_enums.PrecisionPolicy.FP16, - optimization_level=0)) - torch.jit.save(compiled_clip_model, clip_compiled_path) + clip_param = CompileParam(inputs, soc_version, True, False, False) + common_compile(model, clip_compiled_path, clip_param) + class VaeExport(torch.nn.Module): def __init__(self, vae_model, scaling_factor, shift_factor): @@ -50,16 +70,9 @@ class VaeExport(torch.nn.Module): return image def compile_vae(model, inputs, vae_compiled_path, soc_version): - compiled_vae_model = ( - mindietorch.compile(model, - inputs=inputs, - allow_tensor_replace_int=True, - require_full_compilation=True, - truncate_long_and_double=True, - soc_version=soc_version, - 
precision_policy=_enums.PrecisionPolicy.FP16, - optimization_level=0)) - torch.jit.save(compiled_vae_model, vae_compiled_path) + vae_param = CompileParam(inputs, soc_version) + common_compile(model, vae_compiled_path, vae_param) + class DiTExport(torch.nn.Module): def __init__(self, dit_model): @@ -76,14 +89,7 @@ class DiTExport(torch.nn.Module): return self.dit_model(hidden_states, encoder_hidden_states, pooled_projections, timestep, None, False)[0] + def compile_dit(model, inputs, dit_compiled_path, soc_version): - compiled_dit_model = ( - mindietorch.compile(model, - inputs=inputs, - allow_tensor_replace_int=True, - require_full_compilation=True, - truncate_long_and_double=True, - soc_version=soc_version, - precision_policy=_enums.PrecisionPolicy.FP16, - optimization_level=0)) - torch.jit.save(compiled_dit_model, dit_compiled_path) + dit_param = CompileParam(inputs, soc_version) + common_compile(model, dit_compiled_path, dit_param) -- Gitee From 264145a0fd72a89e0321721f6d43a80bc5474951 Mon Sep 17 00:00:00 2001 From: huanghao Date: Mon, 26 Aug 2024 16:18:56 +0800 Subject: [PATCH 4/7] =?UTF-8?q?SD3=E4=B8=8A=E5=BA=93=E7=AC=AC=E4=B8=89?= =?UTF-8?q?=E9=83=A8=E5=88=86=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../stable_diffusion3_pipeline.py | 451 ++++++++++++++++++ 1 file changed, 451 insertions(+) create mode 100644 MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline.py b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline.py new file mode 100644 index 0000000000..30c4eb7948 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline.py @@ -0,0 +1,451 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
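+
+# Usage sketch (illustrative): AIEStableDiffusion3Pipeline below reuses the weights and config
+# of diffusers' StableDiffusion3Pipeline, but expects the caller to attach the MindIE-compiled
+# modules and device information before calling forward(); roughly:
+#
+#     pipe = AIEStableDiffusion3Pipeline.from_pretrained(model_base)
+#     pipe.device_0 = 0
+#     pipe.use_parallel_inferencing = False
+#     pipe.compiled_clip_model = torch.jit.load("models/clip/clip_bs1_compile_1024x1024.ts").eval()
+#     pipe.compiled_clip_model_2 = torch.jit.load("models/clip/clip2_bs1_compile_1024x1024.ts").eval()
+#     pipe.compiled_t5_model = torch.jit.load("models/clip/t5_bs1_compile_1024x1024.ts").eval()
+#     pipe.compiled_dit_model = torch.jit.load("models/dit/dit_bs2_compile_1024x1024.ts").eval()
+#     pipe.compiled_vae_model = torch.jit.load("models/vae/vae_bs1_compile_1024x1024.ts").eval()
+#     images = pipe.forward(prompt, num_inference_steps=28).images
+#
+# When two --device IDs are passed, the text-conditioned DiT branch runs in a BackgroundRuntime
+# on the other NPU (see self.dit_bg and use_parallel_inferencing below).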
+ +import argparse +import logging +import csv +import json +import os +import time +from typing import Any, Callable, Dict, List, Optional, Union +import numpy as np +import torch +import mindietorch +from diffusers import StableDiffusion3Pipeline +from diffusers.pipelines.stable_diffusion_3.pipeline_output import StableDiffusion3PipelineOutput +from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import retrieve_timesteps +from background_runtime import BackgroundRuntime, RuntimeIOInfo + +clip_time = 0 +t5_time = 0 +dit_time = 0 +vae_time = 0 +p1_time = 0 +p2_time = 0 +p3_time = 0 + + +class AIEStableDiffusion3Pipeline(StableDiffusion3Pipeline): + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + ): + device = f"npu:{self.device_0}" + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer_3( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logging.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + + global t5_time + start = time.time() + prompt_embeds = self.compiled_t5_model(text_input_ids.to(device))[0].to('cpu') + t5_time += (time.time() - start) + + dtype = self.text_encoder_3.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device='cpu') + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + return prompt_embeds + + def _get_clip_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + clip_skip: Optional[int] = None, + clip_model_index: int = 0, + ): + device = f"npu:{self.device_0}" + clip_tokenizers = [self.tokenizer, self.tokenizer_2] + clip_text_encoders = [self.compiled_clip_model, self.compiled_clip_model_2] + + tokenizer = clip_tokenizers[clip_model_index] + text_encoder = clip_text_encoders[clip_model_index] + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logging.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + + global clip_time + start = time.time() + prompt_embeds = text_encoder(text_input_ids.to(device)) + pooled_prompt_embeds = 
prompt_embeds[0].to('cpu') + clip_time += (time.time() - start) + + if clip_skip is None: + prompt_embeds = prompt_embeds[2][-2].to('cpu') + else: + prompt_embeds = prompt_embeds[2][-(clip_skip + 2)].to('cpu') + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device='cpu') + _, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1) + pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds, pooled_prompt_embeds + + def encode_prompt( + self, + prompt: Union[str, List[str]], + prompt_2: Union[str, List[str]], + prompt_3: Union[str, List[str]], + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + negative_prompt_3: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + clip_skip: Optional[int] = None, + ): + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + prompt_3 = prompt_3 or prompt + prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3 + + prompt_embed, pooled_prompt_embed = self._get_clip_prompt_embeds( + prompt=prompt, + num_images_per_prompt=num_images_per_prompt, + clip_skip=clip_skip, + clip_model_index=0, + ) + prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds( + prompt=prompt_2, + num_images_per_prompt=num_images_per_prompt, + clip_skip=clip_skip, + clip_model_index=1, + ) + clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1) + + t5_prompt_embed = self._get_t5_prompt_embeds( + prompt=prompt_3, + num_images_per_prompt=num_images_per_prompt, + ) + + clip_prompt_embeds = torch.nn.functional.pad( + clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1]) + ) + + prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2) + pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + negative_prompt_3 = negative_prompt_3 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + negative_prompt_3 = ( + batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3 + ) + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." 
+ ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embed, negative_pooled_prompt_embed = self._get_clip_prompt_embeds( + negative_prompt, + num_images_per_prompt=num_images_per_prompt, + clip_skip=None, + clip_model_index=0, + ) + negative_prompt_2_embed, negative_pooled_prompt_2_embed = self._get_clip_prompt_embeds( + negative_prompt_2, + num_images_per_prompt=num_images_per_prompt, + clip_skip=None, + clip_model_index=1, + ) + negative_clip_prompt_embeds = torch.cat([negative_prompt_embed, negative_prompt_2_embed], dim=-1) + + t5_negative_prompt_embed = self._get_t5_prompt_embeds( + prompt=negative_prompt_3, num_images_per_prompt=num_images_per_prompt + ) + + negative_clip_prompt_embeds = torch.nn.functional.pad( + negative_clip_prompt_embeds, + (0, t5_negative_prompt_embed.shape[-1] - negative_clip_prompt_embeds.shape[-1]), + ) + + negative_prompt_embeds = torch.cat([negative_clip_prompt_embeds, t5_negative_prompt_embed], dim=-2) + negative_pooled_prompt_embeds = torch.cat( + [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1 + ) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + @torch.no_grad() + def forward( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + prompt_3: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 28, + timesteps: List[int] = None, + guidance_scale: float = 7.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + negative_prompt_3: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + ): + global p1_time, p2_time, p3_time + start = time.time() + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + prompt_3, + height, + width, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. 
Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + device = self._execution_device + + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_3=prompt_3, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + do_classifier_free_guidance=self.do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + clip_skip=self.clip_skip, + num_images_per_prompt=num_images_per_prompt, + ) + + p1_time += (time.time() - start) + start1 = time.time() + + if self.do_classifier_free_guidance and not self.use_parallel_inferencing: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) + else: + prompt_embeds, prompt_embeds_1 = negative_prompt_embeds, prompt_embeds + pooled_prompt_embeds, pooled_prompt_embeds_1 = negative_pooled_prompt_embeds, pooled_prompt_embeds + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. 
Denoising loop + global dit_time + global vae_time + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + if not self.use_parallel_inferencing and self.do_classifier_free_guidance: + latent_model_input = torch.cat([latents] * 2) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]).to(torch.int64) + timestep_npu = timestep.to(f"npu:{self.device_0}") + else: + latent_model_input = latents + timestep = t.to(torch.int64) + self.dit_bg.infer_asyn([ + latent_model_input.numpy(), + prompt_embeds_1.numpy(), + pooled_prompt_embeds_1.numpy(), + timestep[None].numpy().astype(np.int64) + ]) + timestep_npu = timestep[None].to(f"npu:{self.device_0}") + + latent_model_input_npu = latent_model_input.to(f"npu:{self.device_0}") + prompt_embeds_npu = prompt_embeds.to(f"npu:{self.device_0}") + pooled_prompt_embeds_npu = pooled_prompt_embeds.to(f"npu:{self.device_0}") + + start = time.time() + noise_pred = self.compiled_dit_model( + latent_model_input_npu, + prompt_embeds_npu, + pooled_prompt_embeds_npu, + timestep_npu + ).to("cpu") + dit_time += (time.time() - start) + + # perform guidance + if self.do_classifier_free_guidance: + if self.use_parallel_inferencing: + noise_pred_text = torch.from_numpy(self.dit_bg.wait_and_get_outputs()[0]) + else: + noise_pred, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred + self.guidance_scale * (noise_pred_text - noise_pred) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + p2_time += (time.time() - start1) + start2 = time.time() + + if output_type == "latent": + image = latents + else: + start = time.time() + image = self.compiled_vae_model(latents.to(f"npu:{self.device_0}")).to("cpu") + vae_time += (time.time() - start) + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + p3_time += (time.time() - start2) + + if not return_dict: + return (image,) + + return StableDiffusion3PipelineOutput(images=image) \ No newline at end of file -- Gitee From 1b5e239656ce73658c6f5032ebdd124818553c4a Mon Sep 17 00:00:00 2001 From: huanghao Date: Mon, 26 Aug 2024 16:30:20 +0800 Subject: [PATCH 5/7] =?UTF-8?q?SD3=E4=B8=8A=E5=BA=93=E7=AC=AC=E4=B8=89?= =?UTF-8?q?=E9=83=A8=E5=88=86=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../foundation/stable_diffusion_3/stable_diffusion3_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline.py b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline.py index 30c4eb7948..1c1a1aaa32 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline.py @@ -1,4 +1,4 @@ -# Copyright 2024 Huawei Technologies Co., Ltd +# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
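Editor's note (illustrative only, not part of the patch above): the denoising loop supports two classifier-free-guidance layouts, either a single NPU running the negative and positive prompts as one doubled batch, or a second device running the positive branch through the background runtime while device_0 handles the negative branch. Both reduce to the same guidance formula. Below is a minimal, self-contained sketch; the tensor shapes are placeholders and the combine_cfg helper is invented to stand in for the compiled DiT outputs.

import torch

def combine_cfg(noise_uncond: torch.Tensor, noise_text: torch.Tensor, guidance_scale: float) -> torch.Tensor:
    # Classifier-free guidance: move the prediction away from the unconditional
    # branch toward the text-conditioned branch.
    return noise_uncond + guidance_scale * (noise_text - noise_uncond)

# Batched layout (single device): negative and positive share one forward pass.
latents = torch.randn(1, 16, 128, 128)              # placeholder latent shape
latent_model_input = torch.cat([latents] * 2)       # [uncond, cond] stacked on the batch dim
noise_pred = torch.randn_like(latent_model_input)   # stands in for the DiT output
noise_uncond, noise_text = noise_pred.chunk(2)      # same order as the torch.cat above
out_batched = combine_cfg(noise_uncond, noise_text, guidance_scale=7.5)

# Parallel layout (two devices): each branch keeps batch size 1; the text branch
# would come back from dit_bg.wait_and_get_outputs() instead of a second chunk.
noise_uncond = torch.randn_like(latents)
noise_text = torch.randn_like(latents)
out_parallel = combine_cfg(noise_uncond, noise_text, guidance_scale=7.5)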
-- Gitee From 03b9a005a174f02e7b8da1923b988d7554fc4213 Mon Sep 17 00:00:00 2001 From: huanghao Date: Mon, 26 Aug 2024 17:37:44 +0800 Subject: [PATCH 6/7] =?UTF-8?q?SD3=E4=B8=8A=E5=BA=93=E7=AC=AC=E5=9B=9B?= =?UTF-8?q?=E9=83=A8=E5=88=86=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../stable_diffusion3_pipeline.py | 356 +++++++++++++++++- 1 file changed, 355 insertions(+), 1 deletion(-) diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline.py b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline.py index 1c1a1aaa32..e5cea8855a 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline.py @@ -36,7 +36,171 @@ p2_time = 0 p3_time = 0 +class PromptLoader: + def __init__( + self, + prompt_file: str, + prompt_file_type: str, + batch_size: int, + num_images_per_prompt: int = 1, + max_num_prompts: int = 0 + ): + self.prompts = [] + self.catagories = ['Not_specified'] + self.batch_size = batch_size + self.num_images_per_prompt = num_images_per_prompt + + if prompt_file_type == 'plain': + self.load_prompts_plain(prompt_file, max_num_prompts) + elif prompt_file_type == 'parti': + self.load_prompts_parti(prompt_file, max_num_prompts) + elif prompt_file_type == 'hpsv2': + self.load_prompts_hpsv2(max_num_prompts) + else: + print("This operation is not supported!") + + self.current_id = 0 + self.inner_id = 0 + + def __len__(self): + return len(self.prompts) * self.num_images_per_prompt + + def __iter__(self): + return self + + def __next__(self): + if self.current_id == len(self.prompts): + raise StopIteration + + ret = { + 'prompts': [], + 'catagories': [], + 'save_names': [], + 'n_prompts': self.batch_size, + } + for _ in range(self.batch_size): + if self.current_id == len(self.prompts): + ret['prompts'].append('') + ret['save_names'].append('') + ret['catagories'].append('') + ret['n_prompts'] -= 1 + + else: + prompt, catagory_id = self.prompts[self.current_id] + ret['prompts'].append(prompt) + ret['catagories'].append(self.catagories[catagory_id]) + ret['save_names'].append(f'{self.current_id}_{self.inner_id}') + + self.inner_id += 1 + if self.inner_id == self.num_images_per_prompt: + self.inner_id = 0 + self.current_id += 1 + + return ret + + def load_prompts_plain(self, file_path: str, max_num_prompts: int): + with os.fdopen(os.open(file_path, os.O_RDONLY), "r") as f: + for i, line in enumerate(f): + if max_num_prompts and i == max_num_prompts: + break + + prompt = line.strip() + self.prompts.append((prompt, 0)) + + def load_prompts_parti(self, file_path: str, max_num_prompts: int): + with os.fdopen(os.open(file_path, os.O_RDONLY), "r") as f: + # Skip the first line + next(f) + tsv_file = csv.reader(f, delimiter="\t") + for i, line in enumerate(tsv_file): + if max_num_prompts and i == max_num_prompts: + break + + prompt = line[0] + catagory = line[1] + if catagory not in self.catagories: + self.catagories.append(catagory) + + catagory_id = self.catagories.index(catagory) + self.prompts.append((prompt, catagory_id)) + + def load_prompts_hpsv2(self, max_num_prompts: int): + with open('hpsv2_benchmark_prompts.json', 'r') as file: + all_prompts = json.load(file) + count = 0 + for style, prompts in all_prompts.items(): + for prompt in prompts: + count += 1 + if max_num_prompts and count >= max_num_prompts: + break + + if style 
not in self.catagories: + self.catagories.append(style) + + catagory_id = self.catagories.index(style) + self.prompts.append((prompt, catagory_id)) + + class AIEStableDiffusion3Pipeline(StableDiffusion3Pipeline): + def parser_args(self, args): + self.args = args + self.is_init = False + if isinstance(self.args.device, list): + self.device_0, self.device_1 = args.device + else: + self.device_0 = args.device + self.data = None + + def compile_aie_model(self): + if self.is_init: + return + size = self.args.batch_size + if hasattr(self, 'device_1'): + batch_size = self.args.batch_size + else: + batch_size = self.args.batch_size * 2 + + tail = f"_{self.args.height}x{self.args.width}" + vae_compiled_path = os.path.join(self.args.output_dir, f"vae/vae_bs{size}_compile{tail}.ts") + self.compiled_vae_model = torch.jit.load(vae_compiled_path).eval() + + clip1_compiled_path = os.path.join(self.args.output_dir, f"clip/clip_bs{size}_compile{tail}.ts") + self.compiled_clip_model = torch.jit.load(clip1_compiled_path).eval() + + clip2_compiled_path = os.path.join(self.args.output_dir, f"clip/clip2_bs{size}_compile{tail}.ts") + self.compiled_clip_model_2 = torch.jit.load(clip2_compiled_path).eval() + + t5_compiled_path = os.path.join(self.args.output_dir, f"clip/t5_bs{size}_compile{tail}.ts") + self.compiled_t5_model = torch.jit.load(t5_compiled_path).eval() + + dit_compiled_path = os.path.join(self.args.output_dir, f"dit/dit_bs{batch_size}_compile{tail}.ts") + self.compiled_dit_model = torch.jit.load(dit_compiled_path).eval() + + self.use_parallel_inferencing = False + + if hasattr(self, 'device_1'): + sample_size = self.transformer.config.sample_size + in_channels = self.transformer.config.in_channels + encoder_hidden_size_2 = self.text_encoder_2.config.hidden_size + encoder_hidden_size = self.text_encoder.config.hidden_size + encoder_hidden_size_2 + max_position_embeddings = self.text_encoder.config.max_position_embeddings * 2 + + runtime_info = RuntimeIOInfo( + input_shapes=[ + (batch_size, in_channels, sample_size, sample_size), + (batch_size, max_position_embeddings, encoder_hidden_size * 2), + (batch_size, encoder_hidden_size), + (batch_size,), + ], + input_dtypes=[np.float32, np.float32, np.float32, np.int64], + output_shapes=[(batch_size, in_channels, sample_size, sample_size)], + output_dtypes=[np.float32] + ) + self.dit_bg = BackgroundRuntime.clone(self.device_1, dit_compiled_path, runtime_info) + self.use_parallel_inferencing = True + + self.is_init = True + def _get_t5_prompt_embeds( self, prompt: Union[str, List[str]] = None, @@ -448,4 +612,194 @@ class AIEStableDiffusion3Pipeline(StableDiffusion3Pipeline): if not return_dict: return (image,) - return StableDiffusion3PipelineOutput(images=image) \ No newline at end of file + return StableDiffusion3PipelineOutput(images=image) + + +def check_device_range_valid(value): + # if contain , split to int list + min_value = 0 + max_value = 255 + if ',' in value: + ilist = [int(v) for v in value.split(',')] + for ivalue in ilist[:2]: + if ivalue < min_value or ivalue > max_value: + raise argparse.ArgumentTypeError( + "{} of device:{} is invalid. valid value range is [{}, {}]" + .format(ivalue, value, min_value, max_value)) + return ilist[:2] + else: + # default as single int value + ivalue = int(value) + if ivalue < min_value or ivalue > max_value: + raise argparse.ArgumentTypeError( + "device:{} is invalid. 
valid value range is [{}, {}]".format( + ivalue, min_value, max_value)) + return ivalue + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument( + "-m", + "--model", + type=str, + default="./stable-diffusion-3-medium-diffusers", + help="Path or name of the pre-trained model.", + ) + parser.add_argument( + "--prompt_file", + type=str, + default="./prompts.txt", + help="A text file of prompts for generating images.", + ) + parser.add_argument( + "--prompt_file_type", + choices=["plain", "parti", "hpsv2"], + default="plain", + help="Type of prompt file.", + ) + parser.add_argument( + "--save_dir", + type=str, + default="./results", + help="Path to save result images.", + ) + parser.add_argument( + "--info_file_save_path", + type=str, + default="./image_info.json", + help="Path to save image information file.", + ) + parser.add_argument( + "--steps", + type=int, + default=28, + help="Number of inference steps.", + ) + parser.add_argument( + "--device", + type=check_device_range_valid, + default=0, + help="NPU device id. Give 2 ids to enable parallel inferencing.", + ) + parser.add_argument( + "--num_images_per_prompt", + default=1, + type=int, + help="Number of images generated for each prompt.", + ) + parser.add_argument( + "--max_num_prompts", + default=0, + type=int, + help="Limit the number of prompts (0: no limit).", + ) + parser.add_argument( + "-bs", + "--batch_size", + type=int, + default=1, + help="Batch size." + ) + parser.add_argument( + "-o", + "--output_dir", + type=str, + default="./models", + help="Path of directory to save compiled models.", + ) + parser.add_argument( + "--scheduler", + choices=["FlowMatchEuler"], + default="FlowMatchEuler", + help="Type of Sampling methods. Default FlowMatchEuler", + ) + parser.add_argument( + "--height", + default=1024, + type=int, + help="image height", + ) + parser.add_argument( + "--width", + default=1024, + type=int, + help="image width" + ) + + return parser.parse_args() + + +def main(): + args = parse_arguments() + save_dir = args.save_dir + if not os.path.exists(save_dir): + os.makedirs(save_dir) + + pipe = AIEStableDiffusion3Pipeline.from_pretrained(args.model).to("cpu") + pipe.parser_args(args) + pipe.compile_aie_model() + if isinstance(args.device, list): + mindietorch.set_device(args.device[0]) + else: + mindietorch.set_device(args.device) + + use_time = 0 + prompt_loader = PromptLoader(args.prompt_file, + args.prompt_file_type, + args.batch_size, + args.num_images_per_prompt, + args.max_num_prompts) + + infer_num = 0 + image_info = [] + current_prompt = None + for i, input_info in enumerate(prompt_loader): + prompts = input_info['prompts'] + catagories = input_info['catagories'] + save_names = input_info['save_names'] + n_prompts = input_info['n_prompts'] + + print(f"[{infer_num + n_prompts}/{len(prompt_loader)}]: {prompts}") + infer_num += args.batch_size + + start_time = time.time() + images = pipe.forward( + prompts, + negative_prompt="", + num_inference_steps=args.steps, + guidance_scale=7.5, + ) + if i > 4: # do not count the time spent inferring the first 0 to 4 images + use_time += time.time() - start_time + + for j in range(n_prompts): + image_save_path = os.path.join(save_dir, f"{save_names[j]}.png") + image = images[0][j] + image.save(image_save_path) + + if current_prompt != prompts[j]: + current_prompt = prompts[j] + image_info.append({'images': [], 'prompt': current_prompt, 'category': catagories[j]}) + + image_info[-1]['images'].append(image_save_path) + + infer_num = infer_num - 5 # do 
not count the time spent inferring the first 5 images + print(f"[info] infer number: {infer_num}; use time: {use_time:.3f}s\n" + f"average time: {use_time / infer_num:.3f}s\n") + + if hasattr(pipe, 'device_1'): + if (pipe.dit_bg): + pipe.dit_bg.stop() + + # Save image information to a json file + if os.path.exists(args.info_file_save_path): + os.remove(args.info_file_save_path) + + with os.fdopen(os.open(args.info_file_save_path, os.O_RDWR | os.O_CREAT, 0o640), "w") as f: + json.dump(image_info, f) + mindietorch.finalize() + + +if __name__ == "__main__": + main() -- Gitee From 8349c4cc0d4b143a801937cd87c9a5cc876a239d Mon Sep 17 00:00:00 2001 From: huanghao Date: Tue, 27 Aug 2024 16:04:23 +0800 Subject: [PATCH 7/7] =?UTF-8?q?SD3=E4=BF=AE=E6=94=B9soc=20choices=E9=94=99?= =?UTF-8?q?=E8=AF=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../built-in/foundation/stable_diffusion_3/export_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/export_model.py b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/export_model.py index 3608813b1e..187cceea83 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/export_model.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/export_model.py @@ -70,7 +70,7 @@ def parse_arguments() -> Namespace: parser.add_argument("--use_cache", action="store_true", help="Use cache during inference.") parser.add_argument("-p", "--parallel", action="store_true", help="Export the unet of bs=1 for parallel inferencing.") - parser.add_argument("--soc", choices=["Duo", "A2"], default="A2", help="soc_version.") + parser.add_argument("--soc", choices=["Ascend910B4", "Ascend310P3"], default="Ascend910B4", help="soc_version.") parser.add_argument( "--device", default=0, -- Gitee
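Editor's note (illustrative only, not part of the patch series): a short usage sketch of the PromptLoader batching behaviour added above. The prompt file name, prompts and batch settings are invented, and the sketch assumes the patched stable_diffusion3_pipeline module and its dependencies (torch, diffusers, mindietorch) are importable from the working directory.

from stable_diffusion3_pipeline import PromptLoader

# A throwaway prompt file in the 'plain' format (one prompt per line); contents are made up.
with open("demo_prompts.txt", "w") as f:
    f.write("a photo of a cat\na watercolor of a city at night\na bowl of ramen\n")

loader = PromptLoader("demo_prompts.txt", "plain", batch_size=2, num_images_per_prompt=1)
print(len(loader))  # 3 prompts x 1 image per prompt = 3

for batch in loader:
    # Every batch carries batch_size slots; once the prompts run out, the tail is padded
    # with empty strings and 'n_prompts' reports how many slots are real, which is why
    # main() only saves images for range(n_prompts).
    print(batch["n_prompts"], batch["prompts"], batch["save_names"])

On the CLI side, check_device_range_valid lets --device accept either a single NPU id or a pair such as "0,1"; with two ids, parser_args records device_0 and device_1, and compile_aie_model clones the DiT onto the second device through BackgroundRuntime, which enables the parallel-inferencing path in the denoising loop.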