diff --git a/torch_npu/meta/_meta_registrations.py b/torch_npu/meta/_meta_registrations.py index 8f8caac40a753364a1e890c25f045921c38309bf..1488db4f3dfa0266f3ec3c641dc8c3e5a44d7966 100644 --- a/torch_npu/meta/_meta_registrations.py +++ b/torch_npu/meta/_meta_registrations.py @@ -635,9 +635,11 @@ def npu_group_quant_meta(x, scale, group_index, *, offset=None, dst_dtype=None): return torch.empty_like(x, dtype=torch.int8) elif dst_dtype == torch.quint4x2: dim_num = x.dim() + if dim_num == 0: + raise RuntimeError("Input x can't be scalar" + ops_error(ErrCode.PARAM)) if x.size(dim_num - 1) % 8: - raise RuntimeError("If dst_dtype is quint4x2, last dim must be divisible by 8" + - ops_error(ErrCode.NOT_SUPPORT)) + raise RuntimeError("If dst_dtype is quint4x2, x last dim must be divisible by 8" + + ops_error(ErrCode.PARAM)) output_shape = [] for dim in range(dim_num - 1): output_shape.append(x.size(dim))