diff --git a/src/kernels/include/asdops/params/activation.h b/src/kernels/include/asdops/params/activation.h index 5bf22fd4dc7f7815d7e47fe6f7d7ca8a07dd9d0f..664baacd838d2eb2b82bdeae4953c06161b9017d 100644 --- a/src/kernels/include/asdops/params/activation.h +++ b/src/kernels/include/asdops/params/activation.h @@ -16,7 +16,7 @@ namespace AsdOps { namespace OpParam { struct Activation { - enum ActivationType { + enum class ActivationType { ACTIVATION_UNDEFINED = 0, ACTIVATION_RELU, ACTIVATION_GELU, @@ -28,6 +28,7 @@ struct Activation { ACTIVATION_SIGMOID, ACTIVATION_FASTER_GELU_FORWARD, }; + using enum ActivationType; ActivationType activationType; float scale = 1.0f; // for Swish diff --git a/src/kernels/include/asdops/params/elewise.h b/src/kernels/include/asdops/params/elewise.h index 86e1b7c0f20b6ae9585154267f0b2c2ce838b700..247f6037f751353d093fa38eb5ee24463c5ca5d7 100644 --- a/src/kernels/include/asdops/params/elewise.h +++ b/src/kernels/include/asdops/params/elewise.h @@ -16,7 +16,7 @@ namespace AsdOps { namespace OpParam { struct Elewise { - enum ElewiseType { + enum class ElewiseType { ELEWISE_UNDEFINED = 0, ELEWISE_CAST, ELEWISE_MULS, @@ -39,6 +39,7 @@ struct Elewise { ELEWISE_DEQUANT_PER_CHANNEL, ELEWISE_DYNAMIC_QUANT, }; + using enum ElewiseType; ElewiseType elewiseType; float varAttr = 0.0f; // MULS diff --git a/src/kernels/include/asdops/params/faUpdate.h b/src/kernels/include/asdops/params/faUpdate.h index 370fd34f0356d31805f634208f6ccb6eaea723b0..4182a01c7098c7deb7a22ec71db4f6d55b3fcc2a 100644 --- a/src/kernels/include/asdops/params/faUpdate.h +++ b/src/kernels/include/asdops/params/faUpdate.h @@ -17,9 +17,10 @@ namespace AsdOps { namespace OpParam { struct FaUpdate { - enum FaUpdateType { + enum class FaUpdateType { DECODE_UPDATE = 0, }; + using enum FaUpdateType; FaUpdateType faUpdateType; uint32_t sp; diff --git a/src/kernels/include/asdops/params/index.h b/src/kernels/include/asdops/params/index.h index ad9f683073c22ec56edf8bc238048fa5c876b66a..a747d4b3046c40b7728bd2a8052c32a2c4460e15 100644 --- a/src/kernels/include/asdops/params/index.h +++ b/src/kernels/include/asdops/params/index.h @@ -16,11 +16,12 @@ namespace AsdOps { namespace OpParam { struct Index { - enum IndexType { + enum class IndexType { INDEX_UNDEFINED = 0, INDEX_ADD, INDEX_ADD_VALID, }; + using enum IndexType; IndexType indexType; int64_t axis = 0; diff --git a/src/kernels/include/asdops/params/norm.h b/src/kernels/include/asdops/params/norm.h index 62fd7c11da26eeb21ccec602152f649060d8cb62..16e8179397a2bd53b1f74aa9f382f64dfa835513 100644 --- a/src/kernels/include/asdops/params/norm.h +++ b/src/kernels/include/asdops/params/norm.h @@ -18,7 +18,7 @@ namespace AsdOps { namespace OpParam { struct Norm { - enum NormType { + enum class NormType { NORM_UNDEFINED = 0, LAYER_NORM, RMS_NORM, @@ -26,6 +26,7 @@ struct Norm { RMS_NORM_BACKWARD, GATHER_PRE_RMS_NORM, }; + using enum NormType; NormType normType; // layernorm int32_t beginNormAxis = 0; diff --git a/src/kernels/include/asdops/params/reduce.h b/src/kernels/include/asdops/params/reduce.h index cfa087872ede2ab8c04e66c0572435dca974c5c3..0d3125286d425112e388260094f6362dfa6d972e 100644 --- a/src/kernels/include/asdops/params/reduce.h +++ b/src/kernels/include/asdops/params/reduce.h @@ -17,12 +17,13 @@ namespace AsdOps { namespace OpParam { struct Reduce { - enum ReduceType { + enum class ReduceType { REDUCE_UNDEFINED = 0, REDUCE_MAX, REDUCE_MIN, REDUCE_SUM, }; + using enum ReduceType; ReduceType reduceType; Mki::SVector axis; diff --git a/src/kernels/include/asdops/params/scatter_elements_v2.h b/src/kernels/include/asdops/params/scatter_elements_v2.h index 5a91b87024af45489b5932f5857ab7e677d97418..f9f441faa1cd5b617e6823300d3172ddb0b081ce 100644 --- a/src/kernels/include/asdops/params/scatter_elements_v2.h +++ b/src/kernels/include/asdops/params/scatter_elements_v2.h @@ -21,10 +21,11 @@ namespace OpParam { * @brief ScatterElementsV2 算子的参数结构体 */ struct ScatterElementsV2 { - enum ReductionType { + enum class ReductionType { NONE = 0, ADD, }; + using enum ScatterElementsV2; ReductionType reduction; // 指定更新的方式(none,add) int32_t axis; // 指定更新的轴 diff --git a/src/kernels/include/asdops/params/transdata.h b/src/kernels/include/asdops/params/transdata.h index 8e8fa15cf9ec49b3f30e465dfac409e8ac8bf051..6e558e4728baccc55f0bcf7e545f9e1b63c3b11e 100644 --- a/src/kernels/include/asdops/params/transdata.h +++ b/src/kernels/include/asdops/params/transdata.h @@ -17,8 +17,9 @@ namespace AsdOps { namespace OpParam { struct Transdata { - enum TransdataType { UNDEFINED = 0, FRACTAL_NZ_TO_ND, ND_TO_FRACTAL_NZ }; - TransdataType transdataType = UNDEFINED; + enum class TransdataType { UNDEFINED = 0, FRACTAL_NZ_TO_ND, ND_TO_FRACTAL_NZ }; + using enum TransdataType; + TransdataType transdataType = TransdataType::UNDEFINED; Mki::SVector outCrops = {0, 0}; bool operator==(const Transdata &other) const diff --git a/src/kernels/include/atbops/params/blockcopy.h b/src/kernels/include/atbops/params/blockcopy.h index ba438dba37e4f67622a7698e7d8613871a631059..5219b72f53af089fb2db8642fb2da4552822b7f0 100644 --- a/src/kernels/include/atbops/params/blockcopy.h +++ b/src/kernels/include/atbops/params/blockcopy.h @@ -14,11 +14,12 @@ namespace AtbOps { namespace OpParam { struct BlockCopy { - enum Type { + enum class Type { BLOCK_COPY_CACHE_ND = 0, BLOCK_COPY_CACHE_NZ = 1 }; - Type type = BLOCK_COPY_CACHE_ND; + using enum Type; + Type type = Type::BLOCK_COPY_CACHE_ND; bool operator==(const BlockCopy &other) const { return this->type == other.type; diff --git a/src/kernels/include/atbops/params/gmm_deq_swiglu_quant_gmm_deq.h b/src/kernels/include/atbops/params/gmm_deq_swiglu_quant_gmm_deq.h index 4e1e86cc670ef53ab858ae0a9a9450f3c7888dec..38d853a520be9228839894e432a1cd18123e1bdf 100644 --- a/src/kernels/include/atbops/params/gmm_deq_swiglu_quant_gmm_deq.h +++ b/src/kernels/include/atbops/params/gmm_deq_swiglu_quant_gmm_deq.h @@ -13,27 +13,30 @@ namespace AtbOps { namespace OpParam { struct GmmDeqSwigluQuantGmmDeq { - enum OutputType { + enum class OutputType { OUTPUT_FLOAT16 = 0, OUTPUT_BFLOAT16, OUTPUT_INVALID }; + using enum OutputType; - enum GroupListType { + enum class GroupListType { GROUP_LIST_CUMSUM = 0, GROUP_LIST_SINGLE, GROUP_LIST_INVALID }; + using enum GroupListType; - enum WeightUpPermuteType { + enum class WeightUpPermuteType { PERMUTE_N256 = 0, PERMUTE_N128, PERMUTE_INVALID }; + using enum WeightUpPermuteType; - OutputType outputType = OUTPUT_FLOAT16; - GroupListType groupListType = GROUP_LIST_CUMSUM; - WeightUpPermuteType weightUpPermuteType = PERMUTE_N256; + OutputType outputType = OutputType::OUTPUT_FLOAT16; + GroupListType groupListType = GroupListType::GROUP_LIST_CUMSUM; + WeightUpPermuteType weightUpPermuteType = WeightUpPermuteType::PERMUTE_N256; bool transposeWeightUp = false; bool transposeWeightDown = true; diff --git a/src/kernels/include/atbops/params/mla.h b/src/kernels/include/atbops/params/mla.h index 9b9b363e6e973f887ce15080b3efbcac82ddea64..538c5723766b912d604b1c5955c8132655daf6b5 100644 --- a/src/kernels/include/atbops/params/mla.h +++ b/src/kernels/include/atbops/params/mla.h @@ -19,10 +19,11 @@ namespace AtbOps { namespace OpParam { struct MLA { - enum Type { + enum class Type { SPLIT_CACHE = 0, PREFILL_SPLIT_CACHE = 1 }; + using enum Type; Type type; int32_t headSize = 0; float tor = 0; @@ -31,7 +32,7 @@ struct MLA { std::vector kTensorList; std::vector vTensorList; - enum MaskType { + enum class MaskType { MASK_TYPE_NONE = 0, MASK_TYPE_NORM = 1, MASK_TYPE_ALIBI = 2, @@ -40,16 +41,18 @@ struct MLA { MASK_TYPE_CAUSAL_MASK = 5, MASK_TYPE_SWA_NORM = 6 }; + using enum MaskType; - MaskType maskType = MASK_TYPE_NONE; + MaskType maskType = MaskType::MASK_TYPE_NONE; - enum QuantType { + enum class QuantType { TYPE_QUANT_UNDEFINED = 0, TYPE_DEQUANT_FUSION, TYPE_QUANT_QKV_OFFLINE, TYPE_QUANT_QKV_ONLINE }; - QuantType quantType = TYPE_QUANT_UNDEFINED; + using enum QuantType; + QuantType quantType = QuantType::TYPE_QUANT_UNDEFINED; std::vector qSeqLen; std::vector kvSeqLen; diff --git a/src/kernels/include/atbops/params/mm_deq_swiglu_quant_mm_deq.h b/src/kernels/include/atbops/params/mm_deq_swiglu_quant_mm_deq.h index 9ecb22709004c8680cd6ccc63beee81700c03c4c..821e5057ffbbd0697d3aa681062050100e718fde 100644 --- a/src/kernels/include/atbops/params/mm_deq_swiglu_quant_mm_deq.h +++ b/src/kernels/include/atbops/params/mm_deq_swiglu_quant_mm_deq.h @@ -13,20 +13,22 @@ namespace AtbOps { namespace OpParam { struct MmDeqSwigluQuantMmDeq { - enum OutputType { + enum class OutputType { OUTPUT_FLOAT16 = 0, OUTPUT_BFLOAT16, OUTPUT_INVALID }; + using enum OutputType; - enum WeightUpPermuteType { + enum class WeightUpPermuteType { PERMUTE_N256 = 0, PERMUTE_N128, PERMUTE_INVALID }; + using enum WeightUpPermuteType; - OutputType outputType = OUTPUT_FLOAT16; - WeightUpPermuteType weightUpPermuteType = PERMUTE_N256; + OutputType outputType = OutputType::OUTPUT_FLOAT16; + WeightUpPermuteType weightUpPermuteType = WeightUpPermuteType::PERMUTE_N256; bool transposeWeightUp = false; bool transposeWeightDown = true; diff --git a/src/kernels/include/atbops/params/paged_cache_load.h b/src/kernels/include/atbops/params/paged_cache_load.h index c6ebed8f744dd6c6775fecde0d9396afc1bd9d3c..8147bf40ad1e1a10d76874519f4f158d34aed5df 100644 --- a/src/kernels/include/atbops/params/paged_cache_load.h +++ b/src/kernels/include/atbops/params/paged_cache_load.h @@ -15,11 +15,12 @@ namespace AtbOps { namespace OpParam { struct PagedCacheLoad { - enum Type { + enum class Type { PAGED_CACHE_LOAD_ND = 0, PAGED_CACHE_LOAD_NZ = 1, }; - Type type = PAGED_CACHE_LOAD_ND; + using enum Type; + Type type = Type::PAGED_CACHE_LOAD_ND; bool cuSeqLens = false; bool hasSeqStarts = false; diff --git a/src/kernels/include/atbops/params/ring_mla.h b/src/kernels/include/atbops/params/ring_mla.h index e2dfc51ca1924d74bca27cc12174ae6d860badf2..235f1bf9ac883779f701c8907520ae6d6ac9b158 100644 --- a/src/kernels/include/atbops/params/ring_mla.h +++ b/src/kernels/include/atbops/params/ring_mla.h @@ -19,10 +19,11 @@ namespace AtbOps { namespace OpParam { struct RINGMLA { - enum Type { + enum class Type { SPLIT_CACHE = 0, PREFILL_SPLIT_CACHE = 1, }; + using enum Type; Type type; int32_t headSize = 0; float tor = 0; @@ -31,23 +32,25 @@ struct RINGMLA { std::vector kTensorList; std::vector vTensorList; - enum MaskType { + enum class MaskType { MASK_TYPE_NONE = 0, MASK_TYPE_NORM = 1, MASK_TYPE_ALIBI = 2, MASK_TYPE_LOOK_AHEAD = 3, MASK_TYPE_MASK_FREE = 4 }; + using enum MaskType; - MaskType maskType = MASK_TYPE_NONE; + MaskType maskType = MaskType::MASK_TYPE_NONE; - enum QuantType { + enum class QuantType { TYPE_QUANT_UNDEFINED = 0, TYPE_DEQUANT_FUSION, TYPE_QUANT_QKV_OFFLINE, TYPE_QUANT_QKV_ONLINE }; - QuantType quantType = TYPE_QUANT_UNDEFINED; + using enum QuantType; + QuantType quantType = QuantType::TYPE_QUANT_UNDEFINED; std::vector qSeqLen; std::vector kvSeqLen; diff --git a/src/kernels/include/atbops/params/unpad_flash_attention_nz.h b/src/kernels/include/atbops/params/unpad_flash_attention_nz.h index aed8b578ff71f40bd2699d6e255cc72f31cdfdaa..f7277aca2b970cac9b3526b648addb2f33f88254 100644 --- a/src/kernels/include/atbops/params/unpad_flash_attention_nz.h +++ b/src/kernels/include/atbops/params/unpad_flash_attention_nz.h @@ -36,12 +36,13 @@ struct UnpadFlashAttentionNz { SCALE_LOGN = 1 }; ScaleType scaleType = SCALE_TOR; - enum PrecType { + enum class PrecType { BMM1_FP16_EXP_FP32 = 0, BMM1_FP32_EXP_FP32 = 1, BMM2_ONLINE_SOFTMAX_FP16 = 4 }; - PrecType precType = BMM1_FP16_EXP_FP32; + using enum PrecType; + PrecType precType = PrecType::BMM1_FP16_EXP_FP32; // DATA enum DataDimOrder : int { TYPE_BSND = 0, diff --git a/src/kernels/mixkernels/gmm_deq_swiglu_quant_gmm_deq/tiling/gmm_deq_swiglu_quant_gmm_deq_tiling_data.h b/src/kernels/mixkernels/gmm_deq_swiglu_quant_gmm_deq/tiling/gmm_deq_swiglu_quant_gmm_deq_tiling_data.h index d821ba1db659371f7bc1cfb653e33121b1300328..0fb7aeedab39174f3edd59f550e66747a945023d 100644 --- a/src/kernels/mixkernels/gmm_deq_swiglu_quant_gmm_deq/tiling/gmm_deq_swiglu_quant_gmm_deq_tiling_data.h +++ b/src/kernels/mixkernels/gmm_deq_swiglu_quant_gmm_deq/tiling/gmm_deq_swiglu_quant_gmm_deq_tiling_data.h @@ -22,25 +22,26 @@ constexpr uint32_t CUSTOM_L0C_STAGES = 1; constexpr bool CUSTOM_ENABLE_UNIT_FLAG = true; constexpr bool CUSTOM_ENABLE_SHUFFLE_K = true; -enum PermuteType { +enum class PermuteType{ PERMUTE_N256, PERMUTE_N128, PERMUTE_INVALID }; +using enum PermuteType; template struct Gmm1TileArgs { }; template <> -struct Gmm1TileArgs { +struct Gmm1TileArgs { static constexpr uint32_t L1M = 128; static constexpr uint32_t L1N = 256; static constexpr uint32_t EPIM = 32; }; template <> -struct Gmm1TileArgs { +struct Gmm1TileArgs { static constexpr uint32_t L1M = 256; static constexpr uint32_t L1N = 128; static constexpr uint32_t EPIM = 64; diff --git a/src/kernels/mixkernels/mm_deq_swiglu_quant_mm_deq/tiling/mm_deq_swiglu_quant_mm_deq_tiling_data.h b/src/kernels/mixkernels/mm_deq_swiglu_quant_mm_deq/tiling/mm_deq_swiglu_quant_mm_deq_tiling_data.h index 3c22980d66ce6927a3e3ee9b81c29a787cb404b8..96203e7f77f7c93846c79a9c3865d9eb8952b3e2 100644 --- a/src/kernels/mixkernels/mm_deq_swiglu_quant_mm_deq/tiling/mm_deq_swiglu_quant_mm_deq_tiling_data.h +++ b/src/kernels/mixkernels/mm_deq_swiglu_quant_mm_deq/tiling/mm_deq_swiglu_quant_mm_deq_tiling_data.h @@ -22,25 +22,26 @@ constexpr uint32_t CUSTOM_L0C_STAGES = 1; constexpr bool CUSTOM_ENABLE_UNIT_FLAG = true; constexpr bool CUSTOM_ENABLE_SHUFFLE_K = true; -enum PermuteType { +enum class PermuteType{ PERMUTE_N256, PERMUTE_N128, PERMUTE_INVALID }; +using enum PermuteType; template struct Mm1TileArgs { }; template <> -struct Mm1TileArgs { +struct Mm1TileArgs { static constexpr uint32_t L1M = 128; static constexpr uint32_t L1N = 256; static constexpr uint32_t EPIM = 32; }; template <> -struct Mm1TileArgs { +struct Mm1TileArgs { static constexpr uint32_t L1M = 256; static constexpr uint32_t L1N = 128; static constexpr uint32_t EPIM = 64;