From f36987df5cc7004040584bc8ce6fb67bbe7ef158 Mon Sep 17 00:00:00 2001
From: guowenna
Date: Tue, 16 Jul 2024 10:32:58 +0800
Subject: [PATCH] deepcache (aggressive version) + deepcache
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../background_runtime_cache_faster.py        | 194 +++
 .../foundation/stable_diffusion/export_ts.py  |  79 ++-
 .../stable_diffusion_pipeline.py              | 422 +++++++-----
 .../stable_diffusion_pipeline_parallel.py     | 528 ++++++++++------
 .../stable_diffusion_unet_patch.py            |   1 +
 .../stable_diffusion/unet_2d_blocks.patch     |  69 +++
 .../stable_diffusion/unet_2d_condition.patch  | 128 +++--
 7 files changed, 1047 insertions(+), 374 deletions(-)
 create mode 100644 MindIE/MindIE-Torch/built-in/foundation/stable_diffusion/background_runtime_cache_faster.py
 create mode 100644 MindIE/MindIE-Torch/built-in/foundation/stable_diffusion/unet_2d_blocks.patch

diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion/background_runtime_cache_faster.py b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion/background_runtime_cache_faster.py
new file mode 100644
index 0000000000..f4f4fbf405
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion/background_runtime_cache_faster.py
@@ -0,0 +1,194 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
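+
+# BackgroundRuntimeCacheFaster runs the DeepCache "cache" and "skip" UNet
+# variants in a background worker process on a separate NPU device. Inputs
+# and outputs are exchanged with the main process through shared-memory
+# buffers, and a duplex pipe carries the per-step control messages
+# ('cache', 'skip', 'STOP').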
+
+import multiprocessing as mp
+import time
+from dataclasses import dataclass
+from typing import List
+
+import numpy as np
+import torch
+import mindietorch
+
+
+@dataclass
+class RuntimeIOInfoCacheFaster:
+    input_shapes: List[tuple]
+    input_dtypes: List[type]
+    output_shapes: List[tuple]
+    output_dtypes: List[type]
+
+
+class BackgroundRuntimeCacheFaster:
+    def __init__(
+        self,
+        device_id: int,
+        model_path: List[str],
+        io_info: RuntimeIOInfoCacheFaster
+    ):
+        # Create a pipe for process synchronization
+        self.sync_pipe, sync_pipe_peer = mp.Pipe(duplex=True)
+
+        # Create shared buffers
+        input_spaces = self.create_shared_buffers(io_info.input_shapes,
+                                                  io_info.input_dtypes)
+        output_spaces = self.create_shared_buffers(io_info.output_shapes,
+                                                   io_info.output_dtypes)
+
+        # Build numpy arrays on the shared buffers
+        self.input_arrays = [
+            np.frombuffer(b, dtype=t).reshape(s) for (b, s, t) in zip(
+                input_spaces, io_info.input_shapes, io_info.input_dtypes)
+        ]
+        self.output_arrays = [
+            np.frombuffer(b, dtype=t).reshape(s) for (b, s, t) in zip(
+                output_spaces, io_info.output_shapes, io_info.output_dtypes)
+        ]
+
+        mp.set_start_method('fork', force=True)
+        self.p = mp.Process(target=self.run_infer,
+                            args=[
+                                sync_pipe_peer, input_spaces, output_spaces,
+                                io_info, device_id, model_path
+                            ])
+        self.p.start()
+
+        # Wait until the subprocess is ready
+        self.wait()
+
+    @staticmethod
+    def create_shared_buffers(shapes: List[tuple],
+                              dtypes: List[type]) -> List[mp.RawArray]:
+        buffers = []
+        for shape, dtype in zip(shapes, dtypes):
+            size = 1
+            for x in shape:
+                size *= x
+
+            raw_array = mp.RawArray(np.ctypeslib.as_ctypes_type(dtype), size)
+            buffers.append(raw_array)
+
+        return buffers
+
+    def infer_asyn(self, feeds: List[np.ndarray], skip: bool = False) -> None:
+        for i, _ in enumerate(self.input_arrays):
+            self.input_arrays[i][:] = feeds[i][:]
+
+        if skip:
+            self.sync_pipe.send('skip')
+        else:
+            self.sync_pipe.send('cache')
+
+    def wait(self) -> None:
+        self.sync_pipe.recv()
+
+    def get_outputs(self) -> List[np.ndarray]:
+        return self.output_arrays
+
+    def wait_and_get_outputs(self) -> List[np.ndarray]:
+        self.wait()
+        return self.get_outputs()
+
+    def infer(self, feeds: List[np.ndarray]) -> List[np.ndarray]:
+        self.infer_asyn(feeds)
+        return self.wait_and_get_outputs()
+
+    def stop(self):
+        # Stop the subprocess
+        self.sync_pipe.send('STOP')
+
+    @staticmethod
+    def run_infer(
+        sync_pipe: mp.connection.Connection,
+        input_spaces: List[np.ndarray],
+        output_spaces: List[np.ndarray],
+        io_info: RuntimeIOInfoCacheFaster,
+        device_id: int,
+        model_path: list,
+    ) -> None:
+        # The subprocess entry: create a runtime on the assigned device
+        mindietorch.set_device(device_id)
+        print(f"[info] bg device id: {device_id}")
+
+        # Load the compiled cache and skip UNet models
+        model_cache = torch.jit.load(model_path[0]).eval()
+        model_skip = torch.jit.load(model_path[1]).eval()
+
+        # Build numpy arrays on the shared buffers
+        input_arrays = [
+            np.frombuffer(b, dtype=t).reshape(s) for (b, s, t) in zip(
+                input_spaces, io_info.input_shapes, io_info.input_dtypes)
+        ]
+
+        output_arrays = [
+            np.frombuffer(b, dtype=t).reshape(s) for (b, s, t) in zip(
+                output_spaces, io_info.output_shapes, io_info.output_dtypes)
+        ]
+
+        # Tell the main process that we are ready
+        sync_pipe.send('')
+
+        stream = mindietorch.npu.Stream(f"npu:{device_id}")
+
+        return_cache = None
+        return_cache_faster = None
+
+        # Keep looping until a 'STOP' message is received
+        while True:
+            flag = sync_pipe.recv()
+            if flag == 'STOP':
+                break
+
+            sample, timestep, encoder_hidden_states, return_flag, return_faster_flag = [
+                torch.Tensor(input_array) for input_array in input_arrays
+            ]
+
+            sample_npu = sample.to(torch.float32).to(f"npu:{device_id}")
+            timestep_npu = timestep.to(torch.int64).to(f"npu:{device_id}")
+            encoder_hidden_states_npu = encoder_hidden_states.to(torch.float32).to(f"npu:{device_id}")
+            flag_npu = return_flag.to(torch.int64).to(f"npu:{device_id}")
+            faster_flag_npu = return_faster_flag.to(torch.int64).to(f"npu:{device_id}")
+
+            if flag == 'cache':
+                with mindietorch.npu.stream(stream):
+                    output_npu = model_cache(sample_npu, timestep_npu, encoder_hidden_states_npu, flag_npu, faster_flag_npu)
+                stream.synchronize()
+
+                output_cpu0 = output_npu[0].to('cpu')
+                output0 = output_cpu0.numpy()
+                output_arrays[0][:] = output0
+
+                return_cache = output_npu[1]
+                return_cache_faster = output_npu[2]
+            else:
+                with mindietorch.npu.stream(stream):
+                    output_npu = model_skip(sample_npu, timestep_npu, encoder_hidden_states_npu, flag_npu, faster_flag_npu, return_cache, return_cache_faster)
+                stream.synchronize()
+
+                output_cpu0 = output_npu.to('cpu')
+                output0 = output_cpu0.numpy()
+                output_arrays[0][:] = output0
+
+            sync_pipe.send('')
+
+    @classmethod
+    def clone(cls, device_id: int, model_path: List[str],
+              runtime_info: RuntimeIOInfoCacheFaster) -> 'BackgroundRuntimeCacheFaster':
+        # Get shapes and datatypes from an existing engine,
+        # then use them to create a new BackgroundRuntimeCacheFaster
+        return cls(device_id, model_path, runtime_info)
diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion/export_ts.py b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion/export_ts.py
index ecbb1e1fba..4ba0cd0dce 100644
--- a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion/export_ts.py
+++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion/export_ts.py
@@ -60,10 +60,15 @@ def parse_arguments() -> Namespace:
         help="guidance_scale"
     )
     parser.add_argument(
-        "--use_cache",
+        "--use_cache",
         action="store_true",
         help="Use cache during inference."
     )
+    parser.add_argument(
+        "--use_cache_faster",
+        action="store_true",
+        help="Use cache with faster during inference."
+ ) parser.add_argument( "-p", "--parallel", @@ -180,7 +185,7 @@ def export_unet(sd_pipeline: StableDiffusionPipeline, save_dir: str, batch_size: torch.ones([1], dtype=torch.int64), torch.ones([batch_size, max_position_embeddings, encoder_hidden_size], dtype=torch.float32), torch.ones([1], dtype=torch.int64), - torch.ones([batch_size, 640, sample_size, sample_size], dtype=torch.float32), + torch.ones([batch_size, 320, sample_size, sample_size], dtype=torch.float32), ) else: dummy_input = ( @@ -196,6 +201,61 @@ def export_unet(sd_pipeline: StableDiffusionPipeline, save_dir: str, batch_size: torch.jit.trace(unet, dummy_input).save(unet_pt_path) +class UnetExportFaster(torch.nn.Module): + def __init__(self, unet_model): + super().__init__() + self.unet_model = unet_model + + def forward(self, sample, timestep, encoder_hidden_states, if_skip, if_faster, inputCache=None, inputFasterCache=None): + if if_skip: + return self.unet_model(sample, timestep, encoder_hidden_states, if_skip=if_skip, if_faster=if_faster, inputCache=inputCache, inputFasterCache=inputFasterCache)[0] + else: + return self.unet_model(sample, timestep, encoder_hidden_states, if_skip=if_skip, if_faster=if_faster) + + +def export_unet_faster(sd_pipeline: StableDiffusionPipeline, save_dir: str, batch_size: int, if_skip: int, if_faster: int) -> None: + print("Exporting the image information creater...") + unet_path = os.path.join(save_dir, "unet") + if not os.path.exists(unet_path): + os.makedirs(unet_path, mode=0o640) + + unet_pt_path = os.path.join(unet_path, f"unet_bs{batch_size}_{if_skip}_{if_faster}.pt") + if os.path.exists(unet_pt_path): + return + + unet_model = sd_pipeline.unet + clip_model = sd_pipeline.text_encoder + + sample_size = unet_model.config.sample_size + in_channels = unet_model.config.in_channels + encoder_hidden_size = clip_model.config.hidden_size + max_position_embeddings = clip_model.config.max_position_embeddings + + if if_skip: + dummy_input = ( + torch.ones([batch_size, in_channels, sample_size, sample_size], dtype=torch.float32), + torch.ones([1], dtype=torch.int64), + torch.ones([batch_size, max_position_embeddings, encoder_hidden_size], dtype=torch.float32), + torch.ones([1], dtype=torch.int64), + torch.ones([1], dtype=torch.int64), + torch.ones([batch_size, 320, sample_size, sample_size], dtype=torch.float32), + torch.ones([batch_size, 2*320, sample_size, sample_size], dtype=torch.float32), + ) + else: + dummy_input = ( + torch.ones([batch_size, in_channels, sample_size, sample_size], dtype=torch.float32), + torch.ones([1], dtype=torch.int64), + torch.ones([batch_size, max_position_embeddings, encoder_hidden_size], dtype=torch.float32), + torch.zeros([1], dtype=torch.int64), + torch.ones([1], dtype=torch.int64), + ) + + unet = UnetExportFaster(unet_model) + unet.eval() + + torch.jit.trace(unet, dummy_input).save(unet_pt_path) + + class CatExport(torch.nn.Module): def __init__(self, scale_model_input): super(CatExport, self).__init__() @@ -397,7 +457,7 @@ def export_vae(sd_pipeline: StableDiffusionPipeline, save_dir: str, batch_size: torch.jit.trace(vae_export, dummy_input).save(vae_pt_path) -def export(model_path: str, save_dir: str, batch_size: int, steps: int, guidance_scale: float, use_cache: bool, parallel: bool) -> None: +def export(model_path: str, save_dir: str, batch_size: int, steps: int, guidance_scale: float, use_cache: bool, use_cache_faster: bool, parallel: bool) -> None: pipeline = StableDiffusionPipeline.from_pretrained(model_path).to("cpu") export_clip(pipeline, save_dir, batch_size) 
@@ -414,6 +474,17 @@ def export(model_path: str, save_dir: str, batch_size: int, steps: int, guidance export_unet(pipeline, save_dir, batch_size * 2, 0) # 单卡, unet_skip export_unet(pipeline, save_dir, batch_size * 2, 1) + if use_cache_faster: + if parallel: + # 双卡, unet_cache带faster + export_unet_faster(pipeline, save_dir, batch_size, 0, 1) + # 双卡, unet_skip带faster + export_unet_faster(pipeline, save_dir, batch_size, 1, 1) + else: + # 单卡, unet_cache带faster + export_unet_faster(pipeline, save_dir, batch_size * 2, 0, 1) + # 单卡, unet_skip带faster + export_unet_faster(pipeline, save_dir, batch_size * 2, 1, 1) else: if parallel: # 双卡不带unetcache @@ -434,7 +505,7 @@ def export(model_path: str, save_dir: str, batch_size: int, steps: int, guidance def main(): args = parse_arguments() - export(args.model, args.output_dir, args.batch_size, args.steps, args.guidance_scale, args.use_cache, args.parallel) + export(args.model, args.output_dir, args.batch_size, args.steps, args.guidance_scale, args.use_cache, args.use_cache_faster, args.parallel) print("Done.") diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion/stable_diffusion_pipeline.py b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion/stable_diffusion_pipeline.py index a953ae4805..16d986539b 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion/stable_diffusion_pipeline.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion/stable_diffusion_pipeline.py @@ -34,6 +34,7 @@ p2_time = 0 p3_time = 0 scheduler_time = 0 + class PromptLoader: def __init__( self, @@ -150,15 +151,15 @@ class AIEStableDiffusionPipeline(StableDiffusionPipeline): self.compiled_clip_model = ( mindietorch.compile(model, - inputs=[mindietorch.Input((self.args.batch_size, - max_position_embeddings), - dtype=mindietorch.dtype.INT64)], - allow_tensor_replace_int=True, - require_full_compilation=True, - truncate_long_and_double=True, - precision_policy=_enums.PrecisionPolicy.FP16, - soc_version=soc_version, - optimization_level=0)) + inputs=[mindietorch.Input((self.args.batch_size, + max_position_embeddings), + dtype=mindietorch.dtype.INT64)], + allow_tensor_replace_int=True, + require_full_compilation=True, + truncate_long_and_double=True, + precision_policy=_enums.PrecisionPolicy.FP16, + soc_version=soc_version, + optimization_level=0)) torch.jit.save(self.compiled_clip_model, clip_compiled_path) vae_compiled_path = os.path.join(self.args.output_dir, f"vae/vae_bs{size}_aie_compile.ts") @@ -169,17 +170,17 @@ class AIEStableDiffusionPipeline(StableDiffusionPipeline): self.compiled_vae_model = ( mindietorch.compile(model, - inputs=[ - mindietorch.Input((self.args.batch_size, in_channels, - sample_size, sample_size), - dtype=mindietorch.dtype.FLOAT)], - allow_tensor_replace_int=True, - require_full_compilation=True, - truncate_long_and_double=True, - soc_version=soc_version, - precision_policy=_enums.PrecisionPolicy.FP16, - optimization_level=0 - )) + inputs=[ + mindietorch.Input((self.args.batch_size, in_channels, + sample_size, sample_size), + dtype=mindietorch.dtype.FLOAT)], + allow_tensor_replace_int=True, + require_full_compilation=True, + truncate_long_and_double=True, + soc_version=soc_version, + precision_policy=_enums.PrecisionPolicy.FP16, + optimization_level=0 + )) torch.jit.save(self.compiled_vae_model, vae_compiled_path) scheduler_compiled_path = os.path.join(self.args.output_dir, f"ddim/ddim{batch_size}_aie_compile.ts") @@ -190,77 +191,49 @@ class AIEStableDiffusionPipeline(StableDiffusionPipeline): self.compiled_scheduler 
= ( mindietorch.compile(model, - inputs=[mindietorch.Input((batch_size, - in_channels, sample_size, - sample_size), - dtype=mindietorch.dtype.FLOAT), - mindietorch.Input((1,), - dtype=mindietorch.dtype.INT64), - mindietorch.Input((batch_size//2, - in_channels, sample_size, - sample_size), - dtype=mindietorch.dtype.FLOAT), - mindietorch.Input((1,), - dtype=mindietorch.dtype.INT64)], - allow_tensor_replace_int=True, - require_full_compilation=False, - truncate_long_and_double=False, - precision_policy=_enums.PrecisionPolicy.FP16, - soc_version=soc_version, - optimization_level=0)) + inputs=[mindietorch.Input((batch_size, + in_channels, sample_size, + sample_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64), + mindietorch.Input((batch_size // 2, + in_channels, sample_size, + sample_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64)], + allow_tensor_replace_int=True, + require_full_compilation=False, + truncate_long_and_double=False, + precision_policy=_enums.PrecisionPolicy.FP16, + soc_version=soc_version, + optimization_level=0)) torch.jit.save(self.compiled_scheduler, scheduler_compiled_path) cat_compiled_path = os.path.join(self.args.output_dir, "cat/cat_aie_compile.ts") if os.path.exists(cat_compiled_path): - self.compiled_cat = torch.jit.load(cat_compiled_path).eval() + self.compiled_cat = torch.jit.load(cat_compiled_path).eval() else: model = torch.jit.load(os.path.join(self.args.output_dir, "cat/cat.pt")).eval() self.compiled_cat = ( mindietorch.compile(model, - inputs=[mindietorch.Input((batch_size//2, - in_channels, sample_size, - sample_size), - dtype=mindietorch.dtype.FLOAT), - mindietorch.Input((1,), - dtype=mindietorch.dtype.FLOAT)], - allow_tensor_replace_int=True, - require_full_compilation=True, - truncate_long_and_double=True, - precision_policy=_enums.PrecisionPolicy.FP16, - soc_version=soc_version, - optimization_level=0)) - torch.jit.save(self.compiled_cat, cat_compiled_path) - - if not args.use_cache: - unet_compiled_path = os.path.join(self.args.output_dir, f"unet/unet_bs{batch_size}_aie_compile.ts") - if os.path.exists(unet_compiled_path): - self.compiled_unet = torch.jit.load(unet_compiled_path).eval() - else: - model = torch.jit.load(os.path.join(self.args.output_dir, f"unet/unet_bs{batch_size}.pt")).eval() - - self.compiled_unet = ( - mindietorch.compile(model, - inputs=[mindietorch.Input((batch_size, - in_channels, sample_size, - sample_size), - dtype=mindietorch.dtype.FLOAT), + inputs=[mindietorch.Input((batch_size // 2, + in_channels, sample_size, + sample_size), + dtype=mindietorch.dtype.FLOAT), mindietorch.Input((1,), - dtype=mindietorch.dtype.INT64), - mindietorch.Input((batch_size, - max_position_embeddings, - encoder_hidden_size), - dtype=mindietorch.dtype.FLOAT)], + dtype=mindietorch.dtype.FLOAT)], allow_tensor_replace_int=True, require_full_compilation=True, truncate_long_and_double=True, - soc_version=soc_version, precision_policy=_enums.PrecisionPolicy.FP16, - optimization_level=0 - )) - torch.jit.save(self.compiled_unet, unet_compiled_path) + soc_version=soc_version, + optimization_level=0)) + torch.jit.save(self.compiled_cat, cat_compiled_path) - else: + if args.use_cache: unet_cache_compiled_path = os.path.join(self.args.output_dir, f"unet/unet_bs{batch_size}_aie_compile_0.ts") if os.path.exists(unet_cache_compiled_path): self.compiled_unet_cache = torch.jit.load(unet_cache_compiled_path).eval() @@ -269,25 +242,25 @@ class 
AIEStableDiffusionPipeline(StableDiffusionPipeline): self.compiled_unet_cache = ( mindietorch.compile(unet_cache, - inputs=[mindietorch.Input((batch_size, - in_channels, sample_size, - sample_size), - dtype=mindietorch.dtype.FLOAT), - mindietorch.Input((1,), - dtype=mindietorch.dtype.INT64), - mindietorch.Input((batch_size, - max_position_embeddings, - encoder_hidden_size), - dtype=mindietorch.dtype.FLOAT), - mindietorch.Input((1,), - dtype=mindietorch.dtype.INT64)], - allow_tensor_replace_int=True, - require_full_compilation=False, - truncate_long_and_double=True, - soc_version=soc_version, - precision_policy=_enums.PrecisionPolicy.FP16, - optimization_level=0 - )) + inputs=[mindietorch.Input((batch_size, + in_channels, sample_size, + sample_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64), + mindietorch.Input((batch_size, + max_position_embeddings, + encoder_hidden_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64)], + allow_tensor_replace_int=True, + require_full_compilation=False, + truncate_long_and_double=True, + soc_version=soc_version, + precision_policy=_enums.PrecisionPolicy.FP16, + optimization_level=0 + )) torch.jit.save(self.compiled_unet_cache, unet_cache_compiled_path) unet_skip_compiled_path = os.path.join(self.args.output_dir, f"unet/unet_bs{batch_size}_aie_compile_1.ts") @@ -298,30 +271,130 @@ class AIEStableDiffusionPipeline(StableDiffusionPipeline): self.compiled_unet_skip = ( mindietorch.compile(unet_skip, - inputs=[mindietorch.Input((batch_size, - in_channels, sample_size, - sample_size), - dtype=mindietorch.dtype.FLOAT), - mindietorch.Input((1,), - dtype=mindietorch.dtype.INT64), - mindietorch.Input((batch_size, - max_position_embeddings, - encoder_hidden_size), - dtype=mindietorch.dtype.FLOAT), - mindietorch.Input((1,), - dtype=mindietorch.dtype.INT64), - mindietorch.Input((batch_size, - 640, sample_size, - sample_size), - dtype=mindietorch.dtype.FLOAT)], - allow_tensor_replace_int=True, - require_full_compilation=True, - truncate_long_and_double=True, - soc_version=soc_version, - precision_policy=_enums.PrecisionPolicy.FP16, - optimization_level=0 - )) + inputs=[mindietorch.Input((batch_size, + in_channels, sample_size, + sample_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64), + mindietorch.Input((batch_size, + max_position_embeddings, + encoder_hidden_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64), + mindietorch.Input((batch_size, + 320, sample_size, + sample_size), + dtype=mindietorch.dtype.FLOAT)], + allow_tensor_replace_int=True, + require_full_compilation=True, + truncate_long_and_double=True, + soc_version=soc_version, + precision_policy=_enums.PrecisionPolicy.FP16, + optimization_level=0 + )) + torch.jit.save(self.compiled_unet_skip, unet_skip_compiled_path) + elif args.use_cache_faster: + unet_cache_faster_compiled_path = os.path.join(self.args.output_dir, + f"unet/unet_bs{batch_size}_aie_compile_0_1.ts") + if os.path.exists(unet_cache_faster_compiled_path): + self.compiled_unet_cache_faster = torch.jit.load(unet_cache_faster_compiled_path).eval() + else: + unet_cache_faster = torch.jit.load( + os.path.join(self.args.output_dir, f"unet/unet_bs{batch_size}_0_1.pt")).eval() + + self.compiled_unet_cache_faster = ( + mindietorch.compile(unet_cache_faster, + inputs=[mindietorch.Input((batch_size, + in_channels, sample_size, + sample_size), + 
dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64), + mindietorch.Input((batch_size, + max_position_embeddings, + encoder_hidden_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64)], + allow_tensor_replace_int=True, + require_full_compilation=False, + truncate_long_and_double=True, + soc_version=soc_version, + precision_policy=_enums.PrecisionPolicy.FP16, + optimization_level=0 + )) + torch.jit.save(self.compiled_unet_cache_faster, unet_cache_faster_compiled_path) + + unet_skip_compiled_path = os.path.join(self.args.output_dir, f"unet/unet_bs{batch_size}_aie_compile_1_1.ts") + if os.path.exists(unet_skip_compiled_path): + self.compiled_unet_skip = torch.jit.load(unet_skip_compiled_path).eval() + else: + unet_skip = torch.jit.load( + os.path.join(self.args.output_dir, f"unet/unet_bs{batch_size}_1_1.pt")).eval() + + self.compiled_unet_skip = ( + mindietorch.compile(unet_skip, + inputs=[mindietorch.Input((batch_size, + in_channels, sample_size, + sample_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64), + mindietorch.Input((batch_size, + max_position_embeddings, + encoder_hidden_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64), + mindietorch.Input((batch_size, + 320, sample_size, + sample_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size, + 2 * 320, sample_size, + sample_size), + dtype=mindietorch.dtype.FLOAT)], + allow_tensor_replace_int=True, + require_full_compilation=True, + truncate_long_and_double=True, + soc_version=soc_version, + precision_policy=_enums.PrecisionPolicy.FP16, + optimization_level=0 + )) torch.jit.save(self.compiled_unet_skip, unet_skip_compiled_path) + else: + unet_compiled_path = os.path.join(self.args.output_dir, f"unet/unet_bs{batch_size}_aie_compile.ts") + if os.path.exists(unet_compiled_path): + self.compiled_unet = torch.jit.load(unet_compiled_path).eval() + else: + model = torch.jit.load(os.path.join(self.args.output_dir, f"unet/unet_bs{batch_size}.pt")).eval() + + self.compiled_unet = ( + mindietorch.compile(model, + inputs=[mindietorch.Input((batch_size, + in_channels, sample_size, + sample_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64), + mindietorch.Input((batch_size, + max_position_embeddings, + encoder_hidden_size), + dtype=mindietorch.dtype.FLOAT)], + allow_tensor_replace_int=True, + require_full_compilation=True, + truncate_long_and_double=True, + soc_version=soc_version, + precision_policy=_enums.PrecisionPolicy.FP16, + optimization_level=0 + )) + torch.jit.save(self.compiled_unet, unet_compiled_path) self.is_init = True @@ -342,8 +415,9 @@ class AIEStableDiffusionPipeline(StableDiffusionPipeline): callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: Optional[int] = 1, - skip_steps = None, + skip_steps=None, flag_cache: int = None, + flag_cache_faster: int = None, **kwargs, ): r""" @@ -448,8 +522,10 @@ class AIEStableDiffusionPipeline(StableDiffusionPipeline): global vae_time cache = None + cache_faster = None skip_flag = torch.ones([1], dtype=torch.long).to(f'npu:{self.device_0}') cache_flag = torch.zeros([1], dtype=torch.long).to(f'npu:{self.device_0}') + cache_faster_flag = torch.ones([1], dtype=torch.long).to(f'npu:{self.device_0}') 
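+        # Control flags for DeepCache: skip_flag is fed to the "skip" UNet,
+        # cache_flag to the full UNet pass that refreshes the cache, and
+        # cache_faster_flag enables the extra down-block cache of the faster variant.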
stream = mindietorch.npu.Stream(f'npu:{self.device_0}') for i, t in enumerate(self.progress_bar(timesteps)): @@ -469,19 +545,38 @@ class AIEStableDiffusionPipeline(StableDiffusionPipeline): if skip_steps[i]: with mindietorch.npu.stream(stream): noise_pred = self.compiled_unet_skip(latent_model_input, - t_npu, - text_embeddings_npu, - skip_flag, - cache) + t_npu, + text_embeddings_npu, + skip_flag, + cache) else: with mindietorch.npu.stream(stream): outputs = self.compiled_unet_cache(latent_model_input, - t_npu, - text_embeddings_npu, - cache_flag, - ) + t_npu, + text_embeddings_npu, + cache_flag) noise_pred = outputs[0] cache = outputs[1] + elif flag_cache_faster: + if skip_steps[i]: + with mindietorch.npu.stream(stream): + noise_pred = self.compiled_unet_skip(latent_model_input, + t_npu, + text_embeddings_npu, + skip_flag, + cache_faster_flag, + cache, + cache_faster) + else: + with mindietorch.npu.stream(stream): + outputs = self.compiled_unet_cache_faster(latent_model_input, + t_npu, + text_embeddings_npu, + cache_flag, + cache_faster_flag) + noise_pred = outputs[0] + cache = outputs[1] + cache_faster = outputs[2] else: with mindietorch.npu.stream(stream): noise_pred = self.compiled_unet(latent_model_input, t_npu, text_embeddings_npu) @@ -496,10 +591,10 @@ class AIEStableDiffusionPipeline(StableDiffusionPipeline): with mindietorch.npu.stream(stream): latents = self.compiled_scheduler( - noise_pred, - t_npu, - latents, - y[None].to(f'npu:{self.device_0}')) + noise_pred, + t_npu, + latents, + y[None].to(f'npu:{self.device_0}')) # call the callback, if provided if callback is not None and i % callback_steps == 0: @@ -529,7 +624,6 @@ class AIEStableDiffusionPipeline(StableDiffusionPipeline): p3_time += time.time() - start3 return (image, has_nsfw_concept) - def ascendie_infer( self, prompt: Union[str, List[str]], @@ -546,8 +640,9 @@ class AIEStableDiffusionPipeline(StableDiffusionPipeline): callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: Optional[int] = 1, - skip_steps = None, + skip_steps=None, flag_cache: int = None, + flag_cache_faster: int = None, **kwargs, ): # 0. 
Default height and width to unet @@ -598,9 +693,12 @@ class AIEStableDiffusionPipeline(StableDiffusionPipeline): global vae_time cache = None + cache_faster = None skip_flag = torch.ones([1], dtype=torch.long).to(f'npu:{self.device_0}') cache_flag = torch.zeros([1], dtype=torch.long).to(f'npu:{self.device_0}') + cache_faster_flag = torch.ones([1], dtype=torch.long).to(f'npu:{self.device_0}') + stream = mindietorch.npu.Stream(f'npu:{self.device_0}') for i, t in enumerate(self.progress_bar(timesteps)): if i == 50: break @@ -621,18 +719,36 @@ class AIEStableDiffusionPipeline(StableDiffusionPipeline): if flag_cache: if skip_steps[i]: noise_pred = self.compiled_unet_skip(latent_model_input_npu, - t_npu, - text_embeddings_npu, - skip_flag, - cache).to('cpu') + t_npu, + text_embeddings_npu, + skip_flag, + cache).to('cpu') else: outputs = self.compiled_unet_cache(latent_model_input_npu, - t_npu, - text_embeddings_npu, - cache_flag, - ) + t_npu, + text_embeddings_npu, + cache_flag, + ) noise_pred = outputs[0].to('cpu') cache = outputs[1] + elif flag_cache_faster: + if skip_steps[i]: + noise_pred = self.compiled_unet_skip(latent_model_input_npu, + t_npu, + text_embeddings_npu, + skip_flag, + cache_faster_flag, + cache, + cache_faster).to('cpu') + else: + outputs = self.compiled_unet_cache_faster(latent_model_input_npu, + t_npu, + text_embeddings_npu, + cache_flag, + cache_faster_flag) + noise_pred = outputs[0].to('cpu') + cache = outputs[1] + cache_faster = outputs[2] else: with mindietorch.npu.stream(stream): noise_pred = self.compiled_unet(latent_model_input_npu, @@ -650,7 +766,7 @@ class AIEStableDiffusionPipeline(StableDiffusionPipeline): # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, - **extra_step_kwargs).prev_sample + **extra_step_kwargs)[0] # call the callback, if provided if callback is not None and i % callback_steps == 0: @@ -883,13 +999,17 @@ def parse_arguments(): help="Use cache during inference." ) parser.add_argument( - "--cache_steps", - type=str, + "--use_cache_faster", + action="store_true", + help="Use cache with faster during inference." + ) + parser.add_argument( + "--cache_steps", + type=str, default="1,2,3,4,5,7,9,10,12,13,14,16,18,19,21,23,24,26,27,29,\ - 30,31,33,34,36,37,39,40,41,43,44,45,47,48,49", + 30,31,33,34,36,37,39,40,41,43,44,45,47,48,49", help="Steps to use cache data." 
) - return parser.parse_args() @@ -922,6 +1042,14 @@ def main(): continue skip_steps[int(i)] = 1 + flag_cache_faster = 0 + if args.use_cache_faster: + flag_cache_faster = 1 + for i in args.cache_steps.split(','): + if int(i) >= args.steps: + continue + skip_steps[int(i)] = 1 + use_time = 0 prompt_loader = PromptLoader(args.prompt_file, args.prompt_file_type, @@ -953,6 +1081,7 @@ def main(): num_inference_steps=args.steps, skip_steps=skip_steps, flag_cache=flag_cache, + flag_cache_faster=flag_cache_faster, ) else: images = pipe.ascendie_infer( @@ -960,6 +1089,7 @@ def main(): num_inference_steps=args.steps, skip_steps=skip_steps, flag_cache=flag_cache, + flag_cache_faster=flag_cache_faster, ) if i > 4: # do not count the time spent inferring the first 0 to 4 images diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion/stable_diffusion_pipeline_parallel.py b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion/stable_diffusion_pipeline_parallel.py index 76c7e606ca..5e64cac270 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion/stable_diffusion_pipeline_parallel.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion/stable_diffusion_pipeline_parallel.py @@ -28,6 +28,7 @@ from diffusers import DPMSolverMultistepScheduler, EulerDiscreteScheduler, DDIMS from background_runtime import BackgroundRuntime, RuntimeIOInfo from background_runtime_cache import BackgroundRuntimeCache, RuntimeIOInfoCache +from background_runtime_cache_faster import BackgroundRuntimeCacheFaster, RuntimeIOInfoCacheFaster clip_time = 0 unet_time = 0 @@ -37,6 +38,7 @@ p2_time = 0 p3_time = 0 scheduler_time = 0 + class PromptLoader: def __init__( self, @@ -126,6 +128,7 @@ class AIEStableDiffusionPipeline(StableDiffusionPipeline): use_parallel_inferencing = False unet_bg = None unet_bg_cache = None + unet_bg_cache_faster = None def parser_args(self, args): self.args = args @@ -164,15 +167,15 @@ class AIEStableDiffusionPipeline(StableDiffusionPipeline): self.compiled_clip_model = ( mindietorch.compile(model, - inputs=[mindietorch.Input((self.args.batch_size, - max_position_embeddings), - dtype=mindietorch.dtype.INT64)], - allow_tensor_replace_int=True, - require_full_compilation=True, - truncate_long_and_double=True, - precision_policy=_enums.PrecisionPolicy.FP16, - soc_version=soc_version, - optimization_level=0)) + inputs=[mindietorch.Input((self.args.batch_size, + max_position_embeddings), + dtype=mindietorch.dtype.INT64)], + allow_tensor_replace_int=True, + require_full_compilation=True, + truncate_long_and_double=True, + precision_policy=_enums.PrecisionPolicy.FP16, + soc_version=soc_version, + optimization_level=0)) torch.jit.save(self.compiled_clip_model, clip_compiled_path) vae_compiled_path = os.path.join(self.args.output_dir, f"vae/vae_bs{batch_size}_aie_compile.ts") @@ -183,17 +186,17 @@ class AIEStableDiffusionPipeline(StableDiffusionPipeline): self.compiled_vae_model = ( mindietorch.compile(model, - inputs=[ - mindietorch.Input((self.args.batch_size, in_channels, - sample_size, sample_size), - dtype=mindietorch.dtype.FLOAT)], - allow_tensor_replace_int=True, - require_full_compilation=True, - truncate_long_and_double=True, - soc_version=soc_version, - precision_policy=_enums.PrecisionPolicy.FP16, - optimization_level=0 - )) + inputs=[ + mindietorch.Input((self.args.batch_size, in_channels, + sample_size, sample_size), + dtype=mindietorch.dtype.FLOAT)], + allow_tensor_replace_int=True, + require_full_compilation=True, + truncate_long_and_double=True, + 
soc_version=soc_version, + precision_policy=_enums.PrecisionPolicy.FP16, + optimization_level=0 + )) torch.jit.save(self.compiled_vae_model, vae_compiled_path) scheduler_compiled_path = os.path.join(self.args.output_dir, f"ddim/ddim{batch_size}_aie_compile.ts") @@ -204,73 +207,31 @@ class AIEStableDiffusionPipeline(StableDiffusionPipeline): self.compiled_scheduler = ( mindietorch.compile(model, - inputs=[mindietorch.Input((batch_size, - in_channels, sample_size, - sample_size), - dtype=mindietorch.dtype.FLOAT), - mindietorch.Input((batch_size, - in_channels, sample_size, - sample_size), - dtype=mindietorch.dtype.FLOAT), - mindietorch.Input((1,), - dtype=mindietorch.dtype.INT64), - mindietorch.Input((batch_size, - in_channels, sample_size, - sample_size), - dtype=mindietorch.dtype.FLOAT), - mindietorch.Input((1,), - dtype=mindietorch.dtype.INT64)], - allow_tensor_replace_int=True, - require_full_compilation=True, - truncate_long_and_double=False, - precision_policy=_enums.PrecisionPolicy.FP16, - soc_version=soc_version, - optimization_level=0)) - torch.jit.save(self.compiled_scheduler, scheduler_compiled_path) - - if not args.use_cache: - unet_compiled_path = os.path.join(self.args.output_dir, f"unet/unet_bs{batch_size}_aie_compile.ts") - if os.path.exists(unet_compiled_path): - self.compiled_unet = torch.jit.load(unet_compiled_path).eval() - else: - model = torch.jit.load(os.path.join(self.args.output_dir, f"unet/unet_bs{batch_size}.pt")).eval() - - self.compiled_unet = ( - mindietorch.compile(model, inputs=[mindietorch.Input((batch_size, - in_channels, sample_size, - sample_size), - dtype=mindietorch.dtype.FLOAT), + in_channels, sample_size, + sample_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size, + in_channels, sample_size, + sample_size), + dtype=mindietorch.dtype.FLOAT), mindietorch.Input((1,), - dtype=mindietorch.dtype.INT64), + dtype=mindietorch.dtype.INT64), mindietorch.Input((batch_size, - max_position_embeddings, - encoder_hidden_size), - dtype=mindietorch.dtype.FLOAT)], + in_channels, sample_size, + sample_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64)], allow_tensor_replace_int=True, require_full_compilation=True, - truncate_long_and_double=True, - soc_version=soc_version, + truncate_long_and_double=False, precision_policy=_enums.PrecisionPolicy.FP16, - optimization_level=0 - )) - torch.jit.save(self.compiled_unet, unet_compiled_path) - - runtime_info = RuntimeIOInfo( - input_shapes=[ - (batch_size, in_channels, sample_size, sample_size), - (1,), - (batch_size, max_position_embeddings, encoder_hidden_size) - ], - input_dtypes=[np.float32, np.int64, np.float32], - output_shapes=[(batch_size, in_channels, sample_size, sample_size)], - output_dtypes=[np.float32] - ) - if hasattr(self, 'device_1'): - self.unet_bg = BackgroundRuntime.clone(self.device_1, unet_compiled_path, runtime_info) - self.use_parallel_inferencing = True + soc_version=soc_version, + optimization_level=0)) + torch.jit.save(self.compiled_scheduler, scheduler_compiled_path) - else: + if args.use_cache: unet_cache_compiled_path = os.path.join(self.args.output_dir, f"unet/unet_bs{batch_size}_aie_compile_0.ts") if os.path.exists(unet_cache_compiled_path): self.compiled_unet_cache = torch.jit.load(unet_cache_compiled_path).eval() @@ -279,25 +240,25 @@ class AIEStableDiffusionPipeline(StableDiffusionPipeline): self.compiled_unet_cache = ( mindietorch.compile(unet_cache, - inputs=[mindietorch.Input((batch_size, - in_channels, sample_size, - 
sample_size), - dtype=mindietorch.dtype.FLOAT), - mindietorch.Input((1,), - dtype=mindietorch.dtype.INT64), - mindietorch.Input((batch_size, - max_position_embeddings, - encoder_hidden_size), - dtype=mindietorch.dtype.FLOAT), - mindietorch.Input((1,), - dtype=mindietorch.dtype.INT64)], - allow_tensor_replace_int=True, - require_full_compilation=False, - truncate_long_and_double=True, - soc_version=soc_version, - precision_policy=_enums.PrecisionPolicy.FP16, - optimization_level=0 - )) + inputs=[mindietorch.Input((batch_size, + in_channels, sample_size, + sample_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64), + mindietorch.Input((batch_size, + max_position_embeddings, + encoder_hidden_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64)], + allow_tensor_replace_int=True, + require_full_compilation=False, + truncate_long_and_double=True, + soc_version=soc_version, + precision_policy=_enums.PrecisionPolicy.FP16, + optimization_level=0 + )) torch.jit.save(self.compiled_unet_cache, unet_cache_compiled_path) unet_skip_compiled_path = os.path.join(self.args.output_dir, f"unet/unet_bs{batch_size}_aie_compile_1.ts") @@ -308,29 +269,29 @@ class AIEStableDiffusionPipeline(StableDiffusionPipeline): self.compiled_unet_skip = ( mindietorch.compile(unet_skip, - inputs=[mindietorch.Input((batch_size, - in_channels, sample_size, - sample_size), - dtype=mindietorch.dtype.FLOAT), - mindietorch.Input((1,), - dtype=mindietorch.dtype.INT64), - mindietorch.Input((batch_size, - max_position_embeddings, - encoder_hidden_size), - dtype=mindietorch.dtype.FLOAT), - mindietorch.Input((1,), - dtype=mindietorch.dtype.INT64), - mindietorch.Input((batch_size, - 640, sample_size, - sample_size), - dtype=mindietorch.dtype.FLOAT)], - allow_tensor_replace_int=True, - require_full_compilation=True, - truncate_long_and_double=True, - soc_version=soc_version, - precision_policy=_enums.PrecisionPolicy.FP16, - optimization_level=0 - )) + inputs=[mindietorch.Input((batch_size, + in_channels, sample_size, + sample_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64), + mindietorch.Input((batch_size, + max_position_embeddings, + encoder_hidden_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64), + mindietorch.Input((batch_size, + 320, sample_size, + sample_size), + dtype=mindietorch.dtype.FLOAT)], + allow_tensor_replace_int=True, + require_full_compilation=True, + truncate_long_and_double=True, + soc_version=soc_version, + precision_policy=_enums.PrecisionPolicy.FP16, + optimization_level=0 + )) torch.jit.save(self.compiled_unet_skip, unet_skip_compiled_path) runtime_info_cache = RuntimeIOInfoCache( @@ -342,12 +303,152 @@ class AIEStableDiffusionPipeline(StableDiffusionPipeline): ], input_dtypes=[np.float32, np.int64, np.float32, np.int64], output_shapes=[(batch_size, in_channels, sample_size, sample_size), - (batch_size, 640, sample_size, sample_size)], + (batch_size, 320, sample_size, sample_size)], output_dtypes=[np.float32, np.float32] ) if hasattr(self, 'device_1'): - self.unet_bg_cache = BackgroundRuntimeCache.clone(self.device_1, [unet_cache_compiled_path, unet_skip_compiled_path], runtime_info_cache) + self.unet_bg_cache = BackgroundRuntimeCache.clone(self.device_1, + [unet_cache_compiled_path, unet_skip_compiled_path], + runtime_info_cache) + self.use_parallel_inferencing = True + + elif args.use_cache_faster: + 
unet_cache_faster_compiled_path = os.path.join(self.args.output_dir, + f"unet/unet_bs{batch_size}_aie_compile_0_1.ts") + if os.path.exists(unet_cache_faster_compiled_path): + self.compiled_unet_cache_faster = torch.jit.load(unet_cache_faster_compiled_path).eval() + else: + unet_cache_faster = torch.jit.load( + os.path.join(self.args.output_dir, f"unet/unet_bs{batch_size}_0_1.pt")).eval() + + self.compiled_unet_cache_faster = ( + mindietorch.compile(unet_cache_faster, + inputs=[mindietorch.Input((batch_size, + in_channels, sample_size, + sample_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64), + mindietorch.Input((batch_size, + max_position_embeddings, + encoder_hidden_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64)], + allow_tensor_replace_int=True, + require_full_compilation=False, + truncate_long_and_double=True, + soc_version=soc_version, + precision_policy=_enums.PrecisionPolicy.FP16, + optimization_level=0 + )) + torch.jit.save(self.compiled_unet_cache_faster, unet_cache_faster_compiled_path) + + unet_skip_compiled_path = os.path.join(self.args.output_dir, f"unet/unet_bs{batch_size}_aie_compile_1_1.ts") + if os.path.exists(unet_skip_compiled_path): + self.compiled_unet_skip = torch.jit.load(unet_skip_compiled_path).eval() + else: + unet_skip = torch.jit.load( + os.path.join(self.args.output_dir, f"unet/unet_bs{batch_size}_1_1.pt")).eval() + + self.compiled_unet_skip = ( + mindietorch.compile(unet_skip, + inputs=[mindietorch.Input((batch_size, + in_channels, sample_size, + sample_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64), + mindietorch.Input((batch_size, + max_position_embeddings, + encoder_hidden_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64), + mindietorch.Input((batch_size, + 320, sample_size, + sample_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size, + 2 * 320, sample_size, + sample_size), + dtype=mindietorch.dtype.FLOAT)], + allow_tensor_replace_int=True, + require_full_compilation=True, + truncate_long_and_double=True, + soc_version=soc_version, + precision_policy=_enums.PrecisionPolicy.FP16, + optimization_level=0 + )) + torch.jit.save(self.compiled_unet_skip, unet_skip_compiled_path) + + runtime_info_cache_faster = RuntimeIOInfoCacheFaster( + input_shapes=[ + (batch_size, in_channels, sample_size, sample_size), + (1,), + (batch_size, max_position_embeddings, encoder_hidden_size), + (1,), + (1,) + ], + input_dtypes=[np.float32, np.int64, np.float32, np.int64, np.int64], + output_shapes=[(batch_size, in_channels, sample_size, sample_size), + (batch_size, 320, sample_size, sample_size), + (batch_size, 2 * 320, sample_size, sample_size)], + output_dtypes=[np.float32, np.float32, np.float32] + ) + + if hasattr(self, 'device_1'): + self.unet_bg_cache_faster = BackgroundRuntimeCacheFaster.clone(self.device_1, + [unet_cache_faster_compiled_path, + unet_skip_compiled_path], + runtime_info_cache_faster) + self.use_parallel_inferencing = True + + else: + unet_compiled_path = os.path.join(self.args.output_dir, f"unet/unet_bs{batch_size}_aie_compile.ts") + if os.path.exists(unet_compiled_path): + self.compiled_unet = torch.jit.load(unet_compiled_path).eval() + else: + model = torch.jit.load(os.path.join(self.args.output_dir, 
f"unet/unet_bs{batch_size}.pt")).eval() + + self.compiled_unet = ( + mindietorch.compile(model, + inputs=[mindietorch.Input((batch_size, + in_channels, sample_size, + sample_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64), + mindietorch.Input((batch_size, + max_position_embeddings, + encoder_hidden_size), + dtype=mindietorch.dtype.FLOAT)], + allow_tensor_replace_int=True, + require_full_compilation=True, + truncate_long_and_double=True, + soc_version=soc_version, + precision_policy=_enums.PrecisionPolicy.FP16, + optimization_level=0 + )) + torch.jit.save(self.compiled_unet, unet_compiled_path) + + runtime_info = RuntimeIOInfo( + input_shapes=[ + (batch_size, in_channels, sample_size, sample_size), + (1,), + (batch_size, max_position_embeddings, encoder_hidden_size) + ], + input_dtypes=[np.float32, np.int64, np.float32], + output_shapes=[(batch_size, in_channels, sample_size, sample_size)], + output_dtypes=[np.float32] + ) + if hasattr(self, 'device_1'): + self.unet_bg = BackgroundRuntime.clone(self.device_1, unet_compiled_path, runtime_info) self.use_parallel_inferencing = True self.is_init = True @@ -369,8 +470,9 @@ class AIEStableDiffusionPipeline(StableDiffusionPipeline): callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: Optional[int] = 1, - skip_steps = None, + skip_steps=None, flag_cache: int = None, + flag_cache_faster: int = None, **kwargs, ): r""" @@ -478,8 +580,10 @@ class AIEStableDiffusionPipeline(StableDiffusionPipeline): text_embeddings, text_embeddings_2 = text_embeddings.chunk(2) cache = None + cache_faster = None skip_flag = torch.ones([1], dtype=torch.long) cache_flag = torch.zeros([1], dtype=torch.long) + cache_faster_flag = torch.zeros([1], dtype=torch.long) stream = mindietorch.npu.Stream(f'npu:{self.device_0}') for i, t in enumerate(self.progress_bar(timesteps)): @@ -502,7 +606,17 @@ class AIEStableDiffusionPipeline(StableDiffusionPipeline): skip_flag.numpy(), # cache_numpy, ], - skip_steps[i]) + skip_steps[i]) + elif flag_cache_faster: + self.unet_bg_cache_faster.infer_asyn([ + latent_model_input.numpy(), + t[None].numpy().astype(np.int64), + text_embeddings_2.numpy(), + skip_flag.numpy(), + cache_faster_flag.numpy() + # cache_numpy, + ], + skip_steps[i]) else: self.unet_bg.infer_asyn([ latent_model_input.numpy(), @@ -511,7 +625,7 @@ class AIEStableDiffusionPipeline(StableDiffusionPipeline): ]) latent_model_input_npu = latent_model_input.to(f'npu:{self.device_0}') - t_npu = t[None].to(f'npu:{self.device_0}') + t_npu = t.to(torch.int64)[None].to(f'npu:{self.device_0}') text_embeddings_npu = text_embeddings.to(f'npu:{self.device_0}') start = time.time() @@ -520,18 +634,39 @@ class AIEStableDiffusionPipeline(StableDiffusionPipeline): with mindietorch.npu.stream(stream): if (skip_steps[i]): noise_pred = self.compiled_unet_skip(latent_model_input_npu, - t_npu, - text_embeddings_npu, - skip_flag.to(f'npu:{self.device_0}'), - cache) + t_npu, + text_embeddings_npu, + skip_flag.to(f'npu:{self.device_0}'), + cache) else: outputs = self.compiled_unet_cache(latent_model_input_npu, - t_npu, - text_embeddings_npu, - cache_flag.to(f'npu:{self.device_0}'), - ) + t_npu, + text_embeddings_npu, + cache_flag.to(f'npu:{self.device_0}'), + ) + noise_pred = outputs[0] + cache = outputs[1] + stream.synchronize() + elif flag_cache_faster: + with mindietorch.npu.stream(stream): + if (skip_steps[i]): + noise_pred = self.compiled_unet_skip(latent_model_input_npu, + t_npu, + text_embeddings_npu, + 
skip_flag.to(f'npu:{self.device_0}'), + cache_faster_flag.to(f'npu:{self.device_0}'), + cache, + cache_faster) + else: + outputs = self.compiled_unet_cache_faster(latent_model_input_npu, + t_npu, + text_embeddings_npu, + cache_flag.to(f'npu:{self.device_0}'), + cache_faster_flag.to(f'npu:{self.device_0}'), + ) noise_pred = outputs[0] cache = outputs[1] + cache_faster = outputs[2] stream.synchronize() else: with mindietorch.npu.stream(stream): @@ -551,6 +686,12 @@ class AIEStableDiffusionPipeline(StableDiffusionPipeline): else: out = self.unet_bg_cache.wait_and_get_outputs() noise_pred_text = torch.from_numpy(out[0]) + elif flag_cache_faster: + if (skip_steps[i]): + noise_pred_text = torch.from_numpy(self.unet_bg_cache_faster.wait_and_get_outputs()[0]) + else: + out = self.unet_bg_cache_faster.wait_and_get_outputs() + noise_pred_text = torch.from_numpy(out[0]) else: noise_pred_text = torch.from_numpy(self.unet_bg.wait_and_get_outputs()[0]) else: @@ -560,11 +701,11 @@ class AIEStableDiffusionPipeline(StableDiffusionPipeline): y = torch.from_numpy(x).long() latents = self.compiled_scheduler( - noise_pred.to(f'npu:{self.device_0}'), - noise_pred_text.to(f'npu:{self.device_0}'), - t[None].to(f'npu:{self.device_0}'), - latents.to(f'npu:{self.device_0}'), - y[None].to(f'npu:{self.device_0}')).to('cpu') + noise_pred.to(f'npu:{self.device_0}'), + noise_pred_text.to(f'npu:{self.device_0}'), + t[None].to(f'npu:{self.device_0}'), + latents.to(f'npu:{self.device_0}'), + y[None].to(f'npu:{self.device_0}')).to('cpu') # call the callback, if provided if callback is not None and i % callback_steps == 0: @@ -592,7 +733,6 @@ class AIEStableDiffusionPipeline(StableDiffusionPipeline): p3_time += time.time() - start3 return (image, has_nsfw_concept) - def ascendie_infer( self, prompt: Union[str, List[str]], @@ -609,8 +749,9 @@ class AIEStableDiffusionPipeline(StableDiffusionPipeline): callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: Optional[int] = 1, - skip_steps = None, + skip_steps=None, flag_cache: int = None, + flag_cache_faster: int = None, **kwargs, ): # 0. 
Default height and width to unet @@ -664,8 +805,10 @@ class AIEStableDiffusionPipeline(StableDiffusionPipeline): text_embeddings, text_embeddings_2 = text_embeddings.chunk(2) cache = None + cache_faster = None skip_flag = torch.ones([1], dtype=torch.long) cache_flag = torch.zeros([1], dtype=torch.long) + cache_faster_flag = torch.zeros([1], dtype=torch.long) for i, t in enumerate(self.progress_bar(timesteps)): if i == 50: @@ -686,8 +829,19 @@ class AIEStableDiffusionPipeline(StableDiffusionPipeline): t[None].numpy().astype(np.int64), text_embeddings_2.numpy(), skip_flag.numpy(), + # cache_numpy, + ], + skip_steps[i]) + elif flag_cache_faster: + self.unet_bg_cache_faster.infer_asyn([ + latent_model_input.numpy(), + t[None].numpy().astype(np.int64), + text_embeddings_2.numpy(), + skip_flag.numpy(), + cache_faster_flag.numpy() + # cache_numpy, ], - skip_steps[i]) + skip_steps[i]) else: self.unet_bg.infer_asyn([ latent_model_input.numpy(), @@ -704,18 +858,38 @@ class AIEStableDiffusionPipeline(StableDiffusionPipeline): if flag_cache: if skip_steps[i]: noise_pred = self.compiled_unet_skip(latent_model_input_npu, - t_npu, - text_embeddings_npu, - skip_flag.to(f'npu:{self.device_0}'), - cache).to('cpu') + t_npu, + text_embeddings_npu, + skip_flag.to(f'npu:{self.device_0}'), + cache).to('cpu') else: outputs = self.compiled_unet_cache(latent_model_input_npu, - t_npu, - text_embeddings_npu, - cache_flag.to(f'npu:{self.device_0}'), - ) + t_npu, + text_embeddings_npu, + cache_flag.to(f'npu:{self.device_0}'), + ) noise_pred = outputs[0].to('cpu') cache = outputs[1] + elif flag_cache_faster: + if skip_steps[i]: + with mindietorch.npu.stream(stream): + noise_pred = self.compiled_unet_skip(latent_model_input, + t_npu, + text_embeddings_npu, + skip_flag, + cache_faster_flag, + cache, + cache_faster) + else: + with mindietorch.npu.stream(stream): + outputs = self.compiled_unet_cache_faster(latent_model_input, + t_npu, + text_embeddings_npu, + cache_flag, + cache_faster_flag) + noise_pred = outputs[0] + cache = outputs[1] + cache_faster = outputs[2] else: noise_pred = self.compiled_unet(latent_model_input_npu, t_npu, @@ -733,6 +907,12 @@ class AIEStableDiffusionPipeline(StableDiffusionPipeline): else: out = self.unet_bg_cache.wait_and_get_outputs() noise_pred_text = torch.from_numpy(out[0]) + elif flag_cache_faster: + if (skip_steps[i]): + noise_pred_text = torch.from_numpy(self.unet_bg_cache_faster.wait_and_get_outputs()[0]) + else: + out = self.unet_bg_cache_faster.wait_and_get_outputs() + noise_pred_text = torch.from_numpy(out[0]) else: noise_pred_text = torch.from_numpy(self.unet_bg.wait_and_get_outputs()[0]) else: @@ -743,7 +923,7 @@ class AIEStableDiffusionPipeline(StableDiffusionPipeline): # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, - **extra_step_kwargs).prev_sample + **extra_step_kwargs)[0] # call the callback, if provided if callback is not None and i % callback_steps == 0: @@ -895,6 +1075,13 @@ def check_device_range_valid(value): def parse_arguments(): parser = argparse.ArgumentParser() + parser.add_argument( + "-o", + "--output_dir", + type=str, + default="./models", + help="Path of directory to save compiled models.", + ) parser.add_argument( "-m", "--model", @@ -964,22 +1151,20 @@ def parse_arguments(): help="soc_version.", ) parser.add_argument( - "-o", - "--output_dir", - type=str, - default="./models", - help="Path of directory to save compiled models.", + "--use_cache", + action="store_true", + help="Use cache during inference." 
) parser.add_argument( - "--use_cache", + "--use_cache_faster", action="store_true", - help="Use cache during inference." + help="Use cache with faster during inference." ) parser.add_argument( - "--cache_steps", - type=str, + "--cache_steps", + type=str, default="1,2,3,4,5,7,9,10,12,13,14,16,18,19,21,23,24,26,27,29,\ - 30,31,33,34,36,37,39,40,41,43,44,45,47,48,49", + 30,31,33,34,36,37,39,40,41,43,44,45,47,48,49", help="Steps to use cache data." ) @@ -1003,7 +1188,9 @@ def main(): pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) if args.scheduler == "SA-Solver": pipe.scheduler = SASolverScheduler.from_config(pipe.scheduler.config) + pipe.compile_aie_model() + mindietorch.set_device(pipe.device_0) skip_steps = [0] * args.steps @@ -1016,6 +1203,14 @@ def main(): continue skip_steps[int(i)] = 1 + flag_cache_faster = 0 + if args.use_cache_faster: + flag_cache_faster = 1 + for i in args.cache_steps.split(','): + if int(i) >= args.steps: + continue + skip_steps[int(i)] = 1 + use_time = 0 prompt_loader = PromptLoader(args.prompt_file, args.prompt_file_type, @@ -1034,13 +1229,16 @@ def main(): print(f"[{infer_num + n_prompts}/{len(prompt_loader)}]: {prompts}") infer_num += args.batch_size - start_time = time.time() + if i > 4: + start_time = time.time() + if args.scheduler == "DDIM": images = pipe.ascendie_infer_ddim( prompts, num_inference_steps=args.steps, skip_steps=skip_steps, flag_cache=flag_cache, + flag_cache_faster=flag_cache_faster, ) else: images = pipe.ascendie_infer( @@ -1048,6 +1246,7 @@ def main(): num_inference_steps=args.steps, skip_steps=skip_steps, flag_cache=flag_cache, + flag_cache_faster=flag_cache_faster, ) if i > 4: # do not count the time spent inferring the first 0 to 4 images @@ -1080,6 +1279,9 @@ def main(): if (pipe.unet_bg_cache): pipe.unet_bg_cache.stop() + if (pipe.unet_bg_cache_faster): + pipe.unet_bg_cache_faster.stop() + # Save image information to a json file if os.path.exists(args.info_file_save_path): os.remove(args.info_file_save_path) diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion/stable_diffusion_unet_patch.py b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion/stable_diffusion_unet_patch.py index 112dd0c88e..bdb74c5f47 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion/stable_diffusion_unet_patch.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion/stable_diffusion_unet_patch.py @@ -22,6 +22,7 @@ def main(): assert diffusers_version is not '0.26.3', "expectation diffusers==0.26.3" os.system(f'patch -p0 {diffusers_path[0]}/models/unets/unet_2d_condition.py unet_2d_condition.patch') + os.system(f'patch -p0 {diffusers_path[0]}/models/unets/unet_2d_blocks.py unet_2d_blocks.patch') if __name__ == '__main__': diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion/unet_2d_blocks.patch b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion/unet_2d_blocks.patch new file mode 100644 index 0000000000..674baeb1f8 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion/unet_2d_blocks.patch @@ -0,0 +1,69 @@ +--- ./unet_2d_blocks.py 2024-06-24 09:06:04.593004325 +0800 ++++ ./unet_2d_blocks_wz.py 2024-06-24 09:33:26.052027073 +0800 +@@ -1159,6 +1159,7 @@ class CrossAttnDownBlock2D(nn.Module): + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, ++ block_number: Optional[int]=None, + additional_residuals: 
Optional[torch.FloatTensor] = None, + ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: + output_states = () +@@ -1210,6 +1211,8 @@ class CrossAttnDownBlock2D(nn.Module): + hidden_states = hidden_states + additional_residuals + + output_states = output_states + (hidden_states,) ++ if block_number is not None and len(output_states) == block_number + 1: ++ return hidden_states, output_states + + if self.downsamplers is not None: + for downsampler in self.downsamplers: +@@ -2364,6 +2367,7 @@ class CrossAttnUpBlock2D(nn.Module): + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, ++ block_number: Optional[int]=None, + ) -> torch.FloatTensor: + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + is_freeu_enabled = ( +@@ -2372,8 +2376,12 @@ class CrossAttnUpBlock2D(nn.Module): + and getattr(self, "b1", None) + and getattr(self, "b2", None) + ) +- +- for resnet, attn in zip(self.resnets, self.attentions): ++ ++ prev_feature = [] ++ for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)): ++ if block_number is not None and i < len(self.resnets) - block_number - 1: ++ continue ++ + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] +@@ -2390,8 +2398,9 @@ class CrossAttnUpBlock2D(nn.Module): + b2=self.b2, + ) + ++ prev_feature.append(hidden_states) + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) +- ++ + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): +@@ -2428,12 +2437,12 @@ class CrossAttnUpBlock2D(nn.Module): + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] +- ++ + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size, scale=lora_scale) +- +- return hidden_states ++ ++ return hidden_states, prev_feature + + + class UpBlock2D(nn.Module): diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion/unet_2d_condition.patch b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion/unet_2d_condition.patch index cc120ce5ce..1e816cec44 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion/unet_2d_condition.patch +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion/unet_2d_condition.patch @@ -1,19 +1,37 @@ ---- unet_2d_condition.py 2024-03-09 11:58:33.000000000 +0000 -+++ unet_2d_condition_cache.py 2024-03-09 11:58:35.000000000 +0000 -@@ -855,6 +855,8 @@ +--- ./unet_2d_condition.py 2024-06-24 09:06:04.594004325 +0800 ++++ ./unet_2d_condition_wz.py 2024-06-24 09:33:25.980027072 +0800 +@@ -855,6 +855,10 @@ class UNet2DConditionModel(ModelMixin, C down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None, encoder_attention_mask: Optional[torch.Tensor] = None, return_dict: bool = True, + if_skip: int = 0, -+ inputCache: torch.FloatTensor = None, ++ if_faster: int = 0, ++ inputCache: Optional[torch.FloatTensor] = None, ++ inputFasterCache: Optional[torch.FloatTensor] = None, ) -> Union[UNet2DConditionOutput, Tuple]: r""" The [`UNet2DConditionModel`] forward method. 
-@@ -1109,30 +1111,57 @@ - ) +@@ -1110,29 +1114,60 @@ class UNet2DConditionModel(ModelMixin, C down_intrablock_additional_residuals = down_block_additional_residuals is_adapter = True -+ + +- down_block_res_samples = (sample,) +- for downsample_block in self.down_blocks: +- if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: +- # For t2i-adapter CrossAttnDownBlock2D +- additional_residuals = {} +- if is_adapter and len(down_intrablock_additional_residuals) > 0: +- additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0) +- +- sample, res_samples = downsample_block( +- hidden_states=sample, +- temb=emb, +- encoder_hidden_states=encoder_hidden_states, +- attention_mask=attention_mask, +- cross_attention_kwargs=cross_attention_kwargs, +- encoder_attention_mask=encoder_attention_mask, +- **additional_residuals, +- ) + if not if_skip: + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: @@ -39,57 +57,43 @@ + + down_block_res_samples += res_samples + else: -+ down_block_res_samples = (sample,) -+ for downsample_block in self.down_blocks: -+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: -+ # For t2i-adapter CrossAttnDownBlock2D -+ additional_residuals = {} -+ if is_adapter and len(down_intrablock_additional_residuals) > 0: -+ additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0) -+ -+ sample, res_samples = downsample_block( -+ hidden_states=sample, -+ temb=emb, -+ encoder_hidden_states=encoder_hidden_states, -+ attention_mask=attention_mask, -+ cross_attention_kwargs=cross_attention_kwargs, -+ encoder_attention_mask=encoder_attention_mask, -+ **additional_residuals, -+ ) -+ else: -+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale) -+ if is_adapter and len(down_intrablock_additional_residuals) > 0: -+ sample += down_intrablock_additional_residuals.pop(0) - -- down_block_res_samples = (sample,) -- for downsample_block in self.down_blocks: -- if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: -- # For t2i-adapter CrossAttnDownBlock2D -- additional_residuals = {} -- if is_adapter and len(down_intrablock_additional_residuals) > 0: -- additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0) -- -- sample, res_samples = downsample_block( -- hidden_states=sample, -- temb=emb, -- encoder_hidden_states=encoder_hidden_states, -- attention_mask=attention_mask, -- cross_attention_kwargs=cross_attention_kwargs, -- encoder_attention_mask=encoder_attention_mask, -- **additional_residuals, -- ) -- else: ++ if if_faster: ++ down_block_res_samples = inputFasterCache.chunk(2, dim=1) + else: - sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale) - if is_adapter and len(down_intrablock_additional_residuals) > 0: - sample += down_intrablock_additional_residuals.pop(0) -+ down_block_res_samples += res_samples -+ break ++ down_block_res_samples = (sample,) ++ for downsample_block in self.down_blocks: - down_block_res_samples += res_samples ++ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: ++ # For t2i-adapter CrossAttnDownBlock2D ++ additional_residuals = {} ++ if is_adapter and len(down_intrablock_additional_residuals) > 0: ++ additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0) ++ ++ sample, 
res_samples = downsample_block( ++ hidden_states=sample, ++ temb=emb, ++ encoder_hidden_states=encoder_hidden_states, ++ attention_mask=attention_mask, ++ cross_attention_kwargs=cross_attention_kwargs, ++ encoder_attention_mask=encoder_attention_mask, ++ block_number=0, ++ **additional_residuals, ++ ) ++ else: ++ sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale) ++ if is_adapter and len(down_intrablock_additional_residuals) > 0: ++ sample += down_intrablock_additional_residuals.pop(0) ++ ++ down_block_res_samples += res_samples ++ break if is_controlnet: new_down_block_res_samples = () -@@ -1146,61 +1175,85 @@ +@@ -1146,61 +1181,87 @@ class UNet2DConditionModel(ModelMixin, C down_block_res_samples = new_down_block_res_samples # 4. mid @@ -172,6 +176,10 @@ - scale=lora_scale, - ) + if not if_skip: ++ if if_faster: ++ inputFasterCache = [tmp.clone() for tmp in down_block_res_samples] ++ inputFasterCache = torch.cat(inputFasterCache[0:2], dim=1) ++ + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + @@ -184,7 +192,7 @@ + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: -+ sample = upsample_block( ++ sample, record_feature = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, @@ -203,16 +211,13 @@ + scale=lora_scale, + ) + -+ if (not if_skip) and (i == 2): -+ inputCache = sample -+ ++ if (not if_skip) and (i == 3): ++ inputCache = record_feature[-2] + else: + for i, upsample_block in enumerate(self.up_blocks): + if i==3: -+ -+ res_samples = down_block_res_samples[-4:-1] -+ -+ sample = upsample_block( ++ res_samples = down_block_res_samples[-2:] ++ sample, _ = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, @@ -221,16 +226,17 @@ + upsample_size=upsample_size, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, ++ block_number=1, + ) # 6. post-process if self.conv_norm_out: -@@ -1215,4 +1268,7 @@ +@@ -1215,4 +1276,7 @@ class UNet2DConditionModel(ModelMixin, C if not return_dict: return (sample,) - return UNet2DConditionOutput(sample=sample) -+ if (not if_skip): -+ return (sample, inputCache) ++ if not if_skip: ++ return (sample, inputCache, inputFasterCache) if if_faster else (sample, inputCache) + else: + return UNet2DConditionOutput(sample=sample) -- Gitee
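
A minimal, illustrative sketch of the step scheduling this patch wires into the denoising loop (not part of the patch; run_denoise_loop and the dummy callables below are invented stand-ins for the compiled cache/skip UNets): on the steps listed in --cache_steps a shallow pass reuses the previously cached deep feature, and on every other step the full UNet runs and refreshes that cache.

# Illustrative sketch only; stand-in callables replace compiled_unet_cache / compiled_unet_skip.
import torch


def run_denoise_loop(latents, timesteps, skip_steps, unet_cache, unet_skip):
    cache = None
    for i, t in enumerate(timesteps):
        if skip_steps[i] and cache is not None:
            # Skip step: shallow pass that reuses the cached deep feature.
            noise_pred = unet_skip(latents, t, cache)
        else:
            # Cache step: full pass that also refreshes the deep-feature cache.
            noise_pred, cache = unet_cache(latents, t)
        # Stand-in for scheduler.step(noise_pred, t, latents, ...)
        latents = latents - 0.1 * noise_pred
    return latents


# Toy stand-ins so the sketch runs end to end.
full_unet = lambda x, t: (torch.zeros_like(x), torch.ones(1))
skip_unet = lambda x, t, cache: torch.zeros_like(x)
result = run_denoise_loop(torch.randn(1, 4, 8, 8), list(range(10)),
                          [0, 1] * 5, full_unet, skip_unet)
print(result.shape)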
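
The block_number argument added in unet_2d_blocks.patch lets a down block stop early once it has produced block_number + 1 residuals (and lets an up block run only its last few resnet/attention pairs), so the skip pass recomputes only the shallow layers. A toy sketch of that early-exit pattern, using made-up ToyCrossAttnBlock modules rather than the real resnet/attention pairs:

# Toy early-exit block; linear layers stand in for (resnet, attn) pairs.
import torch
import torch.nn as nn


class ToyCrossAttnBlock(nn.Module):
    def __init__(self, n_layers=3):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(n_layers)])

    def forward(self, hidden_states, block_number=None):
        output_states = ()
        for layer in self.layers:
            hidden_states = layer(hidden_states)
            output_states = output_states + (hidden_states,)
            # Early exit once enough residuals have been produced.
            if block_number is not None and len(output_states) == block_number + 1:
                return hidden_states, output_states
        return hidden_states, output_states


block = ToyCrossAttnBlock()
x = torch.randn(2, 8)
_, partial = block(x, block_number=0)   # only the first layer runs
_, full = block(x)                      # all three layers run
print(len(partial), len(full))          # 1 3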
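
The if_faster path in unet_2d_condition.patch carries the retained down-block residuals between steps as a single tensor: they are cloned and concatenated on a cache step and recovered with chunk(2, dim=1) on a skip step. A sketch of that round trip with example shapes (the real tensors are the UNet's own feature maps and must have matching channel counts for chunk to undo cat):

# cat/chunk round trip for a packed residual cache; shapes are examples only.
import torch

down_block_res_samples = (torch.randn(2, 320, 64, 64), torch.randn(2, 320, 64, 64))

# Cache step: clone and concatenate the first two residuals along the channel dim.
input_faster_cache = torch.cat([t.clone() for t in down_block_res_samples[0:2]], dim=1)

# Skip step: chunk(2, dim=1) restores a 2-tuple usable as down_block_res_samples.
restored = input_faster_cache.chunk(2, dim=1)
assert all(torch.equal(a, b) for a, b in zip(restored, down_block_res_samples))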