diff --git a/test/contrib/test_linear_a8w8_quant.py b/test/contrib/test_linear_a8w8_quant.py
index 27ff64d1e5023910868974e84f46b0496ddf5dbb..aa32865b33b509dcea80dbbd8337f6dec838a3fe 100644
--- a/test/contrib/test_linear_a8w8_quant.py
+++ b/test/contrib/test_linear_a8w8_quant.py
@@ -10,7 +10,7 @@ DEVICE_NAME = torch_npu.npu.get_device_name(0)[:10]
 
 class TestLinearA8W8Quant(TestCase):
     def npu_linear_quant(self, in_features, out_features, x1, x2, scale):
-        model = LinearA8W8Quant(in_features, out_features, False)
+        model = LinearA8W8Quant(in_features, out_features, bias=False, pertoken_scale=False, offset=False)
         model = model.npu()
         model.weight.data = x2
         model.scale.data = scale
diff --git a/test/test_fake_tensor.py b/test/test_fake_tensor.py
index f32f9cadfde953da9fcabbc9a861a033f7617ffd..b5fc5450fbe6ac5df2c5890f72e173a903505920 100644
--- a/test/test_fake_tensor.py
+++ b/test/test_fake_tensor.py
@@ -1634,13 +1634,13 @@ class TestQuantMatmul(TestCase):
         scale = torch.randn(1, dtype=torch.float32).npu()
         offset = torch.randn(1, dtype=torch.float32).npu()
         bias = torch.randint(-1, -1, (1, 1, 100), dtype=torch.int32).npu()
-        res = torch_npu.npu_quant_matmul(x1, x2, scale, offset, bias)
+        res = torch_npu.npu_quant_matmul(x1, x2, scale, offset=offset, bias=bias)
         self.assertTrue(expect_ret.shape == res.shape)
         self.assertTrue(expect_ret.dtype == res.dtype)
 
         expect_ret_bf16 = torch.randint(-1, 1, (1, 1, 100), dtype=torch.bfloat16).npu()
         scale_bf16 = torch.randn(1, dtype=torch.bfloat16).npu()
-        res_bf16 = torch_npu.npu_quant_matmul(x1, x2, scale_bf16, None, bias, "bfloat16")
+        res_bf16 = torch_npu.npu_quant_matmul(x1, x2, scale_bf16, offset=None, bias=bias, output_dtype=torch.bfloat16)
         self.assertTrue(expect_ret_bf16.shape == res_bf16.shape)
         self.assertTrue(expect_ret_bf16.dtype == res_bf16.dtype)
 
diff --git a/torch_npu/contrib/module/linear_a8w8_quant.py b/torch_npu/contrib/module/linear_a8w8_quant.py
index 92dc193eb4b1528cccb236b918083fe31a30f592..9dd285d3edf56bccf96d49b9e2234980ad31c0b0 100644
--- a/torch_npu/contrib/module/linear_a8w8_quant.py
+++ b/torch_npu/contrib/module/linear_a8w8_quant.py
@@ -30,6 +30,7 @@ class LinearA8W8Quant(nn.Module):
             :math:`k = \frac{1}{\text{in\_features}}`
         scale: quant matmul calculation parameter
         offset: quant matmul calculation parameter
+        pertoken_scale: inverse quant matmul calculation parameter
         bias: the learnable bias of the module of shape :math:`(\text{out\_features})`.
                 If :attr:`bias` is ``True``, the values are initialized from
                 :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
@@ -52,9 +53,11 @@ class LinearA8W8Quant(nn.Module):
     weight: Tensor
     scale: Tensor
     offset: Tensor
+    pertoken_scale: Tensor
+    bias: Tensor
 
-    def __init__(self, in_features: int, out_features: int, bias: bool = True, offset: bool = False,
-                 device=None, dtype=None, output_dtype=None) -> None:
+    def __init__(self, in_features: int, out_features: int, *, bias: bool = True, offset: bool = False,
+                 pertoken_scale: bool = False, device=None, dtype=None, output_dtype=None) -> None:
         super(LinearA8W8Quant, self).__init__()
 
         self.in_features = in_features
@@ -66,6 +69,12 @@ class LinearA8W8Quant(nn.Module):
             self.offset = Parameter(torch.empty(out_features, dtype=torch.float32), False)
         else:
             self.register_parameter('offset', None)
+
+        if pertoken_scale:
+            self.pertoken_scale = Parameter(torch.empty(out_features, dtype=torch.float32), False)
+        else:
+            self.register_parameter('pertoken_scale', None)
+
         if bias:
             self.bias = Parameter(torch.empty(out_features, dtype=torch.int32), False)
         else:
@@ -75,7 +84,8 @@ class LinearA8W8Quant(nn.Module):
         scale_quant = self.scale
         first_last_dim = self.weight.dim() - 1
         second_last_dim = self.weight.dim() - 2
-        if self.scale.dtype == torch.float32:
+        if self.scale.dtype == torch.float32 and self.pertoken_scale is None:
             scale_quant = torch_npu.npu_trans_quant_param(self.scale, self.offset)
+
         return torch_npu.npu_quant_matmul(linear_quant_input, self.weight.transpose(second_last_dim, first_last_dim),
-                                          scale_quant, self.offset, self.bias, self.output_dtype)
+                                          scale_quant, offset=self.offset, pertoken_scale=self.pertoken_scale, bias=self.bias, output_dtype=self.output_dtype)
diff --git a/torch_npu/meta/meta_registrations.py b/torch_npu/meta/meta_registrations.py
index 0409f61d65bd9a1bf8ea374a1679d98326362551..1ba6c76d4ac6fcf7271783e4a89416090db8cf75 100644
--- a/torch_npu/meta/meta_registrations.py
+++ b/torch_npu/meta/meta_registrations.py
@@ -319,11 +319,12 @@ def bias_shape_check(x2, bias, batch_val):
     )
 
 
-def quant_matmul_shape_check(x1, x2, scale, offset):
+def quant_matmul_shape_check(x1, x2, scale, offset, pertoken_scale):
     X_MAX_DIM = 6
     X_MIN_DIM = 2
     x1_dim_num = x1.dim()
     x2_dim_num = x2.dim()
+    x1_m_dim = x1.size(x1_dim_num - 2)
     x1_k_dim = x1.size(x1_dim_num - 1)
     x2_k_dim = x2.size(x2_dim_num - 2)
     x2_n_dim = x2.size(x2_dim_num - 1)
@@ -350,6 +351,18 @@ def quant_matmul_shape_check(x1, x2, scale, offset):
             offset_first_dim == 1 or offset_first_dim == x2_n_dim,
             lambda: "the offset 1st dim value must be 1 or x2 n dim value, please check offset 1st dim value",
         )
+    if pertoken_scale is not None:
+        pertoken_scale_dim_num = pertoken_scale.dim()
+        torch._check(
+            pertoken_scale_dim_num == 1,
+            lambda: "the pertoken_scale dim num must be 1, please check scale dim num",
+        )
+        pertoken_scale_first_dim = pertoken_scale.size(0)
+        torch._check(
+            pertoken_scale_first_dim == x1_m_dim,
+            lambda: "the pertoken_scale 1st dim value must be x1 m dim value, please check scale 1st dim value ",
+        )
+
     scale_dim_num = scale.dim()
     torch._check(
         scale_dim_num == 1,
@@ -362,7 +375,8 @@ def quant_matmul_shape_check(x1, x2, scale, offset):
     )
 
 
-def quant_matmul_dtype_check(x1, x2, scale, offset, bias):
+def quant_matmul_dtype_check(*args):
+    x1, x2, scale, offset, pertoken_scale, bias = args
     torch._check(
         x1.dtype == torch.int8 and x2.dtype == torch.int8,
         lambda: "x1'type and x2's type should be int8, but x1.dtype is " + str(x1.dtype) + " and x2.dtype is " +
@@ -378,6 +392,11 @@ def quant_matmul_dtype_check(x1, x2, scale, offset, bias):
             offset.dtype == torch.float32,
             lambda: "offset's type supported for float32, but offset.dtype is " + str(offset.dtype),
         )
+    if pertoken_scale is not None:
+        torch._check(
+            pertoken_scale.dtype == torch.float32,
+            lambda: "pertoken_scale's type supported for float32, but pertoken_scale.dtype is " + str(offset.dtype),
+        )
     if bias is not None:
         torch._check(
             bias.dtype == torch.int32 or bias.dtype == torch.bfloat16,
@@ -385,13 +404,13 @@ def quant_matmul_dtype_check(x1, x2, scale, offset, bias):
         )
 
 
-def quant_matmul_scale_offset_out_check(scale, offset, output_dtype):
+def quant_matmul_scale_offset_out_check(scale, offset, pertoken_scale, output_dtype):
     if scale.dtype == torch.bfloat16:
         torch._check(
-            output_dtype == "bfloat16",
-            lambda: "When scale's dtype is bfloat16, output_dtype must be bfloat16, but output_dtype is " + output_dtype,
+            output_dtype == torch.bfloat16,
+            lambda: "When scale's dtype is bfloat16, output_dtype must be bfloat16, but output_dtype is " + str(output_dtype),
         )
-    if output_dtype == "bfloat16":
+    if output_dtype == torch.bfloat16:
         torch._check(
             scale.dtype == torch.bfloat16,
             lambda: "When output_dtype is bfloat16, scale's dtype must be bfloat16, but scale's dtype is " +
@@ -399,13 +418,24 @@ def quant_matmul_scale_offset_out_check(scale, offset, output_dtype):
         )
     if offset is not None:
         torch._check(
-            output_dtype is None or output_dtype == "int8",
-            lambda: "offset only exists when output_dtype is int8, but output_dtype is " + output_dtype,
+            output_dtype is None or output_dtype == torch.int8,
+            lambda: "offset only exists when output_dtype is int8, but output_dtype is " + str(output_dtype),
+        )
+    if pertoken_scale is not None:
+        if output_dtype == torch.float16:
+            torch._check(
+                scale.dtype == torch.float32,
+                lambda: "When output_dtype is float16 and pertoken_scale is not none, scale's dtype must be float32, but scale's dtype is "
+                + str(scale.dtype),
+            )
+        torch._check(
+            output_dtype == torch.float16 or output_dtype == torch.bfloat16,
+            lambda: "When pertoken_scale is not none, output_dtype must be float16 or bfloat16, but output_dtype is " + str(output_dtype),
         )
 
 
 @impl(m, "npu_quant_matmul")
-def npu_quant_matmul_meta(x1, x2, scale, offset=None, bias=None, output_dtype=None):
+def npu_quant_matmul_meta(x1, x2, scale, *, offset=None, pertoken_scale=None, bias=None, output_dtype=None):
     batch_val = 1
     x1_dim_num = x1.dim()
     x2_dim_num = x2.dim()
@@ -428,7 +458,7 @@ def npu_quant_matmul_meta(x1, x2, scale, offset=None, bias=None, output_dtype=No
     dimn = x2.size(x2.dim() - 1)
     dim_list.append(dimm)
     dim_list.append(dimn)
-    quant_matmul_shape_check(x1, x2, scale, offset)
+    quant_matmul_shape_check(x1, x2, scale, offset, pertoken_scale)
     if bias is not None:
         if bias.dim() == 3:
             torch._check(
@@ -436,13 +466,16 @@ def npu_quant_matmul_meta(x1, x2, scale, offset=None, bias=None, output_dtype=No
                 lambda:"when bias dim is 3, out dim need to be 3",
             )
         bias_shape_check(x2, bias, batch_val)
-    quant_matmul_dtype_check(x1, x2, scale, offset, bias)
-    quant_matmul_scale_offset_out_check(scale, offset, output_dtype)
-    if output_dtype == "float16":
+    quant_matmul_dtype_check(x1, x2, scale, offset, pertoken_scale, bias)
+    quant_matmul_scale_offset_out_check(scale, offset, pertoken_scale, output_dtype)
+    if output_dtype == torch.float16:
         return shape_long.new_empty(tuple(dim_list), dtype=torch.float16)
-    elif output_dtype == "bfloat16":
+    elif output_dtype == torch.bfloat16:
         return shape_long.new_empty(tuple(dim_list), dtype=torch.bfloat16)
-    return shape_long.new_empty(tuple(dim_list), dtype=torch.int8)
+    elif output_dtype is None or output_dtype == torch.int8:
+        return shape_long.new_empty(tuple(dim_list), dtype=torch.int8)
+    else:
+        raise RuntimeError("Not supportted output dtype is " + str(output_dtype))
 
 
 @impl(m, "npu_trans_quant_param")
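
A minimal usage sketch of the call pattern this patch moves to: offset, pertoken_scale, bias and output_dtype become keyword-only, and output_dtype takes a torch.dtype instead of a string. The shapes, values and the LinearA8W8Quant import path below are illustrative assumptions for an NPU environment, not part of the patch itself.

    import torch
    import torch_npu
    from torch_npu.contrib.module import LinearA8W8Quant  # assumed export path

    # int8 activations/weights; 1-D per-channel scale plus an optional per-token scale
    # whose length matches the x1 m dimension, as enforced by quant_matmul_shape_check.
    x1 = torch.randint(-8, 8, (4, 128), dtype=torch.int8).npu()
    x2 = torch.randint(-8, 8, (128, 256), dtype=torch.int8).npu()
    scale = torch.randn(256, dtype=torch.float32).npu()
    pertoken_scale = torch.randn(4, dtype=torch.float32).npu()

    # Keyword-only arguments; with pertoken_scale, output_dtype must be float16 or bfloat16.
    out = torch_npu.npu_quant_matmul(x1, x2, scale,
                                     pertoken_scale=pertoken_scale,
                                     output_dtype=torch.float16)

    # Module wrapper: pertoken_scale is now a keyword-only constructor flag.
    model = LinearA8W8Quant(128, 256, bias=False, offset=False, pertoken_scale=True).npu()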