From 05dddabe5baf3b5eb56e4e236090898b5584ae77 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=85=B4=E5=AE=87?= Date: Fri, 22 Mar 2024 16:33:28 +0800 Subject: [PATCH] =?UTF-8?q?fixed=204ca53fd=20from=20https://gitee.com/chen?= =?UTF-8?q?-xing-yu/pytorch/pulls/10543=20=E4=B8=8D=E5=85=BC=E5=AE=B9?= =?UTF-8?q?=E5=8F=98=E6=9B=B4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- test/contrib/test_linear_a8w8_quant.py | 2 +- test/test_fake_tensor.py | 4 +- torch_npu/contrib/module/linear_a8w8_quant.py | 18 ++++-- torch_npu/meta/meta_registrations.py | 63 ++++++++++++++----- 4 files changed, 65 insertions(+), 22 deletions(-) diff --git a/test/contrib/test_linear_a8w8_quant.py b/test/contrib/test_linear_a8w8_quant.py index 27ff64d1e5..aa32865b33 100644 --- a/test/contrib/test_linear_a8w8_quant.py +++ b/test/contrib/test_linear_a8w8_quant.py @@ -10,7 +10,7 @@ DEVICE_NAME = torch_npu.npu.get_device_name(0)[:10] class TestLinearA8W8Quant(TestCase): def npu_linear_quant(self, in_features, out_features, x1, x2, scale): - model = LinearA8W8Quant(in_features, out_features, False) + model = LinearA8W8Quant(in_features, out_features, bias=False, pertoken_scale=False, offset=False) model = model.npu() model.weight.data = x2 model.scale.data = scale diff --git a/test/test_fake_tensor.py b/test/test_fake_tensor.py index 296fa17925..0ce223f352 100644 --- a/test/test_fake_tensor.py +++ b/test/test_fake_tensor.py @@ -1686,13 +1686,13 @@ class TestQuantMatmul(TestCase): scale = torch.randn(1, dtype=torch.float32).npu() offset = torch.randn(1, dtype=torch.float32).npu() bias = torch.randint(-1, -1, (1, 1, 100), dtype=torch.int32).npu() - res = torch_npu.npu_quant_matmul(x1, x2, scale, offset, bias) + res = torch_npu.npu_quant_matmul(x1, x2, scale, offset=offset, bias=bias) self.assertTrue(expect_ret.shape == res.shape) self.assertTrue(expect_ret.dtype == res.dtype) expect_ret_bf16 = torch.randint(-1, 1, (1, 1, 100), dtype=torch.bfloat16).npu() scale_bf16 = torch.randn(1, dtype=torch.bfloat16).npu() - res_bf16 = torch_npu.npu_quant_matmul(x1, x2, scale_bf16, None, bias, "bfloat16") + res_bf16 = torch_npu.npu_quant_matmul(x1, x2, scale_bf16, offset=None, bias=bias, output_dtype=torch.bfloat16) self.assertTrue(expect_ret_bf16.shape == res_bf16.shape) self.assertTrue(expect_ret_bf16.dtype == res_bf16.dtype) diff --git a/torch_npu/contrib/module/linear_a8w8_quant.py b/torch_npu/contrib/module/linear_a8w8_quant.py index 92dc193eb4..9dd285d3ed 100644 --- a/torch_npu/contrib/module/linear_a8w8_quant.py +++ b/torch_npu/contrib/module/linear_a8w8_quant.py @@ -30,6 +30,7 @@ class LinearA8W8Quant(nn.Module): :math:`k = \frac{1}{\text{in\_features}}` scale: quant matmul calculation parameter offset: quant matmul calculation parameter + pertoken_scale: inverse quant matmul calculation parameter bias: the learnable bias of the module of shape :math:`(\text{out\_features})`. If :attr:`bias` is ``True``, the values are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where @@ -52,9 +53,11 @@ class LinearA8W8Quant(nn.Module): weight: Tensor scale: Tensor offset: Tensor + pertoken_scale: Tensor + bias: Tensor - def __init__(self, in_features: int, out_features: int, bias: bool = True, offset: bool = False, - device=None, dtype=None, output_dtype=None) -> None: + def __init__(self, in_features: int, out_features: int, *, bias: bool = True, offset: bool = False, + pertoken_scale: bool = False, device=None, dtype=None, output_dtype=None) -> None: super(LinearA8W8Quant, self).__init__() self.in_features = in_features @@ -66,6 +69,12 @@ class LinearA8W8Quant(nn.Module): self.offset = Parameter(torch.empty(out_features, dtype=torch.float32), False) else: self.register_parameter('offset', None) + + if pertoken_scale: + self.pertoken_scale = Parameter(torch.empty(out_features, dtype=torch.float32), False) + else: + self.register_parameter('pertoken_scale', None) + if bias: self.bias = Parameter(torch.empty(out_features, dtype=torch.int32), False) else: @@ -75,7 +84,8 @@ class LinearA8W8Quant(nn.Module): scale_quant = self.scale first_last_dim = self.weight.dim() - 1 second_last_dim = self.weight.dim() - 2 - if self.scale.dtype == torch.float32: + if self.scale.dtype == torch.float32 and self.pertoken_scale is None: scale_quant = torch_npu.npu_trans_quant_param(self.scale, self.offset) + return torch_npu.npu_quant_matmul(linear_quant_input, self.weight.transpose(second_last_dim, first_last_dim), - scale_quant, self.offset, self.bias, self.output_dtype) + scale_quant, offset=self.offset, pertoken_scale=self.pertoken_scale, bias=self.bias, output_dtype=self.output_dtype) diff --git a/torch_npu/meta/meta_registrations.py b/torch_npu/meta/meta_registrations.py index 33e813a3c8..3d0e2b479b 100644 --- a/torch_npu/meta/meta_registrations.py +++ b/torch_npu/meta/meta_registrations.py @@ -352,11 +352,12 @@ def bias_shape_check(x2, bias, batch_val): ) -def quant_matmul_shape_check(x1, x2, scale, offset): +def quant_matmul_shape_check(x1, x2, scale, offset, pertoken_scale): X_MAX_DIM = 6 X_MIN_DIM = 2 x1_dim_num = x1.dim() x2_dim_num = x2.dim() + x1_m_dim = x1.size(x1_dim_num - 2) x1_k_dim = x1.size(x1_dim_num - 1) x2_k_dim = x2.size(x2_dim_num - 2) x2_n_dim = x2.size(x2_dim_num - 1) @@ -383,6 +384,18 @@ def quant_matmul_shape_check(x1, x2, scale, offset): offset_first_dim == 1 or offset_first_dim == x2_n_dim, lambda: "the offset 1st dim value must be 1 or x2 n dim value, please check offset 1st dim value", ) + if pertoken_scale is not None: + pertoken_scale_dim_num = pertoken_scale.dim() + torch._check( + pertoken_scale_dim_num == 1, + lambda: "the pertoken_scale dim num must be 1, please check scale dim num", + ) + pertoken_scale_first_dim = pertoken_scale.size(0) + torch._check( + pertoken_scale_first_dim == x1_m_dim, + lambda: "the pertoken_scale 1st dim value must be x1 m dim value, please check scale 1st dim value ", + ) + scale_dim_num = scale.dim() torch._check( scale_dim_num == 1, @@ -395,7 +408,8 @@ def quant_matmul_shape_check(x1, x2, scale, offset): ) -def quant_matmul_dtype_check(x1, x2, scale, offset, bias): +def quant_matmul_dtype_check(*args): + x1, x2, scale, offset, pertoken_scale, bias = args torch._check( x1.dtype == torch.int8 and x2.dtype == torch.int8, lambda: "x1'type and x2's type should be int8, but x1.dtype is " + str(x1.dtype) + " and x2.dtype is " + @@ -411,6 +425,11 @@ def quant_matmul_dtype_check(x1, x2, scale, offset, bias): offset.dtype == torch.float32, lambda: "offset's type supported for float32, but offset.dtype is " + str(offset.dtype), ) + if pertoken_scale is not None: + torch._check( + pertoken_scale.dtype == torch.float32, + lambda: "pertoken_scale's type supported for float32, but pertoken_scale.dtype is " + str(offset.dtype), + ) if bias is not None: torch._check( bias.dtype == torch.int32 or bias.dtype == torch.bfloat16, @@ -418,13 +437,13 @@ def quant_matmul_dtype_check(x1, x2, scale, offset, bias): ) -def quant_matmul_scale_offset_out_check(scale, offset, output_dtype): +def quant_matmul_scale_offset_out_check(scale, offset, pertoken_scale, output_dtype): if scale.dtype == torch.bfloat16: torch._check( - output_dtype == "bfloat16", - lambda: "When scale's dtype is bfloat16, output_dtype must be bfloat16, but output_dtype is " + output_dtype, + output_dtype == torch.bfloat16, + lambda: "When scale's dtype is bfloat16, output_dtype must be bfloat16, but output_dtype is " + str(output_dtype), ) - if output_dtype == "bfloat16": + if output_dtype == torch.bfloat16: torch._check( scale.dtype == torch.bfloat16, lambda: "When output_dtype is bfloat16, scale's dtype must be bfloat16, but scale's dtype is " + @@ -432,13 +451,24 @@ def quant_matmul_scale_offset_out_check(scale, offset, output_dtype): ) if offset is not None: torch._check( - output_dtype is None or output_dtype == "int8", - lambda: "offset only exists when output_dtype is int8, but output_dtype is " + output_dtype, + output_dtype is None or output_dtype == torch.int8, + lambda: "offset only exists when output_dtype is int8, but output_dtype is " + str(output_dtype), + ) + if pertoken_scale is not None: + if output_dtype == torch.float16: + torch._check( + scale.dtype == torch.float32, + lambda: "When output_dtype is float16 and pertoken_scale is not none, scale's dtype must be float32, but scale's dtype is " + + str(scale.dtype), + ) + torch._check( + output_dtype == torch.float16 or output_dtype == torch.bfloat16, + lambda: "When pertoken_scale is not none, output_dtype must be float16 or bfloat16, but output_dtype is " + str(output_dtype), ) @impl(m, "npu_quant_matmul") -def npu_quant_matmul_meta(x1, x2, scale, offset=None, bias=None, output_dtype=None): +def npu_quant_matmul_meta(x1, x2, scale, *, offset=None, pertoken_scale=None, bias=None, output_dtype=None): batch_val = 1 x1_dim_num = x1.dim() x2_dim_num = x2.dim() @@ -461,7 +491,7 @@ def npu_quant_matmul_meta(x1, x2, scale, offset=None, bias=None, output_dtype=No dimn = x2.size(x2.dim() - 1) dim_list.append(dimm) dim_list.append(dimn) - quant_matmul_shape_check(x1, x2, scale, offset) + quant_matmul_shape_check(x1, x2, scale, offset, pertoken_scale) if bias is not None: if bias.dim() == 3: torch._check( @@ -469,13 +499,16 @@ def npu_quant_matmul_meta(x1, x2, scale, offset=None, bias=None, output_dtype=No lambda:"when bias dim is 3, out dim need to be 3", ) bias_shape_check(x2, bias, batch_val) - quant_matmul_dtype_check(x1, x2, scale, offset, bias) - quant_matmul_scale_offset_out_check(scale, offset, output_dtype) - if output_dtype == "float16": + quant_matmul_dtype_check(x1, x2, scale, offset, pertoken_scale, bias) + quant_matmul_scale_offset_out_check(scale, offset, pertoken_scale, output_dtype) + if output_dtype == torch.float16: return shape_long.new_empty(tuple(dim_list), dtype=torch.float16) - elif output_dtype == "bfloat16": + elif output_dtype == torch.bfloat16: return shape_long.new_empty(tuple(dim_list), dtype=torch.bfloat16) - return shape_long.new_empty(tuple(dim_list), dtype=torch.int8) + elif output_dtype is None or output_dtype == torch.int8: + return shape_long.new_empty(tuple(dim_list), dtype=torch.int8) + else: + raise RuntimeError("Not supportted output dtype is " + str(output_dtype)) @impl(m, "npu_trans_quant_param") -- Gitee