From d31851217e4bbc3d48745b0c3d7e8278133c1c3b Mon Sep 17 00:00:00 2001 From: sunjunjie Date: Thu, 26 Jun 2025 17:36:47 +0800 Subject: [PATCH] add checkpoint conversion 2.0 --- configs/checkpoint/model_cfg.json | 44 + convert_ckpt_v2.py | 69 ++ .../mcore/deepseek3/convert_ckpt_deepseek3.py | 4 +- .../convert_ckpt_deepseek3_mcore2hf.py | 4 +- mindspeed_llm/tasks/checkpoint/convert.py | 152 ++++ .../tasks/checkpoint/convert_hf2mg.py | 796 ++++++++++++++++++ .../tasks/checkpoint/model_builder.py | 305 +++++++ 7 files changed, 1370 insertions(+), 4 deletions(-) create mode 100644 convert_ckpt_v2.py create mode 100644 mindspeed_llm/tasks/checkpoint/convert.py create mode 100644 mindspeed_llm/tasks/checkpoint/convert_hf2mg.py create mode 100644 mindspeed_llm/tasks/checkpoint/model_builder.py diff --git a/configs/checkpoint/model_cfg.json b/configs/checkpoint/model_cfg.json index 458d639c0..a07dcbc70 100644 --- a/configs/checkpoint/model_cfg.json +++ b/configs/checkpoint/model_cfg.json @@ -593,6 +593,50 @@ "final_layernorm": "model.norm", "output_layer": "lm_head" } + }, + "deepseek3": { + "__base__": "base", + "config_set_value": { + "qkv_type": "pack_mla", + "multi_head_latent_attention": true, + "qk_layernorm": true, + "router_bias": true + }, + "config_hf_key_mapping": { + "first_k_dense_replace": "first_k_dense_replace", + "kv_lora_rank": "kv_lora_rank", + "moe_intermediate_size": "moe_intermediate_size", + "moe_layer_freq": "moe_layer_freq", + "num_experts": "n_routed_experts", + "n_shared_experts": "n_shared_experts", + "q_lora_rank": "q_lora_rank", + "qk_nope_head_dim": "qk_nope_head_dim", + "qk_rope_head_dim": "qk_rope_head_dim" + }, + "model_hf_key_mapping": { + "layers_self_attention_linear_q_proj": "model.layers[layer_idx].self_attn.q_a_proj", + "layers_self_attention_linear_kv_proj": "model.layers[layer_idx].self_attn.kv_a_proj_with_mqa", + "layers_self_attention_linear_proj": "model.layers[layer_idx].self_attn.o_proj", + "layers_self_attention_linear_qb": "model.layers[layer_idx].self_attn.q_b_proj", + "layers_self_attention_linear_kvb": "model.layers[layer_idx].self_attn.kv_b_proj", + "layers_self_attention_q_layernorm": "model.layers[layer_idx].self_attn.q_a_layernorm", + "layers_self_attention_k_layernorm": "model.layers[layer_idx].self_attn.kv_a_layernorm", + "layers_mlp_router": "model.layers[layer_idx].mlp.gate", + "layers_mlp_router_bias": "model.layers[layer_idx].mlp.gate.e_score_correction_bias", + "layers_mlp_experts_gate_proj": "model.layers[layer_idx].mlp.experts[expert_idx].gate_proj", + "layers_mlp_experts_up_proj": "model.layers[layer_idx].mlp.experts[expert_idx].up_proj", + "layers_mlp_experts_linear_fc2": "model.layers[layer_idx].mlp.experts[expert_idx].down_proj", + "layers_mlp_shared_experts_gate_proj": "model.layers[layer_idx].mlp.shared_experts.gate_proj", + "layers_mlp_shared_experts_up_proj": "model.layers[layer_idx].mlp.shared_experts.up_proj", + "layers_mlp_shared_experts_linear_fc2": "model.layers[layer_idx].mlp.shared_experts.down_proj", + "final_layernorm": "model.norm", + "output_layer": "lm_head", + "mtp_layers_enorm": "model.layers[layer_idx].enorm", + "mtp_layers_hnorm": "model.layers[layer_idx].hnorm", + "mtp_layers_eh_proj": "model.layers[layer_idx].eh_proj", + "mtp_layers_embed_tokens": "model.layers[layer_idx].embed_tokens", + "mtp_layers_shared_head_norm": "model.layers[layer_idx].shared_head.norm" + } } } } diff --git a/convert_ckpt_v2.py b/convert_ckpt_v2.py new file mode 100644 index 000000000..c91bd9484 --- /dev/null +++ b/convert_ckpt_v2.py
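The bracketed placeholders in model_hf_key_mapping above ([layer_idx], [expert_idx]) are templates; the converter expands them into concrete HuggingFace parameter names by regex substitution in HuggingFaceModel.get_weight() (model_builder.py, later in this patch). A minimal sketch of that expansion; expand_key is an illustrative helper, not part of the patch:

    import re

    def expand_key(template, layer_idx, expert_idx=0):
        # Substitute the placeholders, then append ".weight" unless the key
        # already names a weight or bias tensor (same rule as get_weight()).
        value = re.sub(r'\[layer_idx\]', f'.{layer_idx}', template)
        value = re.sub(r'\[expert_idx\]', f'.{expert_idx}', value)
        return value + ".weight" if ("weight" not in value and "bias" not in value) else value

    expand_key("model.layers[layer_idx].self_attn.q_a_proj", 3)
    # -> 'model.layers.3.self_attn.q_a_proj.weight'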
@@ -0,0 +1,69 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. +import argparse +import logging as logger +import time +from mindspeed_llm.tasks.checkpoint.convert_hf2mg import Hf2MgConvert + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--load-model-type', type=str, nargs='?', + default='hf', const=None, choices=['hf'], + help='Type of the converter') + parser.add_argument('--save-model-type', type=str, default='mg', + choices=['mg'], help='Save model type') + parser.add_argument('--load-dir', type=str, required=True, + help='Directory to load model checkpoint from') + parser.add_argument('--save-dir', type=str, required=True, + help='Directory to save model checkpoint to') + parser.add_argument('--model-type-hf', type=str, default="qwen3", + choices=['qwen3', 'qwen3-moe', 'deepseek3'], + help='model type of huggingface') + parser.add_argument('--target-tensor-parallel-size', type=int, default=1, + help='Target tensor model parallel size, defaults to 1.') + parser.add_argument('--target-pipeline-parallel-size', type=int, default=1, + help='Target pipeline model parallel size, defaults to 1.') + parser.add_argument('--target-expert-parallel-size', type=int, default=1, + help='Target expert model parallel size, defaults to 1.') + parser.add_argument('--num-layers-per-virtual-pipeline-stage', type=int, default=None, + help='Number of layers per virtual pipeline stage') + parser.add_argument('--moe-grouped-gemm', action='store_true', + help='Use moe grouped gemm.') + parser.add_argument("--noop-layers", type=str, default=None, help='Specify the noop layers.') + parser.add_argument('--mtp-num-layers', type=int, default=0, help='Multi-Token prediction layer num') + parser.add_argument('--num-layer-list', type=str, + help='a list of number of layers, separated by comma; e.g., 4,4,4,4') + parser.add_argument('--first-k-dense-replace', type=int, default=0, + help='Customizing the number of dense layers.') + parser.add_argument("--moe-tp-extend-ep", action='store_true', + help="use tp group to extend expert parallelism instead of sharding weight tensor of experts in tp group") + parser.add_argument('--mla-mm-split', action='store_true', default=False, + help='Split 2 up-proj matmul into 4 in MLA') + parser.add_argument("--shared-expert-gate", action='store_true', + help="moe model has shared expert gate") + parser.add_argument('--schedules-method', type=str, default=None, choices=['dualpipev'], + help='An innovative bidirectional pipeline parallelism algorithm.') + parser.add_argument('--qlora-nf4', action='store_true', + help='use bitsandbytes nf4 to quantize model.') + + args, _ = parser.parse_known_args() + return args + + +def main(): + args = get_args() + logger.info(f"Arguments: {args}") + + if args.load_model_type == 'hf' and args.save_model_type == 'mg': + converter = Hf2MgConvert(args) + else: + raise ValueError("This conversion scheme is not supported") + + start_time = time.time() + converter.run() + end_time = time.time() + logger.info("time-consuming: {:.2f}s".format(end_time - start_time)) + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/examples/mcore/deepseek3/convert_ckpt_deepseek3.py b/examples/mcore/deepseek3/convert_ckpt_deepseek3.py index 6603750b8..db5521b74 100644 --- a/examples/mcore/deepseek3/convert_ckpt_deepseek3.py +++ b/examples/mcore/deepseek3/convert_ckpt_deepseek3.py @@ -874,11 +874,11 @@ def get_args(): parser.add_argument('--num-layers-per-virtual-pipeline-stage', type=int,
default=None, help='Number of layers per virtual pipeline stage') parser.add_argument('--moe-grouped-gemm', action='store_true', - help='Usr moe grouped gemm.') + help='Use moe grouped gemm.') parser.add_argument("--noop-layers", type=str, default=None, help='Specity the noop layers.') parser.add_argument('--mtp-num-layers', type=int, default=0, help='Multi-Token prediction layer num') parser.add_argument('--num-layer-list', type=str, - help='a list of number of layers, seperated by comma; e.g., 4,4,4,4') + help='a list of number of layers, separated by comma; e.g., 4,4,4,4') parser.add_argument('--num-layers', type=int, default=61, help='Number of transformer layers.') parser.add_argument('--first-k-dense-replace', type=int, default=3, diff --git a/examples/mcore/deepseek3/convert_ckpt_deepseek3_mcore2hf.py b/examples/mcore/deepseek3/convert_ckpt_deepseek3_mcore2hf.py index 24a86bb37..a94c32971 100644 --- a/examples/mcore/deepseek3/convert_ckpt_deepseek3_mcore2hf.py +++ b/examples/mcore/deepseek3/convert_ckpt_deepseek3_mcore2hf.py @@ -996,11 +996,11 @@ def get_args(): help='Source expert model parallel size, default to 1') parser.add_argument('--num-layers-per-virtual-pipeline-stage', type=int, default=None, help='Number of layers per virtual pipeline stage') - parser.add_argument('--moe-grouped-gemm', action='store_true', help='Usr moe grouped gemm.') + parser.add_argument('--moe-grouped-gemm', action='store_true', help='Use moe grouped gemm.') parser.add_argument("--noop-layers", type=str, default=None, help='Specity the noop layers.') parser.add_argument('--mtp-num-layers', type=int, default=0, help='Multi-Token prediction layer num') parser.add_argument('--num-layer-list', type=str, - help='a list of number of layers, seperated by comma; e.g., 4,4,4,4') + help='a list of number of layers, separated by comma; e.g., 4,4,4,4') parser.add_argument("--moe-tp-extend-ep", action='store_true', help="use tp group to extend experts parallism instead of sharding weight tensor of experts in tp group") parser.add_argument('--mla-mm-split', action='store_true', default=False, diff --git a/mindspeed_llm/tasks/checkpoint/convert.py b/mindspeed_llm/tasks/checkpoint/convert.py new file mode 100644 index 000000000..1893dceb3 --- /dev/null +++ b/mindspeed_llm/tasks/checkpoint/convert.py @@ -0,0 +1,152 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. 
+import abc +import logging as logger +import os +from collections import defaultdict +import numpy as np +import torch +from .model_builder import MegatronModel, HuggingFaceModel + + +logger.basicConfig(format="") +logger.getLogger().setLevel(logger.INFO) + + +class Convert(abc.ABC): + + def __init__(self, args): + self.load_model = None + self.save_model = None + self.model_type_hf = args.model_type_hf + + # parallel train arguments + self.tp_size = args.target_tensor_parallel_size + self.pp_size = args.target_pipeline_parallel_size + self.ep_size = args.target_expert_parallel_size + self.num_layer_list = args.num_layer_list + self.noop_layers = args.noop_layers + self.vpp_stage = args.num_layers_per_virtual_pipeline_stage + + # lora arguments + self.qlora_nf4 = args.qlora_nf4 + + # features arguments + self.moe_grouped_gemm = args.moe_grouped_gemm + self.moe_tp_extend_ep = args.moe_tp_extend_ep + self.mla_mm_split = args.mla_mm_split + self.schedules_method = args.schedules_method + self.first_k_dense_replace = args.first_k_dense_replace + self.mtp_num_layers = args.mtp_num_layers + + # model arguments + self.num_layers = None + + + @staticmethod + def mg_path_process(mg_path): + """megatron model path""" + iter_mg_path = os.path.join(mg_path, "iter_0000001") + if not os.path.exists(mg_path): + os.makedirs(mg_path, exist_ok=True) + with open(os.path.join(mg_path, "latest_checkpointed_iteration.txt"), 'w') as f: + f.write("1") + return iter_mg_path + + + def generate_mg_weights_dir(self, tp_rank, pp_rank, ep_rank): + """Generate the megatron weight directory.""" + if self.ep_size == 1 and self.pp_size == 1: + prefix = f"mp_rank_{tp_rank:02}" + elif self.ep_size == 1: + prefix = f"mp_rank_{tp_rank:02}_{pp_rank:03}" + elif self.pp_size == 1: + prefix = f"mp_rank_{tp_rank:02}_{ep_rank:03}" + else: + prefix = f"mp_rank_{tp_rank:02}_{pp_rank:03}_{ep_rank:03}" + return prefix + + + def generate_pp_local_layer_idx(self): + """generate each pp local layer index""" + pp_local_layer_idx = defaultdict() + + for pp_rank in range(self.pp_size): + if self.num_layer_list is not None: + layer_list = list(map(int, self.num_layer_list.split(','))) + pp_local_layer_idx[pp_rank] = [i for i in range(layer_list[pp_rank])] + else: + pp_local_layer_idx[pp_rank] = [i for i in range(self.num_layers // self.pp_size)] + + if self.noop_layers is not None: + noop_list = list(map(int, self.noop_layers.split(","))) + num_layers_each_pp = self.num_layers // self.pp_size + for num_noop_layers in noop_list: + pp_idx = num_noop_layers // num_layers_each_pp + local_noop_idx = num_noop_layers % num_layers_each_pp + pp_local_layer_idx[pp_idx].remove(local_noop_idx) + + return pp_local_layer_idx + + + def generate_vpp_local_layer_idx(self): + vpp_local_layer_idx = defaultdict() + for pp_rank in range(self.pp_size): + vpp_local_layer_idx[pp_rank] = defaultdict() + + for pp_rank in range(self.pp_size): + for vpp_rank in range(self.vpp_size): + vpp_local_layer_idx[pp_rank][vpp_rank] = [i for i in range(self.vpp_stage)] + + if self.noop_layers is not None: + noop_list = list(map(int, self.noop_layers.split(","))) + num_layers_each_pp = self.num_layers // self.pp_size + + if self.schedules_method == 'dualpipev': + # calc pp rank, vpp rank and local idx of noop layer + for noop_layer in noop_list: + # e.g. pp2 noop5 [0 1 6 7 | 2 3 4 5] -> layer5: pp1 vpp1 local_idx1 + # layer5 and layer2 are symmetrical, so they are in the same pp_rank. + # all layer are divided into two parts. layer5 is in last part. 
so vpp_rank=1 + if noop_layer >= self.num_layers // 2: + mapping_layer = -(noop_layer - self.num_layers + 1) + vpp_idx = 1 + pp_idx = mapping_layer // ((self.num_layers // 2) // self.pp_size) + local_noop_idx = self.vpp_stage - 1 - (mapping_layer - pp_idx * self.vpp_stage) + else: + vpp_idx = 0 + pp_idx = noop_layer // ((self.num_layers // 2) // self.pp_size) + local_noop_idx = noop_layer - pp_idx * self.vpp_stage + vpp_local_layer_idx[pp_idx][vpp_idx].remove(local_noop_idx) + else: + for num_noop_layer in noop_list: + pp_idx = num_noop_layer % (self.pp_size * self.vpp_stage) // self.vpp_stage + vpp_idx = num_noop_layer // self.vpp_stage // self.pp_size + local_noop_idx = num_noop_layer % num_layers_each_pp % self.vpp_stage + vpp_local_layer_idx[pp_idx][vpp_idx].remove(local_noop_idx) + + return vpp_local_layer_idx + + @abc.abstractmethod + def set_model_preprocess(self, weights_dict, mg_model): + """Embedding layer process""" + pass + + @abc.abstractmethod + def set_model_postprocess(self, weights_dict, mg_model): + """Final norm & LM Head process""" + pass + + @abc.abstractmethod + def set_model_layer_norm(self, hf_layer_idx, local_layer_idx, weights_dict, mg_model, mtp_layer_flag=False): + """Layernorm process""" + pass + + @abc.abstractmethod + def set_model_layer_attn(self, hf_layer, local_layer_idx, weights_dict, mg_model, mtp_layer_flag=False): + """Attention layer process""" + pass + + @abc.abstractmethod + def set_model_layer_mlp(self, hf_layer_idx, local_layer_idx, weights_dict, mg_model, mtp_layer_flag=False): + """MLP layer process""" + pass \ No newline at end of file diff --git a/mindspeed_llm/tasks/checkpoint/convert_hf2mg.py b/mindspeed_llm/tasks/checkpoint/convert_hf2mg.py new file mode 100644 index 000000000..640c21aaa --- /dev/null +++ b/mindspeed_llm/tasks/checkpoint/convert_hf2mg.py @@ -0,0 +1,796 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. 
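A small worked example of the pp-rank layer bookkeeping in Convert.generate_pp_local_layer_idx above (a simplified re-implementation for illustration only; like the class, it assumes num_layers already counts the noop layers and is divisible by pp_size): with 8 layers, pp_size=2 and noop_layers="5", pipeline stage 1 drops its local slot 1.

    def pp_local_layer_idx(num_layers, pp_size, noop_layers=None):
        per_pp = num_layers // pp_size
        idx = {pp_rank: list(range(per_pp)) for pp_rank in range(pp_size)}
        for noop in (noop_layers or []):
            # global noop index -> owning pp stage and its local slot
            idx[noop // per_pp].remove(noop % per_pp)
        return idx

    pp_local_layer_idx(8, 2, noop_layers=[5])
    # -> {0: [0, 1, 2, 3], 1: [0, 2, 3]}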
+import logging as logger +import os +from collections import defaultdict +import numpy as np +import torch +from .model_builder import MegatronModel, HuggingFaceModel +from .convert import Convert + +logger.basicConfig(format="") +logger.getLogger().setLevel(logger.INFO) + + +class Hf2MgConvert(Convert): + + def __init__(self, args): + super().__init__(args) + self.load_model = HuggingFaceModel(args) + self.save_model = MegatronModel(args) + + self.load_dir = args.load_dir + self.save_dir = self.mg_path_process(args.save_dir) + + # model arguments + self.num_layers = self.load_model.num_layers + len(eval(self.noop_layers)) + + if self.vpp_stage is None: + self.pprank_layer_idxs = defaultdict() + self.get_pprank_hf_layeridxs() + else: + self.vpp_size = self.num_layers // self.pp_size // self.vpp_stage + self.vpprank_layer_idxs = defaultdict(dict) + self.get_vpprank_hf_layeridxs() + self._valid_parameter() + + def _valid_parameter(self): + if self.schedules_method == 'dualpipev': + if self.tp_size > 1 and not self.moe_tp_extend_ep: + raise ValueError("When dualpipe is enabled, moe-tp-extend-ep should be used at the same time.") + + if self.num_layer_list is None: + if self.num_layers % self.pp_size != 0: + raise ValueError('number of layers should be divisible by the pipeline parallel size') + + if self.vpp_stage is not None: + if self.num_layers % self.pp_size % self.vpp_stage != 0: + raise ValueError('number of pp_stage should bu divisible by the vpp_stage') + else: + layer_list = list(map(int, self.num_layer_list.split(','))) + if self.vpp_stage is not None: + raise ValueError('num_layer_list and vpp cannot be configured at the same time') + if len(layer_list) != self.pp_size: + raise ValueError('number of layer_list should be equal to pipeline parallel size') + if sum(layer_list) != self.num_layers: + raise ValueError('sum of layer_list should be equal to num_layers') + if self.noop_layers is not None: + raise ValueError('num_layer_list and noop_layers cannot be configured at the same time') + + def get_pprank_hf_layeridxs(self) -> None: + """pp_rank -> hf layer map""" + num_noop_layers = 0 if self.noop_layers is None else len(list(map(int, self.noop_layers.split(",")))) + num_real_layers = self.num_layers - num_noop_layers + num_layer_list_ = [i for i in range(num_real_layers)] + + # Specifies the number of dense layers. + if getattr(self, "first_k_dense_replace", None): + """ + Support custom first_k_dense_replace, + but it cannot exceed the number of dense layers in the open source model weights. 
+ """ + if self.first_k_dense_replace != self.load_model.first_k_dense_replace: + logger.warning("The number of custom dense layers is inconsistent with the number of open-source dense layers,\ + so the training is meaningless.") + + if self.first_k_dense_replace <= self.load_model.first_k_dense_replace: + num_moe_layers = num_real_layers - self.first_k_dense_replace + num_layer_list_ = [i for i in range(self.first_k_dense_replace)] + \ + [i + self.load_model.first_k_dense_replace for i in range(num_moe_layers)] + else: + raise ValueError( + "first_k_dense_replace must be less than or equal to the number of dense layers in the open source model") + + if self.num_layer_list is None: + layers_each_pp = [self.num_layers // self.pp_size] * self.pp_size + if self.noop_layers is not None: + for layer in list(map(int, self.noop_layers.split(","))): + cur_pp_rank = layer // (self.num_layers // self.pp_size) + layers_each_pp[cur_pp_rank] -= 1 + else: + layers_each_pp = list(map(int, self.num_layer_list.split(','))) + + for pp_rank in range(self.pp_size): + self.pprank_layer_idxs[pp_rank] = [num_layer_list_.pop(0) for _ in range(layers_each_pp[pp_rank])] + + # mtp layer + if self.mtp_num_layers: + nextn_layer_list = [self.load_model.num_layers + i for i in range(self.mtp_num_layers)] + self.pprank_layer_idxs[self.pp_size - 1].extend(nextn_layer_list) + + def get_vpprank_hf_layeridxs(self) -> None: + """vpp_rank -> hf layer map""" + num_noop_layers = 0 if self.noop_layers is None else len(list(map(int, self.noop_layers.split(",")))) + num_real_layers = self.num_layers - num_noop_layers + num_layer_list_ = [i for i in range(num_real_layers)] + + # Specifies the number of dense layers. + if getattr(self, "first_k_dense_replace", None): + """ + Support custom first_k_dense_replace, + but it cannot exceed the number of dense layers in the open source model weights. 
+ """ + if self.first_k_dense_replace != self.load_model.first_k_dense_replace: + logger.warning("The number of custom dense layers is inconsistent with the number of open-source dense layers,\ + so the training is meaningless.") + + if self.first_k_dense_replace <= self.load_model.first_k_dense_replace: + num_moe_layers = num_real_layers - self.first_k_dense_replace + num_layer_list_ = [i for i in range(self.first_k_dense_replace)] + \ + [i + self.load_model.first_k_dense_replace for i in range(num_moe_layers)] + else: + raise ValueError( + "first_k_dense_replace must be less than or equal to the number of dense layers in the open source model") + + if self.schedules_method == 'dualpipev': + noop_layers_list = None if not self.noop_layers else np.array( + sorted(list(map(int, self.noop_layers.split(","))))) + min_noop_layer = None if not self.noop_layers else noop_layers_list[0] + + dualpipe_layer_list = [] + layers_each_pp = self.num_layers // self.pp_size + layer_pop_num = layers_each_pp // 2 + all_layer_list = [i for i in range(self.num_layers)] + # dualpipe_layer_list example + # pp2: [0 1 2 3 4 5 6 7] -> [0 1 6 7 | 2 3 4 5] + # pp4: [0 1 2 3 4 5 6 7] -> [0 7 | 1 6 | 2 5 | 3 4] + while all_layer_list: + dualpipe_layer_list.extend(all_layer_list[:layer_pop_num]) + dualpipe_layer_list.extend(all_layer_list[-layer_pop_num:]) + all_layer_list = all_layer_list[layer_pop_num:-layer_pop_num] + + # calc pp idx and vpp idx of each hf layer + pp_rank, vpp_rank = 0, 0 + each_pp_layer = self.num_layers // self.pp_size + for idx, layer in enumerate(dualpipe_layer_list): + if vpp_rank not in self.vpprank_layer_idxs[pp_rank]: + self.vpprank_layer_idxs[pp_rank][vpp_rank] = [] + + if not self.noop_layers: + self.vpprank_layer_idxs[pp_rank][vpp_rank].append(layer) + else: + # ignore noop layer + if layer in noop_layers_list: + if (idx + 1) % self.vpp_stage == 0: + vpp_rank += 1 + if (idx + 1) % each_pp_layer == 0: + pp_rank += 1 + vpp_rank = 0 + continue + if layer < min_noop_layer: + self.vpprank_layer_idxs[pp_rank][vpp_rank].append(layer) + if layer > min_noop_layer: + # remove noop layer index + before_nums = sum(noop_layers_list < layer) + self.vpprank_layer_idxs[pp_rank][vpp_rank].append(layer - before_nums) + + # update vpp_rank + if (idx + 1) % self.vpp_stage == 0: + vpp_rank += 1 + # update pp_rank, reset vpp_rank + if (idx + 1) % each_pp_layer == 0: + pp_rank += 1 + vpp_rank = 0 + else: + if self.vpp_stage is not None: + layers_each_vpp = [[self.vpp_stage] * self.vpp_size for _ in range(self.pp_size)] + # examples: num_layers8,pp2,vpp_stage2 [[0 1, 4 5], [2 3, 6 7]] + # no noop layer --> layers_each_vpp:[[2,2], [2,2]] + # noop4,5 --> layers_each_vpp:[[2,0], [2,2]] + if self.noop_layers is not None: + for layer in list(map(int, self.noop_layers.split(","))): + vpp_idx = layer // self.vpp_stage // self.pp_size + pp_idx = layer % (self.pp_size * self.vpp_stage) // self.vpp_stage + layers_each_vpp[pp_idx][vpp_idx] -= 1 + + for vpp_rank in range(self.vpp_size): + for pp_rank in range(self.pp_size): + self.vpprank_layer_idxs[pp_rank][vpp_rank] = [num_layer_list_.pop(0) for _ in + range(layers_each_vpp[pp_rank][vpp_rank])] + + if self.mtp_num_layers: + nextn_layer_list = [self.mtp_layer_number + i for i in range(self.mtp_num_layers)] + # for dualpipe, mtp layer in pp0vpp1 + mtp_pp_rank = 0 if self.schedules_method == 'dualpipev' else self.pp_size - 1 + self.vpprank_layer_idxs[mtp_pp_rank][self.vpp_size - 1].extend(nextn_layer_list) + + def load_matched_hf_weight(self, pp_rank, vpp_rank=None): + 
"""Read the safetensors file corresponding to the layer of pp_rank.""" + if vpp_rank is None: + layer_list = self.pprank_layer_idxs[pp_rank] + else: + layer_list = self.vpprank_layer_idxs[pp_rank][vpp_rank].copy() + if pp_rank == self.pp_size - 1 and self.mtp_num_layers: + nextn_layer_list = [self.load_model.num_layers + i for i in range(self.mtp_num_layers)] + layer_list.extend(nextn_layer_list) + layer_files_map_dict = self.load_model.get_layer_files_map() + + st_filename_list = [] + for layer in layer_list: + # start with model.layers.[layer_number], contains the mtp layer. + st_filename_list.extend(list(layer_files_map_dict[layer])) + + hf_weight_key = self.load_model.get_weight() + if pp_rank == 0: + st_filename_list.extend(list(layer_files_map_dict[hf_weight_key["embedding_word_embeddings"]])) + if self.schedules_method == 'dualpipev': + st_filename_list.extend(list(layer_files_map_dict[hf_weight_key["output_layer"]])) + st_filename_list.extend(list(layer_files_map_dict[hf_weight_key["final_layernorm"]])) + + if pp_rank == self.pp_size - 1 and self.schedules_method is None: + st_filename_list.extend(list(layer_files_map_dict[hf_weight_key["final_layernorm"]])) + st_filename_list.extend(list(layer_files_map_dict[hf_weight_key["output_layer"]])) + + st_filename_list = list(set(st_filename_list)) + st_filename_list.sort() + + all_pp_weights = {} + for filename in st_filename_list: + cur_weights = self.load_model.load_hf_model(os.path.join(self.load_dir, filename)) + all_pp_weights.update(cur_weights) + + return all_pp_weights + + def set_model_preprocess(self, hf_weight, mg_weight): + """Embedding layer process""" + hf_weight_key = self.load_model.get_weight() + mg_weight_key = self.save_model.get_weight() + emb_weight = hf_weight.pop(hf_weight_key["embedding_word_embeddings"]) + + for ep_rank in range(self.ep_size): + emb_weight_lst = torch.chunk(emb_weight, self.tp_size, dim=0) + for tp_rank in range(self.tp_size): + mg_weight[ep_rank][tp_rank][mg_weight_key["embedding_word_embeddings"]] = emb_weight_lst[ + tp_rank].clone() + + def set_model_postprocess(self, hf_weight, mg_weight): + """Final norm & LM Head process""" + hf_weight_key = self.load_model.get_weight() + mg_weight_key = self.save_model.get_weight() + final_norm = hf_weight.pop(hf_weight_key["final_layernorm"]) + lm_head = hf_weight.pop(hf_weight_key["output_layer"]) + + for ep_rank in range(self.ep_size): + lm_head_lst = torch.chunk(lm_head, self.tp_size, dim=0) + for tp_rank in range(self.tp_size): + if self.mtp_num_layers: + mg_weight[ep_rank][tp_rank][mg_weight_key["mtp_final_layernorms"]] = final_norm.clone() + else: + mg_weight[ep_rank][tp_rank][mg_weight_key["final_layernorm"]] = final_norm.clone() + mg_weight[ep_rank][tp_rank][mg_weight_key["output_layer"]] = lm_head_lst[tp_rank].clone() + if self.qlora_nf4: + self.qlora_nf4_quant(mg_weight, ep_rank, tp_rank, mg_weight_key["output_layer"], + lm_head_lst[tp_rank].clone()) + + def set_mtp_preprocess(self, hf_layer_idx, mtp_layer_idx, hf_weight, mg_weight): + """MTP layer preprocess""" + hf_weight_key = self.load_model.get_weight(hf_layer_idx) + mg_weight_key = self.save_model.get_weight(mtp_layer_idx) + enorm_weight = hf_weight.pop(hf_weight_key["mtp_layers_enorm"]) + hnorm_weight = hf_weight.pop(hf_weight_key["mtp_layers_hnorm"]) + eh_proj_weight = hf_weight.pop(hf_weight_key["mtp_layers_eh_proj"]) + emb_weight = hf_weight.pop(hf_weight_key["mtp_layers_embed_tokens"]) + + for ep_rank in range(self.ep_size): + eh_proj_lst = torch.chunk(eh_proj_weight, self.tp_size, 
dim=0) + emb_lst = torch.chunk(emb_weight, self.tp_size, dim=0) + for tp_rank in range(self.tp_size): + mg_weight[ep_rank][tp_rank][mg_weight_key["mtp_layers_enorm"]] = enorm_weight.clone() + mg_weight[ep_rank][tp_rank][mg_weight_key["mtp_layers_hnorm"]] = hnorm_weight.clone() + mg_weight[ep_rank][tp_rank][mg_weight_key["mtp_layers_eh_proj"]] = eh_proj_lst[tp_rank].clone() + if self.qlora_nf4: + self.qlora_nf4_quant(mg_weight, ep_rank, tp_rank, mg_weight_key["mtp_layers_eh_proj"], + eh_proj_lst[tp_rank].clone()) + + if self.pp_size > 1: + mg_weight[ep_rank][tp_rank][mg_weight_key["mtp_layers_embed_tokens"]] = \ + emb_lst[tp_rank].clone() + + def set_mtp_postprocess(self, hf_layer_idx, mtp_layer_idx, hf_weight, mg_weight): + """MTP layer postprocess""" + hf_weight_key = self.load_model.get_weight(hf_layer_idx) + mg_weight_key = self.save_model.get_weight(mtp_layer_idx) + mtp_norm_weight = hf_weight.pop(hf_weight_key["mtp_layers_shared_head_norm"]) + + for ep_rank in range(self.ep_size): + for tp_rank in range(self.tp_size): + mg_weight[ep_rank][tp_rank][ + mg_weight_key["mtp_post_norm"]] = mtp_norm_weight.clone() + + def set_model_layer_norm(self, hf_layer_idx, local_layer_idx, hf_weight, mg_weight, mtp_layer_flag=False): + """Layernorm process""" + hf_weight_key = self.load_model.get_weight(hf_layer_idx) + mg_weight_key = self.save_model.get_weight(local_layer_idx) + input_norm = hf_weight.pop(hf_weight_key["layers_input_layernorm"]) + post_attn_norm = hf_weight.pop(hf_weight_key["layers_self_attention_pre_mlp_layernorm"]) + + # Weight key of the mtp layer is different from that of the transformers layer. + if mtp_layer_flag: + input_norm_key = mg_weight_key["mtp_layers_input_layernorm"] + post_norm_key = mg_weight_key["mtp_layers_self_attention_post_attention_layernorm"] + else: + input_norm_key = mg_weight_key["layers_input_layernorm"] + post_norm_key = mg_weight_key["layers_self_attention_post_attention_layernorm"] if hasattr(self.load_model, + "post_attention") \ + else mg_weight_key["layers_self_attention_pre_mlp_layernorm"] + + for ep_rank in range(self.ep_size): + for tp_rank in range(self.tp_size): + mg_weight[ep_rank][tp_rank][input_norm_key] = input_norm.clone() + mg_weight[ep_rank][tp_rank][post_norm_key] = post_attn_norm.clone() + + def set_model_layer_attn(self, hf_layer_idx, local_layer_idx, hf_weight, mg_weight, mtp_layer_flag=False): + """Attention layer process""" + + hf_weight_key = self.load_model.get_weight(hf_layer_idx) + mg_weight_key = self.save_model.get_weight(local_layer_idx) + + def _generate_mla_attn_layers_key(mtp_flag): + if mtp_flag: + qkv_key = mg_weight_key["mtp_layers_self_attention_linear_qkv"] + dense_key = mg_weight_key["mtp_layers_self_attention_linear_proj"] + q_b_key = mg_weight_key["mtp_layers_self_attention_linear_qb"] + kv_b_key = mg_weight_key["mtp_layers_self_attention_linear_kvb"] + q_layernorm_key = mg_weight_key["mtp_layers_self_attention_q_layernorm"] + kv_layernorm_key = mg_weight_key["mtp_layers_self_attention_k_layernorm"] + else: + qkv_key = mg_weight_key["layers_self_attention_linear_qkv"] + dense_key = mg_weight_key["layers_self_attention_linear_proj"] + q_b_key = mg_weight_key["layers_self_attention_linear_qb"] + kv_b_key = mg_weight_key["layers_self_attention_linear_kvb"] + q_layernorm_key = mg_weight_key["layers_self_attention_q_layernorm"] + kv_layernorm_key = mg_weight_key["layers_self_attention_k_layernorm"] + + return qkv_key, dense_key, q_layernorm_key, kv_layernorm_key, q_b_key, kv_b_key + + def 
_generate_attn_mm_split_key(mtp_flag): + if mtp_flag: + qk_nope_key = mg_weight_key["mtp_layers_self_attention_linear_qk_nope"] + qk_rope_key = mg_weight_key["mtp_layers_self_attention_linear_qk_rope"] + kv_nope_key = mg_weight_key["mtp_layers_self_attention_linear_kv_nope"] + linear_v_key = mg_weight_key["mtp_layers_self_attention_linear_v"] + else: + qk_nope_key = mg_weight_key["layers_self_attention_linear_qk_nope"] + qk_rope_key = mg_weight_key["layers_self_attention_linear_qk_rope"] + kv_nope_key = mg_weight_key["layers_self_attention_linear_kv_nope"] + linear_v_key = mg_weight_key["layers_self_attention_linear_v"] + + return qk_nope_key, qk_rope_key, kv_nope_key, linear_v_key + + def _generate_attn_layers_key(): + qkv_key = mg_weight_key["layers_self_attention_linear_qkv"] + dense_key = mg_weight_key["layers_self_attention_linear_proj"] + q_layernorm_key = mg_weight_key["layers_self_attention_q_layernorm"] + k_layernorm_key = mg_weight_key["layers_self_attention_k_layernorm"] + return qkv_key, dense_key, q_layernorm_key, k_layernorm_key + + nh = self.load_model.num_attention_heads + ng = self.load_model.num_key_value_heads + dim = self.load_model.kv_channels if hasattr(self.load_model, "kv_channels") \ + else self.load_model.hidden_size // self.load_model.num_attention_heads + + if not nh % ng == 0: + raise ValueError("nh % ng should equal 0") + + def qkv_concatenate_weight(qkv): + return torch.cat([ + qkv[0].reshape((ng, dim * nh // ng, -1)), + qkv[1].reshape((ng, dim, -1)), + qkv[2].reshape((ng, dim, -1)), + ], dim=1).reshape((-1, self.load_model.hidden_size)) + + if self.load_model.qkv_type == "pack_mla": + qkv_key, dense_key, q_layernorm_key, kv_layernorm_key, q_b_key, kv_b_key = _generate_mla_attn_layers_key( + mtp_layer_flag) + hf_q_proj = hf_weight.pop(hf_weight_key["layers_self_attention_linear_q_proj"]) + hf_kv_proj = hf_weight.pop(hf_weight_key["layers_self_attention_linear_kv_proj"]) + qkv_weight = torch.cat([hf_q_proj.reshape((-1, self.load_model.hidden_size)), + hf_kv_proj.reshape((-1, self.load_model.hidden_size))], dim=0) + + dense_weight = hf_weight.pop(hf_weight_key["layers_self_attention_linear_proj"]) + dense_lst = torch.chunk(dense_weight, self.tp_size, dim=1) + q_b_proj = hf_weight.pop(hf_weight_key["layers_self_attention_linear_qb"]) + kv_b_proj = hf_weight.pop(hf_weight_key["layers_self_attention_linear_kvb"]) + q_layernorm = hf_weight.pop(hf_weight_key["layers_self_attention_q_layernorm"]) + k_layernorm = hf_weight.pop(hf_weight_key["layers_self_attention_k_layernorm"]) + + if self.mla_mm_split: + q_b_proj = q_b_proj.reshape(self.load_model.num_attention_heads, + (self.load_model.qk_nope_head_dim + self.load_model.qk_rope_head_dim), + -1) + kv_b_proj = kv_b_proj.reshape(self.load_model.num_attention_heads, + (self.load_model.qk_nope_head_dim + self.load_model.v_head_dim), -1) + qk_nope, qk_rope = torch.split(q_b_proj, + [self.load_model.qk_nope_head_dim, self.load_model.qk_rope_head_dim], + dim=1) + kv_nope, linear_v = torch.split(kv_b_proj, + [self.load_model.qk_nope_head_dim, self.load_model.v_head_dim], dim=1) + qk_nope = qk_nope.reshape(self.load_model.num_attention_heads * self.load_model.qk_nope_head_dim, -1) + qk_rope = qk_rope.reshape(self.load_model.num_attention_heads * self.load_model.qk_rope_head_dim, -1) + kv_nope = kv_nope.reshape(self.load_model.num_attention_heads * self.load_model.qk_nope_head_dim, -1) + linear_v = linear_v.reshape(self.load_model.num_attention_heads * self.load_model.v_head_dim, -1) + + qk_nope_lst = torch.chunk(qk_nope, 
self.tp_size, dim=0) + qk_rope_lst = torch.chunk(qk_rope, self.tp_size, dim=0) + kv_nope_lst = torch.chunk(kv_nope, self.tp_size, dim=0) + linear_v_lst = torch.chunk(linear_v, self.tp_size, dim=0) + else: + linear_qb_lst = torch.chunk(q_b_proj, self.tp_size, dim=0) + linear_kvb_lst = torch.chunk(kv_b_proj, self.tp_size, dim=0) + + elif self.load_model.qkv_type == 'unpack': + hf_q_proj = hf_weight.pop(hf_weight_key["layers_self_attention_linear_q_proj"]) + hf_k_proj = hf_weight.pop(hf_weight_key["layers_self_attention_linear_k_proj"]) + hf_v_proj = hf_weight.pop(hf_weight_key["layers_self_attention_linear_v_proj"]) + dense_weight = hf_weight.pop(hf_weight_key["layers_self_attention_linear_proj"]) + dense_lst = torch.chunk(dense_weight, self.tp_size, dim=1) + + qkv_weight = [hf_q_proj, hf_k_proj, hf_v_proj] + qkv_weight = qkv_concatenate_weight(qkv_weight) + qkv_weight_lst = torch.chunk(qkv_weight, self.tp_size, dim=0) + if self.load_model.qk_layernorm: + q_layernorm = hf_weight.pop(hf_weight_key["layers_self_attention_q_layernorm"]) + k_layernorm = hf_weight.pop(hf_weight_key["layers_self_attention_k_layernorm"]) + + for ep_rank in range(self.ep_size): + for tp_rank in range(self.tp_size): + if hasattr(self.load_model, "multi_head_latent_attention"): + qkv_key, dense_key, q_layernorm_key, k_layernorm_key, q_b_key, kv_b_key = _generate_mla_attn_layers_key( + mtp_layer_flag) + mg_weight[ep_rank][tp_rank][qkv_key] = qkv_weight.clone() + mg_weight[ep_rank][tp_rank][dense_key] = dense_lst[tp_rank].clone() + mg_weight[ep_rank][tp_rank][q_layernorm_key] = q_layernorm.clone() + mg_weight[ep_rank][tp_rank][k_layernorm_key] = k_layernorm.clone() + + if self.qlora_nf4: + self.qlora_nf4_quant(mg_weight, ep_rank, tp_rank, qkv_key, qkv_weight.clone()) + self.qlora_nf4_quant(mg_weight, ep_rank, tp_rank, dense_key, dense_lst[tp_rank].clone()) + if self.mla_mm_split: + qk_nope_key, qk_rope_key, kv_nope_key, linear_v_key = _generate_attn_mm_split_key( + mtp_layer_flag) + mg_weight[ep_rank][tp_rank][qk_nope_key] = qk_nope_lst[tp_rank].clone() + mg_weight[ep_rank][tp_rank][qk_rope_key] = qk_rope_lst[tp_rank].clone() + mg_weight[ep_rank][tp_rank][kv_nope_key] = kv_nope_lst[tp_rank].clone() + mg_weight[ep_rank][tp_rank][linear_v_key] = linear_v_lst[tp_rank].clone() + if self.qlora_nf4: + self.qlora_nf4_quant(mg_weight, ep_rank, tp_rank, qk_nope_key, qk_nope_lst[tp_rank].clone()) + self.qlora_nf4_quant(mg_weight, ep_rank, tp_rank, qk_rope_key, qk_rope_lst[tp_rank].clone()) + self.qlora_nf4_quant(mg_weight, ep_rank, tp_rank, kv_nope_key, kv_nope_lst[tp_rank].clone()) + self.qlora_nf4_quant(mg_weight, ep_rank, tp_rank, linear_v_key, + linear_v_lst[tp_rank].clone()) + else: + mg_weight[ep_rank][tp_rank][q_b_key] = linear_qb_lst[tp_rank].clone() + mg_weight[ep_rank][tp_rank][kv_b_key] = linear_kvb_lst[tp_rank].clone() + if self.qlora_nf4: + self.qlora_nf4_quant(mg_weight, ep_rank, tp_rank, q_b_key, linear_qb_lst[tp_rank].clone()) + self.qlora_nf4_quant(mg_weight, ep_rank, tp_rank, kv_b_key, linear_kvb_lst[tp_rank].clone()) + + else: + qkv_key, dense_key, q_layernorm_key, k_layernorm_key = _generate_attn_layers_key() + mg_weight[ep_rank][tp_rank][qkv_key] = qkv_weight_lst[tp_rank].clone() + mg_weight[ep_rank][tp_rank][dense_key] = dense_lst[tp_rank].clone() + if self.load_model.qk_layernorm: + mg_weight[ep_rank][tp_rank][q_layernorm_key] = q_layernorm.clone() + mg_weight[ep_rank][tp_rank][k_layernorm_key] = k_layernorm.clone() + + def get_first_k_dense_replace(self): + if getattr(self, "first_k_dense_replace", None) 
is None: + num_experts = (getattr(self.load_model, 'num_experts', None) or + getattr(self.load_model, 'num_local_experts', None)) + if num_experts is None: + return self.load_model.num_layers + else: + return 0 + else: + return self.first_k_dense_replace + + def set_model_layer_mlp(self, hf_layer_idx, local_layer_idx, hf_weight, mg_weight, mtp_layer_flag=False): + """MLP layer process""" + + hf_weight_key = self.load_model.get_weight(hf_layer_idx) + first_k_dense_replace = self.get_first_k_dense_replace() + if hf_layer_idx >= first_k_dense_replace: + # moe layer & mtp layer + mlp_router_weight = hf_weight.pop(hf_weight_key["layers_mlp_router"]) + mlp_router_weight = mlp_router_weight[:self.load_model.num_experts, :] + + if hasattr(self.load_model, "router_bias"): + mlp_router_bias = hf_weight.pop(hf_weight_key["layers_mlp_router_bias"]) + mlp_router_bias = mlp_router_bias[:self.load_model.num_experts] + + if hasattr(self.load_model, "n_shared_experts"): + shared_gate_proj = hf_weight.pop(hf_weight_key["layers_mlp_shared_experts_gate_proj"]) + shared_up_proj = hf_weight.pop(hf_weight_key["layers_mlp_shared_experts_up_proj"]) + shared_fc2_weight = hf_weight.pop(hf_weight_key["layers_mlp_shared_experts_linear_fc2"]) + + experts_linear_fc1_list = [] + experts_linear_fc2_list = [] + + def _generate_moe_layer_key(mtp_flag): + if mtp_flag: + router_key = mg_weight_key["mtp_layers_mlp_router"] + router_bias_key = mg_weight_key["mtp_layers_mlp_router_bias"] + shared_fc1_key = mg_weight_key["mtp_layers_mlp_shared_experts_linear_fc1"] + shared_fc2_key = mg_weight_key["mtp_layers_mlp_shared_experts_linear_fc2"] + experts_weight1_key = mg_weight_key["mtp_layers_mlp_experts_weight1"] + experts_weight2_key = mg_weight_key["mtp_layers_mlp_experts_weight2"] + else: + router_key = mg_weight_key["layers_mlp_router"] + router_bias_key = mg_weight_key["layers_mlp_router_bias"] + shared_fc1_key = mg_weight_key["layers_mlp_shared_experts_linear_fc1"] + shared_fc2_key = mg_weight_key["layers_mlp_shared_experts_linear_fc2"] + experts_weight1_key = mg_weight_key["layers_mlp_experts_weight1"] + experts_weight2_key = mg_weight_key["layers_mlp_experts_weight2"] + return router_key, router_bias_key, shared_fc1_key, shared_fc2_key, experts_weight1_key, experts_weight2_key + + for expert_idx in range(self.load_model.num_experts): + hf_weight_key = self.load_model.get_weight(hf_layer_idx, expert_idx) + + if hasattr(self.load_model, "n_shared_experts"): + shared_l0_W = torch.chunk(shared_gate_proj, self.tp_size, dim=0) + shared_l0_V = torch.chunk(shared_up_proj, self.tp_size, dim=0) + shared_l0_lst = [torch.cat(weights, dim=0) for weights in zip(shared_l0_W, shared_l0_V)] + shared_l1_lst = torch.chunk(shared_fc2_weight, self.tp_size, dim=1) + + gate_proj = hf_weight.pop(hf_weight_key["layers_mlp_experts_gate_proj"]) + up_proj = hf_weight.pop(hf_weight_key["layers_mlp_experts_up_proj"]) + + expert_tp_size = self.tp_size + if self.moe_tp_extend_ep: + expert_tp_size = 1 + + gate_w_list = torch.chunk(gate_proj, expert_tp_size, dim=0) + up_w_list = torch.chunk(up_proj, expert_tp_size, dim=0) + fc1_weight = torch.cat([torch.cat(weights, dim=0) for weights in zip(gate_w_list, up_w_list)], dim=0) + + fc2_weight = hf_weight.pop(hf_weight_key["layers_mlp_experts_linear_fc2"]) + + experts_linear_fc1_list.append(fc1_weight.t()) + experts_linear_fc2_list.append(fc2_weight.t()) + + for ep_rank in range(self.ep_size): + + # generate weights key + mg_weight_key = self.save_model.get_weight(local_layer_idx, ep_rank) + router_key, 
router_bias_key, shared_fc1_key, shared_fc2_key, experts_weight1_key, experts_weight2_key \ + = _generate_moe_layer_key(mtp_layer_flag) + + for tp_rank in range(self.tp_size): + mg_weight[ep_rank][tp_rank][router_key] = mlp_router_weight.clone() + if hasattr(self.load_model, "router_bias"): + mg_weight[ep_rank][tp_rank][router_bias_key] = mlp_router_bias.clone() + if hasattr(self.load_model, "n_shared_experts"): + mg_weight[ep_rank][tp_rank][shared_fc1_key] = shared_l0_lst[tp_rank].clone() + mg_weight[ep_rank][tp_rank][shared_fc2_key] = shared_l1_lst[tp_rank].clone() + + if self.qlora_nf4: + self.qlora_nf4_quant(mg_weight, ep_rank, tp_rank, shared_fc1_key, + shared_l0_lst[tp_rank].clone()) + self.qlora_nf4_quant(mg_weight, ep_rank, tp_rank, shared_fc2_key, + shared_l1_lst[tp_rank].clone()) + + if self.moe_grouped_gemm: + gemm_fc1 = torch.cat(experts_linear_fc1_list).view(self.load_model.hidden_size, -1) + gemm_fc2 = torch.cat(experts_linear_fc2_list).view(-1, self.load_model.hidden_size) + if self.moe_tp_extend_ep: + gemm_fc1_ep = torch.chunk( + gemm_fc1.view(self.load_model.num_experts, self.load_model.hidden_size, -1), + self.ep_size * self.tp_size, dim=0) + gemm_fc2_ep = torch.chunk( + gemm_fc2.view(self.load_model.num_experts, -1, self.load_model.hidden_size), + self.ep_size * self.tp_size, dim=0) + else: + gemm_fc1_ep = torch.chunk( + gemm_fc1.view(self.load_model.num_experts, self.load_model.hidden_size, -1), self.ep_size, + dim=0) + gemm_fc2_ep = torch.chunk( + gemm_fc2.view(self.load_model.num_experts, -1, self.load_model.hidden_size), self.ep_size, + dim=0) + + for ep_rank in range(self.ep_size): + mg_weight_key = self.save_model.get_weight(local_layer_idx, ep_rank) + router_key, router_bias_key, shared_fc1_key, shared_fc2_key, experts_weight1_key, experts_weight2_key = _generate_moe_layer_key( + mtp_layer_flag) + if not self.moe_tp_extend_ep: + gemm_fc1_ep_tp = torch.chunk(gemm_fc1_ep[ep_rank], self.tp_size, dim=2) + gemm_fc2_ep_tp = torch.chunk(gemm_fc2_ep[ep_rank], self.tp_size, dim=1) + for tp_rank in range(self.tp_size): + if self.moe_tp_extend_ep: + mg_weight[ep_rank][tp_rank][experts_weight1_key] = gemm_fc1_ep[ + ep_rank * self.tp_size + tp_rank].reshape(self.load_model.hidden_size, -1).clone() + mg_weight[ep_rank][tp_rank][experts_weight2_key] = gemm_fc2_ep[ + ep_rank * self.tp_size + tp_rank].reshape(-1, self.load_model.hidden_size).clone() + if self.qlora_nf4: + self.qlora_nf4_quant(mg_weight, ep_rank, tp_rank, experts_weight1_key, + gemm_fc1_ep[ep_rank * self.tp_size + tp_rank].reshape( + self.load_model.hidden_size, -1).clone()) + self.qlora_nf4_quant(mg_weight, ep_rank, tp_rank, experts_weight2_key, + gemm_fc2_ep[ep_rank * self.tp_size + tp_rank].reshape(-1, + self.load_model.hidden_size).clone()) + else: + mg_weight[ep_rank][tp_rank][experts_weight1_key] = gemm_fc1_ep_tp[tp_rank].reshape( + self.load_model.hidden_size, -1).clone() + mg_weight[ep_rank][tp_rank][experts_weight2_key] = gemm_fc2_ep_tp[tp_rank].reshape( + -1, self.load_model.hidden_size).clone() + if self.qlora_nf4: + self.qlora_nf4_quant(mg_weight, ep_rank, tp_rank, experts_weight1_key, + gemm_fc1_ep_tp[tp_rank].reshape(self.load_model.hidden_size, + -1).clone()) + self.qlora_nf4_quant(mg_weight, ep_rank, tp_rank, experts_weight2_key, + gemm_fc2_ep_tp[tp_rank].reshape(-1, + self.load_model.hidden_size).clone()) + else: + num_local_experts = self.load_model.num_experts // self.ep_size + for ep_rank in range(self.ep_size): + mg_weight_key = self.save_model.get_weight(local_layer_idx, ep_rank) + for 
local_experts_idx in range(num_local_experts): + local_fc1_key = mg_weight_key["layers_mlp_experts_linear_fc1"] + local_fc2_key = mg_weight_key["layers_mlp_experts_linear_fc2"] + if mtp_layer_flag: + local_fc1_key = mg_weight_key["mtp_layers_mlp_experts_linear_fc1"] + local_fc2_key = mg_weight_key["mtp_layers_mlp_experts_linear_fc2"] + + global_experts_idx = local_experts_idx + ep_rank * num_local_experts + local_fc1_weight = experts_linear_fc1_list[global_experts_idx].t() + local_fc2_weight = experts_linear_fc2_list[global_experts_idx].t() + + local_fc1_lst = torch.chunk(local_fc1_weight, self.tp_size, dim=0) + local_fc2_lst = torch.chunk(local_fc2_weight, self.tp_size, dim=1) + + for tp_rank in range(self.tp_size): + mg_model[ep_rank][tp_rank][local_fc1_key] = local_fc1_lst[tp_rank].clone() + mg_model[ep_rank][tp_rank][local_fc2_key] = local_fc2_lst[tp_rank].clone() + if self.qlora_nf4: + self.qlora_nf4_quant(mg_model, ep_rank, tp_rank, local_fc1_key, + local_fc1_lst[tp_rank].clone()) + self.qlora_nf4_quant(mg_model, ep_rank, tp_rank, local_fc2_key, + local_fc2_lst[tp_rank].clone()) + else: + mg_weight_key = self.save_model.get_weight(local_layer_idx) + # dense layer + gate_proj = hf_weight.pop(hf_weight_key["layers_mlp_gate_proj"]) + up_proj = hf_weight.pop(hf_weight_key["layers_mlp_up_proj"]) + + linear_fc1_weight = torch.cat([gate_proj, up_proj], dim=0) + linear_fc2_weight = hf_weight.pop(hf_weight_key["layers_mlp_linear_fc2"]) + + for ep_rank in range(self.ep_size): + gate, up = torch.chunk(linear_fc1_weight, 2, dim=0) + + mlp_l0_weight_W = torch.chunk(gate, self.tp_size, dim=0) + mlp_l0_weight_V = torch.chunk(up, self.tp_size, dim=0) + mlp_l0_weight = [torch.cat(weights, dim=0) for weights in zip(mlp_l0_weight_W, mlp_l0_weight_V)] + + mlp_l1_weight = torch.chunk(linear_fc2_weight, self.tp_size, dim=1) + for tp_rank in range(self.tp_size): + mg_weight[ep_rank][tp_rank][mg_weight_key["layers_mlp_linear_fc1"]] = \ + mlp_l0_weight[tp_rank].clone() + mg_weight[ep_rank][tp_rank][mg_weight_key["layers_mlp_linear_fc2"]] = \ + mlp_l1_weight[tp_rank].clone() + if self.qlora_nf4: + self.qlora_nf4_quant(hf_weight, ep_rank, tp_rank, + mg_weight_key["layers_mlp_linear_fc1"], + mlp_l0_weight[tp_rank].clone()) + self.qlora_nf4_quant(hf_weight, ep_rank, tp_rank, + mg_weight_key["layers_mlp_linear_fc2"], + mlp_l1_weight[tp_rank].clone()) + + def run(self): + """save magetron format checkpoint""" + pp_local_layer_idx = self.generate_pp_local_layer_idx() + + if self.vpp_stage is None: + for pp_rank in range(self.pp_size): + mg_weight = defaultdict(lambda: defaultdict(lambda: defaultdict(dict))) + + hf_pp_weights = self.load_matched_hf_weight(pp_rank) + if pp_rank == 0: + self.set_model_preprocess(hf_pp_weights, mg_weight) + + layer_list = self.pprank_layer_idxs[pp_rank] + + if self.mtp_num_layers and pp_rank == self.pp_size - 1: + layer_list.sort() + mtp_layer_list = [layer_list.pop() for _ in range(self.mtp_num_layers)] + + local_mtp_idx = 0 + for mtp_layer in mtp_layer_list: + logger.info(f"Converting the weights of mtp layer {mtp_layer}.") + self.set_mtp_preprocess(mtp_layer, local_mtp_idx, hf_pp_weights, mg_weight) + self.set_model_layer_norm(mtp_layer, local_mtp_idx, hf_pp_weights, mg_weight, + mtp_layer_flag=True) + self.set_model_layer_attn(mtp_layer, local_mtp_idx, hf_pp_weights, mg_weight, + mtp_layer_flag=True) + self.set_model_layer_mlp(mtp_layer, local_mtp_idx, hf_pp_weights, mg_weight, + mtp_layer_flag=True) + self.set_mtp_postprocess(mtp_layer, local_mtp_idx, hf_pp_weights, mg_weight) + 
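+ # this MTP layer is fully converted; advance to the next local MTP slot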
local_mtp_idx += 1 + + local_idx = 0 + cur_pp_local_idx = pp_local_layer_idx[pp_rank] + + for hf_layer in layer_list: + logger.info(f"Converting the weights of layer {hf_layer}.") + local_layer_idx = cur_pp_local_idx[local_idx] + self.set_model_layer_norm(hf_layer, local_layer_idx, hf_pp_weights, mg_weight) + self.set_model_layer_attn(hf_layer, local_layer_idx, hf_pp_weights, mg_weight) + self.set_model_layer_mlp(hf_layer, local_layer_idx, hf_pp_weights, mg_weight) + local_idx += 1 + + if pp_rank == self.pp_size - 1: + self.set_model_postprocess(hf_pp_weights, mg_weight) + + for ep_rank in range(self.ep_size): + for tp_rank in range(self.tp_size): + save_prefix = self.generate_mg_weights_dir(tp_rank=tp_rank, pp_rank=pp_rank, ep_rank=ep_rank) + parallel_save_path = os.path.join(self.save_dir, save_prefix) + os.makedirs(parallel_save_path, exist_ok=True) + save_file_name = os.path.join(parallel_save_path, "model_optim_rng.pt") + logger.info(f"Saving to {save_file_name}") + + torch.save({"model": mg_weight[ep_rank][tp_rank], "checkpoint_version": 3.0, "iteration": 1}, + save_file_name, pickle_protocol=4, _use_new_zipfile_serialization=True) + else: + vpp_local_layer_idx = self.generate_vpp_local_layer_idx() + for pp_rank in range(self.pp_size): + mg_weight = defaultdict() + for vpp_rank in range(self.vpp_size): + hf_pp_weight = self.load_matched_hf_weight(pp_rank, vpp_rank) + mg_weight[vpp_rank] = defaultdict(lambda: defaultdict(lambda: defaultdict(dict))) + vpp_list = self.vpprank_layer_idxs[pp_rank][vpp_rank] + + if pp_rank == 0 and vpp_rank == 0: + self.set_model_preprocess(hf_pp_weight, mg_weight[vpp_rank]) + + if self.schedules_method == 'dualpipev' and pp_rank == 0 and vpp_rank == self.vpp_size - 1: + self.set_model_postprocess(hf_pp_weight, mg_weight[vpp_rank]) + + if self.mtp_num_layers: + dualpipe_mtp_flag = self.schedules_method == 'dualpipev' and pp_rank == 0 and vpp_rank == self.vpp_size - 1 + norm_mtp_flag = self.schedules_method != 'dualpipev' and pp_rank == self.pp_size - 1 and vpp_rank == self.vpp_size - 1 + + if dualpipe_mtp_flag or norm_mtp_flag: + vpp_list.sort() + mtp_layer_list = [vpp_list.pop() for _ in range(self.mtp_num_layers)] + local_mtp_idx = 0 + for mtp_layer in mtp_layer_list: + logger.info(f"Converting the weights of mtp layer {mtp_layer}.") + self.set_mtp_preprocess(mtp_layer, local_mtp_idx, hf_pp_weight, mg_weight[vpp_rank]) + self.set_model_layer_norm(mtp_layer, local_mtp_idx, hf_pp_weight, mg_weight[vpp_rank], + mtp_layer_flag=True) + self.set_model_layer_attn(mtp_layer, local_mtp_idx, hf_pp_weight, mg_weight[vpp_rank], + mtp_layer_flag=True) + self.set_model_layer_mlp(mtp_layer, local_mtp_idx, hf_pp_weight, mg_weight[vpp_rank], + mtp_layer_flag=True) + self.set_mtp_postprocess(mtp_layer, local_mtp_idx, hf_pp_weight, mg_weight[vpp_rank]) + local_mtp_idx += 1 + + local_idx = 0 + cur_vpp_local_idx = vpp_local_layer_idx[pp_rank][vpp_rank] + + for hf_layer in vpp_list: + logger.info(f"Converting the weights of layer {hf_layer}.") + local_layer_idx = cur_vpp_local_idx[local_idx] + self.set_model_layer_norm(hf_layer, local_layer_idx, hf_pp_weight, mg_weight[vpp_rank]) + self.set_model_layer_attn(hf_layer, local_layer_idx, hf_pp_weight, mg_weight[vpp_rank]) + self.set_model_layer_mlp(hf_layer, local_layer_idx, hf_pp_weight, mg_weight[vpp_rank]) + local_idx += 1 + + if self.schedules_method != 'dualpipev' and pp_rank == self.pp_size - 1 and vpp_rank == self.vpp_size - 1: + self.set_model_postprocess(hf_pp_weight, mg_weight[vpp_rank]) + + for ep_rank in 
range(self.ep_size): + for tp_rank in range(self.tp_size): + save_prefix = self.generate_mg_weights_dir(tp_rank=tp_rank, pp_rank=pp_rank, ep_rank=ep_rank) + parallel_save_path = os.path.join(self.save_dir, save_prefix) + os.makedirs(parallel_save_path, exist_ok=True) + save_file_name = os.path.join(parallel_save_path, "model_optim_rng.pt") + logger.info(f"Saving to {save_file_name}") + model_dict = {"checkpoint_version": 3.0, "iteration": 1} + + for vpp_rank in range(self.vpp_size): + model_key = f"model{vpp_rank}" + model_dict[model_key] = mg_weight[vpp_rank][ep_rank][tp_rank] + + torch.save(model_dict, save_file_name, pickle_protocol=4, _use_new_zipfile_serialization=True) + + logger.info("Done!") \ No newline at end of file diff --git a/mindspeed_llm/tasks/checkpoint/model_builder.py b/mindspeed_llm/tasks/checkpoint/model_builder.py new file mode 100644 index 000000000..b6c956865 --- /dev/null +++ b/mindspeed_llm/tasks/checkpoint/model_builder.py @@ -0,0 +1,305 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. +import abc +import json +import logging as logger +import os +from collections import defaultdict +import re +import safetensors +import safetensors.torch + +logger.basicConfig(format="") +logger.getLogger().setLevel(logger.INFO) + + +class Model(abc.ABC): + def __init__(self): + self.module_mapping = None + + @abc.abstractmethod + def get_weight(self): + pass + + @abc.abstractmethod + def get_bias(self): + pass + + @abc.abstractmethod + def get_module_mapping(self): + pass + + +class HuggingFaceModel(Model): + def __init__(self, args): + super(HuggingFaceModel, self).__init__() + self.model_cfg = self.read_model_cfg() + self.model_type_hf = args.model_type_hf + self.hf_path = args.load_dir if args.load_model_type == 'hf' else args.save_dir + self.load_hf_args() + self.module_mapping = self.get_module_mapping() + + + def load_hf_args(self): + """ + Load config.json, apply key mappings and config values from model_cfg, + and set them as instance attributes. 
+ """ + hf_args_path = os.path.join(self.hf_path, "config.json") + with open(hf_args_path) as f: + hf_args = json.load(f) + + config_key_mapping = self.model_cfg.get(self.model_type_hf).get('config_hf_key_mapping') + config_value = self.model_cfg.get(self.model_type_hf).get('config_set_value') + for key_target in config_key_mapping: + key_hf = config_key_mapping[key_target] + if key_hf in hf_args: + setattr(self, key_target, hf_args[key_hf]) + else: + setattr(self, key_hf, hf_args[key_hf]) + + for key_target, value in config_value.items(): + setattr(self, key_target, value) + + + def get_module_mapping(self): + return self.model_cfg.get(self.model_type_hf).get('model_hf_key_mapping') + + @staticmethod + def read_model_cfg(): + def merge_configs(base_config, specific_config): + merged_config = base_config.copy() + for key, value in specific_config.items(): + if isinstance(value, dict) and key in merged_config: + merged_config[key] = merge_configs(merged_config[key], value) + else: + merged_config[key] = value + return merged_config + + current_directory = os.path.dirname(os.path.abspath(__file__)) + cfg_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(current_directory))), + "configs/checkpoint/model_cfg.json") + with open(cfg_dir, 'r') as file: + config = json.load(file) + final_configs = {} + + for model_name, model_config in config["model_mappings"].items(): + if "__base__" in model_config: + base_model_name = model_config["__base__"] + base_config = config["model_mappings"][base_model_name] + specific_config = model_config.copy() + specific_config.pop("__base__", None) + final_config = merge_configs(base_config, specific_config) + else: + final_config = model_config + final_configs[model_name] = final_config + + return final_configs + + + def get_weight(self, layer_idx=0, expert_idx=0): + module_key = {} + for key, value in self.module_mapping.items(): + value = re.sub(r'\[layer_idx\]', f'.{layer_idx}', value) + value = re.sub(r'\[expert_idx\]', f'.{expert_idx}', value) + module_key[key] = value + ".weight" if ("weight" not in value and "bias" not in value) else value + return module_key + + + def get_bias(self, layer_idx=0, expert_idx=0): + module_key = {} + for key, value in self.module_mapping.items(): + value = re.sub(r'\[layer_idx\]', f'.{layer_idx}', value) + value = re.sub(r'\[expert_idx\]', f'.{expert_idx}', value) + module_key[key] = value + ".bias" + return module_key + + + def get_layer_files_map(self): + """layer -> safetensors file map""" + layer_map_dict = defaultdict(set) + weights_map_file_path = os.path.join(self.hf_path, "model.safetensors.index.json") + + with open(weights_map_file_path) as f: + weights_map = json.load(f) + weights_map = weights_map["weight_map"] + + for key, value in weights_map.items(): + if key.startswith("model.layers."): + layer_name = int(key.split('model.layers.')[1].split('.')[0]) + layer_map_dict[layer_name].add(value) + else: + layer_map_dict[key].add(value) + return layer_map_dict + + @staticmethod + def load_hf_model(file_path): + """Load safetensors file""" + logger.info(f"Loading the checkpoint from {file_path}.") + return safetensors.torch.load_file(file_path) + + +class MegatronModel(Model): + def __init__(self, args): + super(MegatronModel, self).__init__() + self.shared_expert_gate = args.shared_expert_gate + self.save_lora_to_hf = False + + self.mla_mm_split = args.mla_mm_split + self.mtp_num_layers = args.mtp_num_layers + self.module_mapping = self.get_module_mapping() + + + def get_weight(self, layer_idx=0, 
expert_idx=0): + module_key = {} + for key, value in self.module_mapping.items(): + value = re.sub(r'\[layer_idx\]', f'.{layer_idx}', value) + value = re.sub(r'\[expert_idx\]', f'.{expert_idx}', value) + module_key[key] = value + ".weight" if ("weight" not in value and "bias" not in value) else value + return module_key + + + def get_bias(self, layer_idx=0, expert_idx=0): + module_key = {} + for key, value in self.module_mapping.items(): + value = re.sub(r'\[layer_idx\]', f'.{layer_idx}', value) + value = re.sub(r'\[expert_idx\]', f'.{expert_idx}', value) + module_key[key] = value + ".bias" + return module_key + + + def get_module_mapping(self): + module_layer = "decoder.layers[layer_idx]." + module_layer_mtp = "mtp.layers[layer_idx].transformer_layer." + module_mapping = { + "embedding": "embedding", + "embedding_word_embeddings": "embedding.word_embeddings", + "embedding_word_embeddings_norm": "embedding.word_embeddings.norm", + "embedding_position_embeddings": "embedding.position_embeddings", + "model": "module", + "layers_input_layernorm": module_layer + "input_layernorm", + "layers": "decoder.layers", + "layers_self_attention_linear_proj": module_layer + "self_attention.linear_proj", + "layers_self_attention_linear_qkv": module_layer + "self_attention.linear_qkv", + "layers_self_attention_q_layernorm": module_layer + "self_attention.q_layernorm", + "layers_self_attention_k_layernorm": module_layer + "self_attention.k_layernorm", + "layers_self_attention_post_attention_layernorm": module_layer + "post_attn_norm", + "layers_self_attention_pre_mlp_layernorm": module_layer + "pre_mlp_layernorm", + "layers_mlp_linear_fc1": module_layer + "mlp.linear_fc1", + "layers_mlp_linear_fc2": module_layer + "mlp.linear_fc2", + "layers_self_attention_post_mlp_layernorm": module_layer + "post_mlp_layernorm", + "final_layernorm": "decoder.final_layernorm", + "output_layer": "output_layer", + "rm_head": "rm_head" + } + + module_mapping["layers_mlp_router"] = module_layer + "mlp.router" + module_mapping["layers_mlp_router_bias"] = module_layer + "mlp.router.expert_bias" + module_mapping[ + "layers_mlp_experts_linear_fc1"] = module_layer + "mlp.experts.local_experts[expert_idx].linear_fc1" + module_mapping[ + "layers_mlp_experts_linear_fc2"] = module_layer + "mlp.experts.local_experts[expert_idx].linear_fc2" + + # MLA + module_mapping["layers_self_attention_linear_qb"] = module_layer + "self_attention.linear_qb" + module_mapping["layers_self_attention_linear_kvb"] = module_layer + "self_attention.linear_kvb" + + # shared experts + module_mapping[ + "layers_mlp_shared_experts_linear_fc1"] = module_layer + "mlp.shared_experts.linear_fc1" + module_mapping[ + "layers_mlp_shared_experts_linear_fc2"] = module_layer + "mlp.shared_experts.linear_fc2" + + # shared experts gate + if self.shared_expert_gate: + module_mapping["layers_mlp_shared_expert_gate"] = module_layer + "mlp.shared_expert_gate" + + # moe grouped gemm + module_mapping[ + "layers_mlp_experts_weight1"] = module_layer + "mlp.experts.weight1" + module_mapping[ + "layers_mlp_experts_weight2"] = module_layer + "mlp.experts.weight2" + + if self.mtp_num_layers: + module_mapping[ + "mtp_layers_enorm"] = "mtp.layers[layer_idx].enorm" + module_mapping[ + "mtp_layers_hnorm"] = "mtp.layers[layer_idx].hnorm" + module_mapping[ + "mtp_layers_eh_proj"] = "mtp.layers[layer_idx].eh_proj" + module_mapping[ + "mtp_layers_embed_tokens"] = "embedding.word_embeddings" + module_mapping[ + "mtp_layers_input_layernorm"] = module_layer_mtp + "input_layernorm" + 
module_mapping[
+                "mtp_layers_self_attention_post_attention_layernorm"] = module_layer_mtp + "pre_mlp_layernorm"
+            module_mapping[
+                "mtp_layers_self_attention_linear_proj"] = module_layer_mtp + "self_attention.linear_proj"
+            module_mapping[
+                "mtp_layers_self_attention_linear_qkv"] = module_layer_mtp + "self_attention.linear_qkv"
+            module_mapping[
+                "mtp_layers_self_attention_linear_qb"] = module_layer_mtp + "self_attention.linear_qb"
+            module_mapping[
+                "mtp_layers_self_attention_linear_kvb"] = module_layer_mtp + "self_attention.linear_kvb"
+            module_mapping[
+                "mtp_layers_self_attention_q_layernorm"] = module_layer_mtp + "self_attention.q_layernorm"
+            module_mapping[
+                "mtp_layers_self_attention_k_layernorm"] = module_layer_mtp + "self_attention.k_layernorm"
+            module_mapping[
+                "mtp_layers_mlp_router"] = module_layer_mtp + "mlp.router"
+            module_mapping[
+                "mtp_layers_mlp_router_bias"] = module_layer_mtp + "mlp.router.expert_bias"
+            module_mapping[
+                "mtp_layers_mlp_experts_weight1"] = module_layer_mtp + "mlp.experts.weight1"
+            module_mapping[
+                "mtp_layers_mlp_experts_weight2"] = module_layer_mtp + "mlp.experts.weight2"
+            module_mapping[
+                "mtp_layers_mlp_shared_experts_linear_fc1"] = module_layer_mtp + "mlp.shared_experts.linear_fc1"
+            module_mapping[
+                "mtp_layers_mlp_shared_experts_linear_fc2"] = module_layer_mtp + "mlp.shared_experts.linear_fc2"
+            module_mapping[
+                "mtp_layers_mlp_experts_linear_fc1"] = module_layer_mtp + "mlp.experts.local_experts[expert_idx].linear_fc1"
+            module_mapping[
+                "mtp_layers_mlp_experts_linear_fc2"] = module_layer_mtp + "mlp.experts.local_experts[expert_idx].linear_fc2"
+            module_mapping[
+                "mtp_post_norm"] = "mtp.final_layernorms[layer_idx]"
+            module_mapping[
+                "mtp_final_layernorms"] = "final_layernorm"
+
+            if self.mla_mm_split:
+                module_mapping[
+                    "mtp_layers_self_attention_linear_qk_nope"] = module_layer_mtp + "self_attention.linear_qk_nope"
+                module_mapping[
+                    "mtp_layers_self_attention_linear_qk_rope"] = module_layer_mtp + "self_attention.linear_qk_rope"
+                module_mapping[
+                    "mtp_layers_self_attention_linear_kv_nope"] = module_layer_mtp + "self_attention.linear_kv_nope"
+                module_mapping[
+                    "mtp_layers_self_attention_linear_v"] = module_layer_mtp + "self_attention.linear_v"
+
+        # lora
+        if self.save_lora_to_hf:
+            module_mapping[
+                "layers_self_attention_linear_qkv_lora_A_default"] = module_layer + "self_attention.linear_qkv.lora_A.default"
+            module_mapping[
+                "layers_self_attention_linear_qkv_lora_B_default"] = module_layer + "self_attention.linear_qkv.lora_B.default"
+            module_mapping[
+                "layers_self_attention_linear_proj_lora_A_default"] = module_layer + "self_attention.linear_proj.lora_A.default"
+            module_mapping[
+                "layers_self_attention_linear_proj_lora_B_default"] = module_layer + "self_attention.linear_proj.lora_B.default"
+            module_mapping[
+                "layers_mlp_linear_fc1_lora_A_default"] = module_layer + "mlp.linear_fc1.lora_A.default"
+            module_mapping[
+                "layers_mlp_linear_fc1_lora_B_default"] = module_layer + "mlp.linear_fc1.lora_B.default"
+            module_mapping[
+                "layers_mlp_linear_fc2_lora_A_default"] = module_layer + "mlp.linear_fc2.lora_A.default"
+            module_mapping[
+                "layers_mlp_linear_fc2_lora_B_default"] = module_layer + "mlp.linear_fc2.lora_B.default"
+            module_mapping[
+                "layers_mlp_experts_linear_fc1_lora_A_default"] = module_layer + "mlp.experts.local_experts[expert_idx].linear_fc1.lora_A.default"
+            module_mapping[
+                "layers_mlp_experts_linear_fc1_lora_B_default"] = module_layer + "mlp.experts.local_experts[expert_idx].linear_fc1.lora_B.default"
+            module_mapping[
+                "layers_mlp_experts_linear_fc2_lora_A_default"] = module_layer + "mlp.experts.local_experts[expert_idx].linear_fc2.lora_A.default"
+            module_mapping[
+                "layers_mlp_experts_linear_fc2_lora_B_default"] = module_layer + "mlp.experts.local_experts[expert_idx].linear_fc2.lora_B.default"
+        return module_mapping
-- 
Gitee