diff --git a/mindspeed_llm/tasks/checkpoint/loader_mg.py b/mindspeed_llm/tasks/checkpoint/loader_mg.py
index 9303dc51ef459ad44abdc7c4ba9d94913b1d0c8a..5fe1ff093a33fd2beac336fd3dc7d292f75c00e9 100644
--- a/mindspeed_llm/tasks/checkpoint/loader_mg.py
+++ b/mindspeed_llm/tasks/checkpoint/loader_mg.py
@@ -47,6 +47,8 @@ def add_arguments(parser):
                        help='Lora alpha.')
     group.add_argument('--moe-grouped-gemm', action='store_true',
                        help='Usr moe grouped gemm.')
+    group.add_argument("--moe-tp-extend-ep", action='store_true',
+                       help="use tp group to extend experts parallelism instead of sharding weight tensor of experts in tp group")
     group.add_argument('--load-from-legacy', action='store_true',
                        help='Is loader legacy')
     group.add_argument('--spec', type=str, default=None, nargs='*',
@@ -92,6 +94,7 @@ def build_metadata(args, margs):
     md.make_vocab_size_divisible_by = margs.make_vocab_size_divisible_by
     md.embed_layernorm = margs.embed_layernorm
     md.moe_grouped_gemm = margs.moe_grouped_gemm
+    md.moe_tp_extend_ep = margs.moe_tp_extend_ep
     md.spec = margs.spec
     md.num_experts = getattr(margs, "num_experts", None)
     md.n_shared_experts = getattr(margs, "n_shared_experts", None)
@@ -336,29 +339,62 @@ def get_message_layer_mlp(message, model, md=None, **kwargs):
                 expert = message["mlp_moe"][f"expert {global_expert_idx}"]
                 _get_message_layer_mlp(expert, model, md, is_moe_mlp=True, **kwargs)
         if margs.moe_grouped_gemm:
-            weight2_ep = []
-            gate_tp_list = [[] for _ in range(margs.num_experts)]
-            up_tp_list = [[] for _ in range(margs.num_experts)]
-            weight1_list = []
-            for ep in range(margs.expert_model_parallel_size):
-                for tp in range(margs.tensor_model_parallel_size):
-                    ep_expert_list = torch.chunk(weight1[ep][tp].view(-1), num_experts_local)
-                    for i in range(num_experts_local):
-                        gate, up = torch.chunk(ep_expert_list[i].view(margs.hidden_size, -1).t(), 2, dim=0)
-                        gate_tp_list[ep * num_experts_local + i].append(gate)
-                        up_tp_list[ep * num_experts_local + i].append(up)
-            for expert_idx in range(margs.num_experts):
-                gate = torch.cat(gate_tp_list[expert_idx], dim=0)
-                up = torch.cat(up_tp_list[expert_idx], dim=0)
-                fc1_weight = torch.cat([gate, up], dim=0)
-                weight1_list.append(fc1_weight.t().view(-1))
-            for ep in range(margs.expert_model_parallel_size):
-                weight2_ep.append(torch.cat(
-                    [tp_weight2.view(num_experts_local, -1, margs.hidden_size) for tp_weight2 in weight2[ep]],
-                    dim=1))
+            if margs.moe_tp_extend_ep:
+                bucket_num = margs.expert_model_parallel_size * margs.tensor_model_parallel_size
+                bucket_expert_num = margs.num_experts // bucket_num
+                weight2_list = []
+                gate_tp_list = [[] for _ in range(margs.num_experts)]
+                up_tp_list = [[] for _ in range(margs.num_experts)]
+                down_list = [[] for _ in range(margs.num_experts)]
+                weight1_list = []
+                for ep_rank in range(margs.expert_model_parallel_size):
+                    for tp_rank in range(margs.tensor_model_parallel_size):
+                        cur_weight1_bucket = weight1[ep_rank][tp_rank]
+                        cur_weight2_bucket = weight2[ep_rank][tp_rank]
+                        cur_w1_list = torch.chunk(cur_weight1_bucket.view(-1), bucket_expert_num, dim=0)
+                        cur_w2_list = torch.chunk(cur_weight2_bucket, bucket_expert_num, dim=0)
+                        global_expert_idx = ep_rank * margs.tensor_model_parallel_size + tp_rank
+                        for idx in range(bucket_expert_num):
+                            local_w1 = cur_w1_list[idx].reshape(margs.hidden_size, -1)
+                            local_w2 = cur_w2_list[idx].reshape(-1, margs.hidden_size)
+                            # global expert idx
+                            expert_idx = global_expert_idx * bucket_expert_num + idx
+                            gate, up = torch.chunk(local_w1.view(margs.hidden_size, -1).t(), 2, dim=0)
+                            gate_tp_list[expert_idx].append(gate)
+                            up_tp_list[expert_idx].append(up)
+                            down = local_w2
+                            down_list[expert_idx].append(down)
+
+                for expert_idx in range(margs.num_experts):
+                    gate = torch.cat(gate_tp_list[expert_idx], dim=0)
+                    up = torch.cat(up_tp_list[expert_idx], dim=0)
+                    fc1_weight = torch.cat([gate, up], dim=0)
+                    weight1_list.append(fc1_weight.t().view(-1))
+                    weight2_list.append(torch.cat([tp_weight2 for tp_weight2 in down_list[expert_idx]], dim=1))
+
+            else:
+                weight2_list = []
+                gate_tp_list = [[] for _ in range(margs.num_experts)]
+                up_tp_list = [[] for _ in range(margs.num_experts)]
+                weight1_list = []
+                for ep_rank in range(margs.expert_model_parallel_size):
+                    for tp_rank in range(margs.tensor_model_parallel_size):
+                        ep_expert_list = torch.chunk(weight1[ep_rank][tp_rank].view(-1), num_experts_local)
+                        for i in range(num_experts_local):
+                            gate, up = torch.chunk(ep_expert_list[i].view(margs.hidden_size, -1).t(), 2, dim=0)
+                            gate_tp_list[ep_rank * num_experts_local + i].append(gate)
+                            up_tp_list[ep_rank * num_experts_local + i].append(up)
+                for expert_idx in range(margs.num_experts):
+                    gate = torch.cat(gate_tp_list[expert_idx], dim=0)
+                    up = torch.cat(up_tp_list[expert_idx], dim=0)
+                    fc1_weight = torch.cat([gate, up], dim=0)
+                    weight1_list.append(fc1_weight.t().view(-1))
+                for ep_rank in range(margs.expert_model_parallel_size):
+                    weight2_list.append(torch.cat(
+                        [tp_weight2.view(num_experts_local, -1, margs.hidden_size) for tp_weight2 in weight2[ep_rank]],
+                        dim=1))
             message["mlp_moe"]["mlp experts weight1 module"] = torch.cat(weight1_list, dim=0).view(margs.hidden_size, -1)
-            message["mlp_moe"]["mlp experts weight2 module"] = torch.cat(weight2_ep, dim=0).view(margs.num_experts,
-                                                                                                 -1)
+            message["mlp_moe"]["mlp experts weight2 module"] = torch.cat(weight2_list, dim=0).view(margs.num_experts, -1)
     else:
         _get_message_layer_mlp(message, model, md, **kwargs)

@@ -451,6 +487,7 @@ def _load_checkpoint(model_provider, queue, args):
     margs = model_mg.get_args()
     margs.add_output_layer_bias = getattr(args, "add_output_layer_bias", False)
     margs.moe_grouped_gemm = args.moe_grouped_gemm
+    margs.moe_tp_extend_ep = args.moe_tp_extend_ep
     margs.spec = args.spec

     md = build_metadata(args, margs)
diff --git a/mindspeed_llm/tasks/checkpoint/models.py b/mindspeed_llm/tasks/checkpoint/models.py
index f9b517a8dbd86fca1c4d4d4bec09449890f3209f..ead74608ab206380b0a4819274a70c0b23bd53fc 100644
--- a/mindspeed_llm/tasks/checkpoint/models.py
+++ b/mindspeed_llm/tasks/checkpoint/models.py
@@ -159,8 +159,12 @@ class ModelBase(abc.ABC):
                 # insert noop layer
                 mg_layer_list.insert(i, -1)
         for dst_layer_idx, src_layer_idx in enumerate(mg_layer_list):
-            if not self.is_noop_layer(src_layer_idx):
-                self.set_layer_state_base(src_model, src_layer_idx=src_layer_idx, dst_layer_idx=dst_layer_idx)
+            if self.args_cmd.save_model_type == "hf":
+                if not self.is_noop_layer(src_layer_idx):
+                    self.set_layer_state_base(src_model, src_layer_idx=dst_layer_idx, dst_layer_idx=src_layer_idx)
+            else:
+                if not self.is_noop_layer(src_layer_idx):
+                    self.set_layer_state_base(src_model, src_layer_idx=src_layer_idx, dst_layer_idx=dst_layer_idx)

     def set_preprocess_state(self, src_model):
         """Set embedding params."""
@@ -495,6 +499,7 @@ class HuggingfaceModel(ModelBase):
         self.args.add_dense_bias = self.args_cmd.add_dense_bias
         self.args.post_norm = self.args_cmd.post_norm
         self.args.save_lora_to_hf = self.args_cmd.save_lora_to_hf
+        self.args.noop_layers = self.args_cmd.noop_layers

     def get_modules_from_config(self, device_map="cpu", trust_remote_code=True):
         # Load Huggingface model.
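
Note (reviewer sketch, not part of the patch): with --moe-tp-extend-ep each (ep_rank, tp_rank) pair holds complete experts rather than TP-sharded slices, so the new branch in get_message_layer_mlp only has to re-derive a global expert index and split each expert's flattened fc1 block into its gate/up halves. The toy dimensions below (hidden, ffn, tp, ep, num_experts) are illustrative assumptions, not repository values; the snippet mirrors only the reshaping logic.

    import torch

    # Toy sizes; in loader_mg.py these come from margs (assumption for illustration).
    hidden, ffn = 4, 6
    tp, ep, num_experts = 2, 2, 8
    bucket_num = ep * tp                           # one bucket per (ep_rank, tp_rank) pair
    bucket_expert_num = num_experts // bucket_num  # complete experts per bucket, no TP sharding

    # A bucket stores its experts' fc1 weights flattened back to back:
    # per expert a (hidden, 2*ffn) block whose columns are [gate | up].
    bucket = torch.randn(bucket_expert_num * hidden * 2 * ffn)

    for ep_rank in range(ep):
        for tp_rank in range(tp):
            for idx, flat in enumerate(torch.chunk(bucket.view(-1), bucket_expert_num, dim=0)):
                local_w1 = flat.reshape(hidden, -1)             # (hidden, 2*ffn)
                gate, up = torch.chunk(local_w1.t(), 2, dim=0)  # each (ffn, hidden)
                global_expert = (ep_rank * tp + tp_rank) * bucket_expert_num + idx
                assert 0 <= global_expert < num_experts and gate.shape == (ffn, hidden)

With the flag off, an expert's fc1 is instead split across TP ranks, so the gate/up slices still have to be concatenated over tp_rank, which is what the retained else branch keeps doing.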
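
Note (reviewer sketch, not part of the patch): the models.py hunk swaps the src/dst layer indices only when saving to Hugging Face, so real layers skipped past noop placeholders land back at their original positions. The snippet below assumes -1 marks a noop entry in mg_layer_list, as the surrounding context lines suggest; the layer counts and --noop-layers value are made up for illustration.

    # Hypothetical example of the index bookkeeping around set_layer_state_base.
    noop_layers = [1, 3]            # assumed --noop-layers value, for illustration only
    hf_layers = list(range(4))      # real (source) layer ids: [0, 1, 2, 3]

    mg_layer_list = list(hf_layers)
    for i in sorted(noop_layers):
        mg_layer_list.insert(i, -1) # -> [0, -1, 1, -1, 2, 3]; -1 marks a noop layer

    for dst_layer_idx, src_layer_idx in enumerate(mg_layer_list):
        if src_layer_idx == -1:     # noop layer: nothing to copy
            continue
        # mg-to-mg/legacy save keeps the old behaviour: read src_layer_idx, write dst_layer_idx.
        # mg-to-hf save (the patched branch) flips the pair so the Megatron layer at
        # dst_layer_idx is written back to its original HF index src_layer_idx.
        print(f"hf layer {src_layer_idx} <-> mg layer {dst_layer_idx}")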