diff --git a/mindspeed_llm/features_manager/__init__.py b/mindspeed_llm/features_manager/__init__.py index f4a9b97e2b53c750606368ee3378d095826c0377..5d9c5f5b3120f67a634a1d3f0a070d3576a96417 100644 --- a/mindspeed_llm/features_manager/__init__.py +++ b/mindspeed_llm/features_manager/__init__.py @@ -70,6 +70,7 @@ from mindspeed_llm.features_manager.transformer.flash_attention.alibi_feature im from mindspeed_llm.features_manager.transformer.mtp import MultiTokenPredictionFeature from mindspeed_llm.features_manager.transformer.multi_latent_attention.mla_feature import MLAFeature from mindspeed_llm.features_manager.transformer.transformer_block import TransformerBlockFeature +from mindspeed_llm.features_manager.ai_framework.ms_patch_feature import MindSporePatchFeature FEATURES_LIST = [ @@ -256,6 +257,12 @@ def add_high_availability_feature(features_list: List[MindSpeedFeature]): ]) +def add_ai_framework_feature(features_list: List[MindSpeedFeature]): + features_list.extend([ + MindSporePatchFeature(), + ]) + + def create_features_list(): features_list = [] add_megatron_basic_features(features_list) @@ -277,6 +284,7 @@ def create_features_list(): add_swap_optimizer_feature(features_list) add_disable_gloo_group_feature(features_list) add_high_availability_feature(features_list) + add_ai_framework_feature(features_list) return features_list diff --git a/mindspeed_llm/features_manager/ai_framework/__init__.py b/mindspeed_llm/features_manager/ai_framework/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mindspeed_llm/features_manager/ai_framework/ms_patch_feature.py b/mindspeed_llm/features_manager/ai_framework/ms_patch_feature.py new file mode 100644 index 0000000000000000000000000000000000000000..65bf5de51cd93e05b8f9c15d214897486ff49323 --- /dev/null +++ b/mindspeed_llm/features_manager/ai_framework/ms_patch_feature.py @@ -0,0 +1,46 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. All rights reserved. + +from argparse import ArgumentParser +from mindspeed.features_manager.feature import MindSpeedFeature + + +class MindSporePatchFeature(MindSpeedFeature): + def __init__(self): + super().__init__('mindspore-patch', optimization_level=0) + + def register_args(self, parser: ArgumentParser): + group = parser.add_argument_group(title=self.feature_name) + group.add_argument('--ai-framework', type=str, default='pytorch', help='support pytorch and mindspore') + + def use_mindspore(self, args) -> bool: + return hasattr(args, "ai_framework") and args.ai_framework == "mindspore" + + def register_patches(self, patch_manager, args): + if not self.use_mindspore(args): + return + from mindspeed_llm.mindspore.mindspore_adaptor_v2 import mindspore_adaptation + mindspore_adaptation(patch_manager, args) + + def pre_validate_args(self, args): + if not self.use_mindspore(args): + return + from mindspeed_llm.mindspore.mindspore_adaptor_v2 import mindspore_pre_validate_args + mindspore_pre_validate_args(args) + + def validate_args(self, args): + if not self.use_mindspore(args): + return + from mindspeed_llm.mindspore.mindspore_adaptor_v2 import mindspore_validate_args + mindspore_validate_args(args) + + def post_validate_args(self, args): + if not self.use_mindspore(args): + return + from mindspeed_llm.mindspore.mindspore_adaptor_v2 import mindspore_post_validate_args + mindspore_post_validate_args(args) + + def pre_register_patches(self, patch_manager, args): + if not self.use_mindspore(args): + return + from mindspeed_llm.mindspore.mindspore_adaptor_v2 import mindspore_pre_register_patches + mindspore_pre_register_patches(patch_manager, args) diff --git a/mindspeed_llm/mindspore/mindspore_adaptor_v2.py b/mindspeed_llm/mindspore/mindspore_adaptor_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..2127f14ab88bd406a034b9a0db0d64d7624dbd92 --- /dev/null +++ b/mindspeed_llm/mindspore/mindspore_adaptor_v2.py @@ -0,0 +1,24 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. All rights reserved. + +def mindspore_adaptation(patch_manager, args): + pass + + +def pre_validate_args(patch_manager): + pass + + +def mindspore_pre_validate_args(args): + pass + + +def mindspore_validate_args(args): + pass + + +def mindspore_post_validate_args(args): + pass + + +def mindspore_pre_register_patches(manager, args): + pass