From fce9356ae89c9b9dd76d32db69620fda24a62d62 Mon Sep 17 00:00:00 2001
From: guowenna
Date: Tue, 26 Aug 2025 11:13:56 +0800
Subject: [PATCH 1/2] fix no cache infer and 32g double card infer

---
 MindIE/MultiModal/Flux.1-DEV/inference_flux.py | 13 ++++++++++++-
 1 file changed, 12 insertions(+), 1 deletion(-)

diff --git a/MindIE/MultiModal/Flux.1-DEV/inference_flux.py b/MindIE/MultiModal/Flux.1-DEV/inference_flux.py
index 87465f8dcd..51fab45dda 100644
--- a/MindIE/MultiModal/Flux.1-DEV/inference_flux.py
+++ b/MindIE/MultiModal/Flux.1-DEV/inference_flux.py
@@ -187,8 +187,9 @@ def infer(args):
         T5_model,
         tensor_parallel={"tp_size": get_world_size()},
     )
+    T5_model.module.to("cpu")
 
-    pipe = FluxPipeline.from_pretrained(args.path, text_encoder_2=T5_model, torch_dtype=torch.bfloat16, local_files_only=True)
+    pipe = FluxPipeline.from_pretrained(args.path, torch_dtype=torch.bfloat16, local_files_only=True)
 
     if args.device_type == "A2-32g-single":
         torch.npu.set_device(args.device_id)
@@ -198,6 +199,8 @@ def infer(args):
         pipe.to(f"npu:{args.device_id}")
     else:
         pipe.to(f"npu:{local_rank}")
+        pipe.text_encoder_2.to("cpu")
+        pipe.text_encoder_2 = T5_model.module.to(f"npu:{local_rank}")
 
     if args.use_cache:
         d_stream_config = CacheConfig(
@@ -227,6 +230,10 @@ def infer(args):
             method="dit_block_cache",
             blocks_count=19,
             steps_count=args.infer_steps,
+            step_start=args.infer_steps,
+            step_interval=2,
+            block_start=18, #double stream block num - 1
+            block_end=18, #double stream block num - 1
         )
         d_stream_agent = CacheAgent(d_stream_config)
         pipe.transformer.d_stream_agent = d_stream_agent
@@ -234,6 +241,10 @@ def infer(args):
             method="dit_block_cache",
             blocks_count=38,
             steps_count=args.infer_steps,
+            step_start=args.infer_steps,
+            step_interval=2,
+            block_start=37, #single stream block num - 1
+            block_end=37, #single stream block num - 1
         )
         s_stream_agent = CacheAgent(s_stream_config)
         pipe.transformer.s_stream_agent = s_stream_agent
-- 
Gitee

From f19c538724fa1b8b4af433d7e43b0c0c3951c87a Mon Sep 17 00:00:00 2001
From: guowenna
Date: Tue, 26 Aug 2025 11:30:10 +0800
Subject: [PATCH 2/2] fix no cache infer and 32g double card infer

---
 MindIE/MultiModal/Flux.1-DEV/inference_flux.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/MindIE/MultiModal/Flux.1-DEV/inference_flux.py b/MindIE/MultiModal/Flux.1-DEV/inference_flux.py
index 51fab45dda..44db5468ac 100644
--- a/MindIE/MultiModal/Flux.1-DEV/inference_flux.py
+++ b/MindIE/MultiModal/Flux.1-DEV/inference_flux.py
@@ -232,8 +232,8 @@ def infer(args):
             steps_count=args.infer_steps,
             step_start=args.infer_steps,
             step_interval=2,
-            block_start=18, #double stream block num - 1
-            block_end=18, #double stream block num - 1
+            block_start=18,
+            block_end=18,
         )
         d_stream_agent = CacheAgent(d_stream_config)
         pipe.transformer.d_stream_agent = d_stream_agent
@@ -243,8 +243,8 @@ def infer(args):
             steps_count=args.infer_steps,
             step_start=args.infer_steps,
             step_interval=2,
-            block_start=37, #single stream block num - 1
-            block_end=37, #single stream block num - 1
+            block_start=37,
+            block_end=37,
         )
         s_stream_agent = CacheAgent(s_stream_config)
         pipe.transformer.s_stream_agent = s_stream_agent
-- 
Gitee