diff --git a/mindspeed_llm/mindspore/convert_ckpt.py b/mindspeed_llm/mindspore/convert_ckpt.py
new file mode 100644
index 0000000000000000000000000000000000000000..c0e57b102e7e922658b523c331151a6ca7bc3418
--- /dev/null
+++ b/mindspeed_llm/mindspore/convert_ckpt.py
@@ -0,0 +1,126 @@
+# Copyright (c) 2025, HUAWEI CORPORATION. All rights reserved.
+import argparse
+import importlib
+import os
+import sys
+from functools import wraps
+import logging as logger
+import mindspore as ms
+import torch
+import torch.multiprocessing as mp
+from mindspeed_llm import megatron_adaptor
+import pretrain_gpt
+from mindspeed_llm.tasks.posttrain.orm.orm_trainer import ORMTrainer
+
+
+ms.set_context(device_target='CPU', pynative_synchronize=True)
+torch.configs.set_pyboost(False)
+
+MODULE_ROOT = "mindspeed_llm.tasks.checkpoint"
+
+
+def load_plugin(plugin_type, name):
+    if name == '':
+        module_name = f"{MODULE_ROOT}.{plugin_type}"
+    else:
+        module_name = f"{MODULE_ROOT}.{plugin_type}_{name}"
+    try:
+        plugin = importlib.import_module(module_name)
+    except ModuleNotFoundError:
+        module_name = f"{MODULE_ROOT}.{name}"
+        try:
+            plugin = importlib.import_module(module_name)
+        except ModuleNotFoundError:
+            sys.exit(f"Unable to load {plugin_type} plugin {name}. Exiting.")
+
+    if not hasattr(plugin, 'add_arguments'):
+        sys.exit(f"{module_name} module is not a plugin. Exiting.")
+
+    logger.info(f"Loaded {module_name} as the {plugin_type}.")
+    return plugin
+
+
+def main():
+
+    parser = argparse.ArgumentParser(description="Megatron Checkpoint Utility Arguments",
+                                     allow_abbrev=False, conflict_handler='resolve')
+
+    parser.add_argument('--model-type', type=str, required=True,
+                        choices=['GPT', 'BERT'],
+                        help='Type of the model')
+    parser.add_argument('--loader', type=str, default='megatron',
+                        help='Module name to load checkpoint, should be on the Python path')
+    parser.add_argument('--load-model-type', type=str, nargs='?',
+                        default=None, const=None, choices=['hf', 'mg', 'optim'],
+                        help='Type of checkpoint to load: hf, mg or optim')
+    parser.add_argument('--saver', type=str, default='megatron',
+                        help='Module name to save checkpoint, should be on the Python path')
+    parser.add_argument('--load-dir', type=str, required=True,
+                        help='Directory to load model checkpoint from')
+    parser.add_argument('--save-dir', type=str, required=True,
+                        help='Directory to save model checkpoint to')
+    parser.add_argument('--max-queue-size', type=int, default=50,
+                        help='Maximum number of tensors in the queue')
+    parser.add_argument('--no-checking', action='store_false',
+                        help='Do not perform checking on the name and ordering of weights',
+                        dest='checking')
+    parser.add_argument('--spec', type=str, default=None, nargs='*',
+                        help='Specify the <module_location function_name> pair '
+                             'that returns a spec to customize a transformer layer, depending on the use case.')
+    parser.add_argument('--model-type-hf', type=str, default="llama2",
+                        choices=['baichuan', 'baichuan2', 'llama2', 'mixtral', 'chatglm3', 'gemma', 'gemma2', 'qwen3',
+                                 'bloom', 'bloom_3b', 'qwen', 'internlm2', 'deepseek2', 'minicpm', 'minicpm3', 'minicpm-moe',
+                                 'deepseek2-lite', 'qwen2-moe', 'qwen3-moe', 'phi3.5', 'phi3.5-moe', 'hunyuan', 'glm4'],
+                        help='Model type of the Hugging Face checkpoint')
+    parser.add_argument('--ckpt-cfg-path', type=str, default="configs/checkpoint/model_cfg.json",
+                        help="Path to the config directory. If not specified, the default path in the repository will be used.")
+    parser.add_argument('--qlora-nf4', action='store_true',
+                        help='Use bitsandbytes nf4 to quantize the model.')
+    parser.add_argument('--orm', action="store_true", default=False,
+                        help='Specify the ORM ckpt conversion; converts the additional rm_head layer in ORM.')
+    parser.add_argument('--save-lora-to-hf', action='store_true', default=False,
+                        help='Only save the LoRA checkpoint in Hugging Face format')
+    parser.add_argument('--load-checkpoint-loosely', action='store_true', default=False,
+                        help='Allow non-strict checkpoint loading.')
+    known_args, _ = parser.parse_known_args()
+
+
+    if known_args.load_model_type == 'optim':
+        loader = load_plugin('loader', known_args.load_model_type)
+        loader.add_arguments(parser)
+        args = parser.parse_args()
+        model_provider = pretrain_gpt.model_provider
+        loader.load_checkpoint(model_provider, args)
+    else:
+        use_saver = known_args.load_model_type is None
+        if use_saver:
+            loader = load_plugin('loader', known_args.loader)
+            saver = load_plugin('saver', known_args.saver)
+        else:
+            loader = load_plugin('loader', known_args.load_model_type)
+            saver = load_plugin('saver', '')
+
+        loader.add_arguments(parser)
+        saver.add_arguments(parser)
+
+        args = parser.parse_args()
+
+        queue = mp.Queue(maxsize=args.max_queue_size)
+        model_provider = ORMTrainer.model_provider if args.orm else pretrain_gpt.model_provider
+        if args.orm and not args.use_mcore_models:
+            raise AssertionError("Currently the Outcome Reward Model only supports Mcore models")
+
+        logger.info("Starting saver...")
+        saver_proc = mp.Process(target=saver.save_model_checkpoint, args=(model_provider, queue, args))
+        saver_proc.start()
+
+        logger.info("Starting loader...")
+        loader.load_checkpoint(model_provider, queue, args)
+
+        logger.info("Waiting for saver to complete...")
+        saver_proc.join()
+
+
+if __name__ == '__main__':
+    main()
+
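For readers skimming the conversion script: `load_plugin()` first tries a plugin-type-prefixed module under `MODULE_ROOT` and then falls back to the bare name. A minimal standalone sketch of that resolution order (illustration only, not part of the patch):

    MODULE_ROOT = "mindspeed_llm.tasks.checkpoint"

    def candidate_modules(plugin_type, name):
        # Order in which load_plugin() attempts importlib.import_module().
        first = f"{MODULE_ROOT}.{plugin_type}" if name == '' else f"{MODULE_ROOT}.{plugin_type}_{name}"
        return [first, f"{MODULE_ROOT}.{name}"]

    print(candidate_modules('loader', 'hf'))
    # ['mindspeed_llm.tasks.checkpoint.loader_hf', 'mindspeed_llm.tasks.checkpoint.hf']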
diff --git a/mindspeed_llm/mindspore/core/datasets/__init__.py b/mindspeed_llm/mindspore/core/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/mindspeed_llm/mindspore/core/datasets/blended_megatron_dataset_builder.py b/mindspeed_llm/mindspore/core/datasets/blended_megatron_dataset_builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..33687a5953c9e12fa24c44cf412b40f15a3e4564
--- /dev/null
+++ b/mindspeed_llm/mindspore/core/datasets/blended_megatron_dataset_builder.py
@@ -0,0 +1,22 @@
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2025, HUAWEI CORPORATION. All rights reserved.
+
+import torch
+
+from megatron.training import get_args
+from megatron.core import mpu
+from mindspore.communication import get_local_rank
+
+
+def need_to_build_dataset():
+    args = get_args()
+    share_save = not args.no_shared_storage
+    rank = torch.distributed.get_rank()
+    if share_save:
+        return rank == 0
+    gpus_per_node = torch.cuda.device_count()
+    current_rank = get_local_rank()
+    if args.tensor_model_parallel_size > gpus_per_node:
+        return mpu.get_tensor_model_parallel_rank() == 0
+    return mpu.get_tensor_model_parallel_rank() == 0 and current_rank % gpus_per_node == 0
+
diff --git a/mindspeed_llm/mindspore/mindspore_adaptor.py b/mindspeed_llm/mindspore/mindspore_adaptor.py
index a140ee1713ab74959ab10a441dd2f058d2a37621..2bb3820b5009393a2001c050057b5c7f03bbdd89 100644
--- a/mindspeed_llm/mindspore/mindspore_adaptor.py
+++ b/mindspeed_llm/mindspore/mindspore_adaptor.py
@@ -1,6 +1,7 @@
 # Copyright (c) 2025; NVIDIA CORPORATION. All rights reserved.
 # Copyright (c) 2025, Huawei Technologies Co., Ltd. All rights reserved
 
+from mindspeed_llm.mindspore.utils import clear_wrapper
 from mindspeed_llm.tasks.megatron_adaptor import MegatronAdaptation, MegatronAdaptationABC
 
 
@@ -344,6 +345,35 @@ class MindSporeAdaptation(MegatronAdaptationABC):
         from mindspeed_llm.mindspore.core.transformer.dot_product_attention import flash_attention_forward
         MindSporeAdaptation.register('mindspeed_llm.core.transformer.dot_product_attention.flash_attention_forward',
                                      flash_attention_forward)
+        if args.reuse_fp32_param:
+            from megatron.core.optimizer.distrib_optimizer import DistributedOptimizer
+            from mindspeed.mindspore.optimizer.distrib_optimizer import reuse_fp32_param_distrib_optimizer_init_wrapper
+            target_func = DistributedOptimizer.__init__
+            target_func_name = 'megatron.core.optimizer.distrib_optimizer.DistributedOptimizer.__init__'
+            clear_wrapper(target_func_name, target_func)
+            MindSporeAdaptation.register(target_func_name, reuse_fp32_param_distrib_optimizer_init_wrapper)
+
+        from mindspeed_llm.mindspore.core.datasets.blended_megatron_dataset_builder import need_to_build_dataset
+        MindSporeAdaptation.register(
+            'mindspeed_llm.core.datasets.blended_megatron_dataset_builder.need_to_build_dataset',
+            need_to_build_dataset)
+        from mindspeed.mindspore.ops.npu_rotary_position_embedding import npu_rotary_position_embedding
+        MindSporeAdaptation.register(
+            'mindspeed.ops.npu_rotary_position_embedding.npu_rotary_position_embedding',
+            npu_rotary_position_embedding)
+
+        from mindspeed_llm.mindspore.tasks.checkpoint.models import register_functions, get_modules_from_pretrained
+        MindSporeAdaptation.register(
+            'mindspeed_llm.tasks.checkpoint.models.ModelBase._ModelBase__register_functions', register_functions)
+        MindSporeAdaptation.register(
+            'mindspeed_llm.tasks.checkpoint.models.HuggingfaceModel.get_modules_from_pretrained',
+            get_modules_from_pretrained)
+
+        from mindspeed.mindspore.legacy.model.module import fp32_to_float16, float16_to_fp32
+        MindSporeAdaptation.register('megatron.legacy.model.module.fp32_to_float16', fp32_to_float16)
+        MindSporeAdaptation.register('megatron.legacy.model.module.float16_to_fp32', float16_to_fp32)
+
+
     @staticmethod
     def reparse_args():
         """
diff --git a/mindspeed_llm/mindspore/tasks/checkpoint/__init__.py b/mindspeed_llm/mindspore/tasks/checkpoint/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/mindspeed_llm/mindspore/tasks/checkpoint/models.py b/mindspeed_llm/mindspore/tasks/checkpoint/models.py
new file mode 100644
index 0000000000000000000000000000000000000000..65c18c3b6d88622cf6b8da8d0255f4862a995720
--- /dev/null
+++ b/mindspeed_llm/mindspore/tasks/checkpoint/models.py
@@ -0,0 +1,140 @@
+# Copyright (c) 2025, HUAWEI CORPORATION. All rights reserved.
+
+import abc
+import os
+import sys
+import re
+from tqdm import tqdm
+import torch
+from transformers import AutoModelForCausalLM, AutoConfig, AutoModelForSequenceClassification
+from peft import get_peft_model, LoraConfig, TaskType
+from mindspeed_llm.tasks.checkpoint.models import ModelBase
+
+
+def register_functions(self):
+    self.get_module_mapping()
+
+    def _get_obj(self, value, **kwargs):
+        pattern = r'(\w+)(?:\[(\w+)\])?'
+        matches = re.findall(pattern, value)
+        self.update_kwargs_idx(**kwargs)
+        obj = self.get_model_item(**kwargs)
+        for attr, attr_ident in matches:
+            if hasattr(obj, attr):
+                obj = getattr(obj, attr)
+            else:
+                return None
+            if attr_ident:
+                if attr_ident in self.kwargs_idx:
+                    attr_idx = self.kwargs_idx[attr_ident]
+                    obj = obj[attr_idx]
+                else:
+                    raise AssertionError(f"check {self.__class__.__name__}.module_mapping **{attr_ident}**.")
+        return obj
+
+    def _get_dst_obj(self, value, **kwargs):
+        if kwargs.get("layer_idx") is None:
+            kwargs["layer_idx"] = kwargs.get("dst_layer_idx")
+
+        return _get_obj(self, value, **kwargs)
+
+    def _get_src_obj(self, value, **kwargs):
+        if kwargs.get("layer_idx") is None:
+            kwargs["layer_idx"] = kwargs.get("src_layer_idx")
+
+        return _get_obj(self, value, **kwargs)
+
+    def _func_generator_get_module(value):
+        def func(self, **kwargs):
+            return _get_src_obj(self, value, **kwargs)
+        return func
+
+    def _func_generator_get_weight(value):
+        def func(self, **kwargs):
+            return _get_src_obj(self, value, **kwargs).weight.data
+        return func
+
+    def _func_generator_get_bias(value):
+        def func(self, **kwargs):
+            return _get_src_obj(self, value, **kwargs).bias.data
+        return func
+
+    def _func_generator_set_weight(value):
+        def func(self, **kwargs):
+            set_tensor = _get_dst_obj(self, value, **kwargs)
+            data = kwargs.get('data')
+            if data.dtype != set_tensor.weight.dtype:
+                data = data.to(dtype=set_tensor.weight.dtype)
+            set_tensor.weight.data = data
+            return set_tensor.weight.data
+
+        return func
+
+    def _func_generator_set_module(value):
+        def func(self, **kwargs):
+            return _get_dst_obj(self, value, **kwargs).data.copy_(kwargs.get('data'))
+        return func
+
+    def _func_generator_set_bias(value):
+        def func(self, **kwargs):
+            set_tensor = _get_dst_obj(self, value, **kwargs)
+            data = kwargs.get('data')
+            if data.dtype != set_tensor.weight.dtype:
+                data = data.to(dtype=set_tensor.weight.dtype)
+            set_tensor.bias.data = data
+            return set_tensor.bias.data
+
+        return func
+
+    def _func_generator_has_module(value):
+        def func(self, **kwargs):
+            obj = _get_src_obj(self, value, **kwargs)
+            return True if obj else False
+        return func
+
+    def _func_generator_has_bias(value):
+        def func(self, **kwargs):
+            bias = getattr(_get_src_obj(self, value, **kwargs), 'bias', None)
+            return bias is not None
+        return func
+
+    if self.module_mapping:
+        for key, value in self.module_mapping.items():
+            setattr(self, "get_" + key + "_module", _func_generator_get_module(value).__get__(self, ModelBase))
+            setattr(self, "set_" + key + "_module", _func_generator_set_module(value).__get__(self, ModelBase))
+            setattr(self, "get_" + key + "_weight", _func_generator_get_weight(value).__get__(self, ModelBase))
+            setattr(self, "get_" + key + "_bias", _func_generator_get_bias(value).__get__(self, ModelBase))
+            setattr(self, "set_" + key + "_weight", _func_generator_set_weight(value).__get__(self, ModelBase))
+            setattr(self, "set_" + key + "_bias", _func_generator_set_bias(value).__get__(self, ModelBase))
+            setattr(self, "has_" + key + "_module", _func_generator_has_module(value).__get__(self, ModelBase))
+            setattr(self, "has_" + key + "_bias", _func_generator_has_bias(value).__get__(self, ModelBase))
+
+
+def get_modules_from_pretrained(self, device_map="cpu", trust_remote_code=True):
+    # Load Huggingface model.
+    if self.args_cmd.save_model_type == "hf":
+        load_dir = self.args_cmd.save_dir
+    else:
+        load_dir = self.args_cmd.load_dir
+    if self.args_cmd.orm:
+        self.module = [AutoModelForSequenceClassification.from_pretrained(
+            load_dir, device_map=device_map, trust_remote_code=trust_remote_code, local_files_only=True,
+            num_labels=1
+        )]
+    else:
+        self.module = [AutoModelForCausalLM.from_pretrained(
+            load_dir, trust_remote_code=trust_remote_code, local_files_only=True, low_cpu_mem_usage=False
+        )]
+
+    if self.args_cmd.save_lora_to_hf:
+        lora_config = LoraConfig(
+            task_type=TaskType.CAUSAL_LM,
+            r=self.args_cmd.lora_r,
+            lora_alpha=self.args_cmd.lora_alpha,
+            target_modules=self.target_lora_modules_hf,
+            lora_dropout=0.0,
+            bias="none"
+        )
+        self.module = [get_peft_model(self.module[0], lora_config)]
+    if hasattr(self.args, "torch_dtype") and self.args.torch_dtype in ["float16", "bfloat16"]:
+        self.module[0] = self.module[0].to(eval(f'torch.{self.args.torch_dtype}'))
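The `register_functions` helper above materializes one `get_/set_/has_` accessor per entry of `module_mapping` by binding generated closures onto the model instance. A self-contained toy sketch of the same pattern (the class `ToyModel` and the mapping key `word_embeddings` are invented for this illustration and do not appear in the patch):

    import torch

    class ToyModel:
        module_mapping = {"word_embeddings": "embedding"}  # hypothetical mapping

        def __init__(self):
            self.embedding = torch.nn.Embedding(4, 2)
            for key, value in self.module_mapping.items():
                # Same idea as register_functions: bind a generated getter per mapping key.
                setattr(self, f"get_{key}_weight",
                        self._make_weight_getter(value).__get__(self, ToyModel))

        @staticmethod
        def _make_weight_getter(value):
            def func(self, **kwargs):
                return getattr(self, value).weight.data
            return func

    model = ToyModel()
    print(model.get_word_embeddings_weight().shape)  # torch.Size([4, 2])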
diff --git a/mindspeed_llm/mindspore/tasks/models/transformer/multi_head_latent_attention.py b/mindspeed_llm/mindspore/tasks/models/transformer/multi_head_latent_attention.py
index b0e9065abe3cc48d8b23629fab0787fd791f39d6..1ae8598ad92ab8604e135591d7532dd4c239288f 100644
--- a/mindspeed_llm/mindspore/tasks/models/transformer/multi_head_latent_attention.py
+++ b/mindspeed_llm/mindspore/tasks/models/transformer/multi_head_latent_attention.py
@@ -259,5 +259,3 @@ def LinearNoTP_forward(self, input_):
     output = torch.matmul(input_, self.weight.t())
     output = output.view(bs, seq_len, self.output_size)
     return output
-
-
\ No newline at end of file
diff --git a/mindspeed_llm/mindspore/utils.py b/mindspeed_llm/mindspore/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec7bc1597c475c577424d1f04f7301a5a907f056
--- /dev/null
+++ b/mindspeed_llm/mindspore/utils.py
@@ -0,0 +1,56 @@
+# Copyright (c) 2025, HUAWEI CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +"""General utilities.""" + + +import importlib +import inspect +import sys + +from mindspeed.patch_utils import Patch +from mindspeed_llm.tasks.megatron_adaptor import MegatronAdaptation, MegatronAdaptationABC + + +def clear_wrapper(original_func_name, target_func): + '''update the pt wrapper patch with mindspore wrapper''' + reset_patch(original_func_name) + # orig_func is the original megatron method + orig_func = inspect.unwrap(target_func) + # patch with orig_func, which is equivalent to restore this patch to the original megatron method + apply_patch(original_func_name, orig_func) + + +def reset_patch(original_func_name): + '''clear the wrapper info in Patch object''' + target_patch = MegatronAdaptation._patch_info_collection[original_func_name] + target_patch.wrappers = [] + + +def apply_patch(original_func_name, new_func): + split_name = original_func_name.rsplit('.', 1) + if len(split_name) == 1: + orig_module_name, orig_func_name = original_func_name, None + else: + orig_module_name, orig_func_name = split_name + + orig_module, orig_func = Patch.parse_path(orig_module_name, orig_func_name, False) + final_patch_func = new_func + if orig_func_name is not None: + setattr(orig_module, orig_func_name, final_patch_func) + for _, value in sys.modules.copy().items(): + if orig_func_name is not None and hasattr(value, orig_func_name) \ + and id(getattr(value, orig_func_name)) == id(orig_func): + setattr(value, orig_func_name, final_patch_func) +