From 2528431effb7e5af7d756c8f39dc1881e53f6415 Mon Sep 17 00:00:00 2001
From: 汪越
Date: Wed, 23 Jul 2025 14:31:31 +0800
Subject: [PATCH] Update the OpenSoraPlan-1.0 script: add quantization options
 and quantize the transformer with the mindiesd quantization method
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../inference_opensora_plan.py | 39 ++++++++++++++++++++++++++++++++++++++++
 1 file changed, 39 insertions(+)

diff --git a/MindIE/MultiModal/OpenSoraPlan-1.0/inference_opensora_plan.py b/MindIE/MultiModal/OpenSoraPlan-1.0/inference_opensora_plan.py
index a9f9d42167..bf139da9e4 100644
--- a/MindIE/MultiModal/OpenSoraPlan-1.0/inference_opensora_plan.py
+++ b/MindIE/MultiModal/OpenSoraPlan-1.0/inference_opensora_plan.py
@@ -31,6 +31,7 @@ sys.path.append(os.path.split(sys.path[0])[0])
 from opensoraplan import OpenSoraPlanPipeline
 from opensoraplan import compile_pipe, get_scheduler, set_parallel_manager
 from opensoraplan import CacheConfig, OpenSoraPlanDiTCacheManager
+from mindsd.quantization.quantize import quantize
 
 MASTER_PORT = '42043'
 
@@ -42,6 +43,13 @@ def main(args):
     torch.set_grad_enabled(False)
     device = "npu" if torch.npu.is_available() else "cpu"
 
+    if args.type == "float16":
+        dtype = torch.float16
+    elif args.type == "bfloat16":
+        dtype = torch.bfloat16
+    else:
+        raise ValueError("Unknown torch dtype, make sure the '--type' parameter is in ['float16', 'bfloat16']!")
+
     sp_size = args.sequence_parallel_size
     if sp_size == 1:
         os.environ['RANK'] = '0'
@@ -72,6 +80,33 @@ def main(args):
     # compile pipeline and set the cache_manager and cfg_last_step
     videogen_pipeline = compile_pipe(videogen_pipeline, cache_manager, args.cfg_last_step)
 
+    if args.need_quant:
+        # Quantization settings for each supported '--bit' recipe; both use dynamic quantization.
+        quant_configs = {
+            "w4a8_mxfp4": {
+                "dtype": dtype,
+                "is_dynamic": True
+            },
+            "w4a4_mxfp4": {
+                "dtype": dtype,
+                "is_dynamic": True
+            },
+        }
+        quant_config = quant_configs.get(args.bit, {})
+
+        if not quant_config:
+            raise ValueError(f"Unsupported quantization type, make sure the '--bit' parameter is in {list(quant_configs.keys())}!")
+
+        if not os.path.exists(args.quant_des_path):
+            raise ValueError(f"The model quantization description file '{args.quant_des_path}' does not exist!")
+
+        # Replace the transformer with its quantized counterpart before inference.
+        videogen_pipeline.transformer = quantize(
+            videogen_pipeline.transformer,
+            quant_des_path=args.quant_des_path,
+            **quant_config
+        )
+
     if not os.path.exists(args.save_img_path):
         os.makedirs(args.save_img_path)
 
@@ -156,7 +191,11 @@ if __name__ == "__main__":
     parser.add_argument("--cfg_last_step", type=int, default=10000)
     parser.add_argument("--text_prompt", nargs='+')
     parser.add_argument('--force_images', action='store_true')
     parser.add_argument('--sequence_parallel_size', type=int, default=1)
+    parser.add_argument('--bit', type=str, default="w4a8_mxfp4", help="quantization recipe: w4a8_mxfp4 or w4a4_mxfp4")
+    parser.add_argument('--type', type=str, default="float16", help="computation dtype: float16 or bfloat16")
+    parser.add_argument('--quant_des_path', type=str, default="/xxx/xxx.json", help="path to the quantization description file; an absolute path is recommended")
+    parser.add_argument('--need_quant', action='store_true', help="enable mindiesd quantization of the transformer")
 
     args_input = parser.parse_args()
     if not os.path.exists(args_input.model_path):
-- 
Gitee
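
A hypothetical invocation exercising the new quantization options. The --need_quant, --bit, --type, and --quant_des_path flags come from this patch; --model_path and --text_prompt are assumed from the script's existing argument parser, and the model and description-file paths below are placeholders, not values shipped with the patch:

    python inference_opensora_plan.py \
        --model_path /path/to/OpenSoraPlan-1.0 \
        --text_prompt "a sunset over the sea" \
        --need_quant \
        --bit w4a8_mxfp4 \
        --type float16 \
        --quant_des_path /path/to/quant_des.json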