diff --git a/torch_npu/meta/meta_registrations.py b/torch_npu/meta/meta_registrations.py index 5df8ebd955274c2fcb143072cdcf27002bce3919..cdbf78affdd5cfd27d1ed94b9c98c8b917f7be86 100644 --- a/torch_npu/meta/meta_registrations.py +++ b/torch_npu/meta/meta_registrations.py @@ -326,6 +326,18 @@ def npu_mm_all_reduce_base_forward(x1, x2, hcom, reduce_op='sum', bias=None, ant return x1.new_empty(tuple(dim_list)) +@impl(m, "npu_mm_all_reduce_add_rms_norm") +def npu_mm_all_reduce_add_rms_norm_forward(x1, x2, residual, gamma, hcom, reduce_op='sum', epsilon=1e-6, bias=None, + antiquant_scale=None, antiquant_offset=None, dequant_scale=None, + antiquant_group_size=0, comm_turn=0): + return (torch.empty_like(residual, dtype=residual.dtype), torch.empty_like(residual, dtype=residual.dtype)) + + +@impl(m, "npu_mm_all_reduce_add_rms_norm_") +def npu_inplace_mm_all_reduce_add_rms_norm_forward(x1, x2, residual, gamma, hcom, reduce_op='sum', epsilon=1e-6, bias=None, + antiquant_scale=None, antiquant_offset=None, dequant_scale=None, + antiquant_group_size=0, comm_turn=0): + return (torch.empty_like(residual, dtype=residual.dtype), torch.empty_like(residual, dtype=residual.dtype)) @impl(m, "npu_mm_reduce_scatter_base") def npu_mm_reduce_scatter_base_meta(self, x2, hcom, world_size, reduce_op='sum',