From 8719461e65f9ad1b9689d5cd33ac3cf772d2bdc2 Mon Sep 17 00:00:00 2001
From: freyafu
Date: Mon, 7 Apr 2025 14:54:52 +0800
Subject: [PATCH] router fp32 fix

---
 mindspeed_llm/core/transformer/moe/router.py | 20 ++++++++++++--------
 1 file changed, 12 insertions(+), 8 deletions(-)

diff --git a/mindspeed_llm/core/transformer/moe/router.py b/mindspeed_llm/core/transformer/moe/router.py
index 25d37f620..9ddfcd810 100644
--- a/mindspeed_llm/core/transformer/moe/router.py
+++ b/mindspeed_llm/core/transformer/moe/router.py
@@ -445,14 +445,18 @@ def apply_seq_aux_loss(self, activation, logits, topk_idx):
 def topk_router_gating_func(self, input: torch.Tensor):
     _args = get_args()
     if _args.router_gating_in_fp32:
-        def to_fp32(_input, weight):
-            return _input.type(torch.float32), weight.type(torch.float32)
-        self.fp32_checkpoint_manager = CheckpointWithoutOutput()
-        input, weight = self.fp32_checkpoint_manager.checkpoint(to_fp32, False, input, self.weight)
-        logits = torch.nn.functional.linear(input, weight)
-        self.fp32_checkpoint_manager.discard_output()
-        if logits.requires_grad:
-            logits.register_hook(self.fp32_checkpoint_manager.recompute)
+        if not self.weight.requires_grad:
+            # A weight without requires_grad (e.g. LoRA finetuning) cannot take part in the checkpoint manager's autograd recompute, so compute the fp32 gating directly.
+            logits = F.linear(input.type(torch.float32), self.weight.type(torch.float32))
+        else:
+            def to_fp32(_input, weight):
+                return _input.type(torch.float32), weight.type(torch.float32)
+            self.fp32_checkpoint_manager = CheckpointWithoutOutput()
+            input, weight = self.fp32_checkpoint_manager.checkpoint(to_fp32, False, input, self.weight)
+            logits = torch.nn.functional.linear(input, weight)
+            self.fp32_checkpoint_manager.discard_output()
+            if logits.requires_grad:
+                logits.register_hook(self.fp32_checkpoint_manager.recompute)
     else:
         logits = F.linear(input, self.weight)
--
Gitee
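
Below the patch, a minimal standalone sketch of the behaviour the new branch targets, assuming a toy module with a frozen router weight (as in LoRA finetuning). ToyRouter, its shapes, and the simplified trainable-weight path are illustrative only and are not part of mindspeed_llm.

import torch
import torch.nn.functional as F

class ToyRouter(torch.nn.Module):
    """Illustrative stand-in for the MoE router gating (not the real class)."""

    def __init__(self, hidden, num_experts, freeze_weight):
        super().__init__()
        self.weight = torch.nn.Parameter(
            torch.randn(num_experts, hidden, dtype=torch.bfloat16),
            requires_grad=not freeze_weight,
        )

    def gating(self, x):
        if not self.weight.requires_grad:
            # Mirrors the patch: with a frozen weight the fp32 cast is done
            # inline, since a checkpointed recompute would need the weight
            # to participate in autograd.
            return F.linear(x.type(torch.float32), self.weight.type(torch.float32))
        # Trainable-weight path; the real code wraps the fp32 cast in
        # CheckpointWithoutOutput, omitted here to keep the sketch minimal.
        return F.linear(x.type(torch.float32), self.weight.type(torch.float32))

x = torch.randn(4, 8, dtype=torch.bfloat16, requires_grad=True)
router = ToyRouter(hidden=8, num_experts=2, freeze_weight=True)
logits = router.gating(x)      # fp32 logits even though inputs are bf16
logits.sum().backward()        # gradient reaches x; the frozen weight gets none
assert logits.dtype == torch.float32
assert x.grad is not None and router.weight.grad is None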