From 49ed9d084af718ba2e1c5909f8e089c36fcf2d92 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E5=B0=91=E9=B9=8F?= Date: Fri, 25 Apr 2025 17:38:47 +0800 Subject: [PATCH 1/2] =?UTF-8?q?flux=E5=8F=8C=E5=8D=A1=E6=98=BE=E5=AD=98?= =?UTF-8?q?=E4=BC=98=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- MindIE/MultiModal/Flux.1-DEV/README.md | 2 +- MindIE/MultiModal/Flux.1-DEV/inference_flux.py | 18 ++++++++++++++---- MindIE/MultiModal/Flux.1-DEV/requirements.txt | 3 ++- 3 files changed, 17 insertions(+), 6 deletions(-) diff --git a/MindIE/MultiModal/Flux.1-DEV/README.md b/MindIE/MultiModal/Flux.1-DEV/README.md index a182483f1f..428d666a02 100644 --- a/MindIE/MultiModal/Flux.1-DEV/README.md +++ b/MindIE/MultiModal/Flux.1-DEV/README.md @@ -197,7 +197,7 @@ python3 tpsplit_weight.py --path ${model_path} ``` 3.执行命令运行Flux: ```shell -export ASCEND_LAUNCH_BLOCKING = 1 +export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True ASCEND_RT_VISIBLE_DEVICES=0,1 torchrun --master_port=2002 --nproc_per_node=2 inference_flux.py --device_type "A2-32g-dual" --path ${model_path} --prompt_path "./prompts.txt" --width 1024 --height 1024 --infer_steps 50 --seed 42 --use_cache ``` 参数说明: diff --git a/MindIE/MultiModal/Flux.1-DEV/inference_flux.py b/MindIE/MultiModal/Flux.1-DEV/inference_flux.py index b6ad385676..7952a0404a 100644 --- a/MindIE/MultiModal/Flux.1-DEV/inference_flux.py +++ b/MindIE/MultiModal/Flux.1-DEV/inference_flux.py @@ -28,6 +28,7 @@ from mindiesd import CacheAgent, CacheConfig from FLUX1dev import FluxPipeline from FLUX1dev import get_local_rank, get_world_size, initialize_torch_distributed from FLUX1dev.utils import check_prompts_valid, check_param_valid, check_dir_safety, check_file_safety +from transformers import T5EncoderModel torch_npu.npu.set_compile_mode(jit_compile=False) @@ -175,7 +176,19 @@ def infer(args): FluxPipeline.extract_init_dict = classmethod(replace_tp_extract_init_dict) check_dir_safety(args.path) - pipe = FluxPipeline.from_pretrained(args.path, torch_dtype=torch.bfloat16, local_files_only=True) + T5_model_path = os.path.join(args.path, "text_encoder_2") + T5_model = T5EncoderModel.from_pretrained(T5_model_path).to(torch.bfloat16) + if args.device_type == "A2-32g-dual": + local_rank = get_local_rank() + world_size = get_world_size() + initialize_torch_distributed(local_rank, world_size) + import deepspeed + T5_model = deepspeed.init_inference( + T5_model, + tensor_parallel={"tp_size": get_world_size()}, + ) + + pipe = FluxPipeline.from_pretrained(args.path, text_encoder_2=T5_model,torch_dtype=torch.bfloat16, local_files_only=True) if args.device_type == "A2-32g-single": torch.npu.set_device(args.device_id) @@ -184,9 +197,6 @@ def infer(args): torch.npu.set_device(args.device_id) pipe.to(f"npu:{args.device_id}") else: - local_rank = get_local_rank() - world_size = get_world_size() - initialize_torch_distributed(local_rank, world_size) pipe.to(f"npu:{local_rank}") if args.use_cache: diff --git a/MindIE/MultiModal/Flux.1-DEV/requirements.txt b/MindIE/MultiModal/Flux.1-DEV/requirements.txt index 7ab1879205..370cb323fd 100644 --- a/MindIE/MultiModal/Flux.1-DEV/requirements.txt +++ b/MindIE/MultiModal/Flux.1-DEV/requirements.txt @@ -6,4 +6,5 @@ diffusers==0.32.1 transformers==4.46.3 tensorboard Jinja2 -peft==0.11.1 \ No newline at end of file +peft==0.11.1 +deepspeed \ No newline at end of file -- Gitee From 4cbb79d99de9dca52c73466fb955070644f344bc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E5=B0=91=E9=B9=8F?= Date: Fri, 25 Apr 2025 17:48:16 +0800 Subject: [PATCH 2/2] =?UTF-8?q?flux=E5=8F=8C=E5=8D=A1=E6=98=BE=E5=AD=98?= =?UTF-8?q?=E4=BC=98=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- MindIE/MultiModal/Flux.1-DEV/inference_flux.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/MindIE/MultiModal/Flux.1-DEV/inference_flux.py b/MindIE/MultiModal/Flux.1-DEV/inference_flux.py index 7952a0404a..87465f8dcd 100644 --- a/MindIE/MultiModal/Flux.1-DEV/inference_flux.py +++ b/MindIE/MultiModal/Flux.1-DEV/inference_flux.py @@ -188,7 +188,7 @@ def infer(args): tensor_parallel={"tp_size": get_world_size()}, ) - pipe = FluxPipeline.from_pretrained(args.path, text_encoder_2=T5_model,torch_dtype=torch.bfloat16, local_files_only=True) + pipe = FluxPipeline.from_pretrained(args.path, text_encoder_2=T5_model, torch_dtype=torch.bfloat16, local_files_only=True) if args.device_type == "A2-32g-single": torch.npu.set_device(args.device_id) -- Gitee