From 310ba548b15c59d77ea6edf919091c48ff012cf6 Mon Sep 17 00:00:00 2001 From: huangxiaolan Date: Tue, 23 Sep 2025 10:56:52 +0800 Subject: [PATCH 1/4] fix enum class --- src/kernels/include/asdops/params/activation.h | 2 +- src/kernels/include/asdops/params/elewise.h | 2 +- src/kernels/include/asdops/params/faUpdate.h | 2 +- src/kernels/include/asdops/params/index.h | 2 +- src/kernels/include/asdops/params/norm.h | 2 +- src/kernels/include/asdops/params/reduce.h | 2 +- .../include/asdops/params/scatter_elements_v2.h | 2 +- src/kernels/include/asdops/params/transdata.h | 4 ++-- src/kernels/include/atbops/params/blockcopy.h | 4 ++-- .../atbops/params/gmm_deq_swiglu_quant_gmm_deq.h | 12 ++++++------ src/kernels/include/atbops/params/mla.h | 10 +++++----- .../atbops/params/mm_deq_swiglu_quant_mm_deq.h | 8 ++++---- src/kernels/include/atbops/params/paged_cache_load.h | 4 ++-- src/kernels/include/atbops/params/ring_mla.h | 10 +++++----- .../include/atbops/params/unpad_flash_attention_nz.h | 4 ++-- .../gmm_deq_swiglu_quant_gmm_deq_tiling_data.h | 12 ++++++------ .../tiling/mm_deq_swiglu_quant_mm_deq_tiling_data.h | 12 ++++++------ 17 files changed, 47 insertions(+), 47 deletions(-) diff --git a/src/kernels/include/asdops/params/activation.h b/src/kernels/include/asdops/params/activation.h index 5bf22fd4..e861be22 100644 --- a/src/kernels/include/asdops/params/activation.h +++ b/src/kernels/include/asdops/params/activation.h @@ -16,7 +16,7 @@ namespace AsdOps { namespace OpParam { struct Activation { - enum ActivationType { + enum class ActivationType { ACTIVATION_UNDEFINED = 0, ACTIVATION_RELU, ACTIVATION_GELU, diff --git a/src/kernels/include/asdops/params/elewise.h b/src/kernels/include/asdops/params/elewise.h index 86e1b7c0..e9a4631f 100644 --- a/src/kernels/include/asdops/params/elewise.h +++ b/src/kernels/include/asdops/params/elewise.h @@ -16,7 +16,7 @@ namespace AsdOps { namespace OpParam { struct Elewise { - enum ElewiseType { + enum class ElewiseType { ELEWISE_UNDEFINED = 0, ELEWISE_CAST, ELEWISE_MULS, diff --git a/src/kernels/include/asdops/params/faUpdate.h b/src/kernels/include/asdops/params/faUpdate.h index 370fd34f..88f8ec1c 100644 --- a/src/kernels/include/asdops/params/faUpdate.h +++ b/src/kernels/include/asdops/params/faUpdate.h @@ -17,7 +17,7 @@ namespace AsdOps { namespace OpParam { struct FaUpdate { - enum FaUpdateType { + enum class FaUpdateType { DECODE_UPDATE = 0, }; FaUpdateType faUpdateType; diff --git a/src/kernels/include/asdops/params/index.h b/src/kernels/include/asdops/params/index.h index ad9f6830..0beff7fc 100644 --- a/src/kernels/include/asdops/params/index.h +++ b/src/kernels/include/asdops/params/index.h @@ -16,7 +16,7 @@ namespace AsdOps { namespace OpParam { struct Index { - enum IndexType { + enum class IndexType { INDEX_UNDEFINED = 0, INDEX_ADD, INDEX_ADD_VALID, diff --git a/src/kernels/include/asdops/params/norm.h b/src/kernels/include/asdops/params/norm.h index 62fd7c11..3a618493 100644 --- a/src/kernels/include/asdops/params/norm.h +++ b/src/kernels/include/asdops/params/norm.h @@ -18,7 +18,7 @@ namespace AsdOps { namespace OpParam { struct Norm { - enum NormType { + enum class NormType { NORM_UNDEFINED = 0, LAYER_NORM, RMS_NORM, diff --git a/src/kernels/include/asdops/params/reduce.h b/src/kernels/include/asdops/params/reduce.h index cfa08787..b7234c87 100644 --- a/src/kernels/include/asdops/params/reduce.h +++ b/src/kernels/include/asdops/params/reduce.h @@ -17,7 +17,7 @@ namespace AsdOps { namespace OpParam { struct Reduce { - enum ReduceType { + enum class ReduceType { REDUCE_UNDEFINED = 0, REDUCE_MAX, REDUCE_MIN, diff --git a/src/kernels/include/asdops/params/scatter_elements_v2.h b/src/kernels/include/asdops/params/scatter_elements_v2.h index 5a91b870..30ac6966 100644 --- a/src/kernels/include/asdops/params/scatter_elements_v2.h +++ b/src/kernels/include/asdops/params/scatter_elements_v2.h @@ -21,7 +21,7 @@ namespace OpParam { * @brief ScatterElementsV2 算子的参数结构体 */ struct ScatterElementsV2 { - enum ReductionType { + enum class ReductionType { NONE = 0, ADD, }; diff --git a/src/kernels/include/asdops/params/transdata.h b/src/kernels/include/asdops/params/transdata.h index 8e8fa15c..9a441b68 100644 --- a/src/kernels/include/asdops/params/transdata.h +++ b/src/kernels/include/asdops/params/transdata.h @@ -17,8 +17,8 @@ namespace AsdOps { namespace OpParam { struct Transdata { - enum TransdataType { UNDEFINED = 0, FRACTAL_NZ_TO_ND, ND_TO_FRACTAL_NZ }; - TransdataType transdataType = UNDEFINED; + enum class TransdataType { UNDEFINED = 0, FRACTAL_NZ_TO_ND, ND_TO_FRACTAL_NZ }; + TransdataType transdataType = TransdataType::UNDEFINED; Mki::SVector outCrops = {0, 0}; bool operator==(const Transdata &other) const diff --git a/src/kernels/include/atbops/params/blockcopy.h b/src/kernels/include/atbops/params/blockcopy.h index ba438dba..2cdf6282 100644 --- a/src/kernels/include/atbops/params/blockcopy.h +++ b/src/kernels/include/atbops/params/blockcopy.h @@ -14,11 +14,11 @@ namespace AtbOps { namespace OpParam { struct BlockCopy { - enum Type { + enum class Type { BLOCK_COPY_CACHE_ND = 0, BLOCK_COPY_CACHE_NZ = 1 }; - Type type = BLOCK_COPY_CACHE_ND; + Type type = Type::BLOCK_COPY_CACHE_ND; bool operator==(const BlockCopy &other) const { return this->type == other.type; diff --git a/src/kernels/include/atbops/params/gmm_deq_swiglu_quant_gmm_deq.h b/src/kernels/include/atbops/params/gmm_deq_swiglu_quant_gmm_deq.h index 4e1e86cc..7e9591ae 100644 --- a/src/kernels/include/atbops/params/gmm_deq_swiglu_quant_gmm_deq.h +++ b/src/kernels/include/atbops/params/gmm_deq_swiglu_quant_gmm_deq.h @@ -13,27 +13,27 @@ namespace AtbOps { namespace OpParam { struct GmmDeqSwigluQuantGmmDeq { - enum OutputType { + enum class OutputType { OUTPUT_FLOAT16 = 0, OUTPUT_BFLOAT16, OUTPUT_INVALID }; - enum GroupListType { + enum class GroupListType { GROUP_LIST_CUMSUM = 0, GROUP_LIST_SINGLE, GROUP_LIST_INVALID }; - enum WeightUpPermuteType { + enum class WeightUpPermuteType { PERMUTE_N256 = 0, PERMUTE_N128, PERMUTE_INVALID }; - OutputType outputType = OUTPUT_FLOAT16; - GroupListType groupListType = GROUP_LIST_CUMSUM; - WeightUpPermuteType weightUpPermuteType = PERMUTE_N256; + OutputType outputType = OutputType::OUTPUT_FLOAT16; + GroupListType groupListType = GroupListType::GROUP_LIST_CUMSUM; + WeightUpPermuteType weightUpPermuteType = WeightUpPermuteType::PERMUTE_N256; bool transposeWeightUp = false; bool transposeWeightDown = true; diff --git a/src/kernels/include/atbops/params/mla.h b/src/kernels/include/atbops/params/mla.h index 9b9b363e..1b432591 100644 --- a/src/kernels/include/atbops/params/mla.h +++ b/src/kernels/include/atbops/params/mla.h @@ -19,7 +19,7 @@ namespace AtbOps { namespace OpParam { struct MLA { - enum Type { + enum class Type { SPLIT_CACHE = 0, PREFILL_SPLIT_CACHE = 1 }; @@ -31,7 +31,7 @@ struct MLA { std::vector kTensorList; std::vector vTensorList; - enum MaskType { + enum class MaskType { MASK_TYPE_NONE = 0, MASK_TYPE_NORM = 1, MASK_TYPE_ALIBI = 2, @@ -41,15 +41,15 @@ struct MLA { MASK_TYPE_SWA_NORM = 6 }; - MaskType maskType = MASK_TYPE_NONE; + MaskType maskType = MaskType::MASK_TYPE_NONE; - enum QuantType { + enum class QuantType { TYPE_QUANT_UNDEFINED = 0, TYPE_DEQUANT_FUSION, TYPE_QUANT_QKV_OFFLINE, TYPE_QUANT_QKV_ONLINE }; - QuantType quantType = TYPE_QUANT_UNDEFINED; + QuantType quantType = QuantType::TYPE_QUANT_UNDEFINED; std::vector qSeqLen; std::vector kvSeqLen; diff --git a/src/kernels/include/atbops/params/mm_deq_swiglu_quant_mm_deq.h b/src/kernels/include/atbops/params/mm_deq_swiglu_quant_mm_deq.h index 9ecb2270..c31a38b9 100644 --- a/src/kernels/include/atbops/params/mm_deq_swiglu_quant_mm_deq.h +++ b/src/kernels/include/atbops/params/mm_deq_swiglu_quant_mm_deq.h @@ -13,20 +13,20 @@ namespace AtbOps { namespace OpParam { struct MmDeqSwigluQuantMmDeq { - enum OutputType { + enum class OutputType { OUTPUT_FLOAT16 = 0, OUTPUT_BFLOAT16, OUTPUT_INVALID }; - enum WeightUpPermuteType { + enum class WeightUpPermuteType { PERMUTE_N256 = 0, PERMUTE_N128, PERMUTE_INVALID }; - OutputType outputType = OUTPUT_FLOAT16; - WeightUpPermuteType weightUpPermuteType = PERMUTE_N256; + OutputType outputType = OutputType::OUTPUT_FLOAT16; + WeightUpPermuteType weightUpPermuteType = WeightUpPermuteType::PERMUTE_N256; bool transposeWeightUp = false; bool transposeWeightDown = true; diff --git a/src/kernels/include/atbops/params/paged_cache_load.h b/src/kernels/include/atbops/params/paged_cache_load.h index c6ebed8f..ab259989 100644 --- a/src/kernels/include/atbops/params/paged_cache_load.h +++ b/src/kernels/include/atbops/params/paged_cache_load.h @@ -15,11 +15,11 @@ namespace AtbOps { namespace OpParam { struct PagedCacheLoad { - enum Type { + enum class Type { PAGED_CACHE_LOAD_ND = 0, PAGED_CACHE_LOAD_NZ = 1, }; - Type type = PAGED_CACHE_LOAD_ND; + Type type = Type::PAGED_CACHE_LOAD_ND; bool cuSeqLens = false; bool hasSeqStarts = false; diff --git a/src/kernels/include/atbops/params/ring_mla.h b/src/kernels/include/atbops/params/ring_mla.h index e2dfc51c..9d25c729 100644 --- a/src/kernels/include/atbops/params/ring_mla.h +++ b/src/kernels/include/atbops/params/ring_mla.h @@ -19,7 +19,7 @@ namespace AtbOps { namespace OpParam { struct RINGMLA { - enum Type { + enum class Type { SPLIT_CACHE = 0, PREFILL_SPLIT_CACHE = 1, }; @@ -31,7 +31,7 @@ struct RINGMLA { std::vector kTensorList; std::vector vTensorList; - enum MaskType { + enum class MaskType { MASK_TYPE_NONE = 0, MASK_TYPE_NORM = 1, MASK_TYPE_ALIBI = 2, @@ -39,15 +39,15 @@ struct RINGMLA { MASK_TYPE_MASK_FREE = 4 }; - MaskType maskType = MASK_TYPE_NONE; + MaskType maskType = MaskType::MASK_TYPE_NONE; - enum QuantType { + enum class QuantType { TYPE_QUANT_UNDEFINED = 0, TYPE_DEQUANT_FUSION, TYPE_QUANT_QKV_OFFLINE, TYPE_QUANT_QKV_ONLINE }; - QuantType quantType = TYPE_QUANT_UNDEFINED; + QuantType quantType = QuantType::TYPE_QUANT_UNDEFINED; std::vector qSeqLen; std::vector kvSeqLen; diff --git a/src/kernels/include/atbops/params/unpad_flash_attention_nz.h b/src/kernels/include/atbops/params/unpad_flash_attention_nz.h index aed8b578..fd8e8577 100644 --- a/src/kernels/include/atbops/params/unpad_flash_attention_nz.h +++ b/src/kernels/include/atbops/params/unpad_flash_attention_nz.h @@ -36,12 +36,12 @@ struct UnpadFlashAttentionNz { SCALE_LOGN = 1 }; ScaleType scaleType = SCALE_TOR; - enum PrecType { + enum class PrecType { BMM1_FP16_EXP_FP32 = 0, BMM1_FP32_EXP_FP32 = 1, BMM2_ONLINE_SOFTMAX_FP16 = 4 }; - PrecType precType = BMM1_FP16_EXP_FP32; + PrecType precType = PrecType::BMM1_FP16_EXP_FP32; // DATA enum DataDimOrder : int { TYPE_BSND = 0, diff --git a/src/kernels/mixkernels/gmm_deq_swiglu_quant_gmm_deq/tiling/gmm_deq_swiglu_quant_gmm_deq_tiling_data.h b/src/kernels/mixkernels/gmm_deq_swiglu_quant_gmm_deq/tiling/gmm_deq_swiglu_quant_gmm_deq_tiling_data.h index d821ba1d..107ae861 100644 --- a/src/kernels/mixkernels/gmm_deq_swiglu_quant_gmm_deq/tiling/gmm_deq_swiglu_quant_gmm_deq_tiling_data.h +++ b/src/kernels/mixkernels/gmm_deq_swiglu_quant_gmm_deq/tiling/gmm_deq_swiglu_quant_gmm_deq_tiling_data.h @@ -22,10 +22,10 @@ constexpr uint32_t CUSTOM_L0C_STAGES = 1; constexpr bool CUSTOM_ENABLE_UNIT_FLAG = true; constexpr bool CUSTOM_ENABLE_SHUFFLE_K = true; -enum PermuteType { - PERMUTE_N256, - PERMUTE_N128, - PERMUTE_INVALID +enum class PermuteType{ + N256, + N128, + INVALID }; template @@ -33,14 +33,14 @@ struct Gmm1TileArgs { }; template <> -struct Gmm1TileArgs { +struct Gmm1TileArgs { static constexpr uint32_t L1M = 128; static constexpr uint32_t L1N = 256; static constexpr uint32_t EPIM = 32; }; template <> -struct Gmm1TileArgs { +struct Gmm1TileArgs { static constexpr uint32_t L1M = 256; static constexpr uint32_t L1N = 128; static constexpr uint32_t EPIM = 64; diff --git a/src/kernels/mixkernels/mm_deq_swiglu_quant_mm_deq/tiling/mm_deq_swiglu_quant_mm_deq_tiling_data.h b/src/kernels/mixkernels/mm_deq_swiglu_quant_mm_deq/tiling/mm_deq_swiglu_quant_mm_deq_tiling_data.h index 3c22980d..a4d908ab 100644 --- a/src/kernels/mixkernels/mm_deq_swiglu_quant_mm_deq/tiling/mm_deq_swiglu_quant_mm_deq_tiling_data.h +++ b/src/kernels/mixkernels/mm_deq_swiglu_quant_mm_deq/tiling/mm_deq_swiglu_quant_mm_deq_tiling_data.h @@ -22,10 +22,10 @@ constexpr uint32_t CUSTOM_L0C_STAGES = 1; constexpr bool CUSTOM_ENABLE_UNIT_FLAG = true; constexpr bool CUSTOM_ENABLE_SHUFFLE_K = true; -enum PermuteType { - PERMUTE_N256, - PERMUTE_N128, - PERMUTE_INVALID +enum class PermuteType{ + N256, + N128, + INVALID }; template @@ -33,14 +33,14 @@ struct Mm1TileArgs { }; template <> -struct Mm1TileArgs { +struct Mm1TileArgs { static constexpr uint32_t L1M = 128; static constexpr uint32_t L1N = 256; static constexpr uint32_t EPIM = 32; }; template <> -struct Mm1TileArgs { +struct Mm1TileArgs { static constexpr uint32_t L1M = 256; static constexpr uint32_t L1N = 128; static constexpr uint32_t EPIM = 64; -- Gitee From f08328c5b2ec2d3fc91268e5353fc0f505bd7c9f Mon Sep 17 00:00:00 2001 From: huangxiaolan Date: Thu, 25 Sep 2025 10:33:25 +0800 Subject: [PATCH 2/4] recover enum name --- .../tiling/mm_deq_swiglu_quant_mm_deq_tiling_data.h | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/kernels/mixkernels/mm_deq_swiglu_quant_mm_deq/tiling/mm_deq_swiglu_quant_mm_deq_tiling_data.h b/src/kernels/mixkernels/mm_deq_swiglu_quant_mm_deq/tiling/mm_deq_swiglu_quant_mm_deq_tiling_data.h index a4d908ab..baabe59c 100644 --- a/src/kernels/mixkernels/mm_deq_swiglu_quant_mm_deq/tiling/mm_deq_swiglu_quant_mm_deq_tiling_data.h +++ b/src/kernels/mixkernels/mm_deq_swiglu_quant_mm_deq/tiling/mm_deq_swiglu_quant_mm_deq_tiling_data.h @@ -23,9 +23,9 @@ constexpr bool CUSTOM_ENABLE_UNIT_FLAG = true; constexpr bool CUSTOM_ENABLE_SHUFFLE_K = true; enum class PermuteType{ - N256, - N128, - INVALID + PERMUTE_N256, + PERMUTE_N128, + PERMUTE_INVALID }; template @@ -33,14 +33,14 @@ struct Mm1TileArgs { }; template <> -struct Mm1TileArgs { +struct Mm1TileArgs { static constexpr uint32_t L1M = 128; static constexpr uint32_t L1N = 256; static constexpr uint32_t EPIM = 32; }; template <> -struct Mm1TileArgs { +struct Mm1TileArgs { static constexpr uint32_t L1M = 256; static constexpr uint32_t L1N = 128; static constexpr uint32_t EPIM = 64; -- Gitee From 174c9446fbb75f1554176554e0b167a6a655e848 Mon Sep 17 00:00:00 2001 From: huangxiaolan Date: Thu, 25 Sep 2025 10:36:33 +0800 Subject: [PATCH 3/4] recover enum name --- .../tiling/gmm_deq_swiglu_quant_gmm_deq_tiling_data.h | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/kernels/mixkernels/gmm_deq_swiglu_quant_gmm_deq/tiling/gmm_deq_swiglu_quant_gmm_deq_tiling_data.h b/src/kernels/mixkernels/gmm_deq_swiglu_quant_gmm_deq/tiling/gmm_deq_swiglu_quant_gmm_deq_tiling_data.h index 107ae861..46d519d4 100644 --- a/src/kernels/mixkernels/gmm_deq_swiglu_quant_gmm_deq/tiling/gmm_deq_swiglu_quant_gmm_deq_tiling_data.h +++ b/src/kernels/mixkernels/gmm_deq_swiglu_quant_gmm_deq/tiling/gmm_deq_swiglu_quant_gmm_deq_tiling_data.h @@ -23,9 +23,9 @@ constexpr bool CUSTOM_ENABLE_UNIT_FLAG = true; constexpr bool CUSTOM_ENABLE_SHUFFLE_K = true; enum class PermuteType{ - N256, - N128, - INVALID + PERMUTE_N256, + PERMUTE_N128, + PERMUTE_INVALID }; template @@ -33,14 +33,14 @@ struct Gmm1TileArgs { }; template <> -struct Gmm1TileArgs { +struct Gmm1TileArgs { static constexpr uint32_t L1M = 128; static constexpr uint32_t L1N = 256; static constexpr uint32_t EPIM = 32; }; template <> -struct Gmm1TileArgs { +struct Gmm1TileArgs { static constexpr uint32_t L1M = 256; static constexpr uint32_t L1N = 128; static constexpr uint32_t EPIM = 64; -- Gitee From 07d0d40db1d846bc7341191372659d0e11502044 Mon Sep 17 00:00:00 2001 From: huangxiaolan Date: Thu, 25 Sep 2025 11:26:02 +0800 Subject: [PATCH 4/4] add using enum --- src/kernels/include/asdops/params/activation.h | 1 + src/kernels/include/asdops/params/elewise.h | 1 + src/kernels/include/asdops/params/faUpdate.h | 1 + src/kernels/include/asdops/params/index.h | 1 + src/kernels/include/asdops/params/norm.h | 1 + src/kernels/include/asdops/params/reduce.h | 1 + src/kernels/include/asdops/params/scatter_elements_v2.h | 1 + src/kernels/include/asdops/params/transdata.h | 1 + src/kernels/include/atbops/params/blockcopy.h | 1 + .../include/atbops/params/gmm_deq_swiglu_quant_gmm_deq.h | 3 +++ src/kernels/include/atbops/params/mla.h | 3 +++ src/kernels/include/atbops/params/mm_deq_swiglu_quant_mm_deq.h | 2 ++ src/kernels/include/atbops/params/paged_cache_load.h | 1 + src/kernels/include/atbops/params/ring_mla.h | 3 +++ src/kernels/include/atbops/params/unpad_flash_attention_nz.h | 1 + .../tiling/gmm_deq_swiglu_quant_gmm_deq_tiling_data.h | 1 + .../tiling/mm_deq_swiglu_quant_mm_deq_tiling_data.h | 1 + 17 files changed, 24 insertions(+) diff --git a/src/kernels/include/asdops/params/activation.h b/src/kernels/include/asdops/params/activation.h index e861be22..664baacd 100644 --- a/src/kernels/include/asdops/params/activation.h +++ b/src/kernels/include/asdops/params/activation.h @@ -28,6 +28,7 @@ struct Activation { ACTIVATION_SIGMOID, ACTIVATION_FASTER_GELU_FORWARD, }; + using enum ActivationType; ActivationType activationType; float scale = 1.0f; // for Swish diff --git a/src/kernels/include/asdops/params/elewise.h b/src/kernels/include/asdops/params/elewise.h index e9a4631f..247f6037 100644 --- a/src/kernels/include/asdops/params/elewise.h +++ b/src/kernels/include/asdops/params/elewise.h @@ -39,6 +39,7 @@ struct Elewise { ELEWISE_DEQUANT_PER_CHANNEL, ELEWISE_DYNAMIC_QUANT, }; + using enum ElewiseType; ElewiseType elewiseType; float varAttr = 0.0f; // MULS diff --git a/src/kernels/include/asdops/params/faUpdate.h b/src/kernels/include/asdops/params/faUpdate.h index 88f8ec1c..4182a01c 100644 --- a/src/kernels/include/asdops/params/faUpdate.h +++ b/src/kernels/include/asdops/params/faUpdate.h @@ -20,6 +20,7 @@ struct FaUpdate { enum class FaUpdateType { DECODE_UPDATE = 0, }; + using enum FaUpdateType; FaUpdateType faUpdateType; uint32_t sp; diff --git a/src/kernels/include/asdops/params/index.h b/src/kernels/include/asdops/params/index.h index 0beff7fc..a747d4b3 100644 --- a/src/kernels/include/asdops/params/index.h +++ b/src/kernels/include/asdops/params/index.h @@ -21,6 +21,7 @@ struct Index { INDEX_ADD, INDEX_ADD_VALID, }; + using enum IndexType; IndexType indexType; int64_t axis = 0; diff --git a/src/kernels/include/asdops/params/norm.h b/src/kernels/include/asdops/params/norm.h index 3a618493..16e81793 100644 --- a/src/kernels/include/asdops/params/norm.h +++ b/src/kernels/include/asdops/params/norm.h @@ -26,6 +26,7 @@ struct Norm { RMS_NORM_BACKWARD, GATHER_PRE_RMS_NORM, }; + using enum NormType; NormType normType; // layernorm int32_t beginNormAxis = 0; diff --git a/src/kernels/include/asdops/params/reduce.h b/src/kernels/include/asdops/params/reduce.h index b7234c87..0d312528 100644 --- a/src/kernels/include/asdops/params/reduce.h +++ b/src/kernels/include/asdops/params/reduce.h @@ -23,6 +23,7 @@ struct Reduce { REDUCE_MIN, REDUCE_SUM, }; + using enum ReduceType; ReduceType reduceType; Mki::SVector axis; diff --git a/src/kernels/include/asdops/params/scatter_elements_v2.h b/src/kernels/include/asdops/params/scatter_elements_v2.h index 30ac6966..f9f441fa 100644 --- a/src/kernels/include/asdops/params/scatter_elements_v2.h +++ b/src/kernels/include/asdops/params/scatter_elements_v2.h @@ -25,6 +25,7 @@ struct ScatterElementsV2 { NONE = 0, ADD, }; + using enum ScatterElementsV2; ReductionType reduction; // 指定更新的方式(none,add) int32_t axis; // 指定更新的轴 diff --git a/src/kernels/include/asdops/params/transdata.h b/src/kernels/include/asdops/params/transdata.h index 9a441b68..6e558e47 100644 --- a/src/kernels/include/asdops/params/transdata.h +++ b/src/kernels/include/asdops/params/transdata.h @@ -18,6 +18,7 @@ namespace AsdOps { namespace OpParam { struct Transdata { enum class TransdataType { UNDEFINED = 0, FRACTAL_NZ_TO_ND, ND_TO_FRACTAL_NZ }; + using enum TransdataType; TransdataType transdataType = TransdataType::UNDEFINED; Mki::SVector outCrops = {0, 0}; diff --git a/src/kernels/include/atbops/params/blockcopy.h b/src/kernels/include/atbops/params/blockcopy.h index 2cdf6282..5219b72f 100644 --- a/src/kernels/include/atbops/params/blockcopy.h +++ b/src/kernels/include/atbops/params/blockcopy.h @@ -18,6 +18,7 @@ struct BlockCopy { BLOCK_COPY_CACHE_ND = 0, BLOCK_COPY_CACHE_NZ = 1 }; + using enum Type; Type type = Type::BLOCK_COPY_CACHE_ND; bool operator==(const BlockCopy &other) const { diff --git a/src/kernels/include/atbops/params/gmm_deq_swiglu_quant_gmm_deq.h b/src/kernels/include/atbops/params/gmm_deq_swiglu_quant_gmm_deq.h index 7e9591ae..38d853a5 100644 --- a/src/kernels/include/atbops/params/gmm_deq_swiglu_quant_gmm_deq.h +++ b/src/kernels/include/atbops/params/gmm_deq_swiglu_quant_gmm_deq.h @@ -18,18 +18,21 @@ struct GmmDeqSwigluQuantGmmDeq { OUTPUT_BFLOAT16, OUTPUT_INVALID }; + using enum OutputType; enum class GroupListType { GROUP_LIST_CUMSUM = 0, GROUP_LIST_SINGLE, GROUP_LIST_INVALID }; + using enum GroupListType; enum class WeightUpPermuteType { PERMUTE_N256 = 0, PERMUTE_N128, PERMUTE_INVALID }; + using enum WeightUpPermuteType; OutputType outputType = OutputType::OUTPUT_FLOAT16; GroupListType groupListType = GroupListType::GROUP_LIST_CUMSUM; diff --git a/src/kernels/include/atbops/params/mla.h b/src/kernels/include/atbops/params/mla.h index 1b432591..538c5723 100644 --- a/src/kernels/include/atbops/params/mla.h +++ b/src/kernels/include/atbops/params/mla.h @@ -23,6 +23,7 @@ struct MLA { SPLIT_CACHE = 0, PREFILL_SPLIT_CACHE = 1 }; + using enum Type; Type type; int32_t headSize = 0; float tor = 0; @@ -40,6 +41,7 @@ struct MLA { MASK_TYPE_CAUSAL_MASK = 5, MASK_TYPE_SWA_NORM = 6 }; + using enum MaskType; MaskType maskType = MaskType::MASK_TYPE_NONE; @@ -49,6 +51,7 @@ struct MLA { TYPE_QUANT_QKV_OFFLINE, TYPE_QUANT_QKV_ONLINE }; + using enum QuantType; QuantType quantType = QuantType::TYPE_QUANT_UNDEFINED; std::vector qSeqLen; diff --git a/src/kernels/include/atbops/params/mm_deq_swiglu_quant_mm_deq.h b/src/kernels/include/atbops/params/mm_deq_swiglu_quant_mm_deq.h index c31a38b9..821e5057 100644 --- a/src/kernels/include/atbops/params/mm_deq_swiglu_quant_mm_deq.h +++ b/src/kernels/include/atbops/params/mm_deq_swiglu_quant_mm_deq.h @@ -18,12 +18,14 @@ struct MmDeqSwigluQuantMmDeq { OUTPUT_BFLOAT16, OUTPUT_INVALID }; + using enum OutputType; enum class WeightUpPermuteType { PERMUTE_N256 = 0, PERMUTE_N128, PERMUTE_INVALID }; + using enum WeightUpPermuteType; OutputType outputType = OutputType::OUTPUT_FLOAT16; WeightUpPermuteType weightUpPermuteType = WeightUpPermuteType::PERMUTE_N256; diff --git a/src/kernels/include/atbops/params/paged_cache_load.h b/src/kernels/include/atbops/params/paged_cache_load.h index ab259989..8147bf40 100644 --- a/src/kernels/include/atbops/params/paged_cache_load.h +++ b/src/kernels/include/atbops/params/paged_cache_load.h @@ -19,6 +19,7 @@ struct PagedCacheLoad { PAGED_CACHE_LOAD_ND = 0, PAGED_CACHE_LOAD_NZ = 1, }; + using enum Type; Type type = Type::PAGED_CACHE_LOAD_ND; bool cuSeqLens = false; bool hasSeqStarts = false; diff --git a/src/kernels/include/atbops/params/ring_mla.h b/src/kernels/include/atbops/params/ring_mla.h index 9d25c729..235f1bf9 100644 --- a/src/kernels/include/atbops/params/ring_mla.h +++ b/src/kernels/include/atbops/params/ring_mla.h @@ -23,6 +23,7 @@ struct RINGMLA { SPLIT_CACHE = 0, PREFILL_SPLIT_CACHE = 1, }; + using enum Type; Type type; int32_t headSize = 0; float tor = 0; @@ -38,6 +39,7 @@ struct RINGMLA { MASK_TYPE_LOOK_AHEAD = 3, MASK_TYPE_MASK_FREE = 4 }; + using enum MaskType; MaskType maskType = MaskType::MASK_TYPE_NONE; @@ -47,6 +49,7 @@ struct RINGMLA { TYPE_QUANT_QKV_OFFLINE, TYPE_QUANT_QKV_ONLINE }; + using enum QuantType; QuantType quantType = QuantType::TYPE_QUANT_UNDEFINED; std::vector qSeqLen; diff --git a/src/kernels/include/atbops/params/unpad_flash_attention_nz.h b/src/kernels/include/atbops/params/unpad_flash_attention_nz.h index fd8e8577..f7277aca 100644 --- a/src/kernels/include/atbops/params/unpad_flash_attention_nz.h +++ b/src/kernels/include/atbops/params/unpad_flash_attention_nz.h @@ -41,6 +41,7 @@ struct UnpadFlashAttentionNz { BMM1_FP32_EXP_FP32 = 1, BMM2_ONLINE_SOFTMAX_FP16 = 4 }; + using enum PrecType; PrecType precType = PrecType::BMM1_FP16_EXP_FP32; // DATA enum DataDimOrder : int { diff --git a/src/kernels/mixkernels/gmm_deq_swiglu_quant_gmm_deq/tiling/gmm_deq_swiglu_quant_gmm_deq_tiling_data.h b/src/kernels/mixkernels/gmm_deq_swiglu_quant_gmm_deq/tiling/gmm_deq_swiglu_quant_gmm_deq_tiling_data.h index 46d519d4..0fb7aeed 100644 --- a/src/kernels/mixkernels/gmm_deq_swiglu_quant_gmm_deq/tiling/gmm_deq_swiglu_quant_gmm_deq_tiling_data.h +++ b/src/kernels/mixkernels/gmm_deq_swiglu_quant_gmm_deq/tiling/gmm_deq_swiglu_quant_gmm_deq_tiling_data.h @@ -27,6 +27,7 @@ enum class PermuteType{ PERMUTE_N128, PERMUTE_INVALID }; +using enum PermuteType; template struct Gmm1TileArgs { diff --git a/src/kernels/mixkernels/mm_deq_swiglu_quant_mm_deq/tiling/mm_deq_swiglu_quant_mm_deq_tiling_data.h b/src/kernels/mixkernels/mm_deq_swiglu_quant_mm_deq/tiling/mm_deq_swiglu_quant_mm_deq_tiling_data.h index baabe59c..96203e7f 100644 --- a/src/kernels/mixkernels/mm_deq_swiglu_quant_mm_deq/tiling/mm_deq_swiglu_quant_mm_deq_tiling_data.h +++ b/src/kernels/mixkernels/mm_deq_swiglu_quant_mm_deq/tiling/mm_deq_swiglu_quant_mm_deq_tiling_data.h @@ -27,6 +27,7 @@ enum class PermuteType{ PERMUTE_N128, PERMUTE_INVALID }; +using enum PermuteType; template struct Mm1TileArgs { -- Gitee