diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/README.md b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/README.md index 4a807769b94dc98ee19a0b8941e79c69c344c339..c02b45094d39be8dcb0f4185a81f13ab08e8d62d 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/README.md +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/README.md @@ -67,12 +67,14 @@ ```bash # 安装mindie + source /usr/local/Ascend/ascend-toolkit/set_env.sh chmod +x ./Ascend-mindie_xxx.run ./Ascend-mindie_xxx.run --install source /usr/local/Ascend/mindie/set_env.sh ``` -3. 代码修改 +3. 代码修改(可选) +(1)若需要开启DiTCache、序列压缩等优化,需要执行以下代码修改操作: - 若环境没有patch工具,请自行安装: ```bash apt update @@ -81,6 +83,8 @@ - 执行命令: ```bash python3 attention_patch.py + python3 attention_processor_patch.py + python3 transformer_sd3_patch.py ``` ## 模型推理 @@ -113,23 +117,44 @@ ```bash mkdir ./models ``` + (3) 执行命令查看芯片名称($\{chip\_name\})。 - (3) 执行export命令 + ``` + npu-smi info + #该设备芯片chip_name=310P3 (自行替换) + 回显如下: + +-------------------+-----------------+------------------------------------------------------+ + | NPU Name | Health | Power(W) Temp(C) Hugepages-Usage(page) | + | Chip Device | Bus-Id | AICore(%) Memory-Usage(MB) | + +===================+=================+======================================================+ + | 0 310P3 | OK | 15.8 42 0 / 0 | + | 0 0 | 0000:82:00.0 | 0 1074 / 21534 | + +===================+=================+======================================================+ + | 1 310P3 | OK | 15.4 43 0 / 0 | + | 0 1 | 0000:89:00.0 | 0 1070 / 21534 | + +===================+=================+======================================================+ + ``` + + (4) 执行export命令 ```bash - # 800I A2,非并行 - python3 export_model.py --model ${model_base} --output_dir ./models --batch_size 1 --soc Ascend910B4 --device 0 + # 800I A2,非并行,未加DiTCache优化 + python3 export_model.py --model ${model_base} --output_dir ./models --batch_size 1 --soc Ascend${chip_name} --device_type A2 --device 0 + # 800I A2,非并行。开启DiTCache优化 + python3 export_model.py --model ${model_base} --output_dir ./models --batch_size 1 --soc Ascend${chip_name} --device_type A2 --device 0 --use_cache # 300I Duo,并行 - python3 export_model.py --model ${model_base} --output_dir ./models --parallel --batch_size 1 --soc Ascend310P3 --device 0 + python3 export_model.py --model ${model_base} --output_dir ./models --parallel --batch_size 1 --soc Ascend${chip_name} --device_type Duo --device 0 ``` 参数说明: - --model:模型权重路径 - --output_dir: 存放导出模型的路径 - --parallel: 【可选】导出适用于并行方案的模型 - --batch_size: 设置batch_size, 默认值为1, 当前仅支持batch_size=1的场景 - - --soc:只支持Ascend910B4和Ascend310P3。默认为Ascend910B4。 + - --soc:处理器型号。 + - --device_type: 设备形态,当前支持A2、Duo两种形态。 - --device:推理设备ID + - --use_cache:开启DiTCache优化,不配置则不开启 注意:trace+compile耗时较长且占用较多的CPU资源,请勿在执行export命令时运行其他占用CPU内存的任务,避免程序意外退出。 2. 开始推理验证。 @@ -163,7 +188,7 @@ 3. 
执行推理脚本。 ```bash - # 不使用unetCache策略 + # 不使用DiTCache,单卡推理,适用800I A2场景 numactl -C 0-23 python3 stable_diffusion3_pipeline.py \ --model ${model_base} \ --prompt_file ./prompts.txt \ @@ -176,7 +201,7 @@ --width 1024 \ --batch_size 1 - # 使用UnetCache策略,同时使用双卡并行策略 + # 不使用DiTCache,使用双卡并行推理,适用300I DUO场景 numactl -C 0-23 python3 stable_diffusion3_pipeline.py \ --model ${model_base} \ --prompt_file ./prompts.txt \ @@ -189,6 +214,20 @@ --width 1024 \ --batch_size 1 \ --parallel + + # 使用DiTCache,单卡推理,适用800I A2场景 + numactl -C 0-23 python3 stable_diffusion3_pipeline_cache.py \ + --model ${model_base} \ + --prompt_file ./prompts.txt \ + --prompt_file_type plain \ + --device 0 \ + --save_dir ./results \ + --steps 28 \ + --output_dir ./models \ + --height 1024 \ + --width 1024 \ + --batch_size 1 \ + --use_cache ``` 参数说明: @@ -202,6 +241,7 @@ - --device:推理设备ID;可用逗号分割传入两个设备ID,此时会使用并行方式进行推理。 - --height:生成图像高度,当前只支持1024 - --width:生成图像宽度,当前只支持1024 + - --use_cache:开启DiTCache优化,不配置则不开启 非并行策略,执行完成后在`./results`目录下生成推理图片,在当前目录生成一个`image_info.json`文件,记录着图片和prompt的对应关系,并在终端显示推理时间。 并行策略,同时使用双卡并行策略,执行完成后在`./results_parallel`目录下生成推理图片,在当前目录生成一个`image_info.json`文件,记录着图片和prompt的对应关系。并在终端显示推理时间。 diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/export_model.py b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/export_model.py index 26f23048cf41b09a264e06055bbd8b861be5039a..ea11642ab6fb01e400a34f75de6c94d0824c24a2 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/export_model.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/export_model.py @@ -72,6 +72,7 @@ def parse_arguments() -> Namespace: parser.add_argument("-p", "--parallel", action="store_true", help="Export the unet of bs=1 for parallel inferencing.") parser.add_argument("--soc", help="soc_version.") + parser.add_argument("--device_type", choices=["A2", "Duo"], default="A2", help="device type.") parser.add_argument( "--device", default=0, @@ -379,7 +380,7 @@ def export(args) -> None: if args.use_cache: export_dit_cache(pipeline, args, 0) export_dit_cache(pipeline, args, 1) - if "B" in args.soc: + if args.device_type == "A2": export_dit_cache(pipeline, args, 0, "end") export_dit_cache(pipeline, args, 1, "end") else: diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline_cache.py b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline_cache.py index 31ccc987cd8d2e7c54af2ceb150f079ef60fc194..5f980cb3662cb74a5f93038c2304f9fb48e441bf 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline_cache.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/stable_diffusion3_pipeline_cache.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import argparse import json import os import time @@ -20,7 +21,7 @@ import torch import mindietorch from diffusers.pipelines.stable_diffusion_3.pipeline_output import StableDiffusion3PipelineOutput from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import retrieve_timesteps -from stable_diffusion3_pipeline import PromptLoader, AIEStableDiffusion3Pipeline, parse_arguments +from stable_diffusion3_pipeline import PromptLoader, AIEStableDiffusion3Pipeline tgate = 20 dit_time = 0 @@ -366,6 +367,132 @@ class AIEStableDiffusion3CachePipeline(AIEStableDiffusion3Pipeline): return StableDiffusion3PipelineOutput(images=image) +def check_device_range_valid(value): + # if contain , split to int list + min_value = 0 + max_value = 255 + if ',' in value: + ilist = [int(v) for v in value.split(',')] + for ivalue in ilist[:2]: + if ivalue < min_value or ivalue > max_value: + raise argparse.ArgumentTypeError( + "{} of device:{} is invalid. valid value range is [{}, {}]" + .format(ivalue, value, min_value, max_value)) + return ilist[:2] + else: + # default as single int value + ivalue = int(value) + if ivalue < min_value or ivalue > max_value: + raise argparse.ArgumentTypeError( + "device:{} is invalid. valid value range is [{}, {}]".format( + ivalue, min_value, max_value)) + return ivalue + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument( + "-m", + "--model", + type=str, + default="./stable-diffusion-3-medium-diffusers", + help="Path or name of the pre-trained model.", + ) + parser.add_argument( + "--prompt_file", + type=str, + default="./prompts.txt", + help="A text file of prompts for generating images.", + ) + parser.add_argument( + "--prompt_file_type", + choices=["plain", "parti", "hpsv2"], + default="plain", + help="Type of prompt file.", + ) + parser.add_argument( + "--save_dir", + type=str, + default="./results", + help="Path to save result images.", + ) + parser.add_argument( + "--info_file_save_path", + type=str, + default="./image_info.json", + help="Path to save image information file.", + ) + parser.add_argument( + "--steps", + type=int, + default=28, + help="Number of inference steps.", + ) + parser.add_argument( + "--device", + type=check_device_range_valid, + default=0, + help="NPU device id. Give 2 ids to enable parallel inferencing.", + ) + parser.add_argument( + "--num_images_per_prompt", + default=1, + type=int, + help="Number of images generated for each prompt.", + ) + parser.add_argument( + "--max_num_prompts", + default=0, + type=int, + help="Limit the number of prompts (0: no limit).", + ) + parser.add_argument( + "-bs", + "--batch_size", + type=int, + default=1, + help="Batch size." + ) + parser.add_argument( + "-o", + "--output_dir", + type=str, + default="./models", + help="Path of directory to save compiled models.", + ) + parser.add_argument( + "--scheduler", + choices=["FlowMatchEuler"], + default="FlowMatchEuler", + help="Type of Sampling methods. Default FlowMatchEuler", + ) + parser.add_argument( + "--height", + default=1024, + type=int, + help="image height", + ) + parser.add_argument( + "--width", + default=1024, + type=int, + help="image width" + ) + parser.add_argument( + "--use_cache", + action="store_true", + help="Use cache during inference." 
+ ) + parser.add_argument( + "--cache_param", + default="1,2,20,10", + type=str, + help="steps to use cache data" + ) + + return parser.parse_args() + + def main(): args = parse_arguments() save_dir = args.save_dir diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/transformer_sd3.patch b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/transformer_sd3.patch new file mode 100644 index 0000000000000000000000000000000000000000..bdbbb9c8671d021ed5b4e1669ab88389d03f5ad4 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/transformer_sd3.patch @@ -0,0 +1,123 @@ +--- transformer_sd3.py 2024-09-04 09:21:58.280000000 +0000 ++++ transformer_sd3.py 2024-09-04 10:01:47.196000000 +0000 +@@ -97,6 +97,7 @@ + num_attention_heads=self.config.num_attention_heads, + attention_head_dim=self.inner_dim, + context_pre_only=i == num_layers - 1, ++ layer_idx=i + ) + for i in range(self.config.num_layers) + ] +@@ -106,6 +107,8 @@ + self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) + + self.gradient_checkpointing = False ++ self.delta_cache = None ++ self.delta_cache_hidden = None + + # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking + def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: +@@ -245,9 +248,14 @@ + encoder_hidden_states: torch.FloatTensor = None, + pooled_projections: torch.FloatTensor = None, + timestep: torch.LongTensor = None, ++ cache_dict: torch.LongTensor = None, ++ if_skip: int = 0, ++ delta_cache: torch.FloatTensor = None, ++ delta_cache_hidden: torch.FloatTensor = None, ++ use_cache: bool = False, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, +- ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: ++ ): + """ + The [`SD3Transformer2DModel`] forward method. + +@@ -281,10 +289,6 @@ + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) +- else: +- logger.warning( +- "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." 
+- ) + + height, width = hidden_states.shape[-2:] + +@@ -292,9 +296,8 @@ + temb = self.time_text_embed(timestep, pooled_projections) + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + +- for block in self.transformer_blocks: +- if self.training and self.gradient_checkpointing: +- ++ if self.training and self.gradient_checkpointing: ++ for block in self.transformer_blocks: + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: +@@ -312,11 +315,14 @@ + temb, + **ckpt_kwargs, + ) +- +- else: +- encoder_hidden_states, hidden_states = block( +- hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb +- ) ++ else: ++ ( ++ (encoder_hidden_states, hidden_states), ++ delta_cache, ++ delta_cache_hidden ++ ) = self.forward_blocks(hidden_states, encoder_hidden_states, temb, ++ use_cache, if_skip, cache_dict, delta_cache, ++ delta_cache_hidden) + + hidden_states = self.norm_out(hidden_states, temb) + hidden_states = self.proj_out(hidden_states) +@@ -339,6 +345,43 @@ + unscale_lora_layers(self, lora_scale) + + if not return_dict: +- return (output,) ++ return (output, delta_cache, delta_cache_hidden) + + return Transformer2DModelOutput(sample=output) ++ ++ def forward_blocks_range(self, hidden_states, encoder_hidden_states, temb, start_idx, end_idx): ++ for block_idx, block in enumerate(self.transformer_blocks[start_idx: end_idx]): ++ encoder_hidden_states, hidden_states = block( ++ hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb ++ ) ++ ++ return hidden_states, encoder_hidden_states ++ ++ def forward_blocks(self, hidden_states, encoder_hidden_states, temb, use_cache, if_skip, cache_dict, delta_cache, ++ delta_cache_hidden): ++ if not use_cache: ++ hidden_states, encoder_hidden_states = self.forward_blocks_range(hidden_states, encoder_hidden_states, temb, ++ 0, len(self.transformer_blocks)) ++ else: ++ # infer [0, cache_start) ++ hidden_states, encoder_hidden_states = self.forward_blocks_range(hidden_states, encoder_hidden_states, temb, ++ 0, cache_dict[0]) ++ ++ # infer [cache_start, cache_end) ++ cache_end = cache_dict[0] + cache_dict[2] ++ hidden_states_before_cache = hidden_states.clone() ++ encoder_hidden_states_before_cache = encoder_hidden_states.clone() ++ if not if_skip: ++ hidden_states, encoder_hidden_states = self.forward_blocks_range(hidden_states, encoder_hidden_states, ++ temb, cache_dict[0], ++ cache_end) ++ delta_cache = hidden_states - hidden_states_before_cache ++ delta_cache_hidden = encoder_hidden_states - encoder_hidden_states_before_cache ++ else: ++ hidden_states = hidden_states_before_cache + delta_cache ++ encoder_hidden_states = encoder_hidden_states_before_cache + delta_cache_hidden ++ ++ # infer [cache_end, len(self.blocks)) ++ hidden_states, encoder_hidden_states = self.forward_blocks_range(hidden_states, encoder_hidden_states, temb, ++ cache_end, len(self.transformer_blocks)) ++ return (encoder_hidden_states, hidden_states), delta_cache, delta_cache_hidden diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/transformer_sd3_patch.py b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/transformer_sd3_patch.py new file mode 100644 index 0000000000000000000000000000000000000000..8556cbac2ba1974972fcf3965b78a9f7592b60c8 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_3/transformer_sd3_patch.py @@ -0,0 +1,33 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under 
the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import subprocess
+import logging
+import diffusers
+
+
+
+def main():
+    diffusers_path = diffusers.__path__
+    diffusers_version = diffusers.__version__
+
+    assert diffusers_version == '0.29.0', "expected diffusers==0.29.0"
+    result = subprocess.run(["patch", "-p0", f"{diffusers_path[0]}/models/transformers/transformer_sd3.py",
+                             "transformer_sd3.patch"], capture_output=True, text=True)
+    if result.returncode != 0:
+        logging.error("Patch failed, error message: %s", result.stderr)
+
+
+if __name__ == '__main__':
+    main()
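
说明:`transformer_sd3.patch` 中新增的 DiTCache 路径把 transformer blocks 分成头部、可缓存的中间区段和尾部;大多数步数下跳过中间区段,直接把此前记录的输出增量(delta)加回输入,而不是重新计算。下面是该 delta 缓存思路的一个极简单流示意(真实的 `forward_blocks` 同时携带 `hidden_states` 与 `encoder_hidden_states` 两路);其中 `ToyBlock`、函数名及各参数均为示意性假设,并非仓库实际 API,仅用于说明机制。

```python
import torch
import torch.nn as nn


class ToyBlock(nn.Module):
    """Stand-in for a transformer block; only the residual structure matters here."""

    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, hidden_states):
        return hidden_states + torch.tanh(self.proj(hidden_states))


def forward_with_delta_cache(blocks, hidden_states, cache_start, cache_len, if_skip, delta_cache):
    # Head: always computed.
    for block in blocks[:cache_start]:
        hidden_states = block(hidden_states)

    cache_end = cache_start + cache_len
    before = hidden_states.clone()
    if not if_skip:
        # Compute the cached span and remember how much it changed the activations.
        for block in blocks[cache_start:cache_end]:
            hidden_states = block(hidden_states)
        delta_cache = hidden_states - before
    else:
        # Skip the span: approximate it by re-adding the delta recorded earlier.
        hidden_states = before + delta_cache

    # Tail: always computed.
    for block in blocks[cache_end:]:
        hidden_states = block(hidden_states)
    return hidden_states, delta_cache


if __name__ == "__main__":
    torch.manual_seed(0)
    blocks = nn.ModuleList(ToyBlock(64) for _ in range(8))
    x = torch.randn(1, 16, 64)
    # Step N: run the cached span and record its delta.
    full, delta = forward_with_delta_cache(blocks, x, cache_start=2, cache_len=4, if_skip=0, delta_cache=None)
    # Step N+1 (here fed the same input, so reuse is exact): skip the span.
    reused, _ = forward_with_delta_cache(blocks, x, cache_start=2, cache_len=4, if_skip=1, delta_cache=delta)
    print(torch.allclose(full, reused))  # True
```

在实际 pipeline 中,`if_skip` 由 `--use_cache` 与 `cache_param` 按去噪步数控制:被跳过的步复用最近一次完整计算时记录的 delta,从而减少中间区段 blocks 的重复计算。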