diff --git a/torch_npu/meta/_meta_registrations.py b/torch_npu/meta/_meta_registrations.py
index 8f8caac40a753364a1e890c25f045921c38309bf..3da24df98b3339e1b1abe9c9cf60b637dc2c7993 100644
--- a/torch_npu/meta/_meta_registrations.py
+++ b/torch_npu/meta/_meta_registrations.py
@@ -694,6 +694,32 @@ def npu_apply_rotary_pos_emb_meta(query, key, cos, sin, layout=1):
     return (torch.empty_like(query, dtype=query.dtype), torch.empty_like(key, dtype=key.dtype))
 
 
+@impl(m, "npu_moe_finalize_routing_v2_backward")
+def npu_moe_finalize_routing_v2_backward_meta(grad_y, expanded_row_idx, expanded_x=None, scales=None, expert_idx=None,
+                                              bias=None, drop_pad_mode=0, active_num=0, expert_num=0,
+                                              expert_capacity=0):
+    grad_expanded_x_dim0 = expanded_row_idx.size(0)
+    if drop_pad_mode == 0 and 0 < active_num < grad_expanded_x_dim0:
+        grad_expanded_x_dim0 = active_num
+    elif drop_pad_mode == 1:
+        if expert_num == 0 or expert_capacity == 0:
+            raise RuntimeError("When drop_pad_mode is 1, expert_num and expert_capacity must be greater than 0.")
+        grad_expanded_x_dim0 = expert_num * expert_capacity
+    grad_expanded_x_dim1 = grad_y.size(1)
+    grad_expanded_x = torch.empty([grad_expanded_x_dim0, grad_expanded_x_dim1], dtype=grad_y.dtype, device='meta')
+
+    grad_scales_dim0 = grad_y.size(0)
+    grad_scales_dim1 = 1
+    if grad_y.size(0) != 0:
+        if expanded_row_idx.size(0) % grad_y.size(0) == 0:
+            grad_scales_dim1 = expanded_row_idx.size(0) // grad_y.size(0)
+        else:
+            raise RuntimeError("The first dim of expanded_row_idx must be a multiple of the first dim of grad_y.")
+    grad_scales = torch.empty([grad_scales_dim0, grad_scales_dim1], dtype=grad_y.dtype, device='meta')
+
+    return (grad_expanded_x, grad_scales)
+
+
 @impl(m, "npu_quant_conv2d")
 def npu_quant_conv2d(input_, weight, scale, strides, pads, dilations, groups=1, offset_x=0, round_mode='rint',
                      output_dtype=None, bias=None, offset=None):
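
For reviewers, a minimal sketch of how the new meta kernel would be exercised through the dispatcher. This assumes the op is exposed as `torch.ops.npu.npu_moe_finalize_routing_v2_backward` (the usual namespace for `torch_npu` custom ops), that `torch_npu` is importable and registers the meta implementation on import, and that the illustrative shapes below (8 tokens, hidden size 32, top-2 routing) are representative, not prescribed by this patch:

```python
import torch
import torch_npu  # assumed to register the custom ops and their meta kernels

# Dropless case (drop_pad_mode=0): 8 tokens, hidden size 32, top-2 routing,
# so expanded_row_idx carries 8 * 2 = 16 entries.
grad_y = torch.empty(8, 32, dtype=torch.float16, device='meta')
expanded_row_idx = torch.empty(16, dtype=torch.int32, device='meta')

# Dispatching on meta tensors runs only the shape inference added above;
# no NPU kernel is launched.
grad_expanded_x, grad_scales = torch.ops.npu.npu_moe_finalize_routing_v2_backward(
    grad_y, expanded_row_idx)

print(grad_expanded_x.shape)  # torch.Size([16, 32]): one row per expanded_row_idx entry
print(grad_scales.shape)      # torch.Size([8, 2]):   tokens x (expanded rows per token)
```

Note the integer division (`//`) in the meta function: `grad_scales_dim1` feeds `torch.empty`, which rejects float sizes, so true division would raise a TypeError on this path.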