From 1672aa9d389ed2b2d30c219481efc2164972c4e2 Mon Sep 17 00:00:00 2001 From: zhuxiaochen Date: Tue, 11 Mar 2025 14:45:09 +0800 Subject: [PATCH] add osp --- mindspore_gs/ptq/context.py | 1 + mindspore_gs/ptq/ptq/algorithms/clipper.py | 5 +- mindspore_gs/ptq/ptq/hal.py | 423 +++++++++++++++++- .../ptq/ptq/wrappers/mindformers/__init__.py | 3 +- .../mindformers/linear_all_quant_wrappers.py | 18 +- .../mindformers/linear_smooth_wrappers.py | 248 +++++++++- tests/st/ptq/ptq/ptq_network_runner.py | 20 + tests/st/ptq/ptq/test_ptq.py | 2 +- 8 files changed, 693 insertions(+), 27 deletions(-) diff --git a/mindspore_gs/ptq/context.py b/mindspore_gs/ptq/context.py index 92d63798..22191be2 100644 --- a/mindspore_gs/ptq/context.py +++ b/mindspore_gs/ptq/context.py @@ -117,6 +117,7 @@ class InnerPTQConfig(GSBaseConfig, PTQConfig): layer_quant_info_collect: dict = field(default_factory=dict) algorithm_cache_path: str = '' always_use_fp_input_in_processer: bool = False + use_inner_osp: bool = False dump_path: str = "" dumper: Dumper = Dumper() diff --git a/mindspore_gs/ptq/ptq/algorithms/clipper.py b/mindspore_gs/ptq/ptq/algorithms/clipper.py index 97e14df9..bab2fe92 100644 --- a/mindspore_gs/ptq/ptq/algorithms/clipper.py +++ b/mindspore_gs/ptq/ptq/algorithms/clipper.py @@ -67,8 +67,11 @@ class LinearClipper(Algorithm): if not LinearClipper.linear_map.get(type(cell)): return cell, False layer_policy = self.handler.get_layer_policy(cell_name) + is_inner_osp = layer_policy.outliers_suppression == OutliersSuppressionType.OUTLIER_SUPPRESSION_PLUS \ + and layer_policy.use_inner_osp is_satisfied = layer_policy.outliers_suppression == OutliersSuppressionType.AWQ or \ - layer_policy.outliers_suppression == OutliersSuppressionType.OUTLIER_SUPPRESSION_LITE + layer_policy.outliers_suppression == OutliersSuppressionType.OUTLIER_SUPPRESSION_LITE or \ + is_inner_osp if not layer_policy or not is_satisfied: return cell, False if (any(opname in cell_name for opname in layer_policy.opname_blacklist) or diff --git a/mindspore_gs/ptq/ptq/hal.py b/mindspore_gs/ptq/ptq/hal.py index 11d1c05f..5c2f8bd4 100644 --- a/mindspore_gs/ptq/ptq/hal.py +++ b/mindspore_gs/ptq/ptq/hal.py @@ -17,6 +17,7 @@ import abc import dataclasses import enum from typing import Optional +import warnings import numpy as np from mindspore import Tensor, dtype @@ -28,6 +29,7 @@ from mindspore.common.initializer import initializer from mindspore.ops.operations.comm_ops import ReduceOp from mindspore.communication.management import GlobalComm from mindspore.ops.auto_generate import WeightQuantBatchMatmul, QuantBatchMatmul, DynamicQuantExt, GroupedMatmulV4 +from mindspore_gs.common.utils import value_check from mindspore_gs.common.numpy_quant_common import NumpyQuantOps from mindspore_gs.common import logger from mindspore_gs.ptq import PTQMode @@ -324,6 +326,160 @@ class QuantWithOutlierSuppressionPlusHighPerformance(QuantWithOutlierSuppression return self.quant(x, self.input_scale, self.input_zp, False, "ROUND", dtype.int8) +# for osp of omniquant +class QuantWithOutlierSuppressionPlusSmooth(QuantUnitCell): + """QuantWithOutlierSuppressionPlus""" + def __init__(self, layer_name, parallel_type: ParallelType): + super().__init__(layer_name) + self.parallel_type = parallel_type + self.input_scale = None + self.input_zp = None + self.smooth_scale = None + self.beta_osp = None + + @staticmethod + def create( + layer_name, + x_qparam: QuantParam, + ic, + dst_dtype, + is_deploy, + parallel_type: ParallelType, + smooth_scale, + beta_osp, + kernel_type: KernelType 
= KernelType.ASD, + ): + """create""" + if kernel_type in (KernelType.ASD, KernelType.INTERNAL): + return QuantWithOutlierSuppressionPlusHighPerformanceSmooth( + layer_name, x_qparam, ic, dst_dtype, is_deploy, parallel_type, smooth_scale, beta_osp, dtype.int8 + ) + if kernel_type is KernelType.ACLNN: + return QuantWithOutlierSuppressionPlusHighPerformanceSmooth( + layer_name, x_qparam, ic, dst_dtype, is_deploy, parallel_type, smooth_scale, beta_osp, dtype.float16 + ) + if kernel_type is KernelType.HIGH_PRECISION: + return QuantWithOutlierSuppressionPlusHighPrecisionSmooth( + layer_name, x_qparam, ic, dst_dtype, is_deploy, parallel_type, smooth_scale, beta_osp + ) + raise RuntimeError(f"Not supported kernel type: {kernel_type}") + + def param_shard_state(self, tensor_parallel_num=1, **kwargs) -> dict: + if self.input_scale.shape == (1,) or self.parallel_type == ParallelType.COL_PARALLEL: + input_scale_shard = (1,) + input_zp_shard = (1,) + beta_shard = (1,) + smooth_scale_shard = (1,) + elif self.parallel_type == ParallelType.ROW_PARALLEL: + input_scale_shard = (tensor_parallel_num,) + input_zp_shard = (tensor_parallel_num,) + beta_shard = (tensor_parallel_num,) + smooth_scale_shard = (tensor_parallel_num,) + else: + return {} + return {self.input_scale.name: {'shape': self.input_scale.shape, 'shard': input_scale_shard}, + self.input_zp.name: {'shape': self.input_zp.shape, 'shard': input_zp_shard}, + self.smooth_scale.name: {'shape': self.smooth_scale.shape, 'shard': smooth_scale_shard}, + self.beta_osp.name: {'shape': self.beta_osp.shape, 'shard': beta_shard}} + + +class QuantWithOutlierSuppressionPlusHighPrecisionSmooth(QuantWithOutlierSuppressionPlusSmooth): + """QuantWithOutlierSuppressionPlusHighPrecisionSmooth""" + # pylint: disable=unused-argument + def __init__(self, layer_name, x_qparam: QuantParam, ic, dst_dtype, is_deploy, parallel_type: ParallelType, + smooth_scale, beta_osp): + super().__init__(layer_name, parallel_type) + self.beta_osp = beta_osp + self.smooth_scale = smooth_scale.asnumpy() + self.dst_dtype = dst_dtype + if is_deploy: + self.input_scale = Parameter(initializer('ones', (ic,), dtype.float64)) + self.input_zp = Parameter(initializer('zeros', (ic,), dtype.float64)) + return + # fuse smooth.mul and quant + input_scale_np = x_qparam.scale.asnumpy().astype(np.float64) + if smooth_scale is not None: + final_scale_np = input_scale_np / smooth_scale.astype(np.float64) + logger.debug(f"QuantWithSmoothHighPrecision: input scale with smooth scale of Layer({parallel_type}:" + f"{layer_name}) is {{{final_scale_np.shape}, {final_scale_np.dtype}, {final_scale_np}}}") + else: + if input_scale_np.shape == (1,): # aclnn quant op not support pertensor + final_scale_np = np.tile(input_scale_np, ic) + logger.debug(f"QuantWithSmoothHighPrecision: input scale from vector of Layer({parallel_type}:" + f"{layer_name}) is {{{final_scale_np.shape}, {final_scale_np.dtype}, {final_scale_np}}}") + else: + final_scale_np = input_scale_np + logger.debug(f"QuantWithSmoothHighPrecision: input scale of Layer({parallel_type}:{layer_name}) is " + f"{{{final_scale_np.shape}, {final_scale_np.dtype}, {final_scale_np}}}") + self.input_scale = Parameter(Tensor(final_scale_np, dtype=dtype.float64)) + + if self.input_scale.shape != x_qparam.zero_point.shape: + if isinstance(x_qparam.zero_point, np.number): + raise RuntimeError("Shape of scale and zero point are not compatible.") + self.input_zp = Parameter(Tensor(np.tile(x_qparam.zero_point, ic).astype(np.float64), dtype=dtype.float64)) + 
logger.debug(f"QuantWithSmoothHighPrecision: input zp from vector of Layer({parallel_type}:{layer_name}) is" + f" {{{self.input_zp.shape}, {self.input_zp.dtype}, {self.input_zp}}}") + else: + self.input_zp = Parameter(Tensor(x_qparam.zero_point, dtype=dtype.float64)) + logger.debug(f"QuantWithSmoothHighPrecision: input zp of Layer({parallel_type}:{layer_name}) is " + f"{{{self.input_zp.shape}, {self.input_zp.dtype}, {self.input_zp.asnumpy()}}}") + + def construct(self, x): + x = msops.add(x, self.beta_osp) + out = x / self.input_scale + out = msops.round(out) + out = msops.clip(out, -128., 127.) + return msops.cast(out, self.dst_dtype) + +class QuantWithOutlierSuppressionPlusHighPerformanceSmooth(QuantWithOutlierSuppressionPlusSmooth): + """QuantWithOutlierSuppressionPlusHighPerformanceSmooth""" + # pylint: disable=unused-argument + def __init__(self, layer_name, x_qparam: QuantParam, ic, dst_dtype, is_deploy, parallel_type: ParallelType, + smooth_scale, beta_osp, zp_dtype=dtype.int8): + super().__init__(layer_name, parallel_type) + self.smooth_scale = smooth_scale.asnumpy() + self.quant = QuantV2() + self.beta_osp = beta_osp + self.is_perchannel = smooth_scale is not None + if is_deploy: + if not self.is_perchannel: + self.input_scale = Parameter(initializer('ones', (1,), dst_dtype)) + self.input_zp = Parameter(initializer('zeros', (1,), zp_dtype)) + else: + self.input_scale = Parameter(initializer('ones', (ic,), dst_dtype)) + self.input_zp = Parameter(initializer('zeros', (ic,), zp_dtype)) + return + # fuse smooth.mul and quant + input_scale_np = x_qparam.scale.asnumpy().astype(np.float16) + if self.is_perchannel: + final_scale_np = input_scale_np / smooth_scale.astype(np.float16) + logger.debug(f"QuantWithSmoothHighPerformance: input scale with smooth scale of Layer({parallel_type}:" + f"{layer_name}) is {{{final_scale_np.shape}, {final_scale_np.dtype}, {final_scale_np}}}") + else: + final_scale_np = input_scale_np + logger.debug(f"QuantWithSmoothHighPerformance: input scale from vector of Layer({parallel_type}:" + f"{layer_name}) is {{{final_scale_np.shape}, {final_scale_np.dtype}, {final_scale_np}}}") + if zp_dtype != dtype.int8: # input_scale need divide 1 for aclnn quant ops + self.input_scale = Parameter(Tensor(1 / final_scale_np, dtype=dst_dtype)) + else: + self.input_scale = Parameter(Tensor(final_scale_np, dtype=dst_dtype)) + + + if self.input_scale.shape != x_qparam.zero_point.shape: + if isinstance(x_qparam.zero_point, np.number): + raise RuntimeError("Shape of scale and zero point are not compatible.") + self.input_zp = Parameter(Tensor(np.tile(x_qparam.zero_point, ic).astype(np.float16), dtype=zp_dtype)) + logger.debug(f"QuantWithSmoothHighPerformance: input zp from vector of Layer({parallel_type}:{layer_name})" + f" is {{{self.input_zp.shape}, {self.input_zp.dtype}, {self.input_zp}}}") + else: + self.input_zp = Parameter(Tensor(x_qparam.zero_point, dtype=zp_dtype)) + logger.debug(f"QuantWithSmoothHighPerformance: input zp of Layer({parallel_type}:{layer_name}) is " + f"{{{self.input_zp.shape}, {self.input_zp.dtype}, {self.input_zp.asnumpy()}}}") + + def construct(self, x): + x = msops.add(x, self.beta_osp) + return self.quant(x, self.input_scale, self.input_zp, False, "ROUND", dtype.int8) + class SmoothMatmul(QuantUnitCell): """SmoothMatmul""" def __init__(self, layer_name, mm, smooth_scale_): @@ -402,7 +558,6 @@ class SmoothMatmulForDeploy(QuantUnitCell): return {} return {self.smooth_scale.name: {'shape': self.smooth_scale.shape, 'shard': smooth_scale_shard}} - class 
OutlierSuppressionPlusMatmulForDeploy(QuantUnitCell):
     """OutlierSuppressionPlusMatmulForDeploy"""
     def __init__(self, layer_name, mm, ic_, compute_dtype_):
@@ -430,6 +585,103 @@ class OutlierSuppressionPlusMatmulForDeploy(QuantUnitCell):
             return {}
         return {self.beta.name: {'shape': self.beta.shape, 'shard': beta_shard}}
 
+class OutlierSuppressionPlusSmoothMatmul(QuantUnitCell):
+    """OutlierSuppressionPlusSmoothMatmul"""
+    def __init__(self, layer_name, mm, smooth_scale_, beta_osp_):
+        super().__init__(layer_name)
+        self.mm = mm
+        self.smooth_scale = Parameter(msops.div(1, smooth_scale_))
+        self.beta_osp = Parameter(beta_osp_)
+        self.is_group_mm = isinstance(mm, GroupedMatmulV4)
+        logger.debug(
+            f"OSPSmoothMatmul: smooth_scale for act of Layer({layer_name}) is {{{self.smooth_scale.shape}, "
+            f"{self.smooth_scale.dtype}, {self.smooth_scale.asnumpy()}}}"
+        )
+
+    def update(self, layer_name, mm, smooth_scale_, beta_):
+        self.layer_name = layer_name
+        self.mm = mm
+        self.smooth_scale.set_data(msops.div(1, smooth_scale_))
+        self.beta_osp.set_data(beta_)
+
+    @staticmethod
+    def create(layer_name, src, smooth_scale, beta_osp):
+        if isinstance(src, (msops.MatMul, GroupedMatmulV4)):
+            return OutlierSuppressionPlusSmoothMatmul(layer_name, src, smooth_scale, beta_osp)
+        raise ValueError(f"Not support creating OutlierSuppressionPlusSmoothMatmul from {src}.")
+
+    def construct(self, *args, **kwargs):
+        '''
+        apply the OSP shift (beta_osp) and smooth scale to the activation before matmul
+        '''
+        args = list(args)
+        x = args[0][0] if self.is_group_mm else args[0]
+        smooth_scale = msops.cast(self.smooth_scale, x.dtype)
+        beta_osp = msops.cast(self.beta_osp, x.dtype)
+        x = msops.mul(msops.add(x, beta_osp), smooth_scale)
+        if self.is_group_mm:
+            args[0][0] = x
+        else:
+            args[0] = x
+        return self.mm(*args, **kwargs)
+
+    # pylint: disable=arguments-differ
+    def param_shard_state(self, tensor_parallel_num=1, parallel_type: ParallelType = ParallelType.NO_PARALLEL):
+        if parallel_type == ParallelType.COL_PARALLEL:
+            smooth_scale_shard = (1,)
+            beta_shared = (1,)
+        elif parallel_type == ParallelType.ROW_PARALLEL:
+            smooth_scale_shard = (tensor_parallel_num,)
+            beta_shared = (tensor_parallel_num,)
+        else:
+            return {}
+        shard_state = {self.smooth_scale.name: {'shape': self.smooth_scale.shape, 'shard': smooth_scale_shard}}
+        shard_state[self.beta_osp.name] = {'shape': self.beta_osp.shape, 'shard': beta_shared}
+        return shard_state
+
+class OutlierSuppressionPlusSmoothMatmulForDeploy(QuantUnitCell):
+    """OutlierSuppressionPlusSmoothMatmulForDeploy"""
+    def __init__(self, layer_name, mm, ic_, compute_dtype_):
+        super().__init__(layer_name)
+        self.mm = mm
+        self.smooth_scale = Parameter(initializer('ones', (ic_,), dtype=compute_dtype_))
+        self.beta_osp = Parameter(initializer('zeros', (ic_,), dtype=compute_dtype_))
+        self.is_group_mm = isinstance(mm, GroupedMatmulV4)
+
+    @staticmethod
+    def create(layer_name, src, ic, compute_dtype):
+        if isinstance(src, (msops.MatMul, GroupedMatmulV4)):
+            return OutlierSuppressionPlusSmoothMatmulForDeploy(layer_name, src, ic, compute_dtype)
+        raise ValueError(f"Not support creating OutlierSuppressionPlusSmoothMatmulForDeploy from {src}.")
+
+    def construct(self, *args, **kwargs):
+        '''
+        apply the OSP shift (beta_osp) and smooth scale to the activation before matmul
+        '''
+        args = list(args)
+        x = args[0][0] if self.is_group_mm else args[0]
+        smooth_scale = msops.cast(self.smooth_scale, x.dtype)
+        beta_osp = msops.cast(self.beta_osp, x.dtype)
+        x = msops.mul(msops.add(x, beta_osp), smooth_scale)
+        if self.is_group_mm:
+            args[0][0] = x
+        else:
+            args[0] = x
+        return 
self.mm(*args, **kwargs) + + # pylint: disable=arguments-differ + def param_shard_state(self, tensor_parallel_num=1, parallel_type: ParallelType = ParallelType.NO_PARALLEL): + if parallel_type == ParallelType.COL_PARALLEL: + smooth_scale_shard = (1,) + beta_shared = (1,) + elif parallel_type == ParallelType.ROW_PARALLEL: + smooth_scale_shard = (tensor_parallel_num,) + beta_shared = (tensor_parallel_num,) + else: + return {} + shard_state = {self.smooth_scale.name: {'shape': self.smooth_scale.shape, 'shard': smooth_scale_shard}} + shard_state[self.beta_osp.name] = {'shape': self.beta_osp.shape, 'shard': beta_shared} + return shard_state class DynamicQuantMatmul(QuantUnitCell): """dynamic quant""" @@ -847,6 +1099,36 @@ class AllQuantMatmul(QuantUnitCell): return q_correction_t + @staticmethod + def _correction_into_bias_for_osp_into_bias( + linear_mm, + quant_weight: Parameter, + x_qparam: QuantParam, + w_qparam: QuantParam, + trans_b, + new_bias_need_allreduce, + beta_correction, + ) -> Tensor: + """compute fused bias""" + value_check("deploy class name", linear_mm, OutlierSuppressionPlusSmoothMatmul) + if quant_weight is None: + raise ValueError("quant_weight is None.") + x_zp = x_qparam.zero_point.asnumpy() + q_correction = -np.sum(x_zp.astype(np.int32) * quant_weight.asnumpy().astype(np.int32), + axis=1 if trans_b else 0).astype(np.int32) + dequant_scale = Tensor(np.squeeze(x_qparam.scale.asnumpy() * w_qparam.scale.asnumpy()).astype(np.float64)) + beta_correction_s = beta_correction / dequant_scale + beta_correction = msops.round(beta_correction_s).astype("int32") + # add warning for reduce + mean_ratio = msops.ReduceMean()(beta_correction_s / (beta_correction+1e-6)) + if mean_ratio > 1.1: + warn_str = "Bias for osp error is too big!" + warnings.warn(warn_str, RuntimeWarning) + correction_t = Tensor(q_correction) + beta_correction + if new_bias_need_allreduce: + correction_t = msops.AllReduce(op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_GROUP)(correction_t) + return correction_t + @staticmethod def _from_matmul_prim(layer_name, mm, x_qparam: QuantParam, w_qparam: QuantParam, is_deploy, transpose_a=False, transpose_b=False, quant_bias=None, dst_dtype=dtype.float16, @@ -897,6 +1179,72 @@ class AllQuantMatmul(QuantUnitCell): f'matmul of SmoothMatmulForDeploy should be an instance of {msops.MatMul} or {GroupedMatmulV4}, ' f'but got {src.mm}.') + @staticmethod + def _from_outlier_suppression_plus_smooth_matmul( + layer_name, + src: OutlierSuppressionPlusSmoothMatmul, + x_qparam: QuantParam, + w_qparam: QuantParam, + is_deploy, + transpose_a=False, + transpose_b=False, + quant_bias=None, + dst_dtype=dtype.float16, + kernel_type: KernelType = KernelType.ASD, + ): + """from outlier suppression plus smooth""" + if isinstance(src.mm, (msops.MatMul, GroupedMatmulV4)): + qmm = AllQuantMatmul.create_self( + layer_name, + src.mm, + is_deploy, + x_qparam, + w_qparam, + transpose_a, + transpose_b, + quant_bias, + dst_dtype, + kernel_type, + ) + return qmm, src.smooth_scale, src.beta_osp + raise ValueError( + f"matmul of OSPSmoothMatmul should be an instance of {msops.MatMul} or {GroupedMatmulV4}, " + f"but got {src.mm}." 
+ ) + + @staticmethod + def _from_outlier_suppression_plus_smooth_matmul_for_deploy( + layer_name, + src: OutlierSuppressionPlusSmoothMatmulForDeploy, + x_qparam: QuantParam, + w_qparam: QuantParam, + is_deploy, + transpose_a=False, + transpose_b=False, + quant_bias=None, + dst_dtype=dtype.float16, + kernel_type: KernelType = KernelType.ASD, + ): + """from outlier suppression plus smooth for deploy""" + if isinstance(src.mm, (msops.MatMul, GroupedMatmulV4)): + qmm = AllQuantMatmul.create_self( + layer_name, + src.mm, + is_deploy, + x_qparam, + w_qparam, + transpose_a, + transpose_b, + quant_bias, + dst_dtype, + kernel_type, + ) + return qmm, src.smooth_scale, src.beta_osp + raise ValueError( + f"matmul of OSPSmoothMatmulForDeploy should be an instance of {msops.MatMul} or {GroupedMatmulV4}, " + f"but got {src.mm}." + ) + @staticmethod def create_self(layer_name, mm, is_deploy, x_qparam: QuantParam, w_qparam: QuantParam, transpose_a=False, transpose_b=False, quant_bias=None, dst_dtype=dtype.float16, @@ -911,10 +1259,23 @@ class AllQuantMatmul(QuantUnitCell): raise RuntimeError(f"Not supported kernel type: {kernel_type}") @staticmethod - def create(layer_name, linear, parallel_type: ParallelType, q_weight, x_qparam: QuantParam, w_qparam: QuantParam, - is_deploy, tp_size, dst_dtype=dtype.float16, - kernel_type: KernelType = KernelType.ASD) -> (QuantWithSmooth, 'AllQuantMatmul', Parameter): - """create""" + def create( + layer_name, + linear, + parallel_type: ParallelType, + q_weight, + x_qparam: QuantParam, + w_qparam: QuantParam, + is_deploy, + tp_size, + dst_dtype=dtype.float16, + kernel_type: KernelType = KernelType.ASD, + bias_osp=None, + ) -> (QuantWithSmooth, "AllQuantMatmul", Parameter): + "AllQuantMatmul create" + is_osp = isinstance(linear.matmul, OutlierSuppressionPlusSmoothMatmul) + if bias_osp is None and is_osp: + raise ValueError("Use OSP but bias_osp is None.") trans_a = False trans_b = linear.transpose_b rank = len(q_weight.shape) @@ -930,10 +1291,23 @@ class AllQuantMatmul(QuantUnitCell): else: # fuse bias if parallel_type is ParallelType.ROW_PARALLEL: - t_bias = AllQuantMatmul._correction_into_bias(q_weight, x_qparam, w_qparam, trans_b, tp_size > 1, - dst_dtype) + if not is_osp: + t_bias = AllQuantMatmul._correction_into_bias( + q_weight, x_qparam, w_qparam, trans_b, tp_size > 1, dst_dtype + ) + else: + t_bias = AllQuantMatmul._correction_into_bias_for_osp_into_bias( + linear.matmul, q_weight, x_qparam, w_qparam, trans_b, tp_size > 1, bias_osp + ) else: - t_bias = AllQuantMatmul._correction_into_bias(q_weight, x_qparam, w_qparam, trans_b, False, dst_dtype) + if not is_osp: + t_bias = AllQuantMatmul._correction_into_bias( + q_weight, x_qparam, w_qparam, trans_b, False, dst_dtype + ) + else: + t_bias = AllQuantMatmul._correction_into_bias_for_osp_into_bias( + linear.matmul, q_weight, x_qparam, w_qparam, trans_b, False, bias_osp + ) quant_bias = Parameter(t_bias, name=bias_name) # create qmm @@ -962,6 +1336,39 @@ class AllQuantMatmul(QuantUnitCell): dst_dtype, kernel_type) quant = QuantWithOutlierSuppressionPlus.create(layer_name, x_qparam, ic, dst_dtype, is_deploy, parallel_type, beta, kernel_type) + # for osp of omniquant + elif isinstance(linear.matmul, OutlierSuppressionPlusSmoothMatmul): + qmm, smooth_scale, beta_osp = AllQuantMatmul._from_outlier_suppression_plus_smooth_matmul( + layer_name, + linear.matmul, + x_qparam, + w_qparam, + is_deploy, + trans_a, + trans_b, + quant_bias, + dst_dtype, + kernel_type, + ) + quant = QuantWithOutlierSuppressionPlusSmooth.create( + 
layer_name, x_qparam, ic, dst_dtype, is_deploy, parallel_type, smooth_scale, beta_osp, kernel_type + ) + elif isinstance(linear.matmul, OutlierSuppressionPlusSmoothMatmulForDeploy): + qmm, smooth_scale, beta_osp = AllQuantMatmul._from_outlier_suppression_plus_smooth_matmul_for_deploy( + layer_name, + linear.matmul, + x_qparam, + w_qparam, + is_deploy, + trans_a, + trans_b, + quant_bias, + dst_dtype, + kernel_type, + ) + quant = QuantWithOutlierSuppressionPlusSmooth.create( + layer_name, x_qparam, ic, dst_dtype, is_deploy, parallel_type, smooth_scale, beta_osp, kernel_type + ) else: raise ValueError(f"Not support creating AllQuantMatmul from {linear.matmul}.") return quant, qmm diff --git a/mindspore_gs/ptq/ptq/wrappers/mindformers/__init__.py b/mindspore_gs/ptq/ptq/wrappers/mindformers/__init__.py index eb52cace..59d7208d 100644 --- a/mindspore_gs/ptq/ptq/wrappers/mindformers/__init__.py +++ b/mindspore_gs/ptq/ptq/wrappers/mindformers/__init__.py @@ -14,7 +14,7 @@ # ============================================================================ """Wrapper cells for PTQ for MindFormers.""" -from .linear_smooth_wrappers import SmoothQuantLinearCell, AWQSmoothLinearCell, OutlierSuppressionPlusLinearCell, SearchOutlierSuppressionLiteLinearCell +from .linear_smooth_wrappers import SmoothQuantLinearCell, AWQSmoothLinearCell, OutlierSuppressionPlusLinearCell, OutlierSuppressionPlusSmoothLinearCell, SearchOutlierSuppressionLiteLinearCell from .linear_weight_quant_wrappers import WeightQuantLinearCell from .linear_gptq_quant_wrappers import GptqWeightQuantLinearCell from .linear_clip_wrappers import ClipLinearCell @@ -27,6 +27,7 @@ SearchOutlierSuppressionLiteLinearCell.reg_self() SmoothQuantLinearCell.reg_self() AWQSmoothLinearCell.reg_self() OutlierSuppressionPlusLinearCell.reg_self() +OutlierSuppressionPlusSmoothLinearCell.reg_self() WeightQuantLinearCell.reg_self() GptqWeightQuantLinearCell.reg_self() ClipLinearCell.reg_self() diff --git a/mindspore_gs/ptq/ptq/wrappers/mindformers/linear_all_quant_wrappers.py b/mindspore_gs/ptq/ptq/wrappers/mindformers/linear_all_quant_wrappers.py index f1997dd4..222eddd6 100644 --- a/mindspore_gs/ptq/ptq/wrappers/mindformers/linear_all_quant_wrappers.py +++ b/mindspore_gs/ptq/ptq/wrappers/mindformers/linear_all_quant_wrappers.py @@ -16,7 +16,7 @@ from mindspore import Parameter, Tensor, dtype from mindspore.common.initializer import initializer - +from mindspore import ops as msops from mindformers.modules.layers import Linear from mindformers.experimental.infer.core.layers import RowParallelLinear, ColumnParallelLinear @@ -24,7 +24,7 @@ from mindspore_gs.ptq.ptq_config import PTQMode, QuantGranularity from mindspore_gs.ptq.context import InnerPTQConfig from mindspore_gs.ptq.basic_quant_func import quant_tensor from mindspore_gs.common import logger -from mindspore_gs.ptq.ptq.hal import QuantParam, AllQuantMatmul, ParallelType, KernelType +from mindspore_gs.ptq.ptq.hal import QuantParam, AllQuantMatmul, ParallelType, KernelType, OutlierSuppressionPlusSmoothMatmul from mindspore_gs.ptq.ptq.algorithms.quantizer import Quantizer from mindspore_gs.ptq.ptq.wrapper_cell import Checker from .parallel_minmax import get_min_max_op @@ -92,9 +92,21 @@ class AllQuantLinearInferCell(LinearInferCell): self.cfg = cfg is_deploy = cfg.mode == PTQMode.DEPLOY use_aclnn_quant = any(opname in layer_name for opname in cfg.aclnn_quant_list) + bias_osp = None + if isinstance(self.layer.matmul, OutlierSuppressionPlusSmoothMatmul): + origin_weight = msops.mul(self._layer.weight, 
linear.matmul.smooth_scale)
+            bias_osp = msops.matmul(
+                msops.expand_dims(-linear.matmul.beta_osp, 0),
+                (
+                    origin_weight.astype("float32").transpose()
+                    if self._layer.transpose_b
+                    else origin_weight.astype("float32")
+                ),
+            )
+            bias_osp = bias_osp.squeeze()
         quant, qmm = AllQuantMatmul.create(layer_name, linear, parallel_type, q_weight, x_qparam, w_qparam,
                                            is_deploy, cfg.tp_size, compute_type,
-                                           KernelType.ACLNN if use_aclnn_quant else KernelType.INTERNAL)
+                                           KernelType.ACLNN if use_aclnn_quant else KernelType.INTERNAL, bias_osp)
         if not is_deploy:
             logger.debug(f"AllQuantLinearInferCell: x_qparam of Layer({parallel_type}:{layer_name}) is {x_qparam}")
             logger.debug(f"AllQuantLinearInferCell: w_qparam of Layer({parallel_type}:{layer_name}) is {w_qparam}")
diff --git a/mindspore_gs/ptq/ptq/wrappers/mindformers/linear_smooth_wrappers.py b/mindspore_gs/ptq/ptq/wrappers/mindformers/linear_smooth_wrappers.py
index d00c467e..3911df8b 100644
--- a/mindspore_gs/ptq/ptq/wrappers/mindformers/linear_smooth_wrappers.py
+++ b/mindspore_gs/ptq/ptq/wrappers/mindformers/linear_smooth_wrappers.py
@@ -22,18 +22,19 @@
 from types import MethodType
 from typing import Optional
 import numpy as np
-from mindspore import Tensor, nn
+from mindspore.common.initializer import initializer
+from mindspore import Tensor, nn, Parameter
 from mindspore import ops as msops
 from mindspore import dtype as msdtype
-from mindspore.communication.management import GlobalComm
 from mindspore.ops.operations.comm_ops import ReduceOp
+from mindspore.communication.management import GlobalComm
 from mindformers.modules.layers import Linear
 from mindformers.experimental.infer.core.layers import ColumnParallelLinear, RowParallelLinear
 from mindspore_gs.common import logger
 from mindspore_gs.common.json_cache import JSONCache
 from mindspore_gs.ptq.ptq_config import PTQMode, OutliersSuppressionType, QuantGranularity
 from mindspore_gs.ptq.context import InnerPTQConfig
-from mindspore_gs.ptq.ptq.hal import SmoothMatmul, SmoothMatmulForDeploy, OutlierSuppressionPlusMatmulForDeploy
+from mindspore_gs.ptq.ptq.hal import SmoothMatmul, SmoothMatmulForDeploy, OutlierSuppressionPlusMatmulForDeploy, OutlierSuppressionPlusSmoothMatmulForDeploy, OutlierSuppressionPlusSmoothMatmul
 from mindspore_gs.ptq.ptq.algorithms.anti_outliers import LinearSmoothQuant, LinearAutoSmoother
 from mindspore_gs.ptq.ptq.wrapper_cell import Checker, SearchInputs
 from mindspore_gs.ptq.basic_quant_func import quant_tensor
@@ -49,7 +50,7 @@ class SmoothMethod(enum.Enum):
     SMOOTH_QUANT = 1
     AWQ = 2
     AUTO = 3
-    OSP = 4
+    OUTLIER_SUPPRESSION_PLUS = 4
 
 
 class SmoothLinearCell(WrapperLinearCell):
@@ -73,20 +74,26 @@ class SmoothLinearCell(WrapperLinearCell):
     def _get_smooth_method(self):
         raise NotImplementedError
 
-    def _calc_smooth_scale(self, alpha):
+    def _calc_smooth_scale(self, alpha, shift_values=None):
         if self.smooth_method is SmoothMethod.SMOOTH_QUANT:
-            return self._calc_smooth_quant_smooth_scale(alpha)
+            return self._calc_smooth_quant_smooth_scale(alpha, shift_values)
         if self.smooth_method is SmoothMethod.AWQ:
             return self._calc_awq_smooth_scale(alpha)
         raise RuntimeError(f"Unsupported SmoothMethod: {self.smooth_method}")
 
-    def _calc_smooth_quant_smooth_scale(self, alpha):
+    def _calc_smooth_quant_smooth_scale(self, alpha, shift_values=None):
         """_calc_smooth_scale"""
         self.cfg.dumper.dump_data(self.layer_name, "|smooth_scale|activation_minmax|input0_alpha", Tensor(alpha))
         self.cfg.dumper.dump_data(self.layer_name, "|smooth_scale|activation_minmax|input1_activation_inputs",
self.cat_samples) - act_max = msops.maximum(msops.abs(self.x_obs_max(self.cat_samples, 0)[0]), - msops.abs(self.x_obs_min(self.cat_samples, 0)[0])) + act_max = msops.maximum( + msops.abs( + self.x_obs_max(self.cat_samples - shift_values if shift_values is not None else self.cat_samples, 0)[0] + ), + msops.abs( + self.x_obs_min(self.cat_samples - shift_values if shift_values is not None else self.cat_samples, 0)[0] + ), + ) logger.debug(f"SmoothLinearCell: act_max of Layer({self._layer_name}) is {{{act_max.shape}, {act_max.dtype}}}") input_max_pow = msops.pow(act_max, alpha) self.cfg.dumper.dump_data(self.layer_name, "|smooth_scale|activation_minmax|output0_activation_minmax_pow", @@ -228,7 +235,7 @@ class SmoothLinearCell(WrapperLinearCell): return state_dict @staticmethod - #pylint: disable=W0211 + # pylint: disable=W0211 def row_sharded_state_dict(self): """provide the sharded state dict based on the config""" w_shard = (1, self.tensor_parallel_group_size) if self.transpose_b else (self.tensor_parallel_group_size, 1) @@ -554,7 +561,6 @@ class SearchOutlierSuppressionLiteLinearCell(SmoothQuantLinearCell): self.cat_samples = None self.samples.clear() - class AWQSmoothLinearCell(AWQLinearCell): """AWQLinearCell""" @@ -784,6 +790,221 @@ class AWQSmoothLinearCell(AWQLinearCell): gc.collect() +class OutlierSuppressionPlusSmoothLinearCell(SearchOutlierSuppressionLiteLinearCell): + """OutlierSuppressionPlusSmoothLinearCell""" + + @staticmethod + def reg_self(): + class OutlierSuppressionPlusSmoothChecker(Checker): + def check(self, config: InnerPTQConfig): + return config.outliers_suppression == OutliersSuppressionType.OUTLIER_SUPPRESSION_PLUS and \ + config.use_inner_osp + + LinearAutoSmoother.reg_layer_map( + Linear, OutlierSuppressionPlusSmoothLinearCell, OutlierSuppressionPlusSmoothChecker() + ) + LinearAutoSmoother.reg_layer_map(ColumnParallelLinear, + OutlierSuppressionPlusSmoothLinearCell, + OutlierSuppressionPlusSmoothChecker()) + LinearAutoSmoother.reg_layer_map(RowParallelLinear, + OutlierSuppressionPlusSmoothLinearCell, + OutlierSuppressionPlusSmoothChecker()) + + def __init__(self, linear_name, linear, context, cfg, network_helper, **kwargs): + super().__init__(linear_name, linear, context, cfg, network_helper, **kwargs) + if linear.expert_num and linear.expert_num > 1: + self.is_moe = True + else: + self.is_moe = False + self.linear_name = linear_name.split('.')[-1] if not self.is_moe else 'moe|' + linear_name.split('.')[-1] + self.ic = linear.weight.shape[self.ic_axis] + self.shift_values = Parameter(initializer('ones', (1, self.ic), dtype=msdtype.float32)) + self.all_bias = None + self.osp_bias_t = None + self.is_replaced = False + if context.algorithm_cache_path: + self.enable_cache = True + else: + self.enable_cache = False + + def _quant_info(self): + return "OutlierSuppressionPlus" + + def _get_smooth_method(self): + return SmoothMethod.SMOOTH_QUANT + + def construct(self, x, *args, **kwargs): + if self.quant_forward: + x = (x - self.shift_values) * self.x_scale_fast + self.x_zp + x = msops.round(x) + x = msops.clip(x, -128., 127.) 
+ self._layer.compute_dtype = msdtype.float32 + y = self._layer(x, *args, **kwargs) + self._layer.compute_dtype = self.compute_type + if self.is_replaced: + y = y + self.osp_bias_t + if self.quant_forward: + y = y * self.deq_scale + self.all_bias + y = msops.cast(y, self.compute_type) + return y + + def get_shift_values_for_channel(self, x): + max_values = msops.ReduceMax()(x, axis=0) + min_values = msops.ReduceMin()(x, axis=0) + return (max_values + min_values) / 2 + + def _cache_key(self): + best_alpha_name = '|' + 'layers|' + self.layer_name.split('.')[3] + '|' + self.linear_name + '|best_alpha' + shift_values_name = '|' + 'layers|' + self.layer_name.split('.')[3] + '|' + self.linear_name + '|shift_values' + return best_alpha_name, shift_values_name + + def _search_best_scale(self, alpha): + """search best scale""" + best_alpha_name, shift_values_name = self._cache_key() + best_alpha = self.cache.get(best_alpha_name) + shift_values = self.cache.get(shift_values_name) + + if best_alpha is not None and shift_values is not None: + logger.info(f'layer {self.layer_name} using cached alpha: {best_alpha} and shift values.') + self.shift_values.set_data(Tensor(shift_values, dtype=msdtype.float32)) + best_scale = self._calc_smooth_scale(best_alpha, self.shift_values) + logger.info( + f"OSPSmoothLinearCell: best scale alpha {best_alpha}, best_scale of Layer({self._layer_name})" + ) + else: + best_scale, best_alpha = self._compute_best_scale(alpha) + if self.enable_cache: + self.cache.put(best_alpha_name, best_alpha) + self.cache.put(shift_values_name, self.shift_values.asnumpy()) + return best_scale + + def _compute_best_scale(self, alpha): + """compute best scale""" + history = [] + best_ratio = -1 + best_scale = 0 + best_error = float("inf") + fp16_weight = copy.deepcopy(self._layer.weight).astype(self.compute_type) + # calculate shift_values for cat_samples + self.shift_values = self.get_shift_values_for_channel(self.cat_samples) + # beta * origin_weight + self.osp_bias_t = msops.matmul( + msops.expand_dims(self.shift_values, 0), + ( + self._layer.weight.astype("float32").transpose() + if self._layer.transpose_b + else self._layer.weight.astype("float32") + ), + ) + if isinstance(self._layer, RowParallelLinear): + self.osp_bias_t = msops.AllReduce(op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_GROUP)(self.osp_bias_t) + + group_size = self.cfg.group_size if self.cfg.weight_quant_granularity == QuantGranularity.PER_GROUP \ + else self._layer.weight.shape[self.ic_axis] + fp16_output = self._module_forward(False) + for ratio in alpha: + scales = self._calc_smooth_scale(ratio, self.shift_values) + self._apply_weight_smooth(scales) + xs = (self.cat_samples - self.shift_values) / scales + x_scale, x_zp, _ = quant_tensor(xs, + self.x_quant_min, + self.x_quant_max, + self.cfg.act_narrow_range, + self.cfg.act_symmetric, + False, + group_size, + self.cfg.act_quant_dtype, + -1, + False, + False) + w_scale, _, q_weight = quant_tensor(self._layer.weight.data, + self.w_quant_min, + self.w_quant_max, + self.cfg.weight_narrow_range, + self.cfg.weight_symmetric, + False, + group_size, + self.cfg.weight_quant_dtype, + self.oc_axis, + True, + False) + t_w_scale = Tensor(w_scale) + if self._layer.transpose_b: + t_w_scale = msops.transpose(t_w_scale, (1, 0)) + self.x_scale_fast = Tensor(x_scale) + self.deq_scale = msops.cast((self.x_scale_fast * t_w_scale), msdtype.float32) + self.x_scale_fast = msops.cast(1 / (self.x_scale_fast * Tensor(scales)), msdtype.float32) + self.x_zp = Tensor(x_zp) + + logger.debug( + 
f"SearchOSPSmoothLinearCell: search scale alpha {ratio}, pesudo weight of Layer({self._layer_name})" + ) + self._layer.weight.set_data(msops.cast(q_weight, self.compute_type)) + y_zp = q_weight.sum(axis=self.ic_axis, dtype=msdtype.int32) * self.x_zp.astype(msdtype.int32) + y_zp_with_deq = y_zp * self.deq_scale + if isinstance(self._layer, RowParallelLinear): + y_zp_with_deq = msops.AllReduce(op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_GROUP)(y_zp_with_deq) + self.all_bias = -y_zp_with_deq + self.osp_bias_t + quant_output = self._module_forward(True) + msops.assign(self._layer.weight, fp16_weight.astype(self.compute_type)) + + loss = self._loss(fp16_output, quant_output) + logger.info( + f"SearchOSPSmoothLinearCell: search alpha {ratio}, loss of Layer({self._layer_name}) is {loss}" + ) + history.append(loss) + if loss < best_error: + best_error = loss + best_ratio = ratio + best_scale = scales + self.x_scale_fast = None + self.deq_scale = None + self.x_zp = None + self.y_zp = None + self.all_bias = None + gc.collect() + + del fp16_weight + del fp16_output + del scales + del xs + del x_scale + del x_zp + del w_scale + del q_weight + del t_w_scale + del quant_output + gc.collect() + if best_ratio == -1: + raise RuntimeError(f"Found no suitablt ratio, please check history of loss: {history}.") + logger.info(f"SearchOSPSmoothLinearCell: best scale alpha {best_ratio},best_error of Layer({self._layer_name})" + f"is {best_error}") + return best_scale, best_ratio + + def _apply_act_smooth(self, smooth_scale: Tensor): + """_apply_act_smooth""" + if isinstance(self._layer.matmul, OutlierSuppressionPlusSmoothMatmul): + self._layer.matmul.update(self._layer_name, self._layer.matmul.mm, smooth_scale, -self.shift_values) + else: + self._layer.matmul = OutlierSuppressionPlusSmoothMatmul.create( + self._layer_name, self._layer.matmul, smooth_scale=smooth_scale, beta_osp=-self.shift_values + ) + self.is_replaced = True + + def _apply_act_smooth_for_deploy(self, ic, compute_dtype): + """_apply_act_smooth_by_insert_op_for_deploy""" + self._layer.matmul = OutlierSuppressionPlusSmoothMatmulForDeploy.create( + self._layer_name, self._layer.matmul, ic=ic, compute_dtype=compute_dtype + ) + + def smooth(self): + """smooth""" + smooth_alpha = [i/20 for i in range(21)] + smooth_scale = self._search_best_scale(smooth_alpha) + xs = self.cat_samples * smooth_scale + self.check_xrange(self.cat_samples, xs) + self._apply_smooth(smooth_scale) + class OutlierSuppressionPlusLinearCell(AWQSmoothLinearCell): """SmoothQuantPlusLinearCell""" @@ -791,7 +1012,8 @@ class OutlierSuppressionPlusLinearCell(AWQSmoothLinearCell): def reg_self(): class OutlierSuppressionPlusChecker(Checker): def check(self, config: InnerPTQConfig): - return config.outliers_suppression == OutliersSuppressionType.OUTLIER_SUPPRESSION_PLUS + return config.outliers_suppression == OutliersSuppressionType.OUTLIER_SUPPRESSION_PLUS and \ + not config.use_inner_osp LinearAutoSmoother.reg_layer_map(Linear, OutlierSuppressionPlusLinearCell, OutlierSuppressionPlusChecker()) LinearAutoSmoother.reg_layer_map(ColumnParallelLinear, @@ -804,7 +1026,7 @@ class OutlierSuppressionPlusLinearCell(AWQSmoothLinearCell): return "OutlierSuppressionPlus" def _get_smooth_method(self): - return SmoothMethod.OSP + return SmoothMethod.OUTLIER_SUPPRESSION_PLUS def _apply_act_smooth_for_deploy(self, ic, compute_dtype): """_apply_act_smooth_by_insert_op_for_deploy""" diff --git a/tests/st/ptq/ptq/ptq_network_runner.py b/tests/st/ptq/ptq/ptq_network_runner.py index 2d8b359f..f7b54e6a 
100644 --- a/tests/st/ptq/ptq/ptq_network_runner.py +++ b/tests/st/ptq/ptq/ptq_network_runner.py @@ -178,6 +178,16 @@ def create_cfg(quant_algo_, mode): act_quant_granularity=QuantGranularity.PER_TENSOR, weight_quant_granularity=QuantGranularity.PER_CHANNEL, kvcache_quant_granularity=QuantGranularity.PER_CHANNEL) + elif quant_algo_ == 'OSPQuant_A8W8': + os.environ["MS_INTERNAL_DISABLE_CUSTOM_KERNEL_LIST"] = "RmsNormQuant,PagedAttention" + cfg = PTQConfig(mode=mode, + backend=BackendTarget.ASCEND, + opname_blacklist=["w2", "lm_head"], + weight_quant_dtype=dtype.int8, act_quant_dtype=dtype.int8, + outliers_suppression=OutliersSuppressionType.OUTLIER_SUPPRESSION_PLUS, + act_quant_granularity=QuantGranularity.PER_TENSOR, + weight_quant_granularity=QuantGranularity.PER_CHANNEL, + kvcache_quant_granularity=QuantGranularity.PER_CHANNEL) else: raise ValueError(f"Unsupported quant_algo : {quant_algo_}") return cfg @@ -225,6 +235,9 @@ def quant_llama2(config_path_, ckpt_path, output_dir_, example, quant_algo_): if quant_algo_ == "A16W4_AWQ": # pylint: disable=W0212 ptq._config.weight_symmetric = False + if quant_algo_ == "OSPQuant_A8W8": + # pylint: disable=W0212 + ptq._config.use_inner_osp = True # pylint: disable=W0212 ptq._config.enable_deploy_fusion = False ds = create_hello_ds(tokenizer, 1) @@ -268,6 +281,9 @@ def eval_llama2(input_, is_quant, config_path_, ckpt_path_, quant_algo_): if quant_algo_ == "A8W8_FallBack": # pylint: disable=W0212 ptq._config.fallback_blacklist = {'w2': LayerQuantizeAlgo.A16W8} + if quant_algo_ == "OSPQuant_A8W8": + # pylint: disable=W0212 + ptq._config.use_inner_osp = True # pylint: disable=W0212 ptq._config.enable_deploy_fusion = False network = ptq.apply(network) @@ -324,6 +340,8 @@ def ptq_llama2_predict_2stage(config_path_, fp16_ckpt_path_, quant_ckpt_path_, o ret = np.allclose(qoutput[:, :100], foutput[:, :100], 0, 0) elif quant_algo_ == 'OSL_A8W8': ret = np.allclose(qoutput[:, :3], foutput[:, :3], 0, 0) + elif quant_algo_ == 'OSPQuant_A8W8': + ret = np.allclose(qoutput[:, :3], foutput[:, :3], 0, 0) else: assert False if not ret: @@ -355,6 +373,8 @@ def ptq_llama2_predict_2stage(config_path_, fp16_ckpt_path_, quant_ckpt_path_, o ret = np.allclose(qoutput[:, :100], foutput[:, :100], 0, 0) elif quant_algo_ == 'OSL_A8W8': ret = np.allclose(qoutput[:, :3], foutput[:, :3], 0, 0) + elif quant_algo_ == 'OSPQuant_A8W8': + ret = np.allclose(qoutput[:, :3], foutput[:, :3], 0, 0) else: assert False if not ret: diff --git a/tests/st/ptq/ptq/test_ptq.py b/tests/st/ptq/ptq/test_ptq.py index 85b949c8..51f46fe8 100644 --- a/tests/st/ptq/ptq/test_ptq.py +++ b/tests/st/ptq/ptq/test_ptq.py @@ -840,7 +840,7 @@ def test_ptq_llama2_predict_2stage_1p_run_per_group(quant_algo): @pytest.mark.platform_arm_ascend910b_training @pytest.mark.env_single # 'A16W8C8' -@pytest.mark.parametrize("quant_algo", ['A8W8', 'A16W8', 'A8W8C8', 'OSL_A8W8', 'A8W4_GPTQ']) +@pytest.mark.parametrize("quant_algo", ['A8W8', 'A16W8', 'A8W8C8', 'OSL_A8W8', 'OSPQuant_A8W8', 'A8W4_GPTQ']) def test_ptq_llama2_predict_2stage_2p_run_part1(quant_algo): """ Feature: test PTQ adjust parameter in two stages with two cards. -- Gitee
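Usage note: the snippet below mirrors the OSPQuant_A8W8 branch added to tests/st/ptq/ptq/ptq_network_runner.py and sketches how the inner-OSP path introduced by this patch is enabled. It is only a sketch: the import locations and the PTQ(config=...)/apply(...) flow follow the golden-stick API as exercised by that test, and `network` stands for an already-built MindFormers model.

# Hedged sketch based on the OSPQuant_A8W8 configuration added in this patch.
from mindspore import dtype
from mindspore_gs.common import BackendTarget  # assumed import path
from mindspore_gs.ptq import PTQ, PTQConfig, PTQMode  # assumed import path
from mindspore_gs.ptq.ptq_config import OutliersSuppressionType, QuantGranularity

cfg = PTQConfig(mode=PTQMode.QUANTIZE,
                backend=BackendTarget.ASCEND,
                opname_blacklist=["w2", "lm_head"],
                weight_quant_dtype=dtype.int8, act_quant_dtype=dtype.int8,
                outliers_suppression=OutliersSuppressionType.OUTLIER_SUPPRESSION_PLUS,
                act_quant_granularity=QuantGranularity.PER_TENSOR,
                weight_quant_granularity=QuantGranularity.PER_CHANNEL,
                kvcache_quant_granularity=QuantGranularity.PER_CHANNEL)
ptq = PTQ(config=cfg)
# use_inner_osp selects the smooth-shift OSP wrappers added by this patch; it lives on
# the inner config, so the test toggles it directly on the private attribute.
# pylint: disable=W0212
ptq._config.use_inner_osp = True
network = ptq.apply(network)  # `network`: an already-constructed MindFormers model (placeholder)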