From a2f0614386613b3113c2007fd037e0f3e7a8d7bf Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=99=88=E5=85=B4=E5=AE=87?=
Date: Thu, 21 Mar 2024 20:59:40 +0800
Subject: [PATCH] v2.1.0

---
 test/contrib/test_linear_a8w8_quant.py        |  2 +-
 test/test_fake_tensor.py                      |  4 +-
 torch_npu/contrib/module/linear_a8w8_quant.py | 24 ++++++---
 torch_npu/meta/meta_registrations.py          | 63 ++++++++++++++-----
 4 files changed, 68 insertions(+), 25 deletions(-)

diff --git a/test/contrib/test_linear_a8w8_quant.py b/test/contrib/test_linear_a8w8_quant.py
index 27ff64d1e50..aa32865b33b 100644
--- a/test/contrib/test_linear_a8w8_quant.py
+++ b/test/contrib/test_linear_a8w8_quant.py
@@ -10,7 +10,7 @@ DEVICE_NAME = torch_npu.npu.get_device_name(0)[:10]
 
 class TestLinearA8W8Quant(TestCase):
     def npu_linear_quant(self, in_features, out_features, x1, x2, scale):
-        model = LinearA8W8Quant(in_features, out_features, False)
+        model = LinearA8W8Quant(in_features, out_features, bias=False, pertoken_scale=False, offset=False)
         model = model.npu()
         model.weight.data = x2
         model.scale.data = scale
diff --git a/test/test_fake_tensor.py b/test/test_fake_tensor.py
index 7ff54f378bc..08de43756f0 100644
--- a/test/test_fake_tensor.py
+++ b/test/test_fake_tensor.py
@@ -1689,13 +1689,13 @@ class TestQuantMatmul(TestCase):
         scale = torch.randn(1, dtype=torch.float32).npu()
         offset = torch.randn(1, dtype=torch.float32).npu()
         bias = torch.randint(-1, -1, (1, 1, 100), dtype=torch.int32).npu()
-        res = torch_npu.npu_quant_matmul(x1, x2, scale, offset, bias)
+        res = torch_npu.npu_quant_matmul(x1, x2, scale, offset=offset, bias=bias)
         self.assertTrue(expect_ret.shape == res.shape)
         self.assertTrue(expect_ret.dtype == res.dtype)
 
         expect_ret_bf16 = torch.randint(-1, 1, (1, 1, 100), dtype=torch.bfloat16).npu()
         scale_bf16 = torch.randn(1, dtype=torch.bfloat16).npu()
-        res_bf16 = torch_npu.npu_quant_matmul(x1, x2, scale_bf16, None, bias, "bfloat16")
+        res_bf16 = torch_npu.npu_quant_matmul(x1, x2, scale_bf16, offset=None, bias=bias, output_dtype=torch.bfloat16)
         self.assertTrue(expect_ret_bf16.shape == res_bf16.shape)
         self.assertTrue(expect_ret_bf16.dtype == res_bf16.dtype)
 
diff --git a/torch_npu/contrib/module/linear_a8w8_quant.py b/torch_npu/contrib/module/linear_a8w8_quant.py
index 218dfe4196c..9dd285d3edf 100644
--- a/torch_npu/contrib/module/linear_a8w8_quant.py
+++ b/torch_npu/contrib/module/linear_a8w8_quant.py
@@ -30,6 +30,7 @@ class LinearA8W8Quant(nn.Module):
             :math:`k = \frac{1}{\text{in\_features}}`
         scale: quant matmul calculation parameter
         offset: quant matmul calculation parameter
+        pertoken_scale: inverse quant matmul calculation parameter
         bias: the learnable bias of the module of shape :math:`(\text{out\_features})`.
                 If :attr:`bias` is ``True``, the values are initialized from
                 :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
@@ -52,9 +53,11 @@ class LinearA8W8Quant(nn.Module):
     weight: Tensor
     scale: Tensor
     offset: Tensor
+    pertoken_scale: Tensor
+    bias: Tensor
 
-    def __init__(self, in_features: int, out_features: int, bias: bool = True, offset: bool = False,
-                 device=None, dtype=None, output_dtype=None) -> None:
+    def __init__(self, in_features: int, out_features: int, *, bias: bool = True, offset: bool = False,
+                 pertoken_scale: bool = False, device=None, dtype=None, output_dtype=None) -> None:
         super(LinearA8W8Quant, self).__init__()
 
         self.in_features = in_features
@@ -66,6 +69,12 @@ class LinearA8W8Quant(nn.Module):
             self.offset = Parameter(torch.empty(out_features, dtype=torch.float32), False)
         else:
             self.register_parameter('offset', None)
+
+        if pertoken_scale:
+            self.pertoken_scale = Parameter(torch.empty(out_features, dtype=torch.float32), False)
+        else:
+            self.register_parameter('pertoken_scale', None)
+
         if bias:
             self.bias = Parameter(torch.empty(out_features, dtype=torch.int32), False)
         else:
@@ -73,9 +82,10 @@ class LinearA8W8Quant(nn.Module):
 
     def forward(self, linear_quant_input: Tensor) -> Tensor:
         scale_quant = self.scale
-        weight_k_dim = self.weight.dim() - 1
-        weight_n_dim = self.weight.dim() - 2
-        if self.scale.dtype == torch.float32:
+        first_last_dim = self.weight.dim() - 1
+        second_last_dim = self.weight.dim() - 2
+        if self.scale.dtype == torch.float32 and self.pertoken_scale is None:
             scale_quant = torch_npu.npu_trans_quant_param(self.scale, self.offset)
-        return torch_npu.npu_quant_matmul(linear_quant_input, self.weight.transpose(weight_n_dim, weight_k_dim),
-                                          scale_quant, self.offset, self.bias, self.output_dtype)
+
+        return torch_npu.npu_quant_matmul(linear_quant_input, self.weight.transpose(second_last_dim, first_last_dim),
+                                          scale_quant, offset=self.offset, pertoken_scale=self.pertoken_scale, bias=self.bias, output_dtype=self.output_dtype)
diff --git a/torch_npu/meta/meta_registrations.py b/torch_npu/meta/meta_registrations.py
index 614564bb397..bc78465a65c 100644
--- a/torch_npu/meta/meta_registrations.py
+++ b/torch_npu/meta/meta_registrations.py
@@ -383,11 +383,12 @@ def bias_shape_check(x2, bias, batch_val):
     )
 
 
-def quant_matmul_shape_check(x1, x2, scale, offset):
+def quant_matmul_shape_check(x1, x2, scale, offset, pertoken_scale):
     X_MAX_DIM = 6
     X_MIN_DIM = 2
     x1_dim_num = x1.dim()
     x2_dim_num = x2.dim()
+    x1_m_dim = x1.size(x1_dim_num - 2)
     x1_k_dim = x1.size(x1_dim_num - 1)
     x2_k_dim = x2.size(x2_dim_num - 2)
     x2_n_dim = x2.size(x2_dim_num - 1)
@@ -414,6 +415,18 @@ def quant_matmul_shape_check(x1, x2, scale, offset):
             offset_first_dim == 1 or offset_first_dim == x2_n_dim,
             lambda: "the offset 1st dim value must be 1 or x2 n dim value, please check offset 1st dim value",
         )
+    if pertoken_scale is not None:
+        pertoken_scale_dim_num = pertoken_scale.dim()
+        torch._check(
+            pertoken_scale_dim_num == 1,
+            lambda: "the pertoken_scale dim num must be 1, please check pertoken_scale dim num",
+        )
+        pertoken_scale_first_dim = pertoken_scale.size(0)
+        torch._check(
+            pertoken_scale_first_dim == x1_m_dim,
+            lambda: "the pertoken_scale 1st dim value must be x1 m dim value, please check pertoken_scale 1st dim value",
+        )
+
     scale_dim_num = scale.dim()
     torch._check(
         scale_dim_num == 1,
@@ -426,7 +439,8 @@ def quant_matmul_shape_check(x1, x2, scale, offset):
     )
 
 
-def quant_matmul_dtype_check(x1, x2, scale, offset, bias):
+def quant_matmul_dtype_check(*args):
+    x1, x2, scale, offset, pertoken_scale, bias = args
     torch._check(
         x1.dtype == torch.int8 and x2.dtype == torch.int8,
         lambda: "x1'type and x2's type should be int8, but x1.dtype is " + str(x1.dtype) + " and x2.dtype is " +
@@ -442,6 +456,11 @@ def quant_matmul_dtype_check(x1, x2, scale, offset, bias):
             offset.dtype == torch.float32,
             lambda: "offset's type supported for float32, but offset.dtype is " + str(offset.dtype),
         )
+    if pertoken_scale is not None:
+        torch._check(
+            pertoken_scale.dtype == torch.float32,
+            lambda: "pertoken_scale's type supported for float32, but pertoken_scale.dtype is " + str(pertoken_scale.dtype),
+        )
     if bias is not None:
         torch._check(
             bias.dtype == torch.int32 or bias.dtype == torch.bfloat16,
@@ -449,13 +468,13 @@ def quant_matmul_dtype_check(x1, x2, scale, offset, bias):
         )
 
 
-def quant_matmul_scale_offset_out_check(scale, offset, output_dtype):
+def quant_matmul_scale_offset_out_check(scale, offset, pertoken_scale, output_dtype):
     if scale.dtype == torch.bfloat16:
         torch._check(
-            output_dtype == "bfloat16",
-            lambda: "When scale's dtype is bfloat16, output_dtype must be bfloat16, but output_dtype is " + output_dtype,
+            output_dtype == torch.bfloat16,
+            lambda: "When scale's dtype is bfloat16, output_dtype must be bfloat16, but output_dtype is " + str(output_dtype),
         )
-    if output_dtype == "bfloat16":
+    if output_dtype == torch.bfloat16:
         torch._check(
             scale.dtype == torch.bfloat16,
             lambda: "When output_dtype is bfloat16, scale's dtype must be bfloat16, but scale's dtype is " +
@@ -463,13 +482,24 @@ def quant_matmul_scale_offset_out_check(scale, offset, output_dtype):
         )
     if offset is not None:
         torch._check(
-            output_dtype is None or output_dtype == "int8",
-            lambda: "offset only exists when output_dtype is int8, but output_dtype is " + output_dtype,
+            output_dtype is None or output_dtype == torch.int8,
+            lambda: "offset only exists when output_dtype is int8, but output_dtype is " + str(output_dtype),
         )
+    if pertoken_scale is not None:
+        if output_dtype == torch.float16:
+            torch._check(
+                scale.dtype == torch.float32,
+                lambda: "When output_dtype is float16 and pertoken_scale is not None, scale's dtype must be float32, but scale's dtype is " +
+                str(scale.dtype),
+            )
+        torch._check(
+            output_dtype == torch.float16 or output_dtype == torch.bfloat16,
+            lambda: "When pertoken_scale is not None, output_dtype must be float16 or bfloat16, but output_dtype is " + str(output_dtype),
+        )
 
 
 @impl(m, "npu_quant_matmul")
-def npu_quant_matmul_meta(x1, x2, scale, offset=None, bias=None, output_dtype=None):
+def npu_quant_matmul_meta(x1, x2, scale, *, offset=None, pertoken_scale=None, bias=None, output_dtype=None):
     batch_val = 1
     x1_dim_num = x1.dim()
     x2_dim_num = x2.dim()
@@ -492,7 +522,7 @@ def npu_quant_matmul_meta(x1, x2, scale, offset=None, bias=None, output_dtype=No
     dimn = x2.size(x2.dim() - 1)
     dim_list.append(dimm)
     dim_list.append(dimn)
-    quant_matmul_shape_check(x1, x2, scale, offset)
+    quant_matmul_shape_check(x1, x2, scale, offset, pertoken_scale)
     if bias is not None:
         if bias.dim() == 3:
             torch._check(
@@ -500,13 +530,16 @@ def npu_quant_matmul_meta(x1, x2, scale, offset=None, bias=None, output_dtype=No
                 lambda:"when bias dim is 3, out dim need to be 3",
             )
         bias_shape_check(x2, bias, batch_val)
-    quant_matmul_dtype_check(x1, x2, scale, offset, bias)
-    quant_matmul_scale_offset_out_check(scale, offset, output_dtype)
-    if output_dtype == "float16":
+    quant_matmul_dtype_check(x1, x2, scale, offset, pertoken_scale, bias)
+    quant_matmul_scale_offset_out_check(scale, offset, pertoken_scale, output_dtype)
+    if output_dtype == torch.float16:
         return shape_long.new_empty(tuple(dim_list), dtype=torch.float16)
-    elif output_dtype == "bfloat16":
+    elif output_dtype == torch.bfloat16:
         return shape_long.new_empty(tuple(dim_list), dtype=torch.bfloat16)
-    return shape_long.new_empty(tuple(dim_list), dtype=torch.int8)
+    elif output_dtype is None or output_dtype == torch.int8:
+        return shape_long.new_empty(tuple(dim_list), dtype=torch.int8)
+    else:
+        raise RuntimeError("Unsupported output dtype: " + str(output_dtype))
 
 
 @impl(m, "npu_trans_quant_param")
-- 
Gitee
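Usage sketch of the keyword-only interface touched by this patch (illustrative only: it assumes a torch_npu build with an Ascend NPU device available and that LinearA8W8Quant is imported from torch_npu.contrib.module as in the tests; shapes and values are placeholders):

    import torch
    import torch_npu
    from torch_npu.contrib.module import LinearA8W8Quant

    m, in_features, out_features = 4, 128, 256
    x = torch.randint(-5, 5, (m, in_features), dtype=torch.int8).npu()

    # bias/offset/pertoken_scale/output_dtype are keyword-only after this change.
    model = LinearA8W8Quant(in_features, out_features, bias=False, offset=False,
                            pertoken_scale=True, output_dtype=torch.float16).npu()
    model.weight.data = torch.randint(-5, 5, (out_features, in_features), dtype=torch.int8).npu()
    model.scale.data = torch.randn(out_features, dtype=torch.float32).npu()
    # quant_matmul_shape_check requires pertoken_scale to be 1-D with
    # pertoken_scale.size(0) == x1.size(-2), i.e. one scale per input row.
    model.pertoken_scale.data = torch.randn(m, dtype=torch.float32).npu()

    # With pertoken_scale set, forward() skips npu_trans_quant_param and calls
    # npu_quant_matmul(..., pertoken_scale=..., output_dtype=torch.float16).
    out = model(x)

The same keyword spelling (offset=..., pertoken_scale=..., bias=..., output_dtype=torch.bfloat16) applies to direct torch_npu.npu_quant_matmul calls, as exercised in test_fake_tensor.py above.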