diff --git a/mindspeed_llm/mindspore/mindspore_adaptor.py b/mindspeed_llm/mindspore/mindspore_adaptor.py
index a6e1e2d7b30119876ceb5f436be5afb071af1a12..549fb6f06e69f2d9c6ebab5dd9631a6e2aa06506 100644
--- a/mindspeed_llm/mindspore/mindspore_adaptor.py
+++ b/mindspeed_llm/mindspore/mindspore_adaptor.py
@@ -1,6 +1,7 @@
 # Copyright (c) 2025; NVIDIA CORPORATION. All rights reserved.
 # Copyright (c) 2025, Huawei Technologies Co., Ltd. All rights reserved
 
+from mindspeed_llm.mindspore.utils import clear_wrapper
 from mindspeed_llm.tasks.megatron_adaptor import MegatronAdaptation, MegatronAdaptationABC
 
 
@@ -341,6 +342,15 @@ class MindSporeAdaptation(MegatronAdaptationABC):
         MindSporeAdaptation.register('megatron.core.pipeline_parallel.schedules.deallocate_output_tensor',
                                      deallocate_output_tensor_)
 
+        if args.reuse_fp32_param:
+            from megatron.core.optimizer.distrib_optimizer import DistributedOptimizer
+            from mindspeed.mindspore.optimizer.distrib_optimizer import reuse_fp32_param_distrib_optimizer_init_wrapper
+            target_func = DistributedOptimizer.__init__
+            target_func_name = 'megatron.core.optimizer.distrib_optimizer.DistributedOptimizer.__init__'
+            clear_wrapper(target_func_name, target_func)
+            MindSporeAdaptation.register(target_func_name, reuse_fp32_param_distrib_optimizer_init_wrapper)
+
+
     @staticmethod
     def reparse_args():
         """
diff --git a/mindspeed_llm/mindspore/utils.py b/mindspeed_llm/mindspore/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec7bc1597c475c577424d1f04f7301a5a907f056
--- /dev/null
+++ b/mindspeed_llm/mindspore/utils.py
@@ -0,0 +1,56 @@
+# Copyright (c) 2025, HUAWEI CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +"""General utilities.""" + + +import importlib +import inspect +import sys + +from mindspeed.patch_utils import Patch +from mindspeed_llm.tasks.megatron_adaptor import MegatronAdaptation, MegatronAdaptationABC + + +def clear_wrapper(original_func_name, target_func): + '''update the pt wrapper patch with mindspore wrapper''' + reset_patch(original_func_name) + # orig_func is the original megatron method + orig_func = inspect.unwrap(target_func) + # patch with orig_func, which is equivalent to restore this patch to the original megatron method + apply_patch(original_func_name, orig_func) + + +def reset_patch(original_func_name): + '''clear the wrapper info in Patch object''' + target_patch = MegatronAdaptation._patch_info_collection[original_func_name] + target_patch.wrappers = [] + + +def apply_patch(original_func_name, new_func): + split_name = original_func_name.rsplit('.', 1) + if len(split_name) == 1: + orig_module_name, orig_func_name = original_func_name, None + else: + orig_module_name, orig_func_name = split_name + + orig_module, orig_func = Patch.parse_path(orig_module_name, orig_func_name, False) + final_patch_func = new_func + if orig_func_name is not None: + setattr(orig_module, orig_func_name, final_patch_func) + for _, value in sys.modules.copy().items(): + if orig_func_name is not None and hasattr(value, orig_func_name) \ + and id(getattr(value, orig_func_name)) == id(orig_func): + setattr(value, orig_func_name, final_patch_func) +