diff --git a/mindspeed_llm/tasks/checkpoint/models.py b/mindspeed_llm/tasks/checkpoint/models.py
index 7562de5b07f3411f52e520aa0496707315e766ec..83bf29a06d4c0821995627a0e4f6357b080ed5a2 100644
--- a/mindspeed_llm/tasks/checkpoint/models.py
+++ b/mindspeed_llm/tasks/checkpoint/models.py
@@ -649,8 +649,12 @@ class HuggingfaceModel(ModelBase):
             fc1_weight = self.get_layers_mlp_experts_linear_fc1_weight(**kwargs)
             if getattr(args, "swiglu", None):
                 gate_w, up_w = torch.chunk(fc1_weight, 2, dim=0)
-                gate_w_list = torch.chunk(gate_w, getattr(self.args_cmd, 'target_tensor_parallel_size', 1), dim=0)
-                up_w_list = torch.chunk(up_w, getattr(self.args_cmd, 'target_tensor_parallel_size', 1), dim=0)
+                if args.moe_tp_extend_ep:
+                    gate_w_list = torch.chunk(gate_w, 1, dim=0)
+                    up_w_list = torch.chunk(up_w, 1, dim=0)
+                else:
+                    gate_w_list = torch.chunk(gate_w, getattr(self.args_cmd, 'target_tensor_parallel_size', 1), dim=0)
+                    up_w_list = torch.chunk(up_w, getattr(self.args_cmd, 'target_tensor_parallel_size', 1), dim=0)
                 fc1_weight = torch.cat([torch.cat(weights, dim=0) for weights in zip(gate_w_list, up_w_list)], dim=0)
             experts_linear_fc1_list.append(fc1_weight.t().view(-1))
         return torch.cat(experts_linear_fc1_list).view(args.hidden_size, -1)
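
For context, a minimal standalone sketch of the re-layout this diff patches. For SwiGLU, `fc1_weight` stacks the gate and up projections along dim 0; the conversion splits each half into one shard per tensor-parallel rank and re-interleaves them so rank i receives [gate_i; up_i]. When `moe_tp_extend_ep` is set, expert weights are not sharded by TP, which is exactly the `torch.chunk(..., 1, dim=0)` branch (a tp_size of 1). The helper name `relayout_swiglu_fc1` and the toy shapes below are hypothetical, chosen only to illustrate the chunk-and-interleave logic; this is not the MindSpeed code itself.

```python
import torch


def relayout_swiglu_fc1(fc1_weight: torch.Tensor, tp_size: int) -> torch.Tensor:
    """Re-interleave a stacked [gate; up] SwiGLU fc1 weight into per-TP-rank
    [gate_i; up_i] blocks. tp_size == 1 (the moe_tp_extend_ep case) is a no-op."""
    gate_w, up_w = torch.chunk(fc1_weight, 2, dim=0)
    gate_list = torch.chunk(gate_w, tp_size, dim=0)
    up_list = torch.chunk(up_w, tp_size, dim=0)
    return torch.cat([torch.cat(pair, dim=0) for pair in zip(gate_list, up_list)], dim=0)


# Toy example: hidden=4, ffn=8, so fc1 has 8 gate rows followed by 8 up rows.
w = torch.arange(16 * 4, dtype=torch.float32).view(16, 4)

out = relayout_swiglu_fc1(w, tp_size=2)
# Rank 0's slice out[:8] now holds gate rows 0-3 followed by up rows 0-3.
assert torch.equal(out[:4], w[:4]) and torch.equal(out[4:8], w[8:12])

# tp_size=1 (the new moe_tp_extend_ep branch) leaves the layout unchanged,
# which is why chunking with 1 instead of target_tensor_parallel_size fixes
# the expert weight layout when TP is extended into EP.
assert torch.equal(relayout_swiglu_fc1(w, tp_size=1), w)
```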