diff --git a/mindspeed_llm/core/transformer/moe/router.py b/mindspeed_llm/core/transformer/moe/router.py
index cb0be41703ee36416222876ab37ae0ab6eb7fca8..be87413165871594a49688daa8abe2f596e01e94 100644
--- a/mindspeed_llm/core/transformer/moe/router.py
+++ b/mindspeed_llm/core/transformer/moe/router.py
@@ -555,6 +555,7 @@ def topk_router_forward(self, input: torch.Tensor):
     """
     args = get_args()
     self.hidden = input.shape[-1]
+    _maintain_float32_expert_bias(self)
 
     # add input_jitter to distinguish whether to use
     if args.input_jitter:
@@ -565,3 +566,15 @@ def topk_router_forward(self, input: torch.Tensor):
     scores, indices = self.routing(logits)
 
     return scores, indices
+
+
+def _maintain_float32_expert_bias(self):
+    """
+    Maintain the expert bias in float32.
+
+    When using bf16/fp16, the expert bias gets converted to lower precision in Float16Module.
+    We keep it in float32 to avoid routing errors when updating the expert_bias.
+    """
+    if hasattr(self, 'expert_bias') and self.expert_bias is not None:
+        if self.expert_bias.dtype != torch.float32:
+            self.expert_bias.data = self.expert_bias.data.to(torch.float32)
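
Below is a minimal, self-contained sketch (not part of the patch) of why the cast matters. ToyRouter and its sizes are hypothetical stand-ins; in the real code the expert_bias buffer lives on the TopK router and gets downcast by the Float16Module wrapper when training in bf16/fp16.

    # Hypothetical stand-in for the router; only illustrates the dtype issue.
    import torch


    class ToyRouter(torch.nn.Module):
        def __init__(self, num_experts: int = 4):
            super().__init__()
            # Bias buffer that accumulates small per-step load-balancing corrections.
            self.register_buffer('expert_bias', torch.zeros(num_experts, dtype=torch.float32))


    router = ToyRouter()
    router.bfloat16()  # emulate the Float16Module-style downcast of params/buffers
    assert router.expert_bias.dtype == torch.bfloat16

    # The fix applied on each forward pass: restore float32 so small bias
    # updates are not rounded away in bf16 when expert_bias is adjusted.
    if router.expert_bias.dtype != torch.float32:
        router.expert_bias.data = router.expert_bias.data.to(torch.float32)
    assert router.expert_bias.dtype == torch.float32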