From 6887c605eae29cfdfdff4bf18bc5df92189f032b Mon Sep 17 00:00:00 2001
From: zouyanlong
Date: Wed, 6 Aug 2025 14:59:21 +0800
Subject: [PATCH 1/2] fix

---
 CMakeLists.txt | 4 +-
 scripts/build.sh | 36 +-
 tests/CMakeLists.txt | 3 +
 tests/proftest/CMakeLists.txt | 61 +
 .../core/base/context_factory.cpp | 99 ++
 .../core/base/event_manager.cpp | 227 +++
 .../core/base/external_comm_manager.cpp | 265 ++++
 .../layer_test_framework/core/base/model.cpp | 768 ++++++++++
 .../include/atb_speed/base/context_factory.h | 32 +
 .../include/atb_speed/base/event_manager.h | 165 ++
 .../atb_speed/base/external_comm_manager.h | 84 +
 .../atb_speed/base/hosttensor_binder.h | 30 +
 .../core/include/atb_speed/base/model.h | 174 +++
 .../core/include/atb_speed/log.h | 53 +
 .../atb_speed/utils/ModelTaskExecutor.h | 65 +
 .../core/include/atb_speed/utils/TaskQueue.h | 40 +
 .../core/include/atb_speed/utils/check_util.h | 80 +
 .../core/include/atb_speed/utils/config.h | 39 +
 .../include/atb_speed/utils/file_system.h | 35 +
 .../include/atb_speed/utils/hccl_runner.h | 48 +
 .../core/include/atb_speed/utils/match.h | 24 +
 .../include/atb_speed/utils/model_factory.h | 47 +
 .../atb_speed/utils/operation_factory.h | 51 +
 .../include/atb_speed/utils/operation_util.h | 68 +
 .../include/atb_speed/utils/share_memory.h | 46 +
 .../core/include/atb_speed/utils/singleton.h | 33 +
 .../include/atb_speed/utils/speed_probe.h | 33 +
 .../core/include/atb_speed/utils/statistic.h | 42 +
 .../core/include/atb_speed/utils/str_split.h | 23 +
 .../include/atb_speed/utils/tensor_util.h | 31 +
 .../core/include/atb_speed/utils/timer.h | 35 +
 .../core/utils/ModelTaskExecutor.cpp | 68 +
 .../core/utils/TaskQueue.cpp | 36 +
 .../core/utils/check_util.cpp | 123 ++
 .../core/utils/config.cpp | 57 +
 .../core/utils/file_system.cpp | 129 ++
 .../core/utils/hccl_runner.cpp | 144 ++
 .../layer_test_framework/core/utils/match.cpp | 36 +
 .../core/utils/model_factory.cpp | 51 +
 .../core/utils/operation_factory.cpp | 51 +
 .../core/utils/share_memory.cpp | 131 ++
 .../core/utils/speed_probe.cpp | 33 +
 .../core/utils/statistic.cpp | 52 +
 .../core/utils/str_split.cpp | 56 +
 .../core/utils/tensor_util.cpp | 78 +
 .../layer_test_framework/core/utils/timer.cpp | 44 +
 .../models/base/layer/decoder_layer.cpp | 975 ++++++++++
 .../models/base/layer/decoder_layer.h | 187 +++
 .../models/base/model/decoder_model.cpp | 1362 +++++++++++++++++
 .../models/base/model/decoder_model.h | 264 ++++
 .../models/base/param/dynamic_param.h | 60 +
 .../models/base/param/layer_param.cpp | 63 +
 .../models/base/param/layer_param.h | 74 +
 .../models/base/param/mapping.cpp | 105 ++
 .../models/base/param/mapping.h | 107 ++
 .../models/base/param/model_param.cpp | 364 +++++
 .../models/base/param/model_param.h | 145 ++
 .../models/base/param/param.h | 225 +++
 .../models/base/param/param_utils.h | 50 +
 .../bloom/layer/bloom_decoder_layer.cpp | 50 +
 .../models/bloom/layer/bloom_decoder_layer.h | 42 +
 .../bloom/model/bloom_decoder_model.cpp | 82 +
 .../models/bloom/model/bloom_decoder_model.h | 45 +
 .../aclnn/core/acl_nn_global_cache.cpp | 129 ++
 .../aclnn/core/acl_nn_global_cache.h | 87 ++
 .../aclnn/core/acl_nn_operation.cpp | 255 +++
 .../operations/aclnn/core/acl_nn_operation.h | 106 ++
 .../aclnn/core/acl_nn_operation_cache.cpp | 134 ++
 .../aclnn/core/acl_nn_operation_cache.h | 63 +
 .../operations/aclnn/core/acl_nn_tensor.h | 68 +
 .../aclnn/core/executor_manager.cpp | 72 +
 .../operations/aclnn/core/executor_manager.h | 61 +
 .../add_rms_norm_dynamic_quant_operation.cpp | 161 ++
.../add_rms_norm_dynamic_quant_operation.h | 40 + .../aclnn/ops/add_rms_norm_operation.cpp | 140 ++ .../aclnn/ops/add_rms_norm_operation.h | 83 + .../ops/add_rms_norm_quant_operation.cpp | 166 ++ .../aclnn/ops/add_rms_norm_quant_operation.h | 40 + .../operations/aclnn/ops/argmax_operation.cpp | 157 ++ .../operations/aclnn/ops/argmax_operation.h | 53 + .../aclnn/ops/argsort_operation.cpp | 121 ++ .../operations/aclnn/ops/argsort_operation.h | 75 + .../operations/aclnn/ops/attn_operation.cpp | 199 +++ .../operations/aclnn/ops/attn_operation.h | 115 ++ .../operations/aclnn/ops/cast_operation.cpp | 166 ++ .../operations/aclnn/ops/cast_operation.h | 53 + .../operations/aclnn/ops/concat_operation.cpp | 142 ++ .../operations/aclnn/ops/concat_operation.h | 52 + .../dequant_rope_quant_kvcache_operation.cpp | 190 +++ .../dequant_rope_quant_kvcache_operation.h | 52 + .../ops/dequant_swiglu_quant_operation.cpp | 199 +++ .../ops/dequant_swiglu_quant_operation.h | 50 + .../aclnn/ops/dynamic_quant_operation.cpp | 98 ++ .../aclnn/ops/dynamic_quant_operation.h | 68 + .../aclnn/ops/finalize_routing_operation.cpp | 92 ++ .../aclnn/ops/finalize_routing_operation.h | 94 ++ .../operations/aclnn/ops/gelu_operation.cpp | 214 +++ .../operations/aclnn/ops/gelu_operation.h | 91 ++ .../aclnn/ops/grouped_matmul_operation.cpp | 335 ++++ .../aclnn/ops/grouped_matmul_operation.h | 120 ++ .../ops/grouped_matmul_swiglu_operation.cpp | 173 +++ .../ops/grouped_matmul_swiglu_operation.h | 105 ++ .../aclnn/ops/index_select_operation.cpp | 117 ++ .../aclnn/ops/index_select_operation.h | 89 ++ .../aclnn/ops/indexput_operation.cpp | 104 ++ .../operations/aclnn/ops/indexput_operation.h | 96 ++ .../ops/inplace_nan_to_num_operation.cpp | 125 ++ .../aclnn/ops/inplace_nan_to_num_operation.h | 87 ++ .../inplacemasked_filltensor_operation.cpp | 123 ++ .../ops/inplacemasked_filltensor_operation.h | 81 + .../aclnn/ops/layer_norm_operation.cpp | 216 +++ .../aclnn/ops/layer_norm_operation.h | 107 ++ .../operations/aclnn/ops/len_operation.cpp | 144 ++ .../operations/aclnn/ops/len_operation.h | 33 + .../aclnn/ops/matmul_allreduce_operation.cpp | 186 +++ .../aclnn/ops/matmul_allreduce_operation.h | 43 + .../operations/aclnn/ops/matmul_operation.cpp | 227 +++ .../operations/aclnn/ops/matmul_operation.h | 54 + .../operations/aclnn/ops/max_v2_operation.cpp | 165 ++ .../operations/aclnn/ops/max_v2_operation.h | 52 + .../aclnn/ops/minimum_operation.cpp | 140 ++ .../operations/aclnn/ops/minimum_operation.h | 33 + .../moe_compute_expert_tokens_operation.cpp | 89 ++ .../ops/moe_compute_expert_tokens_operation.h | 78 + .../ops/moe_distribute_combine_operation.cpp | 133 ++ .../ops/moe_distribute_combine_operation.h | 63 + .../moe_distribute_combine_v2_operation.cpp | 123 ++ .../ops/moe_distribute_combine_v2_operation.h | 52 + .../ops/moe_distribute_dispatch_operation.cpp | 194 +++ .../ops/moe_distribute_dispatch_operation.h | 66 + .../moe_distribute_dispatch_v2_operation.cpp | 184 +++ .../moe_distribute_dispatch_v2_operation.h | 55 + .../aclnn/ops/moe_init_routing_operation.cpp | 105 ++ .../aclnn/ops/moe_init_routing_operation.h | 97 ++ .../ops/moe_init_routing_quant_operation.cpp | 119 ++ .../ops/moe_init_routing_quant_operation.h | 104 ++ .../aclnn/ops/moe_topk_softmax_operation.cpp | 103 ++ .../aclnn/ops/moe_topk_softmax_operation.h | 85 + .../ops/moetoken_umpermute_operation.cpp | 111 ++ .../aclnn/ops/moetoken_unpermute_operation.h | 77 + .../ops/obfuscation_calculate_operation.cpp | 104 ++ .../ops/obfuscation_calculate_operation.h | 46 + 
.../aclnn/ops/obfuscation_setup_operation.cpp | 93 ++ .../aclnn/ops/obfuscation_setup_operation.h | 48 + .../ops/prompt_flash_attention_operation.cpp | 168 ++ .../ops/prompt_flash_attention_operation.h | 135 ++ .../ops/quant_batch_matmul_operation.cpp | 186 +++ .../aclnn/ops/quant_batch_matmul_operation.h | 62 + .../aclnn/ops/quant_gmm_dequant_operation.cpp | 198 +++ .../aclnn/ops/quant_gmm_dequant_operation.h | 52 + .../operations/aclnn/ops/repeat_operation.cpp | 105 ++ .../operations/aclnn/ops/repeat_operation.h | 80 + .../aclnn/ops/rms_norm_operation.cpp | 92 ++ .../operations/aclnn/ops/rms_norm_operation.h | 26 + .../aclnn/ops/scatter_operation.cpp | 208 +++ .../operations/aclnn/ops/scatter_operation.h | 57 + .../aclnn/ops/sigmoid_operation.cpp | 125 ++ .../operations/aclnn/ops/sigmoid_operation.h | 76 + .../aclnn/ops/split_with_size_operation.cpp | 180 +++ .../aclnn/ops/split_with_size_operation.h | 53 + .../operations/aclnn/ops/std_operation.cpp | 145 ++ .../operations/aclnn/ops/std_operation.h | 73 + .../aclnn/ops/vector_norm_operation.cpp | 226 +++ .../aclnn/ops/vector_norm_operation.h | 121 ++ .../operations/aclnn/ops/w16a16_operation.cpp | 208 +++ .../operations/aclnn/ops/w16a16_operation.h | 105 ++ .../operations/aclnn/ops/w4a16_operation.cpp | 45 + .../operations/aclnn/ops/w4a16_operation.h | 101 ++ .../operations/aclnn/ops/w4a8_operation.cpp | 109 ++ .../operations/aclnn/ops/w4a8_operation.h | 42 + .../operations/aclnn/ops/w8a16_operation.cpp | 35 + .../operations/aclnn/ops/w8a16_operation.h | 98 ++ .../operations/aclnn/ops/w8a8_operation.cpp | 241 +++ .../operations/aclnn/ops/w8a8_operation.h | 117 ++ .../operations/aclnn/utils/utils.cpp | 270 ++++ .../operations/aclnn/utils/utils.h | 111 ++ .../operations/aclrt/ops/aclrt_cmo_async.cpp | 125 ++ .../operations/aclrt/ops/aclrt_cmo_async.h | 51 + .../fusion/attention/attention_edge.cpp | 841 ++++++++++ .../fusion/attention/attention_edge.h | 49 + .../fusion/attention/fusion_attention.cpp | 636 ++++++++ .../fusion/attention/fusion_attention.h | 316 ++++ .../fusion/attention/qkv_linear_split.cpp | 586 +++++++ .../fusion/attention/qkv_linear_split.h | 100 ++ .../fusion/attention/self_attention.cpp | 364 +++++ .../fusion/attention/self_attention.h | 117 ++ .../operations/fusion/common_op_base.h | 34 + .../fusion/embedding/positional_embedding.cpp | 298 ++++ .../fusion/embedding/positional_embedding.h | 155 ++ .../fusion/embedding/word_embedding.cpp | 100 ++ .../fusion/embedding/word_embedding.h | 80 + .../fusion/infer_shape_functions.cpp | 104 ++ .../operations/fusion/infer_shape_functions.h | 45 + .../operations/fusion/linear/linear.cpp | 690 +++++++++ .../operations/fusion/linear/linear.h | 216 +++ .../fusion/linear/linear_parallel.cpp | 662 ++++++++ .../fusion/linear/linear_parallel.h | 133 ++ .../fusion/lmhead/hidden_state_slice.cpp | 61 + .../fusion/lmhead/hidden_state_slice.h | 37 + .../operations/fusion/lmhead/lmhead.cpp | 313 ++++ .../operations/fusion/lmhead/lmhead.h | 100 ++ .../fusion/lmhead/parallel_lmhead_all2all.cpp | 147 ++ .../fusion/lmhead/parallel_lmhead_all2all.h | 18 + .../operations/fusion/mlp/mlp.cpp | 701 +++++++++ .../operations/fusion/mlp/mlp.h | 234 +++ .../operations/fusion/mlp_gate.cpp | 151 ++ .../operations/fusion/mlp_gate.h | 129 ++ .../operations/fusion/mlp_gate_v2.cpp | 202 +++ .../operations/fusion/mlp_gate_v2.h | 45 + .../fusion/moe/device_limited_routing.cpp | 211 +++ .../fusion/moe/device_limited_routing.h | 38 + .../fusion/moe/ep/all_to_all_collect.cpp | 219 +++ 
.../fusion/moe/ep/all_to_all_collect.h | 53 + .../fusion/moe/ep/all_to_all_dispatch.cpp | 196 +++ .../fusion/moe/ep/all_to_all_dispatch.h | 46 + .../fusion/moe/ep/all_to_all_meta.cpp | 330 ++++ .../fusion/moe/ep/all_to_all_meta.h | 38 + .../fusion/moe/ep/data_preparation.cpp | 166 ++ .../fusion/moe/ep/data_preparation.h | 36 + .../fusion/moe/ep/dynamic_ep_moe.cpp | 482 ++++++ .../operations/fusion/moe/ep/dynamic_ep_moe.h | 91 ++ .../fusion/moe/ep/expert_filter.cpp | 264 ++++ .../operations/fusion/moe/ep/expert_filter.h | 35 + .../fusion/moe/ep/fused_alltoall_gmm.cpp | 271 ++++ .../fusion/moe/ep/fused_alltoall_gmm.h | 45 + .../operations/fusion/moe/integrated_gmm.cpp | 393 +++++ .../operations/fusion/moe/integrated_gmm.h | 69 + .../operations/fusion/moe/moe_mlp.cpp | 1065 +++++++++++++ .../operations/fusion/moe/moe_mlp.h | 86 ++ .../fusion/moe/moe_shared_expert.cpp | 367 +++++ .../operations/fusion/moe/moe_shared_expert.h | 64 + .../operations/fusion/moe/sparse_moe.cpp | 1074 +++++++++++++ .../operations/fusion/moe/sparse_moe.h | 161 ++ .../operations/fusion/norm/norm_linear.cpp | 425 +++++ .../operations/fusion/norm/norm_linear.h | 157 ++ .../operations/fusion/parallel_info.cpp | 83 + .../operations/fusion/parallel_info.h | 58 + .../operations/fusion/parallel_layer.cpp | 163 ++ .../operations/fusion/parallel_layer.h | 101 ++ .../operations/fusion/parallel_layer_v2.cpp | 289 ++++ .../operations/fusion/parallel_layer_v2.h | 62 + .../operations/fusion/parallel_lmhead.cpp | 124 ++ .../operations/fusion/parallel_lmhead.h | 69 + .../operations/fusion/utils.cpp | 378 +++++ .../operations/fusion/utils.h | 317 ++++ tests/proftest/main.cpp | 143 ++ tests/proftest/test_cases/bloom_7b/main.cpp | 219 +++ tests/proftest/utils/include/context_utils.h | 21 + tests/proftest/utils/include/tensor_utils.h | 59 + tests/proftest/utils/include/type_utils.h | 26 + tests/proftest/utils/src/context_utils.cpp | 29 + tests/proftest/utils/src/tensor_utils.cpp | 627 ++++++++ tests/proftest/utils/src/type_utils.cpp | 131 ++ 253 files changed, 37969 insertions(+), 2 deletions(-) create mode 100644 tests/proftest/CMakeLists.txt create mode 100644 tests/proftest/layer_test_framework/core/base/context_factory.cpp create mode 100644 tests/proftest/layer_test_framework/core/base/event_manager.cpp create mode 100644 tests/proftest/layer_test_framework/core/base/external_comm_manager.cpp create mode 100644 tests/proftest/layer_test_framework/core/base/model.cpp create mode 100644 tests/proftest/layer_test_framework/core/include/atb_speed/base/context_factory.h create mode 100644 tests/proftest/layer_test_framework/core/include/atb_speed/base/event_manager.h create mode 100644 tests/proftest/layer_test_framework/core/include/atb_speed/base/external_comm_manager.h create mode 100644 tests/proftest/layer_test_framework/core/include/atb_speed/base/hosttensor_binder.h create mode 100644 tests/proftest/layer_test_framework/core/include/atb_speed/base/model.h create mode 100644 tests/proftest/layer_test_framework/core/include/atb_speed/log.h create mode 100644 tests/proftest/layer_test_framework/core/include/atb_speed/utils/ModelTaskExecutor.h create mode 100644 tests/proftest/layer_test_framework/core/include/atb_speed/utils/TaskQueue.h create mode 100644 tests/proftest/layer_test_framework/core/include/atb_speed/utils/check_util.h create mode 100644 tests/proftest/layer_test_framework/core/include/atb_speed/utils/config.h create mode 100644 tests/proftest/layer_test_framework/core/include/atb_speed/utils/file_system.h create mode 
100644 tests/proftest/layer_test_framework/core/include/atb_speed/utils/hccl_runner.h create mode 100644 tests/proftest/layer_test_framework/core/include/atb_speed/utils/match.h create mode 100644 tests/proftest/layer_test_framework/core/include/atb_speed/utils/model_factory.h create mode 100644 tests/proftest/layer_test_framework/core/include/atb_speed/utils/operation_factory.h create mode 100644 tests/proftest/layer_test_framework/core/include/atb_speed/utils/operation_util.h create mode 100644 tests/proftest/layer_test_framework/core/include/atb_speed/utils/share_memory.h create mode 100644 tests/proftest/layer_test_framework/core/include/atb_speed/utils/singleton.h create mode 100644 tests/proftest/layer_test_framework/core/include/atb_speed/utils/speed_probe.h create mode 100644 tests/proftest/layer_test_framework/core/include/atb_speed/utils/statistic.h create mode 100644 tests/proftest/layer_test_framework/core/include/atb_speed/utils/str_split.h create mode 100644 tests/proftest/layer_test_framework/core/include/atb_speed/utils/tensor_util.h create mode 100644 tests/proftest/layer_test_framework/core/include/atb_speed/utils/timer.h create mode 100644 tests/proftest/layer_test_framework/core/utils/ModelTaskExecutor.cpp create mode 100644 tests/proftest/layer_test_framework/core/utils/TaskQueue.cpp create mode 100644 tests/proftest/layer_test_framework/core/utils/check_util.cpp create mode 100644 tests/proftest/layer_test_framework/core/utils/config.cpp create mode 100644 tests/proftest/layer_test_framework/core/utils/file_system.cpp create mode 100644 tests/proftest/layer_test_framework/core/utils/hccl_runner.cpp create mode 100644 tests/proftest/layer_test_framework/core/utils/match.cpp create mode 100644 tests/proftest/layer_test_framework/core/utils/model_factory.cpp create mode 100644 tests/proftest/layer_test_framework/core/utils/operation_factory.cpp create mode 100644 tests/proftest/layer_test_framework/core/utils/share_memory.cpp create mode 100644 tests/proftest/layer_test_framework/core/utils/speed_probe.cpp create mode 100644 tests/proftest/layer_test_framework/core/utils/statistic.cpp create mode 100644 tests/proftest/layer_test_framework/core/utils/str_split.cpp create mode 100644 tests/proftest/layer_test_framework/core/utils/tensor_util.cpp create mode 100644 tests/proftest/layer_test_framework/core/utils/timer.cpp create mode 100644 tests/proftest/layer_test_framework/models/base/layer/decoder_layer.cpp create mode 100644 tests/proftest/layer_test_framework/models/base/layer/decoder_layer.h create mode 100644 tests/proftest/layer_test_framework/models/base/model/decoder_model.cpp create mode 100644 tests/proftest/layer_test_framework/models/base/model/decoder_model.h create mode 100644 tests/proftest/layer_test_framework/models/base/param/dynamic_param.h create mode 100644 tests/proftest/layer_test_framework/models/base/param/layer_param.cpp create mode 100644 tests/proftest/layer_test_framework/models/base/param/layer_param.h create mode 100644 tests/proftest/layer_test_framework/models/base/param/mapping.cpp create mode 100644 tests/proftest/layer_test_framework/models/base/param/mapping.h create mode 100644 tests/proftest/layer_test_framework/models/base/param/model_param.cpp create mode 100644 tests/proftest/layer_test_framework/models/base/param/model_param.h create mode 100644 tests/proftest/layer_test_framework/models/base/param/param.h create mode 100644 tests/proftest/layer_test_framework/models/base/param/param_utils.h create mode 100644 
tests/proftest/layer_test_framework/models/bloom/layer/bloom_decoder_layer.cpp create mode 100644 tests/proftest/layer_test_framework/models/bloom/layer/bloom_decoder_layer.h create mode 100644 tests/proftest/layer_test_framework/models/bloom/model/bloom_decoder_model.cpp create mode 100644 tests/proftest/layer_test_framework/models/bloom/model/bloom_decoder_model.h create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/core/acl_nn_global_cache.cpp create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/core/acl_nn_global_cache.h create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/core/acl_nn_operation.cpp create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/core/acl_nn_operation.h create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/core/acl_nn_operation_cache.cpp create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/core/acl_nn_operation_cache.h create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/core/acl_nn_tensor.h create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/core/executor_manager.cpp create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/core/executor_manager.h create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/add_rms_norm_dynamic_quant_operation.cpp create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/add_rms_norm_dynamic_quant_operation.h create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/add_rms_norm_operation.cpp create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/add_rms_norm_operation.h create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/add_rms_norm_quant_operation.cpp create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/add_rms_norm_quant_operation.h create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/argmax_operation.cpp create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/argmax_operation.h create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/argsort_operation.cpp create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/argsort_operation.h create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/attn_operation.cpp create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/attn_operation.h create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/cast_operation.cpp create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/cast_operation.h create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/concat_operation.cpp create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/concat_operation.h create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/dequant_rope_quant_kvcache_operation.cpp create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/dequant_rope_quant_kvcache_operation.h create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/dequant_swiglu_quant_operation.cpp create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/dequant_swiglu_quant_operation.h create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/dynamic_quant_operation.cpp create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/dynamic_quant_operation.h create mode 100644 
tests/proftest/layer_test_framework/operations/aclnn/ops/finalize_routing_operation.cpp create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/finalize_routing_operation.h create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/gelu_operation.cpp create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/gelu_operation.h create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/grouped_matmul_operation.cpp create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/grouped_matmul_operation.h create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/grouped_matmul_swiglu_operation.cpp create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/grouped_matmul_swiglu_operation.h create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/index_select_operation.cpp create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/index_select_operation.h create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/indexput_operation.cpp create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/indexput_operation.h create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/inplace_nan_to_num_operation.cpp create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/inplace_nan_to_num_operation.h create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/inplacemasked_filltensor_operation.cpp create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/inplacemasked_filltensor_operation.h create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/layer_norm_operation.cpp create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/layer_norm_operation.h create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/len_operation.cpp create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/len_operation.h create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/matmul_allreduce_operation.cpp create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/matmul_allreduce_operation.h create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/matmul_operation.cpp create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/matmul_operation.h create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/max_v2_operation.cpp create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/max_v2_operation.h create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/minimum_operation.cpp create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/minimum_operation.h create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/moe_compute_expert_tokens_operation.cpp create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/moe_compute_expert_tokens_operation.h create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/moe_distribute_combine_operation.cpp create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/moe_distribute_combine_operation.h create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/moe_distribute_combine_v2_operation.cpp create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/moe_distribute_combine_v2_operation.h create mode 100644 
tests/proftest/layer_test_framework/operations/aclnn/ops/moe_distribute_dispatch_operation.cpp create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/moe_distribute_dispatch_operation.h create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/moe_distribute_dispatch_v2_operation.cpp create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/moe_distribute_dispatch_v2_operation.h create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/moe_init_routing_operation.cpp create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/moe_init_routing_operation.h create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/moe_init_routing_quant_operation.cpp create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/moe_init_routing_quant_operation.h create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/moe_topk_softmax_operation.cpp create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/moe_topk_softmax_operation.h create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/moetoken_umpermute_operation.cpp create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/moetoken_unpermute_operation.h create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/obfuscation_calculate_operation.cpp create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/obfuscation_calculate_operation.h create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/obfuscation_setup_operation.cpp create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/obfuscation_setup_operation.h create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/prompt_flash_attention_operation.cpp create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/prompt_flash_attention_operation.h create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/quant_batch_matmul_operation.cpp create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/quant_batch_matmul_operation.h create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/quant_gmm_dequant_operation.cpp create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/quant_gmm_dequant_operation.h create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/repeat_operation.cpp create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/repeat_operation.h create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/rms_norm_operation.cpp create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/rms_norm_operation.h create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/scatter_operation.cpp create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/scatter_operation.h create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/sigmoid_operation.cpp create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/sigmoid_operation.h create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/split_with_size_operation.cpp create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/split_with_size_operation.h create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/std_operation.cpp create mode 100644 
tests/proftest/layer_test_framework/operations/aclnn/ops/std_operation.h create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/vector_norm_operation.cpp create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/vector_norm_operation.h create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/w16a16_operation.cpp create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/w16a16_operation.h create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/w4a16_operation.cpp create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/w4a16_operation.h create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/w4a8_operation.cpp create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/w4a8_operation.h create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/w8a16_operation.cpp create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/w8a16_operation.h create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/w8a8_operation.cpp create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/ops/w8a8_operation.h create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/utils/utils.cpp create mode 100644 tests/proftest/layer_test_framework/operations/aclnn/utils/utils.h create mode 100644 tests/proftest/layer_test_framework/operations/aclrt/ops/aclrt_cmo_async.cpp create mode 100644 tests/proftest/layer_test_framework/operations/aclrt/ops/aclrt_cmo_async.h create mode 100644 tests/proftest/layer_test_framework/operations/fusion/attention/attention_edge.cpp create mode 100644 tests/proftest/layer_test_framework/operations/fusion/attention/attention_edge.h create mode 100644 tests/proftest/layer_test_framework/operations/fusion/attention/fusion_attention.cpp create mode 100644 tests/proftest/layer_test_framework/operations/fusion/attention/fusion_attention.h create mode 100644 tests/proftest/layer_test_framework/operations/fusion/attention/qkv_linear_split.cpp create mode 100644 tests/proftest/layer_test_framework/operations/fusion/attention/qkv_linear_split.h create mode 100644 tests/proftest/layer_test_framework/operations/fusion/attention/self_attention.cpp create mode 100644 tests/proftest/layer_test_framework/operations/fusion/attention/self_attention.h create mode 100644 tests/proftest/layer_test_framework/operations/fusion/common_op_base.h create mode 100644 tests/proftest/layer_test_framework/operations/fusion/embedding/positional_embedding.cpp create mode 100644 tests/proftest/layer_test_framework/operations/fusion/embedding/positional_embedding.h create mode 100644 tests/proftest/layer_test_framework/operations/fusion/embedding/word_embedding.cpp create mode 100644 tests/proftest/layer_test_framework/operations/fusion/embedding/word_embedding.h create mode 100644 tests/proftest/layer_test_framework/operations/fusion/infer_shape_functions.cpp create mode 100644 tests/proftest/layer_test_framework/operations/fusion/infer_shape_functions.h create mode 100644 tests/proftest/layer_test_framework/operations/fusion/linear/linear.cpp create mode 100644 tests/proftest/layer_test_framework/operations/fusion/linear/linear.h create mode 100644 tests/proftest/layer_test_framework/operations/fusion/linear/linear_parallel.cpp create mode 100644 tests/proftest/layer_test_framework/operations/fusion/linear/linear_parallel.h create mode 100644 
tests/proftest/layer_test_framework/operations/fusion/lmhead/hidden_state_slice.cpp create mode 100644 tests/proftest/layer_test_framework/operations/fusion/lmhead/hidden_state_slice.h create mode 100644 tests/proftest/layer_test_framework/operations/fusion/lmhead/lmhead.cpp create mode 100644 tests/proftest/layer_test_framework/operations/fusion/lmhead/lmhead.h create mode 100644 tests/proftest/layer_test_framework/operations/fusion/lmhead/parallel_lmhead_all2all.cpp create mode 100644 tests/proftest/layer_test_framework/operations/fusion/lmhead/parallel_lmhead_all2all.h create mode 100644 tests/proftest/layer_test_framework/operations/fusion/mlp/mlp.cpp create mode 100644 tests/proftest/layer_test_framework/operations/fusion/mlp/mlp.h create mode 100644 tests/proftest/layer_test_framework/operations/fusion/mlp_gate.cpp create mode 100644 tests/proftest/layer_test_framework/operations/fusion/mlp_gate.h create mode 100644 tests/proftest/layer_test_framework/operations/fusion/mlp_gate_v2.cpp create mode 100644 tests/proftest/layer_test_framework/operations/fusion/mlp_gate_v2.h create mode 100644 tests/proftest/layer_test_framework/operations/fusion/moe/device_limited_routing.cpp create mode 100644 tests/proftest/layer_test_framework/operations/fusion/moe/device_limited_routing.h create mode 100644 tests/proftest/layer_test_framework/operations/fusion/moe/ep/all_to_all_collect.cpp create mode 100644 tests/proftest/layer_test_framework/operations/fusion/moe/ep/all_to_all_collect.h create mode 100644 tests/proftest/layer_test_framework/operations/fusion/moe/ep/all_to_all_dispatch.cpp create mode 100644 tests/proftest/layer_test_framework/operations/fusion/moe/ep/all_to_all_dispatch.h create mode 100644 tests/proftest/layer_test_framework/operations/fusion/moe/ep/all_to_all_meta.cpp create mode 100644 tests/proftest/layer_test_framework/operations/fusion/moe/ep/all_to_all_meta.h create mode 100644 tests/proftest/layer_test_framework/operations/fusion/moe/ep/data_preparation.cpp create mode 100644 tests/proftest/layer_test_framework/operations/fusion/moe/ep/data_preparation.h create mode 100644 tests/proftest/layer_test_framework/operations/fusion/moe/ep/dynamic_ep_moe.cpp create mode 100644 tests/proftest/layer_test_framework/operations/fusion/moe/ep/dynamic_ep_moe.h create mode 100644 tests/proftest/layer_test_framework/operations/fusion/moe/ep/expert_filter.cpp create mode 100644 tests/proftest/layer_test_framework/operations/fusion/moe/ep/expert_filter.h create mode 100644 tests/proftest/layer_test_framework/operations/fusion/moe/ep/fused_alltoall_gmm.cpp create mode 100644 tests/proftest/layer_test_framework/operations/fusion/moe/ep/fused_alltoall_gmm.h create mode 100644 tests/proftest/layer_test_framework/operations/fusion/moe/integrated_gmm.cpp create mode 100644 tests/proftest/layer_test_framework/operations/fusion/moe/integrated_gmm.h create mode 100644 tests/proftest/layer_test_framework/operations/fusion/moe/moe_mlp.cpp create mode 100644 tests/proftest/layer_test_framework/operations/fusion/moe/moe_mlp.h create mode 100644 tests/proftest/layer_test_framework/operations/fusion/moe/moe_shared_expert.cpp create mode 100644 tests/proftest/layer_test_framework/operations/fusion/moe/moe_shared_expert.h create mode 100644 tests/proftest/layer_test_framework/operations/fusion/moe/sparse_moe.cpp create mode 100644 tests/proftest/layer_test_framework/operations/fusion/moe/sparse_moe.h create mode 100644 tests/proftest/layer_test_framework/operations/fusion/norm/norm_linear.cpp create mode 
100644 tests/proftest/layer_test_framework/operations/fusion/norm/norm_linear.h
 create mode 100644 tests/proftest/layer_test_framework/operations/fusion/parallel_info.cpp
 create mode 100644 tests/proftest/layer_test_framework/operations/fusion/parallel_info.h
 create mode 100644 tests/proftest/layer_test_framework/operations/fusion/parallel_layer.cpp
 create mode 100644 tests/proftest/layer_test_framework/operations/fusion/parallel_layer.h
 create mode 100644 tests/proftest/layer_test_framework/operations/fusion/parallel_layer_v2.cpp
 create mode 100644 tests/proftest/layer_test_framework/operations/fusion/parallel_layer_v2.h
 create mode 100644 tests/proftest/layer_test_framework/operations/fusion/parallel_lmhead.cpp
 create mode 100644 tests/proftest/layer_test_framework/operations/fusion/parallel_lmhead.h
 create mode 100644 tests/proftest/layer_test_framework/operations/fusion/utils.cpp
 create mode 100644 tests/proftest/layer_test_framework/operations/fusion/utils.h
 create mode 100644 tests/proftest/main.cpp
 create mode 100644 tests/proftest/test_cases/bloom_7b/main.cpp
 create mode 100644 tests/proftest/utils/include/context_utils.h
 create mode 100644 tests/proftest/utils/include/tensor_utils.h
 create mode 100644 tests/proftest/utils/include/type_utils.h
 create mode 100644 tests/proftest/utils/src/context_utils.cpp
 create mode 100644 tests/proftest/utils/src/tensor_utils.cpp
 create mode 100644 tests/proftest/utils/src/type_utils.cpp

diff --git a/CMakeLists.txt b/CMakeLists.txt
index ff926af1..d12ed339 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -26,6 +26,7 @@ option(USE_PYTHON_TEST "USE_PYTHON_TEST" OFF)
 option(USE_FUZZ_TEST "USE_FUZZ_TEST" OFF)
 option(USE_CSV_OPS_TEST "USE_CSV_OPS_TEST" OFF)
 option(USE_INFRA_TEST "USE_INFRA_TEST" OFF)
+option(USE_PROF_TEST "USE_PROF_TEST" OFF)
 option(USE_TORCH_ATB_TEST "USE_TORCH_ATB_TEST" OFF)
 option(USE_CXX11_ABI "USE_CXX11_ABI" ON)
 option(USE_ASAN "USE_ASAN" OFF)
@@ -39,6 +40,7 @@ message(STATUS "USE_PYTHON_TEST:${USE_PYTHON_TEST}")
 message(STATUS "USE_FUZZ_TEST:${USE_FUZZ_TEST}")
 message(STATUS "USE_CSV_OPS_TEST:${USE_CSV_OPS_TEST}")
 message(STATUS "USE_INFRA_TEST:${USE_INFRA_TEST}")
+message(STATUS "USE_PROF_TEST:${USE_PROF_TEST}")
 message(STATUS "USE_TORCH_ATB_TEST:${USE_TORCH_ATB_TEST}")
 message(STATUS "USE_CXX11_ABI:${USE_CXX11_ABI}")
 message(STATUS "USE_ASAN:${USE_ASAN}")
@@ -102,7 +104,7 @@ link_directories(
     $ENV{PYTORCH_INSTALL_PATH}/lib
     $ENV{PYTORCH_NPU_INSTALL_PATH}/lib)
-if(BUILD_TEST_FRAMEWORK OR USE_UNIT_TEST OR USE_PYTHON_TEST OR USE_FUZZ_TEST OR USE_CSV_OPS_TEST OR USE_INFRA_TEST OR USE_ALL_TEST)
+if(BUILD_TEST_FRAMEWORK OR USE_UNIT_TEST OR USE_PYTHON_TEST OR USE_FUZZ_TEST OR USE_CSV_OPS_TEST OR USE_INFRA_TEST OR USE_PROF_TEST OR USE_ALL_TEST)
     if(USE_FUZZ_TEST OR USE_ALL_TEST)
         set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fprofile-arcs -ftest-coverage")
     endif()
diff --git a/scripts/build.sh b/scripts/build.sh
index 2542c458..76e0083d 100644
--- a/scripts/build.sh
+++ b/scripts/build.sh
@@ -34,7 +34,7 @@ VERSION="8.0.0"
 LOG_PATH="/var/log/cann_atb_log/"
 LOG_NAME="cann_atb_install.log"
-BUILD_OPTION_LIST="help default testframework unittest kernelunittest pythontest torchatbtest kernelpythontest csvopstest fuzztest infratest hitest alltest clean gendoc customizeops"
+BUILD_OPTION_LIST="help default testframework unittest kernelunittest pythontest torchatbtest kernelpythontest csvopstest fuzztest infratest proftest hitest alltest clean gendoc customizeops"
 BUILD_CONFIGURE_LIST=("--verbose" "--use_cxx11_abi=0" "--use_cxx11_abi=1" "--asan"
                       "--skip_build" "--csvopstest_options=.*" "--debug" "--clean-first" "--msdebug" "--mssanitizer" "--no-pybind" "--src-only")
@@ -265,6 +265,25 @@ function fn_build_pybind11()
     git clone --branch v2.10.3 --depth 1 https://github.com/pybind/pybind11.git
 }
+function fn_build_benchmark()
+{
+    if [ -d "$THIRD_PARTY_DIR/benchmark" ]; then
+        return 0
+    fi
+    cd $THIRD_PARTY_DIR
+    git clone https://github.com/google/benchmark.git
+    cd benchmark
+    if [ "$USE_CXX11_ABI" == "ON" ]
+    then
+        sed -i '334 a add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=1)' CMakeLists.txt
+    else
+        sed -i '334 a add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0)' CMakeLists.txt
+    fi
+    cmake -E make_directory "build"
+    cmake -E chdir "build" cmake -DBENCHMARK_DOWNLOAD_DEPENDENCIES=on -DCMAKE_BUILD_TYPE=Release -DBUILD_SHARED_LIBS=ON ../
+    cmake --build "build" --config Release
+}
+
 function fn_build_secodefuzz()
 {
     FUZZ_DEST_PATH=$THIRD_PARTY_DIR/secodefuzz
@@ -302,6 +321,7 @@ function fn_build_3rdparty_for_test()
     fi
     fn_build_googletest
     fn_build_stub
+    fn_build_benchmark
 }
 
 function fn_build_3rdparty_for_compile()
@@ -529,6 +549,14 @@ function fn_run_kernel_cinterfacetest()
     $ATB_HOME_PATH/bin/atb_cinterface
 }
+function fn_run_proftest()
+{
+    export_atb_env
+    export LD_LIBRARY_PATH=$PYTORCH_INSTALL_PATH/lib:$PYTORCH_NPU_INSTALL_PATH/lib:$CODE_ROOT/3rdparty/benchmark/build/src:$LD_LIBRARY_PATH
+    echo "run $ATB_HOME_PATH/bin/atb_proftest"
+    $ATB_HOME_PATH/bin/atb_proftest atb_proftest_baseline
+    $ATB_HOME_PATH/bin/atb_proftest
+}
 
 function fn_run_fuzztest()
 {
@@ -813,6 +841,12 @@ function fn_main()
             fn_build
             fn_run_infratest
             ;;
+        "proftest")
+            COMPILE_OPTIONS="${COMPILE_OPTIONS} -DUSE_PROF_TEST=ON"
+            fn_build_3rdparty_for_test
+            fn_build
+            fn_run_proftest
+            ;;
         "hitest")
             export_atb_hitest_env
             export CMAKE_CXX_COMPILER_LAUNCHER=hitestwrapper
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index d15dc57a..c0b4e832 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -9,6 +9,9 @@
 #
 add_subdirectory(framework)
+if(USE_PROF_TEST)
+    add_subdirectory(proftest)
+endif()
 if(USE_UNIT_TEST OR USE_ALL_TEST)
     add_subdirectory(unittest)
     add_subdirectory(cinterface)
diff --git a/tests/proftest/CMakeLists.txt b/tests/proftest/CMakeLists.txt
new file mode 100644
index 00000000..e5ff3f98
--- /dev/null
+++ b/tests/proftest/CMakeLists.txt
@@ -0,0 +1,61 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd.
+# This file is a part of the CANN Open Software.
+# Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+# Please refer to the License for details. You may not use this file except in compliance with the License.
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+# See LICENSE in the root of the software repository for the full text of the License.
+#
+
+file(GLOB_RECURSE UTILS_SOURCES "utils/src/*.cpp")
+add_library(atb_proftest_utils STATIC ${UTILS_SOURCES})
+target_include_directories(atb_proftest_utils PRIVATE utils/include)
+target_include_directories(atb_proftest_utils PRIVATE ${PROJECT_SOURCE_DIR}/)
+target_include_directories(atb_proftest_utils PRIVATE ${PROJECT_SOURCE_DIR}/3rdparty)
+target_include_directories(atb_proftest_utils PRIVATE ${PROJECT_SOURCE_DIR}/3rdparty/googletest/include)
+target_include_directories(atb_proftest_utils PRIVATE ${ASCEND_HOME_PATH}/include/)
+target_link_directories(atb_proftest_utils PRIVATE ${PROJECT_SOURCE_DIR}/3rdparty/googletest/lib ${PROJECT_SOURCE_DIR}/3rdparty/benchmark/build/src ${ASCEND_HOME_PATH}/lib64/ ${ATB_HOME_PATH}/lib)
+target_compile_options(atb_proftest_utils PRIVATE "-Wno-write-strings")
+target_compile_options(atb_proftest_utils PRIVATE -Wno-sign-compare -Wno-narrowing)
+set_target_properties(atb_proftest_utils PROPERTIES LINK_FLAGS "-Wl,-rpath-link,${PROJECT_SOURCE_DIR}/3rdparty/asdops/lib/")
+target_link_libraries(atb_proftest_utils PRIVATE atb -lgtest -lgtest_main -lc_sec)
+
+
+file(GLOB_RECURSE UTILS_SOURCES "layer_test_framework/*.cpp")
+add_library(layer_test_framework STATIC ${UTILS_SOURCES})
+target_include_directories(layer_test_framework PRIVATE layer_test_framework)
+target_include_directories(layer_test_framework PRIVATE layer_test_framework/core/include)
+target_include_directories(layer_test_framework PRIVATE layer_test_framework/models)
+target_include_directories(layer_test_framework PRIVATE layer_test_framework/operations)
+target_include_directories(layer_test_framework PRIVATE ${PROJECT_SOURCE_DIR}/)
+target_include_directories(layer_test_framework PRIVATE ${PROJECT_SOURCE_DIR}/3rdparty)
+target_include_directories(layer_test_framework PRIVATE ${PROJECT_SOURCE_DIR}/3rdparty/googletest/include)
+target_include_directories(layer_test_framework PRIVATE ${PROJECT_SOURCE_DIR}/3rdparty/nlohmannJson/include)
+target_include_directories(layer_test_framework PRIVATE ${ASCEND_HOME_PATH}/include/)
+target_link_directories(layer_test_framework PRIVATE ${PROJECT_SOURCE_DIR}/3rdparty/googletest/lib ${PROJECT_SOURCE_DIR}/3rdparty/benchmark/build/src ${ASCEND_HOME_PATH}/lib64/ ${ATB_HOME_PATH}/lib $ENV{ASCEND_TOOLKIT_HOME}/lib64)
+target_compile_options(layer_test_framework PRIVATE "-Wno-write-strings")
+target_compile_options(layer_test_framework PRIVATE -Wno-sign-compare -Wno-narrowing)
+set_target_properties(layer_test_framework PROPERTIES LINK_FLAGS "-Wl,-rpath-link,${PROJECT_SOURCE_DIR}/3rdparty/asdops/lib/")
+target_link_libraries(layer_test_framework PRIVATE atb -lgtest -lgtest_main -lc_sec -lopapi)
+
+find_package(Threads REQUIRED)
+file(GLOB_RECURSE TEST_CASES "test_cases/*")
+add_executable(atb_proftest main.cpp ${TEST_CASES})
+target_include_directories(atb_proftest PRIVATE utils/include)
+target_include_directories(atb_proftest PRIVATE layer_test_framework)
+target_include_directories(atb_proftest PRIVATE layer_test_framework/core/include)
+target_include_directories(atb_proftest PRIVATE layer_test_framework/models)
+target_include_directories(atb_proftest PRIVATE layer_test_framework/operations)
+target_include_directories(atb_proftest PRIVATE ${PROJECT_SOURCE_DIR}/)
+target_include_directories(atb_proftest PRIVATE ${PROJECT_SOURCE_DIR}/3rdparty)
+target_include_directories(atb_proftest PRIVATE ${PROJECT_SOURCE_DIR}/3rdparty/googletest/include)
+target_include_directories(atb_proftest PRIVATE ${PROJECT_SOURCE_DIR}/3rdparty/benchmark/include)
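+# The proftest build is split into three targets: atb_proftest_utils (helpers from
+# utils/src), layer_test_framework (the layer/model/operation test framework), and
+# the atb_proftest executable, which links both against ATB and Google Benchmark
+# (built under 3rdparty/benchmark/build/src by `scripts/build.sh proftest`).
+# New cases need no registration here: test_cases/* is globbed into the executable,
+# so adding e.g. test_cases/my_model/main.cpp (a hypothetical path) is picked up
+# automatically, mirroring the existing test_cases/bloom_7b/main.cpp.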
+target_include_directories(atb_proftest PRIVATE ${PROJECT_SOURCE_DIR}/3rdparty/nlohmannJson/include)
+target_include_directories(atb_proftest PRIVATE ${ASCEND_HOME_PATH}/include/)
+target_link_directories(atb_proftest PRIVATE ${PROJECT_SOURCE_DIR}/3rdparty/googletest/lib ${PROJECT_SOURCE_DIR}/3rdparty/benchmark/build/src ${ASCEND_HOME_PATH}/lib64/ ${ATB_HOME_PATH}/lib $ENV{ASCEND_TOOLKIT_HOME}/lib64)
+target_compile_options(atb_proftest PRIVATE "-Wno-write-strings")
+target_compile_options(atb_proftest PRIVATE -Wno-sign-compare -Wno-narrowing)
+set_target_properties(atb_proftest PROPERTIES LINK_FLAGS "-Wl,-rpath-link,${PROJECT_SOURCE_DIR}/3rdparty/asdops/lib/")
+target_link_libraries(atb_proftest PRIVATE atb_proftest_utils layer_test_framework atb Threads::Threads benchmark -lgtest -lgtest_main -lc_sec -lopapi)
+install(TARGETS atb_proftest DESTINATION bin)
diff --git a/tests/proftest/layer_test_framework/core/base/context_factory.cpp b/tests/proftest/layer_test_framework/core/base/context_factory.cpp
new file mode 100644
index 00000000..c00be48c
--- /dev/null
+++ b/tests/proftest/layer_test_framework/core/base/context_factory.cpp
@@ -0,0 +1,99 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "atb_speed/base/context_factory.h"
+#include
+#include "atb_speed/log.h"
+#include "atb_speed/utils/singleton.h"
+#include "atb_speed/utils/config.h"
+
+namespace atb_speed {
+
+const int MAX_STREAM_NUM = 2;
+
+thread_local std::shared_ptr<atb::Context> g_localContext;
+
+bool ContextFactory::cacheWorkspace_ = false;
+
+std::vector<aclrtStream> ContextFactory::GetSubStreams()
+{
+    static std::vector<aclrtStream> streams;
+    static bool initialized = false;
+    if (!initialized) {
+        aclInit(nullptr);
+
+        for (int i = 0; i < MAX_STREAM_NUM; ++i) {
+            aclrtStream subStream;
+
+            aclError ret = aclrtCreateStream(&subStream);
+            if (ret != ACL_ERROR_NONE) {
+                ATB_SPEED_LOG_ERROR("Failed to create aclrtStream: " << ret);
+            }
+            ret = aclrtSetStreamFailureMode(subStream, ACL_STOP_ON_FAILURE);
+            if (ret != 0) {
+                ATB_SPEED_LOG_ERROR("Failed to aclrtSetStreamFailureMode: " << ret);
+            }
+            streams.push_back(subStream);
+        }
+        initialized = true;
+    }
+
+    return streams;
+}
+
+std::shared_ptr<atb::Context> ContextFactory::GetAtbContext(void *stream)
+{
+    if (g_localContext) {
+        ATB_SPEED_LOG_DEBUG("ContextFactory return localContext");
+        return g_localContext;
+    }
+    ATB_SPEED_LOG_DEBUG("ContextFactory create atb::Context start");
+    atb::Context *context = nullptr;
+    atb::Status st = atb::CreateContext(&context);
+    if (st != 0) {
+        ATB_SPEED_LOG_ERROR("ContextFactory create atb::Context fail");
+    }
+
+    if (context) {
+        context->SetExecuteStream(stream);
+        if (atb_speed::GetSingleton<atb_speed::Config>().IsUseTilingCopyStream()) {
+            ATB_SPEED_LOG_DEBUG("ContextFactory use tiling copy stream");
+            context->SetAsyncTilingCopyStatus(true);
+        } else {
+            ATB_SPEED_LOG_DEBUG("ContextFactory not use tiling copy stream");
+        }
+    }
+
+    std::shared_ptr<atb::Context> tmpLocalContext(context, [](atb::Context* context) {atb::DestroyContext(context);});
+    g_localContext = tmpLocalContext;
+
+    return g_localContext;
+}
+
+void ContextFactory::FreeAtbContext()
+{
+    ATB_SPEED_LOG_DEBUG("ContextFactory FreeAtbContext start.");
+    if (!g_localContext) {
+        return;
+    }
+
+    ATB_SPEED_LOG_DEBUG("ContextFactory localContext use_count: " << g_localContext.use_count());
+    if (g_localContext.use_count() != 1) {
+        return;
+    }
+    ATB_SPEED_LOG_DEBUG("ContextFactory localContext reset.");
+    g_localContext.reset();
+}
+}
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/core/base/event_manager.cpp b/tests/proftest/layer_test_framework/core/base/event_manager.cpp
new file mode 100644
index 00000000..a4e2db73
--- /dev/null
+++ b/tests/proftest/layer_test_framework/core/base/event_manager.cpp
@@ -0,0 +1,227 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+
+#include "atb_speed/utils/operation_util.h"
+#include "atb_speed/utils/check_util.h"
+#include "atb_speed/log.h"
+#include "atb_speed/base/event_manager.h"
+
+namespace atb_speed {
+
+thread_local std::vector<std::pair<atb::Operation*, atb::common::EventParam>> g_eventOperationsOfModel;
+
+#define CHECK_ACL_STATUS_RETURN_AND_LOG_IF_ERROR(ret, customErr, msg) \
+    do { \
+        if ((ret) != ACL_SUCCESS) { \
+            ATB_SPEED_LOG_ERROR((msg) << " aclError = " << (ret)); \
+            return (customErr); \
+        } \
+    } while (0)
+
+#define CHECK_EM_STATUS_ONLY_RETURN(status, retVal) \
+    do { \
+        if ((status) != EM_SUCCESS) { \
+            return (retVal); \
+        } \
+    } while (0)
+
+#define CHECK_EM_STATUS_RETURN_AND_LOG_IF_ERROR(status, retVal, msg) \
+    do { \
+        if ((status) != EM_SUCCESS) { \
+            ATB_SPEED_LOG_ERROR((msg) << " status=" << (status)); \
+            return (retVal); \
+        } \
+    } while (0)
+
+EventManager& EventManager::GetInstance()
+{
+    static EventManager instance;
+    return instance;
+}
+
+EventManager::EventManager()
+{
+    ATB_SPEED_LOG_DEBUG("EventManager created.");
+    uint32_t opWaitTimeout = 1080;
+    SetWaitOperationTimeout(opWaitTimeout);
+}
+
+EventManager::~EventManager()
+{
+    std::lock_guard<std::mutex> lk(queueMutex_);
+    for (auto& eventQueue : eventQueues_) {
+        while (!eventQueue.second.empty()) {
+            aclrtEvent event = eventQueue.second.front();
+            eventQueue.second.pop();
+            DestroyEvent(event);
+        }
+    }
+    ATB_SPEED_LOG_DEBUG("EventManager destroyed.");
+}
+
+void EventManager::SetWaitOperationTimeout(uint32_t timeout)
+{
+    aclError ret = aclrtSetOpWaitTimeout(timeout);
+    if (ret != ACL_SUCCESS) {
+        ATB_SPEED_LOG_ERROR("aclrtSetOpWaitTimeout failed, aclError = " << ret);
+    } else {
+        ATB_SPEED_LOG_DEBUG("aclrtSetOpWaitTimeout end, set to " << timeout << " seconds.");
+    }
+}
+
+EventManagerStatus EventManager::CreateEvent(aclrtEvent &event) const
+{
+    uint32_t flags = ACL_EVENT_SYNC;
+    aclError ret = aclrtCreateEventWithFlag(&event, flags);
+    CHECK_ACL_STATUS_RETURN_AND_LOG_IF_ERROR(ret, EM_CREATE_EVENT_FAILED, "aclrtCreateEventWithFlag failed,");
+    ATB_SPEED_LOG_DEBUG("Event created, event = " << event);
+
+    return EM_SUCCESS;
+}
+
+EventManagerStatus EventManager::DestroyEvent(aclrtEvent event) const
+{
+    aclError ret = aclrtDestroyEvent(event);
+    CHECK_ACL_STATUS_RETURN_AND_LOG_IF_ERROR(ret, EM_DESTROY_EVENT_FAILED, "aclrtDestroyEvent failed,");
+    ATB_SPEED_LOG_DEBUG("Event destroyed end, event = " << event);
+
+    return EM_SUCCESS;
+}
+
+EventManagerStatus EventManager::CreateAndPushEvent(aclrtEvent &event, const std::string &pipeKey)
+{
+    // Take the lock before touching eventQueues_; PopEvent may access the map concurrently.
+    std::lock_guard<std::mutex> lk(queueMutex_);
+    if (eventQueues_.find(pipeKey) == eventQueues_.end()) {
+        std::queue<aclrtEvent> queue;
+        eventQueues_[pipeKey] = queue;
+    }
+
+    EventManagerStatus ret = CreateEvent(event);
+    CHECK_EM_STATUS_ONLY_RETURN(ret, ret);
+    eventQueues_[pipeKey].push(event);
+    eventCount_.fetch_add(1, std::memory_order_relaxed);
+    eventCond_.notify_one();
+
+    ATB_SPEED_LOG_DEBUG("PushEvent: event = " << event
+        << ", Event pushed to queue, queueSize = " << eventQueues_[pipeKey].size()
+        << ", current eventCount = " << eventCount_.load());
+
+    return EM_SUCCESS;
+}
+
+EventManagerStatus EventManager::PopEvent(aclrtEvent &event, const std::string &pipeKey)
+{
+    std::unique_lock<std::mutex> lk(queueMutex_);
+    // The 1 us wait makes this effectively a non-blocking try-pop; on timeout the
+    // caller (EventInternal) defers the operation until a matching event is pushed.
+    if (!eventCond_.wait_for(lk, std::chrono::microseconds(1),
+        [this, pipeKey] { return !eventQueues_[pipeKey].empty(); })) {
+        ATB_SPEED_LOG_DEBUG("PopEvent: Timeout waiting for event, current eventCount = " << eventCount_.load());
+        return EM_POP_EVENT_TIMEOUT;
+    }
+    event = eventQueues_[pipeKey].front();
+    eventQueues_[pipeKey].pop();
+    eventCount_.fetch_sub(1, std::memory_order_relaxed);
+    lk.unlock();
+
+    ATB_SPEED_LOG_DEBUG("PopEvent: event = " << event
+        << ", current eventCount = " << eventCount_.load());
+
+    return EM_SUCCESS;
+}
+
+atb::Status EventManager::EventInternal(EventAction eventAction,
+    EventType eventType,
+    atb::Operation*& op,
+    const std::string &pipeKey)
+{
+    atb::common::EventParam eventParam;
+    atb::common::EventParam::OperatorType opType;
+    std::string eventTypeStr;
+
+    if (eventType == EventType::RECORD) {
+        opType = atb::common::EventParam::OperatorType::RECORD;
+        eventTypeStr = "RecordEvent";
+    } else if (eventType == EventType::WAIT) {
+        opType = atb::common::EventParam::OperatorType::WAIT;
+        eventTypeStr = "WaitEvent";
+    } else {
+        ATB_SPEED_LOG_ERROR("Invalid EventType: " << static_cast<int>(eventType));
+        return EM_INVALID_TYPE;
+    }
+    eventParam.operatorType = opType;
+
+    aclrtEvent event = nullptr;
+    EventManagerStatus ret;
+
+    if (eventAction == EventAction::PUSH) {
+        ret = CreateAndPushEvent(event, pipeKey);
+        CHECK_EM_STATUS_RETURN_AND_LOG_IF_ERROR(ret, EM_PUSH_EVENT_FAILED,
+            eventTypeStr + ": CreateAndPushEvent failed with error");
+        ATB_SPEED_LOG_DEBUG(eventTypeStr << ": Pushed event, event: " << event);
+        if (!opsWithoutEvent_[pipeKey].empty()) {
+            // Copy the front element: a reference would dangle once pop() runs.
+            auto opParam = opsWithoutEvent_[pipeKey].front();
+            opsWithoutEvent_[pipeKey].pop();
+            // Hand the oldest queued event to the operation that was created without one.
+            opParam.second.event = eventQueues_[pipeKey].front();
+            eventQueues_[pipeKey].pop();
+            atb::UpdateOperationParam(opParam.first, opParam.second);
+            ATB_SPEED_LOG_DEBUG(eventTypeStr << ": Popped event, event: " << event);
+        }
+    } else if (eventAction == EventAction::POP) {
+        ret = PopEvent(event, pipeKey);
+        if (ret == EM_POP_EVENT_TIMEOUT) {
+            ATB_SPEED_LOG_DEBUG(eventTypeStr << ": Popped event, event: time out");
+            // No event available yet: create the operation with an empty EventParam and
+            // remember it; the event is patched in later via UpdateOperationParam.
+            atb::common::EventParam eventParamTmp;
+            if (atb::CreateOperation(eventParamTmp, &op) != atb::NO_ERROR) {
+                return EM_OPERATION_CREATION_FAILED;
+            }
+            opsWithoutEvent_[pipeKey].push(std::make_pair(op, eventParam));
+            return atb::NO_ERROR;
+        }
+        ATB_SPEED_LOG_DEBUG(eventTypeStr << ": Popped event, event: " << event);
+    } else {
+        return EM_INVALID_ACTION;
+    }
+
+    eventParam.event = event;
+
+    if (atb::CreateOperation(eventParam, &op) != atb::NO_ERROR) {
+        return EM_OPERATION_CREATION_FAILED;
+    }
+
+    g_eventOperationsOfModel.push_back(std::make_pair(op, eventParam));
+
+    return atb::NO_ERROR;
+}
+
+atb::Status EventManager::RecordEvent(atb::Operation*& op, EventAction eventAction, const std::string &pipeKey)
+{
+    return EventInternal(eventAction, EventType::RECORD, op, pipeKey);
+}
+
+atb::Status EventManager::WaitEvent(atb::Operation*& op, EventAction eventAction, const std::string &pipeKey)
+{
+    return EventInternal(eventAction, EventType::WAIT, op, pipeKey);
+}
+
+} // namespace atb_speed
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/core/base/external_comm_manager.cpp b/tests/proftest/layer_test_framework/core/base/external_comm_manager.cpp
new file mode 100644
index 00000000..9311d71b
--- /dev/null
+++ b/tests/proftest/layer_test_framework/core/base/external_comm_manager.cpp
@@ -0,0 +1,265 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <cmath>
+#include <sstream>
+#include "securec.h"
+#include "atb_speed/base/external_comm_manager.h"
+
+namespace atb_speed {
+
+CommInfo::~CommInfo()
+{
+    ATB_SPEED_LOG_DEBUG("External Comm Manager: CommInfo ["
+        << std::hash<void *>{}(this) << "] destruction starts.");
+    if (this->hcclComm_ != nullptr) {
+        auto ret = HcclCommDestroy(this->hcclComm_);
+        if (ret != HCCL_SUCCESS) {
+            ATB_SPEED_LOG_ERROR("External Comm Manager: Call `HcclCommDestroy` API from CANN "
+                << "to destroy hccl communication group failed. "
+                << "Error code: " << ret << ". "
+                << "Check the default log path at $HOME/ascend/log for more details. ");
+        }
+    }
+    this->hcclComm_ = nullptr;
+}
+
+std::string CommInfo::ToString() const
+{
+    std::stringstream ss;
+    ss << "Cache Addr[" << this << "] cacheId_: " << cacheId_
+       << ", subCommRankId_: " << subCommRankId_
+       << ", rankIds_: " << rankIds_
+       << ", bufferSize_: " << bufferSize_
+       << ", backend_: " << backend_
+       << ", hcclComm_: " << hcclComm_
+       << ", streamId_: " << streamId_;
+    return ss.str();
+}
+
+bool AreVectorsEqual(const std::vector<uint32_t> &rankIdsA, const std::vector<uint32_t> &rankIdsB)
+{
+    if (rankIdsA.size() != rankIdsB.size()) {
+        return false;
+    }
+    for (size_t i = 0; i < rankIdsA.size(); i++) {
+        if (rankIdsA.at(i) != rankIdsB.at(i)) {
+            return false;
+        }
+    }
+    return true;
+}
+
+void ExternalCommManager::Init(uint32_t worldSize, uint32_t subCommRankId,
+    std::string backend, std::string rankTableFile, uint32_t streamId)
+{
+    ATB_SPEED_LOG_DEBUG("External Comm Manager: try to create global comm with "
+        << "worldSize " << worldSize << ", subCommRankId " << subCommRankId
+        << ", backend " << backend << ", rankTableFile " << rankTableFile
+    );
+
+    if (this->globalComm_ != nullptr) {
+        ATB_SPEED_LOG_DEBUG("External Comm Manager: A global communication group is already created, "
+            << "so the creation process will be skipped.");
+        return;
+    }
+
+    this->worldSize_ = worldSize;
+    this->rank_ = subCommRankId;
+    this->rankTableFile_ = rankTableFile;
+
+    std::vector<uint32_t> rankIds = {};
+    for (uint32_t id = 0; id < worldSize; id++) {
+        rankIds.push_back(id);
+    }
+    std::shared_ptr<CommInfo> commInfo = std::make_shared<CommInfo>();
+    commInfo->cacheId_ = this->commInfoCache_.size();
+    commInfo->subCommRankId_ = subCommRankId;
+    commInfo->rankIds_ = rankIds;
+    commInfo->backend_ = backend;
+    commInfo->streamId_ = streamId;
+
+    std::string commDomain = "";
+    if ((backend == HCCL || this->rankTableFile_ != "") && rankIds.size() > 1) {
+        commDomain = GetHcclGlobalCommDomain(commInfo);
+    } else if ((backend == LCCL || backend == LCOC) && rankIds.size() > 1) {
+        commDomain = GetCommDomainFromCache(rankIds, backend, 200, streamId); // 200: buffer size
+        if (commDomain == "") { commDomain = GetSelfAssignedCommDomain(commInfo, 0); }
+    }
+    ATB_SPEED_LOG_DEBUG("External Comm Manager: Add [" << commDomain << "] to cache.");
+    this->commInfoCache_[commDomain] = commInfo;
+}
+
+void ExternalCommManager::SetLcclCommDomainRange(int32_t lowerBound, int32_t upperBound)
+{
+    this->lcclCommDomainLowerBound_ = lowerBound;
+    this->lcclCommDomainUpperBound_ = upperBound;
+}
+
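// --- Editor's aside (illustrative sketch, not part of the original patch) ---
// GetCommDomain below first tries to reuse a cached group whose rankIds,
// backend, bufferSize and streamId all match, and only then registers a new
// CommInfo under a fresh domain. The lookup-or-insert shape in miniature
// (hypothetical DemoCommKey type, unrelated to the real cache):
namespace {
struct DemoCommKey {
    std::vector<uint32_t> rankIds;
    std::string backend;
};

[[maybe_unused]] std::string LookupOrInsert(std::vector<std::pair<DemoCommKey, std::string>> &cache,
    const DemoCommKey &key, const std::string &newDomain)
{
    for (const auto &entry : cache) {
        if (entry.first.rankIds == key.rankIds && entry.first.backend == key.backend) {
            return entry.second; // cache hit: reuse the existing comm domain
        }
    }
    cache.emplace_back(key, newDomain); // cache miss: register the new domain
    return newDomain;
}
} // anonymous namespace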
+std::string ExternalCommManager::GetCommDomain(uint32_t groupId, const std::vector<uint32_t> &rankIds,
+    uint32_t subCommRankId, std::string backend, uint32_t bufferSize, uint32_t streamId, bool enableReuse)
+{
+    ATB_SPEED_LOG_DEBUG("External Comm Manager: try to create comm with rankIds " << rankIds
+        << ", subCommRankId " << subCommRankId << ", backend: " << backend << ", bufferSize " << bufferSize
+        << ", streamId " << streamId);
+
+    std::string commDomain = "";
+
+    if (rankIds.size() <= 1) { return commDomain; }
+
+    if (enableReuse) {
+        ATB_SPEED_LOG_DEBUG("External Comm Manager: try to reuse communication group from cache.");
+        commDomain = GetCommDomainFromCache(rankIds, backend, bufferSize, streamId);
+        if (commDomain != "") {
+            return commDomain;
+        }
+    }
+
+    std::shared_ptr<CommInfo> commInfo = std::make_shared<CommInfo>();
+    commInfo->cacheId_ = this->commInfoCache_.size();
+    commInfo->subCommRankId_ = subCommRankId;
+    commInfo->rankIds_ = rankIds;
+    commInfo->backend_ = backend;
+    commInfo->bufferSize_ = bufferSize;
+    commInfo->streamId_ = streamId;
+    commInfo->enableReuse_ = enableReuse;
+    if ((backend == LCCL || backend == LCOC) && rankIds.size() > 1) {
+        commDomain = GetSelfAssignedCommDomain(commInfo, groupId);
+    } else if (backend == HCCL && rankIds.size() > 1) {
+        commDomain = GetHcclSubCommDomain(commInfo, groupId);
+    }
+    this->commInfoCache_[commDomain] = commInfo;
+    ATB_SPEED_LOG_DEBUG("External Comm Manager: Add [" << commDomain << "] to cache");
+    return commDomain;
+}
+
+std::string ExternalCommManager::GetCommDomainFromCache(
+    const std::vector<uint32_t> &rankIds, std::string backend, uint32_t bufferSize, uint32_t streamId)
+{
+    std::map<std::string, std::shared_ptr<CommInfo>>::iterator it;
+    for (it = this->commInfoCache_.begin(); it != this->commInfoCache_.end(); it++) {
+        if (AreVectorsEqual(it->second->rankIds_, rankIds) &&
+            it->second->backend_ == backend && it->second->bufferSize_ == bufferSize &&
+            it->second->streamId_ == streamId && it->second->enableReuse_
+        ) {
+            ATB_SPEED_LOG_DEBUG("External Comm Manager: Comm with rankIds " << rankIds
+                << ", bufferSize " << bufferSize << ", backend: " << backend
+                << ", streamId " << streamId << " hit. CommDomain [" << it->first << "] is reused.");
+            return it->first;
+        }
+    }
+    return "";
+}
+
+std::string ExternalCommManager::GetSelfAssignedCommDomain(std::shared_ptr<CommInfo> &commInfo, uint32_t groupId)
+{
+    uint32_t commDomainInt = this->lcclCommDomainLowerBound_ + this->commDomainCounter_ + groupId;
+    if (commDomainInt >= this->lcclCommDomainUpperBound_) {
+        std::stringstream ss;
+        ss << "External Comm Manager: Lccl commDomain exceeds the upper bound. "
+           << "Available commDomain range is [" << this->lcclCommDomainLowerBound_
+           << ", " << this->lcclCommDomainUpperBound_ << "]. "
+           << "The range of the communication domain is determined by `num_lccl_comm_shards` "
+           << "and `lccl_comm_shard_id`. Please review initialization parameters "
+           << "of the `GeneratorTorch` object.";
+        throw std::runtime_error(ss.str());
+    }
+    std::string commDomain = std::to_string(commDomainInt);
+    // Promote to double before ceil(); plain integer division would truncate first.
+    this->commDomainCounter_ = this->commDomainCounter_ +
+        static_cast<uint32_t>(ceil(static_cast<double>(this->worldSize_) / commInfo->rankIds_.size()));
+    ATB_SPEED_LOG_DEBUG("External Comm Manager: commDomainCounter_ update to " << this->commDomainCounter_);
+    return commDomain;
+}
+
+std::string ExternalCommManager::GetHcclGlobalCommDomain(std::shared_ptr<CommInfo> &commInfo)
+{
+    ATB_SPEED_LOG_DEBUG("GetHcclGlobalCommDomain start.");
+    std::string commDomain = "";
+    if (this->rankTableFile_ != "") {
+        char commName[128] = {}; // 128: max commName length
+        this->globalComm_ = atb::Comm::CreateHcclCommByRankTableFile(commInfo->subCommRankId_, this->worldSize_,
+            this->rankTableFile_.data(), commName);
+        if (this->globalComm_ == nullptr) {
+            throw std::runtime_error("External Comm Manager: Create the hccl communication group failed. "
+                "export ASDOPS_LOG_LEVEL=ERROR, export ASDOPS_LOG_TO_STDOUT=1 to see more details. "
+                "Default log path is $HOME/atb/log. ");
+        }
+        commInfo->hcclComm_ = this->globalComm_;
+        char hcclCommName[128] = {};
+        HcclGetCommName(this->globalComm_, hcclCommName);
+        commDomain = std::string(hcclCommName);
+    } else {
+        // There is only one global commDomain, so the group id is 0.
+        commDomain = GetSelfAssignedCommDomain(commInfo, 0);
+    }
+    ATB_SPEED_LOG_DEBUG("GetHcclGlobalCommDomain end.");
+
+    return commDomain;
+}
+
+std::string ExternalCommManager::GetHcclSubCommDomain(std::shared_ptr<CommInfo> &commInfo, uint32_t groupId)
+{
+    ATB_SPEED_LOG_DEBUG("GetHcclSubCommDomain start.");
+    std::string commDomain = "";
+    if (this->globalComm_ != nullptr) {
+        HcclComm hcclComm;
+        HcclCommConfig config;
+        HcclCommConfigInit(&config);
+        config.hcclBufferSize = commInfo->bufferSize_;
+        std::vector<uint32_t> tempRankIds = {};
+        for (auto item : commInfo->rankIds_) { tempRankIds.push_back(item); }
+        auto ret = HcclCreateSubCommConfig(&this->globalComm_, tempRankIds.size(), tempRankIds.data(),
+            commInfo->cacheId_, commInfo->subCommRankId_, &config, &hcclComm);
+        if (hcclComm == nullptr) {
+            ATB_SPEED_LOG_ERROR("External Comm Manager: Call `HcclCreateSubCommConfig` API from CANN "
+                << "to create the hccl communication group failed. "
+                << "Error code: " << ret << ". "
+                << "Check the default log path at $HOME/ascend/log for more details. 
"); + } + commInfo->hcclComm_ = hcclComm; + char hcclCommName[128] = {}; + HcclGetCommName(hcclComm, hcclCommName); + commDomain = std::string(hcclCommName); + } else { + commDomain = GetSelfAssignedCommDomain(commInfo, groupId); + } + ATB_SPEED_LOG_DEBUG("GetHcclSubCommDomain end."); + return commDomain; +} + +HcclComm ExternalCommManager::GetCommPtr(std::string commDomain) +{ + if (commDomain == "") { return nullptr; } + auto it = this->commInfoCache_.find(commDomain); + if (it == this->commInfoCache_.end()) { + std::stringstream ss; + ss << "External Comm Manager: Comm domain[" << commDomain << "] not found in cache."; + throw std::out_of_range(ss.str()); + } + return it->second->hcclComm_; +} + +std::string ExternalCommManager::PrintCommInfo() +{ + std::stringstream ss; + ss << "External Comm Manager: Comm Info Cache Summary: Count " << this->commInfoCache_.size(); + std::map>::const_iterator it; + for (it = this->commInfoCache_.begin(); it != this->commInfoCache_.end(); it++) { + ss << " Comm domain[" << it->first << "] " << it->second->ToString(); + } + return ss.str(); +} + +} // namespace atb_speed \ No newline at end of file diff --git a/tests/proftest/layer_test_framework/core/base/model.cpp b/tests/proftest/layer_test_framework/core/base/model.cpp new file mode 100644 index 00000000..23096c0a --- /dev/null +++ b/tests/proftest/layer_test_framework/core/base/model.cpp @@ -0,0 +1,768 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include +#include +#include +#include +#include +#include +#include + +#include "atb_speed/log.h" +#include "atb_speed/utils/config.h" +#include "atb_speed/utils/singleton.h" +#include "atb_speed/utils/tensor_util.h" +#include "atb_speed/utils/speed_probe.h" +#include "atb_speed/base/model.h" + +namespace atb_speed { +static std::atomic g_executeStatus(atb::NO_ERROR); +static std::atomic g_preExecuteStatus(atb::NO_ERROR); + +static bool IsTensorDimsEqual(const atb::Dims &left, const atb::Dims &other) +{ + if (left.dimNum != other.dimNum) { + return false; + } + + for (uint64_t i = 0; i < left.dimNum; ++i) { + if (left.dims[i] != other.dims[i]) { + return false; + } + } + + return true; +} + +std::string Model::Graph::ToString() const +{ + std::stringstream ss; + for (size_t i = 0; i < weightTensors.size(); ++i) { + ss << "weightTensors[" << i << "]:" << &weightTensors.at(i) << " " + << TensorUtil::TensorToString(weightTensors.at(i)) << std::endl; + } + for (size_t i = 0; i < inTensors.size(); ++i) { + ss << "inTensors[" << i << "]:" << &inTensors.at(i) << " " << TensorUtil::TensorToString(inTensors.at(i)) + << std::endl; + } + for (size_t i = 0; i < outTensors.size(); ++i) { + ss << "outTensors[" << i << "]:" << &outTensors.at(i) << " " << TensorUtil::TensorToString(outTensors.at(i)) + << std::endl; + } + for (size_t i = 0; i < internalTensors.size(); ++i) { + ss << "internalTensors[" << i << "]:" << &internalTensors.at(i) << " " + << TensorUtil::TensorToString(internalTensors.at(i)) << std::endl; + } + ss << "nodes:" << nodes.size() << std::endl; + + for (size_t i = 0; i < nodes.size(); ++i) { + auto &node = nodes.at(i); + ss << "node[" << i << "] operation:" << node.operation.get() << ", operationName:" << node.operation->GetName() + << std::endl; + for (auto tensorIt : node.inTensors) { + ss << "node[" << i << "] inTensor:" << tensorIt << " " << TensorUtil::TensorToString(*tensorIt) + << std::endl; + } + for (auto tensorIt : node.outTensors) { + ss << "node[" << i << "] outTensor:" << tensorIt << " " << TensorUtil::TensorToString(*tensorIt) + << std::endl; + } + } + return ss.str(); +} + +atb::Status Model::ParseParam(const std::string ¶m) +{ + (void)param; + return atb::NO_ERROR; +} + +atb::Status Model::BindParamHostTensor(uint32_t nodeId) +{ + (void)nodeId; + return atb::NO_ERROR; +} + +void Model::Graph::Init() +{ + for (size_t i = 0; i < nodes.size(); i++) { + auto &node = nodes.at(i); + node.variantPack.inTensors.reserve(node.inTensors.size()); + node.variantPack.inTensors.resize(node.inTensors.size()); + node.variantPack.outTensors.reserve(node.outTensors.size()); + node.variantPack.outTensors.resize(node.outTensors.size()); + } + InitTensorType(); + InitTensorMaxNodeMap(); +} + +void Model::Graph::InitTensorType() +{ + for (auto &node : nodes) { + node.inTensorTypes.reserve(node.inTensors.size()); + node.inTensorTypes.resize(node.inTensors.size()); + node.outTensorTypes.reserve(node.outTensors.size()); + node.outTensorTypes.resize(node.outTensors.size()); + for (size_t i = 0; i < node.inTensors.size(); ++i) { + node.inTensorTypes.at(i) = + IsInternalTensor(node.inTensors.at(i)) ? + Model::TensorType::INTERMEDIATE_TENSOR : Model::TensorType::NOT_INTERMEDIATE_TENSOR; + } + for (size_t i = 0; i < node.outTensors.size(); ++i) { + node.outTensorTypes.at(i) = + IsInternalTensor(node.outTensors.at(i)) ? 
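/* --- Editor's aside (illustrative note, not part of the original patch) ---
   InitTensorType tags a node input/output as INTERMEDIATE_TENSOR exactly when
   its address lies inside graph.internalTensors; only those tensors are later
   allocated lazily and recycled between nodes. The address test in isolation:

       bool Contains(const std::vector<atb::Tensor> &pool, const atb::Tensor *t)
       {
           for (const auto &item : pool) {
               if (&item == t) {
                   return true;  // identity (address) comparison, not value equality
               }
           }
           return false;
       }
*/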
+ Model::TensorType::INTERMEDIATE_TENSOR : Model::TensorType::NOT_INTERMEDIATE_TENSOR; + } + } +} + +bool Model::Graph::IsInternalTensor(const atb::Tensor *tensor) +{ + for (auto &internalTensor : internalTensors) { + if (&internalTensor == tensor) { + return true; + } + } + + return false; +} + +void Model::Graph::InitTensorMaxNodeMap() +{ + std::map tensorMaxNodeIdMap; + maxNodeIdTensorMap.clear(); + + for (size_t i = 0; i < internalTensors.size(); ++i) { + atb::Tensor &internalTensor = internalTensors[i]; + uint64_t maxNodeId = 0; + uint64_t dependNodeCount = 0; + for (size_t nodeId = 0; nodeId < nodes.size(); ++nodeId) { + auto &node = nodes.at(nodeId); + for (auto inTensorIt : node.inTensors) { + if (&internalTensor == inTensorIt) { + maxNodeId = nodeId; + dependNodeCount++; + } + } + } + tensorMaxNodeIdMap[&internalTensor] = maxNodeId; + if (dependNodeCount == 0) { + ATB_SPEED_LOG_WARN("Runner graph internal tensor[" << i << "] dependNodeCount is 0"); + } + maxNodeIdTensorMap[maxNodeId].insert(&internalTensor); + } +} + +Model::Model(const std::string &modelName, const std::string ¶m) : modelName_(modelName), param_(param) +{ + currentDevId_ = 0; + aclrtGetDevice(¤tDevId_); + + CHECK_THROW(param_.empty(), "Model init failed, param is empty, please check."); + CHECK_THROW(param_.size() > MAX_PARAM_STRING_LENGTH, "Model init failed, param is too long, please check."); +} + +Model::~Model() {} + +int64_t Model::Init(GetWorkspaceFunc getWorkSpaceFunc, CreateTensorFromTensorDescFunc createTensorFromTensorDescFunc, + RunTaskFunc runTaskFunc) +{ + // ATB_OPERATION_EXECUTE_ASYNC: whether to enable operator execute pipeline + // 0 - disable, 1 - enable level-2 pipeline (default), 2 - enable level-3 pipeline + const char *envStr = std::getenv("ATB_OPERATION_EXECUTE_ASYNC"); + isUsePlanExecuteAsync_ = (envStr != nullptr && (std::string(envStr) == "2" || std::string(envStr) == "1")); + isUsePlanPreExecuteAsync_ = (envStr != nullptr && std::string(envStr) == "2"); + if (isUsePlanExecuteAsync_ && !runTaskFunc) { + std::thread thread = std::thread(std::bind(&Model::ThreadProcessTask, this)); + taskProcessThread_ = std::move(thread); + } + + ATB_SPEED_LOG_DEBUG(modelName_ << " new, isTaskQueueEnable:" << (runTaskFunc != nullptr) + << ", isUsePlanExecuteAsync:" << isUsePlanExecuteAsync_ << ", currentDevId:" << currentDevId_); + + getWorkSpaceFunc_ = getWorkSpaceFunc; + createTensorFromTensorDescFunc_ = createTensorFromTensorDescFunc; + runTaskFunc_ = runTaskFunc; + + int64_t atbStatus = BuildGraph(); + eventOps_.clear(); + for (auto& eventOp : g_eventOperationsOfModel) { + eventOps_.push_back(eventOp); + } + g_eventOperationsOfModel.clear(); + CHECK_THROW(atbStatus != atb::NO_ERROR, + "Build model graph failed. enable log: export ASDOPS_LOG_LEVEL=ERROR, export ASDOPS_LOG_TO_STDOUT=1 to find " + "the first error. 
For more details, see the MindIE official document."); + graph_.Init(); + ATB_SPEED_LOG_DEBUG(modelName_ << " init graph:\n" << graph_.ToString()); + return atbStatus; +} + +atb::Status Model::SkipEvent(bool isSkipEvent) +{ + atb::Status rt = atb::NO_ERROR; + if (isSkipEvent != isSkipEvent_) { + isSkipEvent_ = isSkipEvent; + atb::common::EventParam eventOpParam; + eventOpParam.operatorType = atb::common::EventParam::OperatorType::UNDEFINED; + for (auto& eventOp : eventOps_) { + if (!isSkipEvent_) { + rt = atb::UpdateOperationParam(eventOp.first, eventOp.second); + } else { + rt = atb::UpdateOperationParam(eventOp.first, eventOpParam); + } + if (rt != atb::NO_ERROR) { + return rt; + } + } + } + return rt; +} + +atb::Status Model::SetNodeStreamId(Node& node, uint32_t streamId) const +{ + node.streamId = streamId; + auto rt = atb::SetExecuteStreamId(node.operation.get(), streamId); + CHECK_THROW(rt != atb::NO_ERROR, "atb::SetExecuteStreamId fail: " << rt); + return rt; +} + +int64_t Model::SetWeight(const std::vector &weightTensors) +{ + if (graph_.weightTensors.size() != weightTensors.size()) { + ATB_SPEED_LOG_ERROR(modelName_ << " weightTensors.size:" << weightTensors.size() << " != " + << " graph.weightTensors.size:" << graph_.weightTensors.size()); + return atb::ERROR_INVALID_IN_TENSOR_NUM; + } + + graph_.weightTensors = weightTensors; + return atb::NO_ERROR; +} + +int64_t Model::SetKVCache(const std::vector &kCacheTensors, const std::vector &vCacheTensors) +{ + if (graph_.kCacheTensors.size() != kCacheTensors.size()) { + ATB_SPEED_LOG_ERROR(modelName_ << " kCacheTensors.size:" << kCacheTensors.size() << " != " + << " graph.kCacheTensors.size:" << graph_.kCacheTensors.size()); + return atb::ERROR_INVALID_IN_TENSOR_NUM; + } + + if (graph_.vCacheTensors.size() != vCacheTensors.size()) { + ATB_SPEED_LOG_ERROR(modelName_ << " vCacheTensors.size:" << vCacheTensors.size() << " != " + << " graph.vCacheTensors.size:" << graph_.vCacheTensors.size()); + return atb::ERROR_INVALID_IN_TENSOR_NUM; + } + + graph_.kCacheTensors = kCacheTensors; + graph_.vCacheTensors = vCacheTensors; + return atb::NO_ERROR; +} + +atb::Status Model::Execute(atb::Context *context, std::vector &inTensors, + std::vector &outTensors, const std::string ¶m) +{ + if (graph_.inTensors.size() != inTensors.size() || graph_.outTensors.size() != outTensors.size()) { + ATB_SPEED_LOG_ERROR(modelName_ << " graph.inTensors.size:" << graph_.inTensors.size() + << ", inTensors.size:" << inTensors.size() + << ", graph.outTensors.size:" << graph_.outTensors.size() + << ", outTensors.size:" << outTensors.size()); + return atb::ERROR_INVALID_GRAPH; + } + + ParseParam(param); + + ClearInternalTensors(); + for (auto& i : nodeOutTensors_) { + i.second.clear(); + } + allTaskFinish_ = false; + context_ = context; + graph_.inTensors = inTensors; + graph_.outTensors = outTensors; + ATB_SPEED_LOG_DEBUG(modelName_ << " execute start, executeCount:" << executeCount_ << ", graph:\n" + << graph_.ToString()); + + for (size_t nodeId = 0; nodeId < graph_.nodes.size(); ++nodeId) { + BuildNodeVariantPack(nodeId); + BindParamHostTensor(nodeId); + atb::Status st = ExecuteNode(nodeId); + if (st != 0) { + return st; + } + } + + if (atb_speed::SpeedProbe::IsReportModelTopoInfo(modelName_)) { + std::string modelTopo = GetModelTopoInfo(); + atb_speed::SpeedProbe::ReportModelTopoInfo(modelName_, modelTopo); + } + + WaitAsyncPlanExecuteFinish(); + + return atb::NO_ERROR; +} + +int64_t Model::UpdateWeightsPtr(void *newWeightsPtr, int64_t oldWeightIds) +{ + if 
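/* --- Editor's aside (illustrative note, not part of the original patch) ---
   Execute above drives every node through the same three steps, so a model
   run is just this loop:

       for (size_t nodeId = 0; nodeId < graph_.nodes.size(); ++nodeId) {
           BuildNodeVariantPack(nodeId);  // resolve real in/out tensors
           BindParamHostTensor(nodeId);   // model-specific host data, e.g. seqLen
           atb::Status st = ExecuteNode(nodeId);  // Setup + (a)sync launch
           if (st != 0) { return st; }
       }
*/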
(static_cast<size_t>(oldWeightIds) >= graph_.weightTensors.size()) {
+        ATB_SPEED_LOG_ERROR(modelName_ << ", replace target weight idx exceeds the graph weights size, idx: "
+            << oldWeightIds << ", graph weights size: " << graph_.weightTensors.size());
+        return atb::ERROR_INVALID_PARAM;
+    }
+    graph_.weightTensors.at(oldWeightIds).deviceData = newWeightsPtr;
+    return atb::NO_ERROR;
+}
+
+int64_t Model::SetWeightFormat(const uint64_t weightId)
+{
+    if (weightId >= graph_.weightTensors.size()) {
+        ATB_SPEED_LOG_ERROR(modelName_ << ", replace source weight idx exceeds the weights size, idx: "
+            << weightId << ", weights size: "
+            << graph_.weightTensors.size());
+        return atb::ERROR_INVALID_PARAM;
+    }
+    graph_.weightTensors.at(weightId).desc.format = ACL_FORMAT_FRACTAL_NZ;
+    return atb::NO_ERROR;
+}
+
+void Model::BuildNodeOutTensorImpl(
+    int nodeId, atb_speed::Model::Node &node, atb::SVector<atb::TensorDesc>& inTensorDescs)
+{
+    atb::SVector<atb::TensorDesc> outTensorDescs;
+    outTensorDescs.reserve(node.operation->GetOutputNum());
+    outTensorDescs.resize(node.operation->GetOutputNum());
+    atb::Status st = node.operation->InferShape(inTensorDescs, outTensorDescs);
+    if (st != 0) {
+        ATB_SPEED_LOG_ERROR(modelName_ << " nodes[" << nodeId << "] "
+            << " infer shape fail, error code: " << st);
+    }
+    for (size_t i = 0; i < outTensorDescs.size(); ++i) {
+        ATB_SPEED_LOG_DEBUG(modelName_ << " nodes[" << nodeId << "] outTensorDescs[" << i
+            << "]:" << TensorUtil::TensorDescToString(outTensorDescs.at(i)));
+    }
+
+    for (size_t i = 0; i < node.outTensors.size(); ++i) {
+        CHECK_THROW(node.outTensors.at(i) == nullptr,
+            modelName_ << " nodes[" << nodeId << "] "
+            << "outTensor " << i << " is NULL");
+        node.variantPack.outTensors.at(i) = *node.outTensors.at(i);
+        if (node.outTensorTypes.at(i) == Model::TensorType::INTERMEDIATE_TENSOR) {
+            node.variantPack.outTensors.at(i)
+                = MallocInternalTensor(node.outTensors.at(i), nodeId, i, outTensorDescs.at(i));
+            *node.outTensors.at(i) = node.variantPack.outTensors.at(i);
+        }
+        if (!TensorUtil::TensorDescEqual(node.variantPack.outTensors.at(i).desc, outTensorDescs.at(i))) {
+            ATB_SPEED_LOG_DEBUG(modelName_ << " nodes[" << nodeId << "] new outTensorDescs[" << i
+                << "]:" << TensorUtil::TensorDescToString(outTensorDescs.at(i))
+                << ", node.variantPack.outTensors.at[" << i
+                << "].desc:" << TensorUtil::TensorDescToString(node.variantPack.outTensors.at(i).desc));
+        }
+    }
+}
+
+void Model::BuildNodeVariantPack(int nodeId)
+{
+    auto &node = graph_.nodes.at(nodeId);
+
+    atb::SVector<atb::TensorDesc> inTensorDescs;
+    inTensorDescs.reserve(node.variantPack.inTensors.size());
+    inTensorDescs.resize(node.variantPack.inTensors.size());
+    for (size_t i = 0; i < node.inTensors.size(); ++i) {
+        CHECK_THROW(node.inTensors.at(i) == nullptr,
+            modelName_ << " nodes[" << nodeId << "] "
+            << "inTensor " << i << " is NULL");
+        node.variantPack.inTensors.at(i) = *node.inTensors.at(i);
+        inTensorDescs.at(i) = node.inTensors.at(i)->desc;
+        ATB_SPEED_LOG_DEBUG(modelName_ << " nodes[" << nodeId << "] inTensors[" << i
+            << "]:" << TensorUtil::TensorToString(node.variantPack.inTensors.at(i)));
+    }
+
+    BuildNodeOutTensorImpl(nodeId, node, inTensorDescs);
+
+    auto it = graph_.maxNodeIdTensorMap.find(nodeId);
+    if (it != graph_.maxNodeIdTensorMap.end()) {
+        for (auto tensorIt : it->second) {
+            FreeInternalTensor(tensorIt, nodeId);
+        }
+    }
+}
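// --- Editor's aside (illustrative sketch, not part of the original patch) ---
// BuildNodeOutTensorImpl above asks each operation to infer its output descs
// from the input descs and then materializes intermediate outputs on demand.
// The contract in isolation (hypothetical helper; the allocator callback
// stands in for createTensorFromTensorDescFunc_):
namespace {
[[maybe_unused]] atb::Status InferAndAllocate(atb::Operation &op,
    const atb::SVector<atb::TensorDesc> &inDescs,
    const std::function<atb::Tensor(const atb::TensorDesc &)> &alloc,
    atb::SVector<atb::Tensor> &outTensors)
{
    atb::SVector<atb::TensorDesc> outDescs;
    outDescs.resize(op.GetOutputNum());
    atb::Status st = op.InferShape(inDescs, outDescs);
    if (st != atb::NO_ERROR) {
        return st; // shape inference failed; nothing to allocate
    }
    outTensors.resize(outDescs.size());
    for (size_t i = 0; i < outDescs.size(); ++i) {
        outTensors.at(i) = alloc(outDescs.at(i)); // allocate from the inferred desc
    }
    return atb::NO_ERROR;
}
} // anonymous namespace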
+atb::Status Model::ExecuteNode(int nodeId)
+{
+    ExecuteNodeView(nodeId);
+    auto &node = graph_.nodes.at(nodeId);
+    if (g_preExecuteStatus == atb::ERROR_OUT_OF_DEVICE_MEMORY || g_executeStatus == atb::ERROR_OUT_OF_DEVICE_MEMORY) {
+        throw std::runtime_error("Npu out of memory, OOM");
+    }
+    if (g_preExecuteStatus != atb::NO_ERROR || g_executeStatus != atb::NO_ERROR) {
+        std::stringstream ss;
+        ss << "Execute fail, enable log: export ASDOPS_LOG_LEVEL=ERROR, export ASDOPS_LOG_TO_STDOUT=1 to find "
+              "the first error. For more details, see the MindIE official document."
+           << std::endl;
+        ATB_SPEED_LOG_ERROR(ss.str(), ATB_MODELS_EXECUTION_FAILURE);
+        throw std::runtime_error(ss.str());
+    }
+    atb::Status st = node.operation->Setup(node.variantPack, node.workspaceSize, context_);
+    if (st == atb::ERROR_OUT_OF_DEVICE_MEMORY) {
+        throw std::runtime_error("Npu out of memory, OOM");
+    }
+    if (st != atb::NO_ERROR) {
+        std::stringstream ss;
+        ss << "Setup fail, enable log: export ASDOPS_LOG_LEVEL=ERROR, export ASDOPS_LOG_TO_STDOUT=1 to find the first "
+              "error. For more details, see the MindIE official document."
+           << std::endl;
+        throw std::runtime_error(ss.str());
+    }
+    if (st != 0) {
+        ATB_SPEED_LOG_ERROR(modelName_ << " setup node[" << nodeId << "] fail, not call execute");
+        return st;
+    }
+
+    ATB_SPEED_LOG_DEBUG(modelName_ << " get node[" << nodeId << "] workspace size:" << node.workspaceSize);
+
+    if (node.workspaceSize > 0) {
+        node.workspace = getWorkSpaceFunc_(node.workspaceSize, node.streamId);
+    }
+
+    if (isUsePlanExecuteAsync_) {
+        ExecutePlanAsync(nodeId);
+    } else {
+        st = ExecutePlanSync(nodeId);
+    }
+    return st;
+}
+
+void Model::ThreadProcessTask()
+{
+    ATB_SPEED_LOG_DEBUG(modelName_ << " thread process operations start");
+    int ret = aclrtSetDevice(currentDevId_);
+    if (ret != 0) {
+        ATB_SPEED_LOG_ERROR("aclrtSetDevice fail, error:" << ret);
+    }
+
+    size_t processTaskCount = 0;
+    while (true) {
+        int nodeId = PopTask();
+        if (nodeId == -1) {
+            ATB_SPEED_LOG_DEBUG(modelName_ << " placeholder task for sync communicate operation");
+        } else {
+            atb::Status st = ExecutePlanSync(nodeId, !isUsePlanPreExecuteAsync_);
+            if (st != 0) {
+                allTaskFinish_ = true;
+                processTaskCount = 0;
+                return;
+            }
+        }
+
+        processTaskCount++;
+        if (processTaskCount == graph_.nodes.size()) {
+            ATB_SPEED_LOG_DEBUG(modelName_ << " thread process all operations");
+            processTaskCount = 0;
+            allTaskFinish_ = true;
+        }
+    }
+}
+
+void Model::PushPreTask(int nodeId)
+{
+    auto task = [this, nodeId] {
+        atb::Status st = PreExecutePlanSync(nodeId);
+        if (st != atb::NO_ERROR) {
+            return;
+        }
+        // Hand the task over to the next pipeline stage
+        if (runTaskFunc_ != nullptr) {
+            runTaskFunc_(modelName_ + graph_.nodes[nodeId].operation->GetName(), [this, nodeId]() {
+                ExecutePlanSync(nodeId, false);
+                return 0;
+            });
+        } else {
+            PushTask(nodeId);
+        }
+
+        if (size_t(nodeId + 1) == graph_.nodes.size() && runTaskFunc_ != nullptr) {
+            // All tasks have been issued; mark the run as finished
+            allTaskFinish_ = true;
+        }
+    };
+    ModelTaskExecutor::Instance().PushTask(currentDevId_, task);
+}
+
+void Model::PushTask(int nodeId)
+{
+    std::unique_lock<std::mutex> lock(mutex_);
+    taskQueue_.push(nodeId);
+    lock.unlock();
+    cond_.notify_one();
+}
+
+int Model::PopTask()
+{
+    std::unique_lock<std::mutex> lock(mutex_);
+    while (taskQueue_.empty()) {
+        cond_.wait(lock);
+    }
+    int nodeId = taskQueue_.front();
+    taskQueue_.pop();
+    return nodeId;
+}
+
+atb::Status Model::ExecutePlanSync(int nodeId, bool doExecuteNormal)
+{
+    auto &node = graph_.nodes.at(nodeId);
+    auto oldType = context_->GetExecuteType();
+    if (!doExecuteNormal) {
+        atb::VariantPack &variantPack = node.variantPack;
+
+        ATB_SPEED_LOG_DEBUG(modelName_ << " execute node[" << nodeId << "] start");
+        context_->SetExecuteType(atb::EXECUTE_LAUNCH);
+        atb::Status st 
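/* --- Editor's aside (illustrative note, not part of the original patch) ---
   The two-stage pipeline re-runs the same Execute call under different
   context execute types: PreExecutePlanSync performs the host-side
   preparation with EXECUTE_PRELAUNCH, and the worker thread later replays
   the node with EXECUTE_LAUNCH (the !doExecuteNormal branch here) so only
   the device launch remains. Schematically:

       caller thread: context->SetExecuteType(atb::EXECUTE_PRELAUNCH); op->Execute(...);
       worker thread: context->SetExecuteType(atb::EXECUTE_LAUNCH);    op->Execute(...);

   The single-stage path uses EXECUTE_NORMAL, which does both at once.
*/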
= node.operation->Execute(variantPack, (uint8_t*)(node.workspace), node.workspaceSize, context_); + context_->SetExecuteType(oldType); + if (st != 0) { + ATB_SPEED_LOG_ERROR("Execute node[" << nodeId << "] fail, error code: " << st); + g_executeStatus = st; + } + return st; + } + atb::VariantPack &variantPack = node.variantPack; + + ATB_SPEED_LOG_DEBUG(modelName_ << "execute node[" << nodeId << "] start"); + context_->SetExecuteType(atb::EXECUTE_NORMAL); + atb::Status st = node.operation->Execute(variantPack, (uint8_t*)(node.workspace), node.workspaceSize, context_); + context_->SetExecuteType(oldType); + if (st != 0) { + ATB_SPEED_LOG_ERROR("Execute node[" << nodeId << "] fail, error code: " << st); + g_executeStatus = st; + } + return st; +} + +atb::Status Model::PreExecutePlanSync(int nodeId) +{ + auto &node = graph_.nodes.at(nodeId); + auto oldType = context_->GetExecuteType(); + atb::VariantPack &variantPack = node.variantPack; + context_->SetExecuteType(atb::EXECUTE_PRELAUNCH); + ATB_SPEED_LOG_DEBUG(modelName_ << "pre execute node[" << nodeId << "] start"); + atb::Status st = node.operation->Execute(variantPack, (uint8_t*)(node.workspace), node.workspaceSize, context_); + context_->SetExecuteType(oldType); + if (st != 0) { + ATB_SPEED_LOG_ERROR("pre execute node[" << nodeId << "] fail, error code: " << st); + g_preExecuteStatus = st; + } + return st; +} + +void Model::ExecutePlanAsync(int nodeId) +{ + if (executeCount_ == 0) { + ExecutePlanSync(nodeId); + executeCount_++; + // put a placeholder task for communicate operation + PushTask(-1); + return; + } + if (isUsePlanPreExecuteAsync_) { + PushPreTask(nodeId); + } else if (runTaskFunc_) { + runTaskFunc_(modelName_ + std::to_string(nodeId), [=]() { + ExecutePlanSync(nodeId); + return 0; + }); + } else { + PushTask(nodeId); + } +} + +void Model::WaitAsyncPlanExecuteFinish() +{ + if (!isUsePlanExecuteAsync_) { + return; + } + + if (!isUsePlanPreExecuteAsync_ && runTaskFunc_ != nullptr) { + return; + } + + while (!allTaskFinish_) { + ; + } + return; +} + +void Model::ExecuteNodeView(int nodeId) +{ + auto &node = graph_.nodes.at(nodeId); + if (node.inTensorReshapeFuncs.size() > 0) { + for (int i = 0; i < int(node.inTensorReshapeFuncs.size()); i++) { + if (node.inTensorReshapeFuncs.at(i) != nullptr) { + node.inTensorReshapeFuncs.at(i)(node.inTensors.at(i)->desc.shape, + node.variantPack.inTensors.at(i).desc.shape); + } + } + } +} + +bool Model::IsTensorDescEqual(const atb::TensorDesc &tensorDesc, const atb::Tensor &atbTensor) const +{ + return atbTensor.desc.dtype == tensorDesc.dtype && atbTensor.desc.format == tensorDesc.format && + IsTensorDimsEqual(atbTensor.desc.shape, tensorDesc.shape); +} + +void Model::ClearInternalTensors() +{ + for (auto& i : internalTensors_) { + i.second.clear(); + } +} + +atb::Tensor Model::MallocInternalTensor(atb::Tensor* outTensor, size_t nodeId, size_t outTensorId, + const atb::TensorDesc &tensorDesc) +{ + auto key = graph_.nodes[nodeId].streamId; + if (nodeOutTensors_.count(key) == 0) { + std::vector emptyOuts; + nodeOutTensors_[key] = emptyOuts; + } + if (internalTensors_.count(key) == 0) { + std::vector> emptyInte; + internalTensors_[key] = emptyInte; + } + if (GetSingleton().IsLayerInternalTensorReuse()) { + std::vector::iterator iter = + std::find(nodeOutTensors_[key].begin(), nodeOutTensors_[key].end(), outTensor); + if (iter != nodeOutTensors_[key].end()) { + ATB_SPEED_LOG_DEBUG(modelName_ << " nodeId: " << nodeId << ", out tensor id: " + << outTensorId << " write inplace"); + return **iter; + } + 
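/* --- Editor's aside (illustrative note, not part of the original patch) ---
   The loop below is a simple free-list: every cached internal tensor carries
   an in-use flag, and a request is served by any idle tensor whose desc
   matches; FreeInternalTensor clears the flag after the last consumer node.
   In miniature, with a hypothetical pool of (tensor, busy) pairs:

       for (auto &entry : pool) {
           if (!entry.second && DescEqual(entry.first.desc, wanted)) {
               entry.second = true;  // mark busy and hand it out again
               return entry.first;
           }
       }
       // fall through: allocate a fresh tensor and append it to the pool
*/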
for (auto &it : internalTensors_[key]) {
+            if (it.second) { // the tensor is still in use and cannot be assigned to another op
+                continue;
+            }
+
+            if (IsTensorDescEqual(tensorDesc, it.first)) {
+                it.second = true;
+                ATB_SPEED_LOG_DEBUG(modelName_ << " use old internal tensor");
+                return it.first;
+            }
+        }
+    }
+
+    ATB_SPEED_LOG_DEBUG(modelName_ << " create internal tensor, node["
+        << nodeId << "], outTensor[" << outTensorId << "]");
+    atb::Tensor newTensor = createTensorFromTensorDescFunc_(tensorDesc);
+    internalTensors_[key].push_back(std::make_pair(newTensor, true));
+    nodeOutTensors_[key].push_back(outTensor);
+    return newTensor;
+}
+
+void Model::FreeInternalTensor(const atb::Tensor *tensorDeviceData, int nodeId)
+{
+    auto key = graph_.nodes[nodeId].streamId;
+    if (GetSingleton<Config>().IsLayerInternalTensorReuse()) {
+        for (auto &it : internalTensors_[key]) {
+            if (it.first.deviceData == tensorDeviceData->deviceData) {
+                it.second = false; // the tensor is released and can be reused by later nodes
+                ATB_SPEED_LOG_DEBUG(modelName_ << " free internal tensor");
+                break;
+            }
+        }
+    }
+}
+
+void Model::GetModelTensorNameList(nlohmann::json &modelJson, std::map<atb::Tensor *, std::string> &tensorNameMap)
+{
+    std::string tensorName;
+    for (size_t i = 0; i < graph_.weightTensors.size(); i++) {
+        tensorName = modelName_ + "_weight_" + std::to_string(i);
+        modelJson["weightTensors"].emplace_back(tensorName);
+        atb::Tensor &weightTensor = graph_.weightTensors[i];
+        tensorNameMap[&weightTensor] = tensorName;
+    }
+
+    for (size_t i = 0; i < graph_.inTensors.size(); i++) {
+        tensorName = modelName_ + "_input_" + std::to_string(i);
+        modelJson["inTensors"].emplace_back(tensorName);
+        atb::Tensor &inTensor = graph_.inTensors[i];
+        tensorNameMap[&inTensor] = tensorName;
+    }
+
+    for (size_t i = 0; i < graph_.outTensors.size(); i++) {
+        tensorName = modelName_ + "_output_" + std::to_string(i);
+        modelJson["outTensors"].emplace_back(tensorName);
+        atb::Tensor &outTensor = graph_.outTensors[i];
+        tensorNameMap[&outTensor] = tensorName;
+    }
+
+    for (size_t i = 0; i < graph_.internalTensors.size(); i++) {
+        tensorName = modelName_ + "_internal_" + std::to_string(i);
+        modelJson["internalTensors"].emplace_back(tensorName);
+        atb::Tensor &internalTensor = graph_.internalTensors[i];
+        tensorNameMap[&internalTensor] = tensorName;
+    }
+
+    for (size_t i = 0; i < graph_.kCacheTensors.size(); i++) {
+        tensorName = modelName_ + "_kCache_" + std::to_string(i);
+        modelJson["kCacheTensors"].emplace_back(tensorName);
+        atb::Tensor &kCacheTensor = graph_.kCacheTensors[i];
+        tensorNameMap[&kCacheTensor] = tensorName;
+    }
+
+    for (size_t i = 0; i < graph_.vCacheTensors.size(); i++) {
+        tensorName = modelName_ + "_vCache_" + std::to_string(i);
+        modelJson["vCacheTensors"].emplace_back(tensorName);
+        atb::Tensor &vCacheTensor = graph_.vCacheTensors[i];
+        tensorNameMap[&vCacheTensor] = tensorName;
+    }
+}
+
+void Model::GetNodeTopoInfo(nlohmann::json &nodeJson, const Node &opNode,
+    const std::map<atb::Tensor *, std::string> tensorNameMap) const
+{
+    nodeJson["opName"] = opNode.operation->GetName();
+
+    for (auto inTensor : opNode.inTensors) {
+        auto it = tensorNameMap.find(inTensor);
+        if (it != tensorNameMap.end()) {
+            nodeJson["inTensors"].emplace_back(it->second);
+        }
+    }
+
+    for (auto outTensor : opNode.outTensors) {
+        auto it = tensorNameMap.find(outTensor);
+        if (it != tensorNameMap.end()) {
+            nodeJson["outTensors"].emplace_back(it->second);
+        }
+    }
+}
+
+std::string Model::GetModelTopoInfo()
+{
+    nlohmann::json modelJson;
+    modelJson["modelName"] = modelName_;
+
+    std::map<atb::Tensor *, std::string> tensorNameMap;
+    GetModelTensorNameList(modelJson, tensorNameMap);
+
+    for (size_t nodeId 
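/* --- Editor's aside (illustrative note, not part of the original patch) ---
   GetModelTopoInfo serializes the graph as JSON: tensor pointers are first
   mapped to stable names by GetModelTensorNameList, then each node records
   its operation name and the named in/out tensors. One node entry might be
   built like this (names hypothetical):

       nlohmann::json nodeJson;
       nodeJson["opName"] = "LinearOperation";
       nodeJson["inTensors"] = {"demo_input_0", "demo_weight_0"};
       nodeJson["outTensors"] = {"demo_internal_0"};
       modelJson["nodes"].emplace_back(nodeJson);
*/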
= 0; nodeId < graph_.nodes.size(); nodeId++) { + const auto &opNode = graph_.nodes.at(nodeId); + nlohmann::json nodeJson; + GetNodeTopoInfo(nodeJson, opNode, tensorNameMap); + modelJson["nodes"].emplace_back(nodeJson); + } + return modelJson.dump(); +} +} // namespace atb_speed diff --git a/tests/proftest/layer_test_framework/core/include/atb_speed/base/context_factory.h b/tests/proftest/layer_test_framework/core/include/atb_speed/base/context_factory.h new file mode 100644 index 00000000..5879055c --- /dev/null +++ b/tests/proftest/layer_test_framework/core/include/atb_speed/base/context_factory.h @@ -0,0 +1,32 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef ATB_SPEED_CONTEXT_FACTORY_H +#define ATB_SPEED_CONTEXT_FACTORY_H + +#include +#include + +namespace atb_speed { +class ContextFactory { +public: + static std::shared_ptr GetAtbContext(void *stream); + static std::vector GetSubStreams(); + static void FreeAtbContext(); + static bool cacheWorkspace_; +}; +} +#endif \ No newline at end of file diff --git a/tests/proftest/layer_test_framework/core/include/atb_speed/base/event_manager.h b/tests/proftest/layer_test_framework/core/include/atb_speed/base/event_manager.h new file mode 100644 index 00000000..c004b7bc --- /dev/null +++ b/tests/proftest/layer_test_framework/core/include/atb_speed/base/event_manager.h @@ -0,0 +1,165 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#ifndef ATB_SPEED_EVENT_MANAGER_H
+#define ATB_SPEED_EVENT_MANAGER_H
+
+#include <atomic>
+#include <condition_variable>
+#include <map>
+#include <mutex>
+#include <queue>
+#include <string>
+#include <utility>
+#include <vector>
+#include <acl/acl.h>
+#include <atb/operation.h>
+#include <atb/types.h>
+
+#include "atb_speed/log.h"
+
+namespace atb_speed {
+extern thread_local std::vector<std::pair<atb::Operation *, atb::common::EventParam>> g_eventOperationsOfModel;
+
+// Status codes returned by EventManager:
+enum EventManagerStatus {
+    EM_SUCCESS = 0,                    // operation succeeded
+    EM_CREATE_EVENT_FAILED = 1,        // creating the event via the ACL API failed
+    EM_PUSH_EVENT_FAILED = 2,          // pushing the event into the event queue failed
+    EM_POP_EVENT_FAILED = 3,           // taking an event out of the queue failed
+    EM_POP_EVENT_TIMEOUT = 4,          // timed out while waiting for an event in the queue
+    EM_INVALID_ACTION = 5,             // invalid event action passed in (neither PUSH nor POP)
+    EM_INVALID_TYPE = 6,               // invalid event type passed in (neither RECORD nor WAIT)
+    EM_DESTROY_EVENT_FAILED = 7,       // destroying the event via the ACL API failed
+    EM_OPERATION_CREATION_FAILED = 8,  // creating the record/wait operation node failed (atb::CreateOperation)
+    EM_INVALID_KEY = 9,                // nonexistent or invalid event map key passed in
+    EM_UNKNOWN_ERROR = 10,             // unknown error, reserved status code
+};
+
+// Event types:
+enum class EventType {
+    UNDEFINED,  // perform no event operation
+    RECORD,     // record-event operation
+    WAIT,       // wait-event operation
+};
+
+// Event actions:
+enum class EventAction {
+    PUSH,  // create a new event and enqueue it
+    POP,   // take an existing event from the queue
+};
+
+class EventManager {
+public:
+    // Singleton accessor returning the unique EventManager instance
+    static EventManager& GetInstance();
+
+    EventManager(const EventManager&) = delete;
+    EventManager& operator=(const EventManager&) = delete;
+
+    /**
+     * @brief Set the ACL operation wait timeout
+     * @param timeout timeout in seconds
+     */
+    void SetWaitOperationTimeout(uint32_t timeout);
+
+    //======================================================
+    // [Multi-stream interface]
+    // - RecordEvent: in multi-stream scenarios, PUSH creates a record event and enqueues it,
+    //   while POP fetches an event from the event queue.
+    // - WaitEvent: in multi-stream scenarios, PUSH creates a wait event and enqueues it,
+    //   while POP fetches an event from the event queue.
+    //------------------------------------------------------
+
+    /**
+     * @brief Perform a record-event operation in multi-stream scenarios; the behaviour
+     *        depends on eventAction (PUSH/POP).
+     * @param op output parameter, the created operation node
+     * @param eventAction event action (PUSH or POP)
+     * @param pipeKey identifier of the event pipe (defaults to "default")
+     * @return status of the operation (atb::Status)
+     */
+    atb::Status RecordEvent(atb::Operation*& op, EventAction eventAction, const std::string &pipeKey = "default");
+
+    /**
+     * @brief Perform a wait-event operation in multi-stream scenarios; the behaviour
+     *        depends on eventAction (PUSH/POP).
+     * @param op output parameter, the created operation node
+     * @param eventAction event action (PUSH or POP)
+     * @param pipeKey identifier of the event pipe (defaults to "default")
+     * @return status of the operation (atb::Status)
+     */
+    atb::Status WaitEvent(atb::Operation*& op, EventAction eventAction, const std::string &pipeKey = "default");
+
+private:
+    // Constructor and destructor are private to enforce the singleton pattern.
+    // The constructor sets the ACL operation wait timeout (1080 seconds here).
+    EventManager();
+    ~EventManager();
+
+    /**
+     * @brief Create an event via the ACL API (ACL_EVENT_SYNC by default)
+     * @param event the created event, returned by reference
+     * @return EM_SUCCESS or EM_CREATE_EVENT_FAILED
+     */
+    EventManagerStatus CreateEvent(aclrtEvent &event) const;
+
+    /**
+     * @brief Destroy an event
+     * @param event the event to destroy
+     * @return EM_SUCCESS or EM_DESTROY_EVENT_FAILED
+     */
+    EventManagerStatus DestroyEvent(aclrtEvent event) const;
+
+    /**
+     * @brief Create an event and push it into the queue
+     * @param event the newly created event, returned by reference and enqueued
+     * @param pipeKey key distinguishing different event pipes
+     * @return EM_SUCCESS or the corresponding error code
+     */
+    EventManagerStatus CreateAndPushEvent(aclrtEvent &event, const std::string &pipeKey);
+
+    /**
+     * @brief Take one event out of the event queue
+     * @param event the popped event, returned by reference
+     * @param pipeKey key distinguishing different event pipes
+     * @return EM_SUCCESS or EM_POP_EVENT_TIMEOUT
+     */
+    EventManagerStatus PopEvent(aclrtEvent &event, const std::string &pipeKey);
+
+    /**
+     * @brief Internal helper holding the logic shared by record/wait events.
+     *        Based on eventAction (PUSH/POP) and eventType (RECORD/WAIT) it obtains
+     *        the ACL event and fills in the event parameters (eventParam).
+     * @param eventAction event action (PUSH or POP)
+     * @param eventType event type (RECORD or WAIT)
+     * @param op output parameter, the created operation node
+     * @param pipeKey identifier of the event pipe
+     * @return status of the operation (atb::Status)
+     */
+    atb::Status EventInternal(EventAction eventAction,
+        EventType eventType,
+        atb::Operation*& op,
+        const std::string &pipeKey);
+
+private:
+    // Condition variable used to block while an event queue is empty and to notify
+    // waiting threads when a new event is enqueued
+    std::condition_variable eventCond_;
+    // Event queues: expose the push/pop interface, keyed by pipe
+    std::map<std::string, std::queue<aclrtEvent>> eventQueues_;
+    std::map<std::string, std::queue<std::pair<atb::Operation *, atb::common::EventParam>>> opsWithoutEvent_;
+    // Mutex protecting eventQueues_
+    std::mutex queueMutex_;
+    // Number of events currently queued
+    std::atomic<uint64_t> eventCount_{0};
+};
+} // namespace atb_speed
+
+#endif // ATB_SPEED_EVENT_MANAGER_H
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/core/include/atb_speed/base/external_comm_manager.h b/tests/proftest/layer_test_framework/core/include/atb_speed/base/external_comm_manager.h
new file mode 100644
index 00000000..5c0d8edc
--- /dev/null
+++ b/tests/proftest/layer_test_framework/core/include/atb_speed/base/external_comm_manager.h
@@ -0,0 +1,84 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef ATB_SPEED_EXTERNAL_COMM_MANAGER_H
+#define ATB_SPEED_EXTERNAL_COMM_MANAGER_H
+
+#include <map>
+#include <memory>
+#include "hccl/hccl.h"
+#include "atb/comm.h"
+#include "atb_speed/log.h"
+
+namespace atb_speed {
+
+const std::string LCCL = "lccl";
+const std::string HCCL = "hccl";
+const std::string LCOC = "lcoc";
+
+/// A cache object holding the information of one communication group
+class CommInfo {
+public:
+    ~CommInfo();
+
+    uint64_t cacheId_ = 0;
+    uint32_t subCommRankId_ = 0;
+    std::vector<uint32_t> rankIds_ = {};
+    std::string backend_ = "";
+    HcclComm hcclComm_ = nullptr;
+    uint32_t bufferSize_ = 0;
+    uint32_t streamId_ = 0;
+    bool enableReuse_ = true;
+
+    std::string ToString() const;
+};
+
+/// Manages all communication groups (including the commDomain and the hcclComm pointer)
+class ExternalCommManager {
+public:
+    void Init(uint32_t worldSize, uint32_t subCommRankId,
+        std::string backend, std::string rankTableFile, uint32_t streamId);
+
+    void SetLcclCommDomainRange(int32_t lowerBound, int32_t upperBound);
+
+    std::string GetCommDomain(uint32_t groupId, const std::vector<uint32_t> &rankIds,
+        uint32_t subCommRankId, std::string backend, uint32_t bufferSize, uint32_t streamId,
+        bool enableReuse = true);
+
+    HcclComm GetCommPtr(std::string commDomain);
+
+    std::string PrintCommInfo();
+
+    uint32_t worldSize_ = 0;
+    uint32_t rank_;
+    std::string rankTableFile_ = "";
+
+private:
+    std::string GetCommDomainFromCache(
+        const std::vector<uint32_t> &rankIds, std::string backend, uint32_t bufferSize, uint32_t streamId);
+    std::string GetSelfAssignedCommDomain(std::shared_ptr<CommInfo> &commInfo, uint32_t groupId);
+    std::string GetHcclSubCommDomain(std::shared_ptr<CommInfo> &commInfo, uint32_t groupId);
+    std::string GetHcclGlobalCommDomain(std::shared_ptr<CommInfo> &commInfo);
+
+    std::map<std::string, std::shared_ptr<CommInfo>> commInfoCache_ = {};
+    HcclComm globalComm_ = nullptr;
+    uint32_t commDomainCounter_ = 0;
+    uint32_t lcclCommDomainLowerBound_ = 0;
+    uint32_t lcclCommDomainUpperBound_ = 0;
+};
+
+} // namespace atb_speed
+
+#endif // 
ATB_SPEED_EXTERNAL_COMM_MANAGER_H \ No newline at end of file diff --git a/tests/proftest/layer_test_framework/core/include/atb_speed/base/hosttensor_binder.h b/tests/proftest/layer_test_framework/core/include/atb_speed/base/hosttensor_binder.h new file mode 100644 index 00000000..3ebb448d --- /dev/null +++ b/tests/proftest/layer_test_framework/core/include/atb_speed/base/hosttensor_binder.h @@ -0,0 +1,30 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef ATB_SPEED_HOSTTENSOR_BINDER_H +#define ATB_SPEED_HOSTTENSOR_BINDER_H +#include +#include + +namespace atb_speed { +class HostTensorBinder { +public: + HostTensorBinder() = default; + virtual ~HostTensorBinder() = default; + virtual void ParseParam(const nlohmann::json ¶mJson) = 0; + virtual void BindTensor(atb::VariantPack &variantPack) = 0; +}; +} // namespace atb_speed +#endif \ No newline at end of file diff --git a/tests/proftest/layer_test_framework/core/include/atb_speed/base/model.h b/tests/proftest/layer_test_framework/core/include/atb_speed/base/model.h new file mode 100644 index 00000000..a95dce53 --- /dev/null +++ b/tests/proftest/layer_test_framework/core/include/atb_speed/base/model.h @@ -0,0 +1,174 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef ATB_SPEED_BASE_MODEL_H +#define ATB_SPEED_BASE_MODEL_H +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "atb_speed/utils/operation_util.h" +#include "atb_speed/utils/check_util.h" +#include "atb_speed/utils/ModelTaskExecutor.h" +#include "atb_speed/base/event_manager.h" + + +namespace atb_speed { +class Model { +public: + using ReshapeFunc = std::function; + using GetWorkspaceFunc = std::function; + using CreateTensorFromTensorDescFunc = std::function; + using Task = std::function; + using RunTaskFunc = std::function; + enum class TensorType { + INTERMEDIATE_TENSOR = 0, + NOT_INTERMEDIATE_TENSOR, + }; + + struct Node { + std::shared_ptr operation; + std::vector inTensors; + std::vector outTensors; + atb::VariantPack variantPack; + // std::vector torchTensors; + std::vector inTensorReshapeFuncs; + atb::SVector inTensorTypes; + atb::SVector outTensorTypes; + uint32_t streamId = 0; + uint64_t workspaceSize = 0; + void *workspace = nullptr; + }; + + struct Graph { + std::vector weightTensors; + std::vector kCacheTensors; + std::vector vCacheTensors; + std::vector inTensors; + std::vector outTensors; + std::vector internalTensors; + std::vector nodes; + std::map> maxNodeIdTensorMap; + void Init(); + std::string ToString() const; + + private: + void InitTensorType(); + bool IsInternalTensor(const atb::Tensor *tensor); + void InitTensorMaxNodeMap(); + }; + + Model(const std::string &modelName, const std::string ¶m); + virtual ~Model(); + int64_t Init(GetWorkspaceFunc getWorkSpaceFunc, CreateTensorFromTensorDescFunc createTensorFromTensorDescFunc, + RunTaskFunc runTaskFunc = nullptr); + + virtual uint32_t GetInputNum() const = 0; + virtual uint32_t GetOutputNum() const = 0; + virtual atb::Status InferShape(const std::vector &inTensorDescs, + std::vector &outTensorDescs) = 0; + + int64_t SetWeight(const std::vector &weightTensors); + int64_t SetWeightFormat(const uint64_t weightId); + int64_t SetKVCache(const std::vector &kCacheTensors, const std::vector &vCacheTensors); + atb::Status SkipEvent(bool isSkipEvent); + atb::Status SetNodeStreamId(Node& node, uint32_t streamId) const; + atb::Status Execute(atb::Context *context, std::vector &inTensors, + std::vector &outTensors, const std::string ¶m); + + int64_t UpdateWeightsPtr(void *newWeightsPtr, int64_t oldWeightIds); + +protected: + virtual int64_t BuildGraph() = 0; + virtual atb::Status ParseParam(const std::string ¶m); + virtual atb::Status BindParamHostTensor(uint32_t nodeId); + virtual void BuildNodeVariantPack(int nodeId); + +protected: + bool IsTensorDescEqual(const atb::TensorDesc &tensorDesc, const atb::Tensor &atbTensor) const; + void ExecuteNodeView(int nodeId); + atb::Status ExecuteNode(int nodeId); + void ThreadProcessTask(); + atb::Status ExecutePlanSync(int nodeId, bool doExecuteNormal = true); + void ExecutePlanAsync(int nodeId); + atb::Status PreExecutePlanSync(int nodeId); + void PushPreTask(int nodeId); + void PushTask(int nodeId); + int PopTask(); + void WaitAsyncPlanExecuteFinish(); + void ClearInternalTensors(); + atb::Tensor MallocInternalTensor(atb::Tensor* outTensor, size_t nodeId, size_t outTensorId, + const atb::TensorDesc &tensorDesc); + void FreeInternalTensor(const atb::Tensor *tensorDeviceData, int nodeId = 0); + void GetModelTensorNameList(nlohmann::json &modelJson, + std::map &tensorNameMap); + void GetNodeTopoInfo(nlohmann::json &nodeJson, const Node &opNode, + const std::map tensorNameMap) const; + 
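/* --- Editor's aside (illustrative note, not part of the original patch) ---
   Concrete models subclass Model, implement the pure virtuals plus BuildGraph,
   then run through Init() and Execute(). A schematic subclass (hypothetical):

       class DemoModel : public Model {
       public:
           using Model::Model;
           uint32_t GetInputNum() const override { return 1; }
           uint32_t GetOutputNum() const override { return 1; }
           atb::Status InferShape(const std::vector<atb::TensorDesc> &in,
                                  std::vector<atb::TensorDesc> &out) override
           {
               out.resize(1);
               out.at(0) = in.at(0);  // e.g. output mirrors the input desc
               return atb::NO_ERROR;
           }
       protected:
           int64_t BuildGraph() override { return atb::NO_ERROR; }  // add graph_ nodes here
       };
*/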
std::string GetModelTopoInfo(); + void BuildNodeOutTensorImpl( + int nodeId, atb_speed::Model::Node &node, atb::SVector& inTensorDescs); + +protected: + GetWorkspaceFunc getWorkSpaceFunc_; + CreateTensorFromTensorDescFunc createTensorFromTensorDescFunc_; + RunTaskFunc runTaskFunc_ = nullptr; + std::string modelName_; + std::string param_; + Graph graph_; + + uint64_t executeCount_ = 0; + atb::Context *context_; + + bool isUsePlanExecuteAsync_ = false; + bool isUsePlanPreExecuteAsync_ = false; + bool isSkipEvent_ = false; + std::queue taskQueue_; + std::mutex mutex_; + std::condition_variable cond_; + std::thread taskProcessThread_; + std::atomic_bool allTaskFinish_; + int32_t currentDevId_ = 0; + std::map>> internalTensors_; + std::map> nodeOutTensors_; + std::vector> eventOps_; +}; +// Max length of param string +const size_t MAX_PARAM_STRING_LENGTH = 200000; +// Max value of tokenOffset, seqLen and qLen +const int MAX_PARAM_VALUE = 600000; +// Max value of vocab_size +const int64_t MAX_VOCAB_SIZE = 10000000; + +#define CHECK_THROW(condition, message) \ + do { \ + if (condition) { \ + std::stringstream ss; \ + ss << message << std::endl; \ + throw std::runtime_error(ss.str()); \ + } \ + } while (0) + +} // namespace atb_speed +#endif \ No newline at end of file diff --git a/tests/proftest/layer_test_framework/core/include/atb_speed/log.h b/tests/proftest/layer_test_framework/core/include/atb_speed/log.h new file mode 100644 index 00000000..a8d40d1c --- /dev/null +++ b/tests/proftest/layer_test_framework/core/include/atb_speed/log.h @@ -0,0 +1,53 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef ATB_SPEED_LOG_H +#define ATB_SPEED_LOG_H + +#include +#include +#include + +#include + +#include "nlohmann/json.hpp" + +template +inline std::ostream& operator<<(std::ostream& os, const std::vector& vec) +{ + for (auto& el : vec) { + os << el << ','; + } + return os; +} + +#define ATB_SPEED_LOG(msg, ...) \ + do { \ + std::ostringstream oss; \ + oss << msg; \ + } while (0) + +#define ATB_SPEED_LOG_DEBUG(msg, ...) ATB_SPEED_LOG(msg, __VA_ARGS__) + +#define ATB_SPEED_LOG_INFO(msg, ...) ATB_SPEED_LOG(msg, __VA_ARGS__) + +#define ATB_SPEED_LOG_WARN(msg, ...) ATB_SPEED_LOG(msg, __VA_ARGS__) + +#define ATB_SPEED_LOG_ERROR(msg, ...) ATB_SPEED_LOG(msg, __VA_ARGS__) + +#define ATB_SPEED_LOG_FATAL(msg, ...) ATB_SPEED_LOG(msg, __VA_ARGS__) + +#endif \ No newline at end of file diff --git a/tests/proftest/layer_test_framework/core/include/atb_speed/utils/ModelTaskExecutor.h b/tests/proftest/layer_test_framework/core/include/atb_speed/utils/ModelTaskExecutor.h new file mode 100644 index 00000000..18ee5364 --- /dev/null +++ b/tests/proftest/layer_test_framework/core/include/atb_speed/utils/ModelTaskExecutor.h @@ -0,0 +1,65 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef ATB_SPEED_MODEL_TASK_EXECUTOR_H +#define ATB_SPEED_MODEL_TASK_EXECUTOR_H + +#include +#include +#include +#include +#include +#include + +#include "atb_speed/utils/TaskQueue.h" +#include "atb_speed/log.h" + +namespace atb_speed { + +class ModelTaskExecutor { +public: + struct Worker { + bool stop = false; + std::thread thread; + TaskQueue queue; + int deviceIdx = -1; + }; + +public: + static ModelTaskExecutor& Instance() + { + static ModelTaskExecutor instance; + return instance; + } + +public: + ~ModelTaskExecutor(); + + void PushTask(int idx, const Task &task); + +private: + ModelTaskExecutor() {} + + void WorkerThread(int workerId); + +private: + std::mutex mutex_; + std::deque workers_; + std::map idx2worker_; +}; +} // namespace atb_speed + +#endif // ATB_SPEED_MODEL_TASK_EXECUTOR_H diff --git a/tests/proftest/layer_test_framework/core/include/atb_speed/utils/TaskQueue.h b/tests/proftest/layer_test_framework/core/include/atb_speed/utils/TaskQueue.h new file mode 100644 index 00000000..b8d50a85 --- /dev/null +++ b/tests/proftest/layer_test_framework/core/include/atb_speed/utils/TaskQueue.h @@ -0,0 +1,40 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef ATB_SPEED_TASK_QUEUE_H +#define ATB_SPEED_TASK_QUEUE_H + +#include +#include +#include +#include + +namespace atb_speed { + +using Task = std::function; + +class TaskQueue { +public: + void Enqueue(const Task &task); + Task Dequeue(); + +private: + std::mutex mutex_; + std::condition_variable cv_; + std::queue queue_; +}; +} // namespace atb_speed + +#endif // ATB_SPEED_TASK_QUEUE_H diff --git a/tests/proftest/layer_test_framework/core/include/atb_speed/utils/check_util.h b/tests/proftest/layer_test_framework/core/include/atb_speed/utils/check_util.h new file mode 100644 index 00000000..52b6d624 --- /dev/null +++ b/tests/proftest/layer_test_framework/core/include/atb_speed/utils/check_util.h @@ -0,0 +1,80 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef ATB_SPEED_UTILS_CHECK_H
+#define ATB_SPEED_UTILS_CHECK_H
+#include <limits>
+#include <stdexcept>
+#include <type_traits>
+#include <vector>
+
+#include "nlohmann/json.hpp"
+
+namespace atb_speed {
+
+using Json = nlohmann::json;
+
+template <typename T, typename U>
+typename std::common_type<T, U>::type CheckIntMulOverFlow(const T a, const U b)
+{
+    if (std::is_signed<T>::value != std::is_signed<U>::value) {
+        throw std::runtime_error("Multiplication between signed and unsigned integer not supported, it's not safe");
+    }
+    using PromotedType = typename std::common_type<T, U>::type;
+    if (a == 0 || b == 0) {
+        return 0;
+    }
+
+    PromotedType pa = static_cast<PromotedType>(a);
+    PromotedType pb = static_cast<PromotedType>(b);
+
+    if constexpr (std::is_signed<PromotedType>::value) {
+        const PromotedType maxVal = std::numeric_limits<PromotedType>::max();
+        const PromotedType minVal = std::numeric_limits<PromotedType>::min();
+        if (pa > 0 && pb > 0) {
+            if (pa > maxVal / pb) {
+                throw std::overflow_error("Integer overflow detected.");
+            }
+        } else if (pa < 0 && pb < 0) {
+            if (pa < maxVal / pb) {
+                throw std::overflow_error("Integer overflow detected.");
+            }
+        } else if (pa > 0 && pb < 0) {
+            if (pa > minVal / pb) {
+                throw std::overflow_error("Integer overflow detected.");
+            }
+        } else if (pa < minVal / pb) {
+            throw std::overflow_error("Integer overflow detected.");
+        }
+    } else {
+        const PromotedType maxVal = std::numeric_limits<PromotedType>::max();
+        if (pa > maxVal / pb) {
+            throw std::overflow_error("Integer overflow detected.");
+        }
+    }
+    return pa * pb;
+}
+int CheckParamRange(const int &intParam, int min, int max);
+int CheckNumHiddenLayersValid(const int &numHiddenLayers);
+int CheckPositive(const int &intParam);
+template <typename T>
+void CheckLinearParamsSufficient(const std::vector<std::vector<T>> &linearParam, \
+    size_t numHiddenLayers, size_t threshold);
+void CheckPackQuantParamsSufficient(const std::vector<std::vector<int>> &packQuantType, size_t numHiddenLayers);
+void CheckLinearPackParamsSufficient(const std::vector<std::vector<int>> &linearPackType, size_t numHiddenLayers);
+void CheckLinearHasBiasSufficient(const std::vector<std::vector<bool>> &linearHasBias, size_t numHiddenLayers);
+void CheckSkipLayerSet(const std::vector<size_t> &skipLayerSet, size_t numHiddenLayers);
+} // namespace atb_speed
+#endif
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/core/include/atb_speed/utils/config.h b/tests/proftest/layer_test_framework/core/include/atb_speed/utils/config.h
new file mode 100644
index 00000000..eb8122f1
--- /dev/null
+++ b/tests/proftest/layer_test_framework/core/include/atb_speed/utils/config.h
@@ -0,0 +1,39 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
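A small, self-contained illustration of the overflow guard above (values chosen to show both outcomes):

    #include <cstdint>
    #include <iostream>
    #include "atb_speed/utils/check_util.h"

    int main()
    {
        std::cout << atb_speed::CheckIntMulOverFlow(3, 4) << std::endl;  // prints 12
        try {
            atb_speed::CheckIntMulOverFlow(INT64_MAX, static_cast<int64_t>(2));
        } catch (const std::overflow_error &e) {
            std::cout << e.what() << std::endl;  // "Integer overflow detected."
        }
        return 0;
    }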
+ */ +#ifndef ATB_SPEED_UTILS_CONFIG_H +#define ATB_SPEED_UTILS_CONFIG_H + +namespace atb_speed { +class Config { +public: + Config(); + ~Config(); + bool IsConvertNCHWToND() const; + bool IsTorchTensorFormatCast() const; + bool IsUseTilingCopyStream() const; + bool IsLayerInternalTensorReuse() const; + +private: + static bool IsEnable(const char *env, bool enable = false); + +private: + bool isConvertNCHWToND_ = false; + bool isTorchTensorFormatCast_ = true; + bool isUseTilingCopyStream_ = false; + bool isLayerInternalTensorReuse_ = false; +}; +} // namespace atb_speed +#endif \ No newline at end of file diff --git a/tests/proftest/layer_test_framework/core/include/atb_speed/utils/file_system.h b/tests/proftest/layer_test_framework/core/include/atb_speed/utils/file_system.h new file mode 100644 index 00000000..34187df0 --- /dev/null +++ b/tests/proftest/layer_test_framework/core/include/atb_speed/utils/file_system.h @@ -0,0 +1,35 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef ATB_SPEED_UTILS_FILESYSTEM_H +#define ATB_SPEED_UTILS_FILESYSTEM_H +#include +#include + +namespace atb_speed { +class FileSystem { +public: + static bool Exists(const std::string &path); + static bool IsDir(const std::string &path); + static std::string Join(const std::vector &paths); + static int64_t FileSize(const std::string &filePath); + static std::string BaseName(const std::string &filePath); + static std::string DirName(const std::string &path); + static bool DeleteFile(const std::string &filePath); + static bool MakeDir(const std::string &dirPath, int mode); + static bool Makedirs(const std::string &dirPath, const mode_t mode); +}; +} // namespace atb_speed +#endif \ No newline at end of file diff --git a/tests/proftest/layer_test_framework/core/include/atb_speed/utils/hccl_runner.h b/tests/proftest/layer_test_framework/core/include/atb_speed/utils/hccl_runner.h new file mode 100644 index 00000000..219b8f82 --- /dev/null +++ b/tests/proftest/layer_test_framework/core/include/atb_speed/utils/hccl_runner.h @@ -0,0 +1,48 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef ATB_SPEED_HCCL_RUNNER_H +#define ATB_SPEED_HCCL_RUNNER_H +#include +#include "share_memory.h" + +namespace atb_speed { +struct CommInitInfo { + int signal = 0; + HcclRootInfo hcclRootInfo = {}; + bool barrier[1]; // Flexible array member +}; + +class HcclRunner { +public: + explicit HcclRunner(int rank = 0, int rankSize = 0, int rankRoot = 0); + ~HcclRunner(); + HcclComm CreateHcclCommInMulitProcessByRootInfo(); + +protected: + int rank_ = 0; + int rankSize_ = 0; + int rankRoot_ = 0; + HcclRootInfo hcclRootInfo_ = {}; + +private: + bool CreateHcclRootInfo(); + void ShmGetHcclRootInfo(ShareMemory &shm, const CommInitInfo &shmInfo); + void ShmSetHcclRootInfo(ShareMemory &shm, CommInitInfo &shmInfo); + bool ShmBarrier(ShareMemory &shm, CommInitInfo &shmInfo); + void ShmSetReady(ShareMemory &shm, CommInitInfo &shmInfo) const; +}; +} // namespace atb_speed +#endif \ No newline at end of file diff --git a/tests/proftest/layer_test_framework/core/include/atb_speed/utils/match.h b/tests/proftest/layer_test_framework/core/include/atb_speed/utils/match.h new file mode 100644 index 00000000..4a3861b8 --- /dev/null +++ b/tests/proftest/layer_test_framework/core/include/atb_speed/utils/match.h @@ -0,0 +1,24 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef ATB_SPEED_UTILS_STRINGS_MATCH_H +#define ATB_SPEED_UTILS_STRINGS_MATCH_H +#include + +namespace atb_speed { +bool StartsWith(const std::string &text, const std::string &prefix); +bool EndsWith(const std::string &text, const std::string &suffix); +} // namespace atb_speed +#endif \ No newline at end of file diff --git a/tests/proftest/layer_test_framework/core/include/atb_speed/utils/model_factory.h b/tests/proftest/layer_test_framework/core/include/atb_speed/utils/model_factory.h new file mode 100644 index 00000000..88007938 --- /dev/null +++ b/tests/proftest/layer_test_framework/core/include/atb_speed/utils/model_factory.h @@ -0,0 +1,47 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef ATB_SPEED_UTILS_MODEL_FACTORY_H +#define ATB_SPEED_UTILS_MODEL_FACTORY_H + +#include +#include +#include +#include + +#include "atb_speed/base/model.h" + +namespace atb_speed { +using CreateModelFuncPtr = std::function(const std::string &)>; + +class ModelFactory { +public: + static bool Register(const std::string &modelName, CreateModelFuncPtr createModel); + static std::shared_ptr CreateInstance(const std::string &modelName, const std::string ¶m); +private: + static std::unordered_map &GetRegistryMap(); +}; + +#define MODEL_NAMESPACE_STRINGIFY(modelNameSpace) #modelNameSpace +#define REGISTER_MODEL(nameSpace, modelName) \ + struct Register##_##nameSpace##_##modelName { \ + inline Register##_##nameSpace##_##modelName() noexcept \ + { \ + ModelFactory::Register(MODEL_NAMESPACE_STRINGIFY(nameSpace##_##modelName), \ + [](const std::string ¶m) { return std::make_shared(param); }); \ + } \ + } static instance_##nameSpace##modelName +} // namespace atb_speed +#endif \ No newline at end of file diff --git a/tests/proftest/layer_test_framework/core/include/atb_speed/utils/operation_factory.h b/tests/proftest/layer_test_framework/core/include/atb_speed/utils/operation_factory.h new file mode 100644 index 00000000..6e854fd2 --- /dev/null +++ b/tests/proftest/layer_test_framework/core/include/atb_speed/utils/operation_factory.h @@ -0,0 +1,51 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef ATB_SPEED_UTILS_OPERATION_FACTORY_H +#define ATB_SPEED_UTILS_OPERATION_FACTORY_H + +#include +#include +#include +#include + +#include "atb/operation.h" +#include "atb_speed/log.h" +#include "nlohmann/json.hpp" + + +namespace atb_speed { +using CreateOperationFuncPtr = std::function; + +class OperationFactory { +public: + static bool Register(const std::string &operationName, CreateOperationFuncPtr createOperation); + static atb::Operation *CreateOperation(const std::string &operationName, const nlohmann::json ¶m); +private: + static std::unordered_map &GetRegistryMap(); +}; + +#define OPERATION_NAMESPACE_STRINGIFY(operationNameSpace) #operationNameSpace +#define REGISTER_OPERATION(nameSpace, operationCreateFunc) \ + struct Register##_##nameSpace##_##operationCreateFunc { \ + inline Register##_##nameSpace##_##operationCreateFunc() \ + { \ + ATB_SPEED_LOG_DEBUG("register operation " << #nameSpace << "_" << #operationCreateFunc; \ + OperationFactory::Register(OPERATION_NAMESPACE_STRINGIFY(nameSpace##_##operationCreateFunc), \ + &(operationCreateFunc))); \ + } \ + } static instance_##nameSpace##operationCreateFunc +} // namespace atb_speed +#endif \ No newline at end of file diff --git a/tests/proftest/layer_test_framework/core/include/atb_speed/utils/operation_util.h b/tests/proftest/layer_test_framework/core/include/atb_speed/utils/operation_util.h new file mode 100644 index 00000000..7693aa1a --- /dev/null +++ b/tests/proftest/layer_test_framework/core/include/atb_speed/utils/operation_util.h @@ -0,0 +1,68 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef ATB_SPEED_UTILS_OPERATION_H +#define ATB_SPEED_UTILS_OPERATION_H +#include + +namespace atb_speed { +#define CREATE_OPERATION(param, operation) \ + do { \ + atb::Status atbStatus = atb::CreateOperation(param, operation); \ + if (atbStatus != atb::NO_ERROR) { \ + return atbStatus; \ + } \ + } while (0) + +#define CHECK_OPERATION_STATUS_RETURN(atbStatus) \ + do { \ + if ((atbStatus) != atb::NO_ERROR) { \ + return (atbStatus); \ + } \ + } while (0) + +#define CHECK_PARAM_LT(param, thershold) \ + do { \ + if ((param) >= (thershold)) { \ + ATB_SPEED_LOG_ERROR("param should be less than " << (thershold) << ", please check"); \ + return atb::ERROR_INVALID_PARAM; \ + } \ + } while (0) + +#define CHECK_PARAM_GT(param, thershold) \ + do { \ + if ((param) <= (thershold)) { \ + ATB_SPEED_LOG_ERROR("param should be greater than " << (thershold) << ", please check"); \ + return atb::ERROR_INVALID_PARAM; \ + } \ + } while (0) + +#define CHECK_PARAM_NE(param, value) \ + do { \ + if ((param) == (value)) { \ + ATB_SPEED_LOG_ERROR("param should not be equal to " << (value) << ", please check"); \ + return atb::ERROR_INVALID_PARAM; \ + } \ + } while (0) + +#define CHECK_TENSORDESC_DIMNUM_VALID(dimNum) \ + do { \ + if ((dimNum) > (8) || (dimNum) == (0) ) { \ + ATB_SPEED_LOG_ERROR("dimNum should be less or equal to 8 and cannot be 0, please check"); \ + return atb::ERROR_INVALID_PARAM; \ + } \ + } while (0) +} // namespace atb_speed +#endif \ No newline at end of file diff --git a/tests/proftest/layer_test_framework/core/include/atb_speed/utils/share_memory.h b/tests/proftest/layer_test_framework/core/include/atb_speed/utils/share_memory.h new file mode 100644 index 00000000..09331862 --- /dev/null +++ b/tests/proftest/layer_test_framework/core/include/atb_speed/utils/share_memory.h @@ -0,0 +1,46 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef SHAREMEMORY_H +#define SHAREMEMORY_H + +#include +#include +#include + +class ShareMemory { +public: + ShareMemory(const std::string &name, uint32_t size); + ~ShareMemory(); + ShareMemory(const ShareMemory &other) = delete; + ShareMemory &operator=(const ShareMemory &other) = delete; + void *GetShm(); + void SemLock() const; + void SemUnLock() const; + +private: + void *CreateShareMemory(const std::string &name, uint32_t size); + void CleanUpShm(); + void CleanUpSem(); + +private: + std::string fullName_; + sem_t *sem_ = nullptr; + uint8_t *shareMemory_ = nullptr; + uint32_t memSize_ = 0; + int shmid_ = -1; +}; + +#endif diff --git a/tests/proftest/layer_test_framework/core/include/atb_speed/utils/singleton.h b/tests/proftest/layer_test_framework/core/include/atb_speed/utils/singleton.h new file mode 100644 index 00000000..cb402a92 --- /dev/null +++ b/tests/proftest/layer_test_framework/core/include/atb_speed/utils/singleton.h @@ -0,0 +1,33 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef ATB_SPEED_UTILS_SINGLETON_H +#define ATB_SPEED_UTILS_SINGLETON_H + +namespace atb_speed { + +template T &GetThreadLocalSingleton() +{ + thread_local static T instance; + return instance; +} + +template T &GetSingleton() +{ + static T instance; + return instance; +} +} // namespace atb_speed +#endif \ No newline at end of file diff --git a/tests/proftest/layer_test_framework/core/include/atb_speed/utils/speed_probe.h b/tests/proftest/layer_test_framework/core/include/atb_speed/utils/speed_probe.h new file mode 100644 index 00000000..a653d2e8 --- /dev/null +++ b/tests/proftest/layer_test_framework/core/include/atb_speed/utils/speed_probe.h @@ -0,0 +1,33 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef ATB_SPEED_PROBE_H +#define ATB_SPEED_PROBE_H + +#include +#include + +namespace atb_speed { + +class SpeedProbe { +public: + static bool IsReportModelTopoInfo(const std::string &modelName); + static void ReportModelTopoInfo(const std::string &modelName, const std::string &graph); +}; + +} + +#endif \ No newline at end of file diff --git a/tests/proftest/layer_test_framework/core/include/atb_speed/utils/statistic.h b/tests/proftest/layer_test_framework/core/include/atb_speed/utils/statistic.h new file mode 100644 index 00000000..9063d703 --- /dev/null +++ b/tests/proftest/layer_test_framework/core/include/atb_speed/utils/statistic.h @@ -0,0 +1,42 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
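The difference between the two accessors above is only storage duration; a short sketch using the Statistic struct from statistic.h below:

    #include "atb_speed/utils/singleton.h"
    #include "atb_speed/utils/statistic.h"

    void Record(uint64_t elapsedUs)
    {
        // one instance per process:
        atb_speed::GetSingleton<atb_speed::Statistic>().totalTime += elapsedUs;
        // one independent instance per thread:
        atb_speed::GetThreadLocalSingleton<atb_speed::Statistic>().totalTime += elapsedUs;
    }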
+ */ +#ifndef ATB_SPEED_UTILS_STATISTIC_H +#define ATB_SPEED_UTILS_STATISTIC_H +#include +#include + +namespace atb_speed { +struct Statistic { + uint64_t totalTime = 0; + uint64_t createTensorTime = 0; + uint64_t planSetupTime = 0; + uint64_t planAsyncTime = 0; + uint64_t planExecuteTime = 0; + uint64_t streamSyncTime = 0; + uint64_t tillingCopyTime = 0; + uint64_t getBestKernelTime = 0; + uint64_t kernelExecuteTime = 0; + uint64_t kernelCacheHitCount = 0; + uint64_t kernelCacheMissCount = 0; + uint64_t mallocTorchTensorSize = 0; + + std::string ToString() const; + void Reset(); +}; + +Statistic &GetStatistic(); +} // namespace atb_speed +#endif \ No newline at end of file diff --git a/tests/proftest/layer_test_framework/core/include/atb_speed/utils/str_split.h b/tests/proftest/layer_test_framework/core/include/atb_speed/utils/str_split.h new file mode 100644 index 00000000..baa84144 --- /dev/null +++ b/tests/proftest/layer_test_framework/core/include/atb_speed/utils/str_split.h @@ -0,0 +1,23 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef ATB_SPEED_UTILS_STRINGS_STRSPLIT_H +#define ATB_SPEED_UTILS_STRINGS_STRSPLIT_H +#include + +namespace atb_speed { +std::string GetFuncNameAndNameSpace(const std::string &inputStr); +} // namespace atb_speed +#endif \ No newline at end of file diff --git a/tests/proftest/layer_test_framework/core/include/atb_speed/utils/tensor_util.h b/tests/proftest/layer_test_framework/core/include/atb_speed/utils/tensor_util.h new file mode 100644 index 00000000..67d87088 --- /dev/null +++ b/tests/proftest/layer_test_framework/core/include/atb_speed/utils/tensor_util.h @@ -0,0 +1,31 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef ATB_SPEED_UTILS_TENSOR_UTIL_H +#define ATB_SPEED_UTILS_TENSOR_UTIL_H +#include +#include + +namespace atb_speed { +class TensorUtil { +public: + static std::string TensorToString(const atb::Tensor &tensor); + static std::string TensorDescToString(const atb::TensorDesc &tensorDesc); + static uint64_t GetTensorNumel(const atb::Tensor &tensor); + static uint64_t GetTensorNumel(const atb::TensorDesc &tensorDesc); + static bool TensorDescEqual(const atb::TensorDesc &tensorDescA, const atb::TensorDesc &tensorDescB); +}; +} // namespace atb_speed +#endif \ No newline at end of file diff --git a/tests/proftest/layer_test_framework/core/include/atb_speed/utils/timer.h b/tests/proftest/layer_test_framework/core/include/atb_speed/utils/timer.h new file mode 100644 index 00000000..246317a6 --- /dev/null +++ b/tests/proftest/layer_test_framework/core/include/atb_speed/utils/timer.h @@ -0,0 +1,35 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef ATB_SPEED_UTILS_TIMER_H +#define ATB_SPEED_UTILS_TIMER_H +#include + +namespace atb_speed { +class Timer { +public: + Timer(); + ~Timer(); + uint64_t ElapsedMicroSecond(); + void Reset(); + +private: + uint64_t GetCurrentTimepoint() const; + +private: + uint64_t startTimepoint_ = 0; +}; +} // namespace atb_speed +#endif \ No newline at end of file diff --git a/tests/proftest/layer_test_framework/core/utils/ModelTaskExecutor.cpp b/tests/proftest/layer_test_framework/core/utils/ModelTaskExecutor.cpp new file mode 100644 index 00000000..a5d02fb7 --- /dev/null +++ b/tests/proftest/layer_test_framework/core/utils/ModelTaskExecutor.cpp @@ -0,0 +1,68 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include + +#include "atb_speed/utils/ModelTaskExecutor.h" + +namespace atb_speed { +ModelTaskExecutor::~ModelTaskExecutor() +{ + for (auto &worker : workers_) { + auto task = [&worker]() -> int { + worker.stop = true; + return 0; + }; + worker.queue.Enqueue(task); + worker.thread.join(); + } +} + +void ModelTaskExecutor::PushTask(int idx, const Task &task) +{ + auto it = idx2worker_.find(idx); + if (it == idx2worker_.end()) { + std::lock_guard guard(mutex_); + it = idx2worker_.find(idx); + if (it == idx2worker_.end()) { + uint32_t workerId = workers_.size(); + workers_.emplace_back(); + auto &worker = workers_[workerId]; + worker.deviceIdx = idx; + worker.thread = std::thread(&ModelTaskExecutor::WorkerThread, this, workerId); + it = idx2worker_.insert({idx, workerId}).first; + } + } + auto &worker = workers_[it->second]; + worker.queue.Enqueue(task); + return; +} + +void ModelTaskExecutor::WorkerThread(int workerId) +{ + ATB_SPEED_LOG_DEBUG("WorkerThread " << workerId << " start."); + auto &worker = workers_[workerId]; + int ret = aclrtSetDevice(worker.deviceIdx); + if (ret != 0) { + ATB_SPEED_LOG_ERROR("AsdRtDeviceSetCurrent fail, error:" << ret); + } + while (!worker.stop) { + auto task = worker.queue.Dequeue(); + task(); + } + ATB_SPEED_LOG_DEBUG("WorkerThread " << workerId << " end."); + return; +} +} // namespace atb_speed \ No newline at end of file diff --git a/tests/proftest/layer_test_framework/core/utils/TaskQueue.cpp b/tests/proftest/layer_test_framework/core/utils/TaskQueue.cpp new file mode 100644 index 00000000..09d3b914 --- /dev/null +++ b/tests/proftest/layer_test_framework/core/utils/TaskQueue.cpp @@ -0,0 +1,36 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "atb_speed/utils/TaskQueue.h" +namespace atb_speed { +void TaskQueue::Enqueue(const Task &task) +{ + std::unique_lock lock(mutex_); + queue_.push(task); + lock.unlock(); + cv_.notify_one(); + return; +} + +Task TaskQueue::Dequeue() +{ + std::unique_lock lock(mutex_); + cv_.wait(lock, [this] {return !queue_.empty();}); + auto task = queue_.front(); + queue_.pop(); + return task; +} +} // namespace atb_speed \ No newline at end of file diff --git a/tests/proftest/layer_test_framework/core/utils/check_util.cpp b/tests/proftest/layer_test_framework/core/utils/check_util.cpp new file mode 100644 index 00000000..bc8bf1af --- /dev/null +++ b/tests/proftest/layer_test_framework/core/utils/check_util.cpp @@ -0,0 +1,123 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "atb_speed/utils/check_util.h"
+
+#include <map>
+#include <sstream>
+
+namespace atb_speed {
+// Param Type Size
+const size_t PACK_QUANT_TYPE_LENGTH = 2;
+const size_t LINEAR_TYPE_LENGTH = 7;
+const size_t LINEAR_BIAS_TYPE_LENGTH = 4;
+const int MAX_NUM_HIDDEN_LAYER = 1000;
+
+static std::map<std::string, std::vector<std::string>> g_integerTypeMap = {
+    {"int32_t", {"2147483647", "-2147483648"}},
+    {"uint32_t", {"4294967295", " "}},
+    {"size_t", {"18446744073709551615", " "}},
+};
+
+
+int CheckParamRange(const int &intParam, int min, int max)
+{
+    if (intParam < min) {
+        std::stringstream ss;
+        ss << "This param must be a number greater or equal to " << min << ", please check." << std::endl;
+        throw std::runtime_error(ss.str());
+    }
+    if (intParam > max) {
+        std::stringstream ss;
+        ss << "This param must be a number less or equal to " << max << ", please check." << std::endl;
+        throw std::runtime_error(ss.str());
+    }
+    return intParam;
+}
+
+int CheckNumHiddenLayersValid(const int &numHiddenLayers)
+{
+    return CheckParamRange(numHiddenLayers, 1, MAX_NUM_HIDDEN_LAYER);
+}
+
+int CheckPositive(const int &intParam)
+{
+    if (intParam <= 0) {
+        std::stringstream ss;
+        ss << "This param must be a number greater than 0, please check." << std::endl;
+        throw std::runtime_error(ss.str());
+    }
+    return intParam;
+}
+
+template <typename T>
+void CheckLinearParamsSufficient(const std::vector<std::vector<T>> &linearParam, \
+    size_t numHiddenLayers, size_t threshold)
+{
+    if (linearParam.size() != numHiddenLayers) {
+        std::stringstream ss;
+        ss << "The size of param must be equal to numHiddenLayers, please check." << std::endl;
+        throw std::runtime_error(ss.str());
+    }
+    for (auto item : linearParam) {
+        if (item.size() != threshold) {
+            std::stringstream ss;
+            ss << "The size of vector within param must be equal to " << threshold << ", please check." << std::endl;
+            throw std::runtime_error(ss.str());
+        }
+    }
+}
+
+void CheckSkipLayerSet(const std::vector<size_t> &skipLayerSet, size_t numHiddenLayers)
+{
+    if (skipLayerSet.size() >= numHiddenLayers) {
+        std::stringstream ss;
+        ss << "The size of skipLayerSet must be less than " <<
+            numHiddenLayers <<
+            ", please check attn and mlp skipped_layers in plugin_params." << std::endl;
+        throw std::runtime_error(ss.str());
+    }
+
+    for (size_t layerId : skipLayerSet) {
+        if (layerId >= numHiddenLayers) {
+            std::stringstream ss;
+            ss << "The layer id must be greater than or equal to 0 and less than " <<
+                numHiddenLayers <<
+                ", please check layer id in attn and mlp skipped_layers."
<< std::endl; + throw std::runtime_error(ss.str()); + } + } +} + +void CheckPackQuantParamsSufficient(const std::vector> &packQuantType, size_t numHiddenLayers) +{ + CheckLinearParamsSufficient(packQuantType, numHiddenLayers, PACK_QUANT_TYPE_LENGTH); +} + +void CheckLinearPackParamsSufficient(const std::vector> &linearPackType, size_t numHiddenLayers) +{ + CheckLinearParamsSufficient(linearPackType, numHiddenLayers, LINEAR_TYPE_LENGTH); +} + +void CheckLinearHasBiasSufficient(const std::vector> &linearHasBias, size_t numHiddenLayers) +{ + CheckLinearParamsSufficient(linearHasBias, numHiddenLayers, LINEAR_BIAS_TYPE_LENGTH); +} + +template void CheckLinearParamsSufficient(const std::vector> &linearParam, \ + size_t numHiddenLayers, size_t thershold); +template void CheckLinearParamsSufficient(const std::vector> &linearParam, \ + size_t numHiddenLayers, size_t thershold); +} // namespace atb_speed \ No newline at end of file diff --git a/tests/proftest/layer_test_framework/core/utils/config.cpp b/tests/proftest/layer_test_framework/core/utils/config.cpp new file mode 100644 index 00000000..d37a5670 --- /dev/null +++ b/tests/proftest/layer_test_framework/core/utils/config.cpp @@ -0,0 +1,57 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "atb_speed/utils/config.h" +#include +#include +#include +#include +#include +#include "atb_speed/log.h" + +namespace atb_speed { +Config::Config() +{ + isConvertNCHWToND_ = true; + isTorchTensorFormatCast_ = true; + isUseTilingCopyStream_ = IsEnable("ATB_USE_TILING_COPY_STREAM"); + isLayerInternalTensorReuse_ = true; + ATB_SPEED_LOG_DEBUG(" \nIsConvertNCHWToND:" << isConvertNCHWToND_ + << "\nIsTorchTensorFormatCast:" << isTorchTensorFormatCast_ + << "\nIsLayerInternalTensorReuse:" << isLayerInternalTensorReuse_); +} + +Config::~Config() {} + +bool Config::IsEnable(const char *env, bool enable) +{ + const char *saveTensor = std::getenv(env); + if (!saveTensor) { + return enable; + } + return std::string(saveTensor) == "1"; +} + +bool Config::IsTorchTensorFormatCast() const { return isTorchTensorFormatCast_; }; + +bool Config::IsConvertNCHWToND() const { return isConvertNCHWToND_; } + +bool Config::IsUseTilingCopyStream() const { return isUseTilingCopyStream_; } + +bool Config::IsLayerInternalTensorReuse() const +{ + return isLayerInternalTensorReuse_; +} +} // namespace atb_speed \ No newline at end of file diff --git a/tests/proftest/layer_test_framework/core/utils/file_system.cpp b/tests/proftest/layer_test_framework/core/utils/file_system.cpp new file mode 100644 index 00000000..9af453ea --- /dev/null +++ b/tests/proftest/layer_test_framework/core/utils/file_system.cpp @@ -0,0 +1,129 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "atb_speed/utils/file_system.h" +#include +#include +#include +#include +#include + +namespace atb_speed { + +constexpr size_t MAX_PATH_LEN = 256; + +bool FileSystem::Exists(const std::string &path) +{ + struct stat st; + if (stat(path.c_str(), &st) < 0) { + return false; + } + return true; +} + +bool FileSystem::IsDir(const std::string &path) +{ + struct stat st; + if (stat(path.c_str(), &st) < 0) { + return false; + } + + return S_ISDIR(st.st_mode); +} + +std::string FileSystem::Join(const std::vector &paths) +{ + std::string retPath; + for (const auto &path : paths) { + if (retPath.empty()) { + retPath.append(path); + } else { + retPath.append("/" + path); + } + } + return retPath; +} + +int64_t FileSystem::FileSize(const std::string &filePath) +{ + struct stat st; + if (stat(filePath.c_str(), &st) < 0) { + return -1; + } + return st.st_size; +} + +std::string FileSystem::BaseName(const std::string &filePath) +{ + std::string fileName; + const char *str = strrchr(filePath.c_str(), '/'); + if (str) { + fileName = str + 1; + } else { + fileName = filePath; + } + return fileName; +} + +std::string FileSystem::DirName(const std::string &path) +{ + int32_t idx = path.size() - 1; + while (idx >= 0 && path[idx] == '/') { + idx--; + } + std::string sub = path.substr(0, idx); + const char *str = strrchr(sub.c_str(), '/'); + if (str == nullptr) { + return "."; + } + idx = str - sub.c_str() - 1; + while (idx >= 0 && path[idx] == '/') { + idx--; + } + if (idx < 0) { + return "/"; + } + return path.substr(0, idx + 1); +} + +bool FileSystem::DeleteFile(const std::string &filePath) +{ + int ret = remove(filePath.c_str()); + return ret == 0; +} + +bool FileSystem::MakeDir(const std::string &dirPath, int mode) +{ + int ret = mkdir(dirPath.c_str(), mode); + return ret == 0; +} + +bool FileSystem::Makedirs(const std::string &dirPath, const mode_t mode) +{ + int32_t offset = 0; + int32_t pathLen = dirPath.size(); + do { + const char *str = strchr(dirPath.c_str() + offset, '/'); + offset = (str == nullptr) ? pathLen : str - dirPath.c_str() + 1; + std::string curPath = dirPath.substr(0, offset); + if (!Exists(curPath)) { + if (!MakeDir(curPath, mode)) { + return false; + } + } + } while (offset != pathLen); + return true; +} +} // namespace atb_speed \ No newline at end of file diff --git a/tests/proftest/layer_test_framework/core/utils/hccl_runner.cpp b/tests/proftest/layer_test_framework/core/utils/hccl_runner.cpp new file mode 100644 index 00000000..7455ce97 --- /dev/null +++ b/tests/proftest/layer_test_framework/core/utils/hccl_runner.cpp @@ -0,0 +1,144 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
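Illustrative use of the helpers above; the cache path is made up for the example:

    #include "atb_speed/utils/file_system.h"

    bool EnsureCacheDir(std::string &outDir)
    {
        outDir = atb_speed::FileSystem::Join({"/tmp", "atb_speed", "cache"});  // "/tmp/atb_speed/cache"
        if (atb_speed::FileSystem::Exists(outDir)) {
            return atb_speed::FileSystem::IsDir(outDir);
        }
        // Makedirs creates each missing path segment in turn.
        return atb_speed::FileSystem::Makedirs(outDir, 0750);
    }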
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "atb_speed/utils/hccl_runner.h"
+#include <ctime>
+#include "atb_speed/log.h"
+
+namespace atb_speed {
+HcclRunner::HcclRunner(int rank, int rankSize, int rankRoot)
+    : rank_(rank),
+      rankSize_(rankSize),
+      rankRoot_(rankRoot) {}
+
+HcclRunner::~HcclRunner()
+{
+    ATB_SPEED_LOG_DEBUG("HcclRunner destructed");
+}
+
+
+HcclComm HcclRunner::CreateHcclCommInMulitProcessByRootInfo()
+{
+    ATB_SPEED_LOG_DEBUG("HCCL Runner single server init");
+    if (!CreateHcclRootInfo()) {
+        return nullptr;
+    }
+
+    HcclComm newHcclComm = nullptr;
+    auto ret = HcclCommInitRootInfo(rankSize_, &hcclRootInfo_, rank_, &newHcclComm);
+    if (ret != HCCL_SUCCESS) {
+        ATB_SPEED_LOG_ERROR("HcclCommInitRootInfo fail, error:" << ret << ", rank:" << rank_
+            << ", rankSize:" << rankSize_);
+    }
+    return newHcclComm;
+}
+
+bool HcclRunner::CreateHcclRootInfo()
+{
+    std::string shmName = "hcclShareMem";
+    ShareMemory shm(shmName, sizeof(atb_speed::CommInitInfo) + rankSize_ * sizeof(bool));
+    auto *shmInfo = static_cast<CommInitInfo *>(shm.GetShm());
+    if (!shmInfo) {
+        ATB_SPEED_LOG_ERROR("create share memory fail, rank:" << rank_);
+        return false;
+    }
+
+    // The root rank obtains hcclRootInfo_ (which carries the host IP) via HcclGetRootInfo and
+    // writes it to shared memory; every other rank reads the root info from there.
+    // Only once all processes are ready do they proceed together to create the HCCL communicator.
+    ATB_SPEED_LOG_DEBUG("create share memory success, rank:" << rank_);
+    if (rank_ == rankRoot_) {
+        auto ret = HcclGetRootInfo(&hcclRootInfo_);
+        if (ret != HCCL_SUCCESS) {
+            ATB_SPEED_LOG_ERROR("HcclGetRootInfo fail, error:" << ret << ", rank:" << rank_);
+            return false;
+        }
+        ATB_SPEED_LOG_DEBUG("HcclGetRootInfo success, write to share memory");
+        ShmSetHcclRootInfo(shm, *shmInfo);
+    } else {
+        ATB_SPEED_LOG_DEBUG("get root info from share memory");
+        ShmGetHcclRootInfo(shm, *shmInfo);
+    }
+
+    return ShmBarrier(shm, *shmInfo);
+}
+
+void HcclRunner::ShmGetHcclRootInfo(ShareMemory &shm, const CommInitInfo &shmInfo)
+{
+    bool commIdReady = false;
+    while (!commIdReady) {
+        shm.SemLock();
+        if (shmInfo.signal != 0) {
+            hcclRootInfo_ = shmInfo.hcclRootInfo;
+            commIdReady = true;
+        }
+        shm.SemUnLock();
+        if (commIdReady) {
+            break;
+        }
+    }
+}
+
+void HcclRunner::ShmSetHcclRootInfo(ShareMemory &shm, CommInitInfo &shmInfo)
+{
+    shm.SemLock();
+    shmInfo.hcclRootInfo = hcclRootInfo_;
+    shmInfo.signal = 1;
+    shm.SemUnLock();
+}
+
+void HcclRunner::ShmSetReady(ShareMemory &shm, CommInitInfo &shmInfo) const
+{
+    shm.SemLock();
+    shmInfo.barrier[rank_] = true;
+    shm.SemUnLock();
+}
+
+bool HcclRunner::ShmBarrier(ShareMemory &shm, CommInitInfo &shmInfo)
+{
+    ATB_SPEED_LOG_DEBUG("barrier start, rank:" << rank_);
+    ShmSetReady(shm, shmInfo);
+
+    ATB_SPEED_LOG_DEBUG("check all ready start");
+    const double timeout = 600;  // 600: 10 minutes timeout
+    time_t startTime = time(nullptr);
+    bool endSignal = false;
+    while (!endSignal) {
+        time_t currentTime = time(nullptr);
+        if (difftime(currentTime, startTime) > timeout) {
+            ATB_SPEED_LOG_ERROR("barrier fail, check all ready timeout");
+            endSignal = true;
+            return false;
+        }
+
+        bool allReady = true;
+        shm.SemLock();
+        for (int i = 0; i < rankSize_; i++) {
+            if (!shmInfo.barrier[i]) {
+                allReady = false;
+                break;
+            }
+        }
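+        // Every rank's flag was inspected while holding the semaphore; release it
+        // before either leaving the barrier or polling again on the next iteration.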
shm.SemUnLock(); + if (allReady) { + ATB_SPEED_LOG_DEBUG("check all ready success"); + break; + } + } + + ATB_SPEED_LOG_DEBUG("barrier success, rank:" << rank_); + return true; +} + +} // namespace atb_speed diff --git a/tests/proftest/layer_test_framework/core/utils/match.cpp b/tests/proftest/layer_test_framework/core/utils/match.cpp new file mode 100644 index 00000000..67cf6053 --- /dev/null +++ b/tests/proftest/layer_test_framework/core/utils/match.cpp @@ -0,0 +1,36 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "atb_speed/utils/match.h" +#include + +namespace atb_speed { +bool StartsWith(const std::string &text, const std::string &prefix) +{ + if (text.data() == nullptr || prefix.data() == nullptr) { + return false; + } + return prefix.empty() || (text.size() >= prefix.size() && memcmp(text.data(), prefix.data(), prefix.size()) == 0); +} + +bool EndsWith(const std::string &text, const std::string &suffix) +{ + if (text.data() == nullptr || suffix.data() == nullptr) { + return false; + } + return suffix.empty() || (text.size() >= suffix.size() && + memcmp(text.data() + (text.size() - suffix.size()), suffix.data(), suffix.size()) == 0); +} +} // namespace atb_speed diff --git a/tests/proftest/layer_test_framework/core/utils/model_factory.cpp b/tests/proftest/layer_test_framework/core/utils/model_factory.cpp new file mode 100644 index 00000000..381811c0 --- /dev/null +++ b/tests/proftest/layer_test_framework/core/utils/model_factory.cpp @@ -0,0 +1,51 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
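Both string helpers above treat an empty prefix/suffix as a match and guard the .data() pointers; e.g. a suffix test for routing weight names (sketch):

    #include <string>
    #include "atb_speed/utils/match.h"

    bool IsBiasWeight(const std::string &weightName)
    {
        return atb_speed::EndsWith(weightName, ".bias");
    }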
+ */ +#include "atb_speed/utils/model_factory.h" +#include "atb_speed/log.h" + +namespace atb_speed { +bool ModelFactory::Register(const std::string &modelName, CreateModelFuncPtr createModel) +{ + auto it = ModelFactory::GetRegistryMap().find(modelName); + if (it != ModelFactory::GetRegistryMap().end()) { + if (it->second == nullptr) { + ATB_SPEED_LOG_ERROR("Find modelName error: " << modelName); + return false; + } + ATB_SPEED_LOG_WARN(modelName << " model already exists, but the duplication doesn't matter."); + return false; + } + ModelFactory::GetRegistryMap()[modelName] = createModel; + return true; +} + +std::shared_ptr ModelFactory::CreateInstance(const std::string &modelName, const std::string ¶m) +{ + auto it = ModelFactory::GetRegistryMap().find(modelName); + if (it != ModelFactory::GetRegistryMap().end()) { + ATB_SPEED_LOG_DEBUG("Find model: " << modelName); + return it->second(param); + } + ATB_SPEED_LOG_WARN("ModelName: " << modelName << " not find in model factory map"); + return nullptr; +} + +std::unordered_map &ModelFactory::GetRegistryMap() +{ + static std::unordered_map modelRegistryMap; + return modelRegistryMap; +} +} // namespace atb_speed \ No newline at end of file diff --git a/tests/proftest/layer_test_framework/core/utils/operation_factory.cpp b/tests/proftest/layer_test_framework/core/utils/operation_factory.cpp new file mode 100644 index 00000000..c6a210c0 --- /dev/null +++ b/tests/proftest/layer_test_framework/core/utils/operation_factory.cpp @@ -0,0 +1,51 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "atb_speed/utils/operation_factory.h" +#include "atb_speed/log.h" + +namespace atb_speed { +bool OperationFactory::Register(const std::string &operationName, CreateOperationFuncPtr createOperation) +{ + auto it = OperationFactory::GetRegistryMap().find(operationName); + if (it != OperationFactory::GetRegistryMap().end()) { + ATB_SPEED_LOG_WARN(operationName << " operation already exists, but the duplication doesn't matter."); + return false; + } + OperationFactory::GetRegistryMap()[operationName] = createOperation; + return true; +} + +atb::Operation *OperationFactory::CreateOperation(const std::string &operationName, const nlohmann::json ¶m) +{ + auto it = OperationFactory::GetRegistryMap().find(operationName); + if (it != OperationFactory::GetRegistryMap().end()) { + if (it->second == nullptr) { + ATB_SPEED_LOG_ERROR("Find operation error: " << operationName); + return nullptr; + } + ATB_SPEED_LOG_DEBUG("Find operation: " << operationName); + return it->second(param); + } + ATB_SPEED_LOG_WARN("OperationName: " << operationName << " not find in operation factory map"); + return nullptr; +} + +std::unordered_map &OperationFactory::GetRegistryMap() +{ + static std::unordered_map operationRegistryMap; + return operationRegistryMap; +} +} // namespace atb_speed diff --git a/tests/proftest/layer_test_framework/core/utils/share_memory.cpp b/tests/proftest/layer_test_framework/core/utils/share_memory.cpp new file mode 100644 index 00000000..a76d7cdd --- /dev/null +++ b/tests/proftest/layer_test_framework/core/utils/share_memory.cpp @@ -0,0 +1,131 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "atb_speed/utils/share_memory.h" +#include +#include +#include +#include +#include +#include +#include "atb_speed/log.h" +#include "atb_speed/log.h" + +constexpr int SEM_TIMEOUT = 300; + +ShareMemory::ShareMemory(const std::string &name, uint32_t size) : memSize_(size) +{ + sem_ = sem_open(name.c_str(), O_CREAT, S_IRUSR | S_IWUSR, 1); + if (SEM_FAILED == sem_) { + ATB_SPEED_LOG_ERROR("share memory open fail, name:" << name); + return; + } + ATB_SPEED_LOG_DEBUG("create share memory begin, name:" << name); + + SemLock(); + shareMemory_ = (uint8_t *)CreateShareMemory(name, memSize_); + ATB_SPEED_LOG_DEBUG("create share memory success"); + SemUnLock(); +} + +void *ShareMemory::GetShm() +{ + return shareMemory_; +}; + +void ShareMemory::SemLock() const +{ + struct timespec ts; + clock_gettime(CLOCK_REALTIME, &ts); + ts.tv_sec += SEM_TIMEOUT; + int ret = sem_timedwait(sem_, &ts); + // 等待信号量超时 + if (ret == -1 && errno == ETIMEDOUT) { + ATB_SPEED_LOG_ERROR("The semaphore waiting duration exceeds 5 minutes. Run the " + "rm -rf /dev/shm/sem." 
<< + fullName_ << " command to clear the semaphore."); + } +}; + +void ShareMemory::SemUnLock() const +{ + sem_post(sem_); +}; + +void *ShareMemory::CreateShareMemory(const std::string &name, uint32_t size) +{ + void *memory = nullptr; + struct shmid_ds buf; + key_t key = static_cast(std::hash{}(name)); + shmid_ = shmget(key, size, IPC_CREAT | 0600); // 0600提供文件所有者有读和写的权限 + ATB_SPEED_LOG_DEBUG("key: " << key << " shmid :" << shmid_); + if (shmid_ == -1) { + ATB_SPEED_LOG_ERROR("shmget err, " << "errno is: " <(-1)) { + ATB_SPEED_LOG_ERROR("shmmat err, " << "errno is: " < + +namespace atb_speed { +constexpr int OPGRAPH_NAME_MAX_LENG = 128; + +std::string GetFuncNameAndNameSpace(const std::string &inputStr) +{ + int spaceInd = 0; + int leftBracketInd = 0; + std::string extractStr; + int inputStrLen = static_cast(inputStr.size()); + for (int i = 0; i < inputStrLen; i++) { + if (inputStr.at(i) == ' ') { + spaceInd = i; + } else if (inputStr.at(i) == '(') { + leftBracketInd = i; + break; + } + } + if (spaceInd >= 0 && (leftBracketInd - spaceInd) > 0) { + int len; + if (leftBracketInd - (spaceInd + 1) > OPGRAPH_NAME_MAX_LENG) { + len = OPGRAPH_NAME_MAX_LENG; + } else { + len = leftBracketInd - (spaceInd + 1); + } + extractStr = inputStr.substr(spaceInd + 1, len); + } else { + extractStr = inputStr; + } + + for (char &i : extractStr) { + if (!isalnum(i) && i != '_') { + i = '_'; + } + } + return extractStr; +} + +} // namespace atb_speed \ No newline at end of file diff --git a/tests/proftest/layer_test_framework/core/utils/tensor_util.cpp b/tests/proftest/layer_test_framework/core/utils/tensor_util.cpp new file mode 100644 index 00000000..0db6025d --- /dev/null +++ b/tests/proftest/layer_test_framework/core/utils/tensor_util.cpp @@ -0,0 +1,78 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
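A worked example of the normalization in GetFuncNameAndNameSpace above: the token between the last space before '(' and the '(' itself is kept, then every character outside [A-Za-z0-9_] becomes '_':

    #include <cassert>
    #include "atb_speed/utils/str_split.h"

    void Demo()
    {
        // "atb_speed::base::Foo" is extracted from the signature; colons become underscores.
        assert(atb_speed::GetFuncNameAndNameSpace(
            "atb::Status atb_speed::base::Foo(const Param &param)") == "atb_speed__base__Foo");
    }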
+ */ +#include "atb_speed/utils/tensor_util.h" +#include +#include +#include +#include +#include "atb_speed/log.h" +#include "atb_speed/utils/check_util.h" + +namespace atb_speed { +std::string TensorUtil::TensorToString(const atb::Tensor &tensor) +{ + std::stringstream ss; + ss << TensorDescToString(tensor.desc) << ", deviceData:" << tensor.deviceData << ", hostData:" << tensor.hostData + << ", dataSize:" << tensor.dataSize; + return ss.str(); +} + +std::string TensorUtil::TensorDescToString(const atb::TensorDesc &tensorDesc) +{ + std::stringstream ss; + ss << "dtype: " << tensorDesc.dtype << ", format: " << tensorDesc.format << ", shape:["; + for (size_t i = 0; i < tensorDesc.shape.dimNum; ++i) { + if (i == 0) { + ss << tensorDesc.shape.dims[i]; + } else { + ss << ", " << tensorDesc.shape.dims[i]; + } + } + ss << "]"; + + return ss.str(); +} + +uint64_t TensorUtil::GetTensorNumel(const atb::Tensor &tensor) { return GetTensorNumel(tensor.desc); } + +uint64_t TensorUtil::GetTensorNumel(const atb::TensorDesc &tensorDesc) +{ + if (tensorDesc.shape.dimNum == 0) { + return 0; + } + + int64_t elementCount = 1; + for (size_t i = 0; i < tensorDesc.shape.dimNum; i++) { + elementCount = CheckIntMulOverFlow(elementCount, tensorDesc.shape.dims[i]); + } + + return elementCount; +} + +bool TensorUtil::TensorDescEqual(const atb::TensorDesc &tensorDescA, const atb::TensorDesc &tensorDescB) +{ + if (tensorDescA.dtype == tensorDescB.dtype && tensorDescA.format == tensorDescB.format && + tensorDescA.shape.dimNum == tensorDescB.shape.dimNum) { + for (size_t i = 0; i < tensorDescA.shape.dimNum; i++) { + if (tensorDescA.shape.dims[i] != tensorDescB.shape.dims[i]) { + return false; + } + } + return true; + } + return false; +} +} // namespace atb_speed \ No newline at end of file diff --git a/tests/proftest/layer_test_framework/core/utils/timer.cpp b/tests/proftest/layer_test_framework/core/utils/timer.cpp new file mode 100644 index 00000000..1851c88d --- /dev/null +++ b/tests/proftest/layer_test_framework/core/utils/timer.cpp @@ -0,0 +1,44 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "atb_speed/utils/timer.h" +#include + +namespace atb_speed { +const uint64_t MICRSECOND_PER_SECOND = 1000000; + +Timer::Timer() { startTimepoint_ = GetCurrentTimepoint(); } + +Timer::~Timer() {} + +uint64_t Timer::ElapsedMicroSecond() +{ + uint64_t now = GetCurrentTimepoint(); + uint64_t use = now - startTimepoint_; + startTimepoint_ = now; + return use; +} + +void Timer::Reset() { startTimepoint_ = GetCurrentTimepoint(); } + +uint64_t Timer::GetCurrentTimepoint() const +{ + struct timeval tv; + gettimeofday(&tv, nullptr); + uint64_t ret = + static_cast(tv.tv_sec * MICRSECOND_PER_SECOND + tv.tv_usec); + return ret; +} +} // namespace atb_speed \ No newline at end of file diff --git a/tests/proftest/layer_test_framework/models/base/layer/decoder_layer.cpp b/tests/proftest/layer_test_framework/models/base/layer/decoder_layer.cpp new file mode 100644 index 00000000..58618bd9 --- /dev/null +++ b/tests/proftest/layer_test_framework/models/base/layer/decoder_layer.cpp @@ -0,0 +1,975 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "operations/fusion/linear/linear.h" +#include "operations/fusion/norm/norm_linear.h" +#include "operations/aclrt/ops/aclrt_cmo_async.h" + +#include "models/base/layer/decoder_layer.h" + +#include "atb_speed/base/event_manager.h" + +namespace atb_speed { +namespace base { + +template +DecoderLayer::DecoderLayer(const LayerParam ¶m) +{ + this->param = param; + this->param.CheckParam(); + this->inTensorCandidates = { + {"input_norm_weight", { + // shape: [hiddenSize] + "in_input_norm_weight", "in_input_norm_bias", "in_input_norm_new_weight", "in_input_norm_new_bias"}}, + {"attn_weight", this->attnWeight}, + {"mlp_weight", this->mlpWeight}, + {"post_attn_norm_weight", { + // shape: [hiddenSize] + "in_post_attn_norm_weight", "in_post_attn_norm_bias", "in_post_attn_norm_new_weight", + "in_post_attn_norm_new_bias"}}, + {"kv_quant_scale", { + "in_k_quant_scale", "in_k_dequant_scale", "in_v_quant_scale", "in_v_dequant_scale"}}, + {"kv_quant_offset", { + "in_k_quant_offset", "in_k_dequant_offset", "in_v_quant_offset", "in_v_dequant_offset"}}, + {"fa3_quant", { + "in_q_quant_scale", "in_k_quant_scale", "in_v_quant_scale", "in_qk_descale", + "q_offset", "kv_offset", "fa3_v_quant_scale", "fa3_offset"}}, + {"reduce_quant_attn", { + "in_attn_reduce_quant_scale", "in_attn_reduce_quant_offset", + "in_attn_gather_quant_scale", "in_attn_gather_quant_offset"}}, + {"reduce_quant_mlp", { + "in_mlp_reduce_quant_scale", "in_mlp_reduce_quant_offset", + "in_mlp_gather_quant_scale", "in_mlp_gather_quant_offset"}}, + {"default", { + "in_hidden_states", // shape: FA: [batchSize, seqLen, hiddenSize] PA: [seqLen, hiddenSize] + "in_cos_embedding", "in_sin_embedding", "in_attention_mask", "in_k_cache", "in_v_cache", "in_seq_len", + "in_token_offset", "in_layer_id", "in_block_tables", "in_slots"}}, + {"compress_head_alibi", {"wins_global", "in_ra_seqlens"}}, + {"compress_head_rope", 
{"wins_global", "in_ra_seqlens", "pffset_index", "razor_offset", + "in_reshape_seqlen"}}, // [batchSize * Numhead] + {"q_len", {"in_q_len"}}, + {"lora_common", {"in_seq_len_cum_sum"}}, + {"lora_attn", { + "in_qkv_lora_a_0", "in_qkv_lora_b_0", "in_qkv_lora_a_1", "in_qkv_lora_b_1", + "in_qkv_lora_a_2", "in_qkv_lora_b_2", "in_qkv_dense_lora_a", "in_qkv_dense_lora_b"} + }, + {"lora_mlp", { + "in_mlp_lora_a_0", "in_mlp_lora_b_0", "in_mlp_lora_a_1", "in_mlp_lora_b_1", + "in_mlp_down_lora_a", "in_mlp_down_lora_b"} + }, + {"attn_dp", { + "in_final_hidden_state", "in_shard_effective_token_indices", "in_token_index_with_padding", + "in_skip_padding_token_indices"} + }, + {"input_add_norm", {"in_last_mlp_out"}}, + {"add_rmsnorm_quant", {"in_qkv_scale_fill", "in_qkv_offset_fill", "in_mlp_scale_fill", "in_mlp_offset_fill"}}, + {"qk_norm", {"q_norm_weight", "k_norm_weight"}}, + {"flash_comm", { + "send_counts", "sdispls", "send_count", "recv_counts", "rdispls", "recv_count", + "fake_rs_shape", "fake_ag_shape"}}, + }; + SetDefaultInternalTensorCandidates(); +} + +template +void DecoderLayer::SetDefaultInternalTensorCandidates() +{ + if (this->param.isAttnSkipLayer) { + this->internalTensorCandidates = {{"default", {"intermediate_mlp_out"}}}; + } else if (this->param.isMlpSkipLayer) { + this->internalTensorCandidates = {{"default", {"intermediate_attn_out"}}}; + } else { + this->internalTensorCandidates = {{"default", {"intermediate_attn_out"}}}; + if (this->param.layerId == (this->param.numHiddenLayers - 1) || !this->param.enableInterLayerAddNorm) { + this->internalTensorCandidates["default"].push_back("intermediate_mlp_out"); + } + } + + if (this->param.hasAttnDp) { + this->internalTensorCandidates[std::string("attn_dp")] = { + "intermediate_dp_attn_out_with_padding", "intermediate_dp_attn_out_all_with_padding", + "intermediate_dp_attn_gathered"}; + } +} + +template +void DecoderLayer::ConstructInTensorMap() +{ + this->inTensorList.clear(); + // 添加默认的Tensor + atb_speed::common::AddTensorToList(this->inTensorCandidates, "input_norm_weight", this->inTensorList); + atb_speed::common::AddTensorToList(this->inTensorCandidates, "attn_weight", this->inTensorList); + atb_speed::common::AddTensorToList(this->inTensorCandidates, "post_attn_norm_weight", this->inTensorList); + atb_speed::common::AddTensorToList(this->inTensorCandidates, "mlp_weight", this->inTensorList); + // 添加 QKNorm 特性的Tensor + if (param.useQKNorm) { + atb_speed::common::AddTensorToList(this->inTensorCandidates, "qk_norm", this->inTensorList); + } + // 添加AddRmsNormQuant特性的Tensor + if (param.enableInterLayerAddNorm || param.enableIntraLayerAddNorm) { + atb_speed::common::AddTensorToList(this->inTensorCandidates, "add_rmsnorm_quant", this->inTensorList); + } + // 添加KV cache int8特性的Tensor + if (param.enableKvQuant) { + atb_speed::common::AddTensorToList(this->inTensorCandidates, "kv_quant_scale", this->inTensorList); + if (param.kvQuantHasOffset) { + atb_speed::common::AddTensorToList(this->inTensorCandidates, "kv_quant_offset", this->inTensorList); + } + } + + // 添加FA3特性的Tensor + if (param.enableFA3) { + atb_speed::common::AddTensorToList(this->inTensorCandidates, "fa3_quant", this->inTensorList); + } + + // 添加lccl reduce int8特性的Tensor + if (param.enableReduceQuant) { + atb_speed::common::AddTensorToList(this->inTensorCandidates, "reduce_quant_attn", this->inTensorList); + atb_speed::common::AddTensorToList(this->inTensorCandidates, "reduce_quant_mlp", this->inTensorList); + } + + atb_speed::common::AddTensorToList(this->inTensorCandidates, 
"default", this->inTensorList); + atb_speed::common::AddTensorToList( + this->internalTensorCandidates, "default", this->intermediateTensorList); + + // 添加头压缩特性的Tensor + if (param.enableCompressHead) { + atb_speed::common::AddTensorToList( + this->inTensorCandidates, + param.positionEmbeddingType == PositionEmbeddingType::ALIBI ? "compress_head_alibi" : "compress_head_rope", + this->inTensorList); + } + + // 添加omniattention特性的Tensor + if (param.enableOmniAttention) { + atb_speed::common::AddTensorToList(this->inTensorCandidates, "compress_head_rope", this->inTensorList); + } + + // 添加并行解码特性或SplitFuse的Tensor + if (param.enableSpeculate || param.enableSplitFuse) { + atb_speed::common::AddTensorToList(this->inTensorCandidates, "q_len", this->inTensorList); + } + + // 添加lora特性的Tensor + if (param.enableLora) { + atb_speed::common::AddTensorToList(this->inTensorCandidates, "lora_common", this->inTensorList); + atb_speed::common::AddTensorToList(this->inTensorCandidates, "lora_attn", this->inTensorList); + atb_speed::common::AddTensorToList(this->inTensorCandidates, "lora_mlp", this->inTensorList); + } + + if (param.hasAttnDp) { + atb_speed::common::AddTensorToList(this->inTensorCandidates, "attn_dp", this->inTensorList); + } + // 添加AddNorm融合特性的Tensor + if (param.enableInterLayerAddNorm && param.layerId != 0) { + atb_speed::common::AddTensorToList(this->inTensorCandidates, "input_add_norm", this->inTensorList); + } + + // Add Flashcomm 1.0 Tensor + if (param.enableFlashComm) { + atb_speed::common::AddTensorToList(this->inTensorCandidates, "flash_comm", this->inTensorList); + } +} + +template +void DecoderLayer::ConstructInternalTensorMap() +{ + this->intermediateTensorList.clear(); + atb_speed::common::AddTensorToList( + this->internalTensorCandidates, "default", this->intermediateTensorList); + if (this->param.hasAttnDp && this->param.hasMlpTp) { + atb_speed::common::AddTensorToList( + this->internalTensorCandidates, "attn_dp", this->intermediateTensorList); + } +} + +template +int64_t DecoderLayer::BuildGraph(atb::Operation **operation) +{ + this->graph.name = param.isPrefill ? 
"Prefill_layer" : "Decoder_layer"; + this->ConstructInTensorMap(); + this->ConstructInternalTensorMap(); + this->graph.inTensorNum = this->inTensorList.size(); + ATB_SPEED_LOG_DEBUG("this->graph.inTensorNum " << this->graph.inTensorNum); + this->graph.internalTensorNum = this->intermediateTensorList.size(); + ATB_SPEED_LOG_DEBUG("this->graph.internalTensorNum " << this->graph.internalTensorNum); + if (this->param.hasAttnDp && this->param.hasMlpTp) { + this->outTensorList.push_back("out_attndp_last_layer"); + } + if (this->param.enableInterLayerAddNorm && (this->param.layerId != (this->param.numHiddenLayers - 1))) { + this->outTensorList.push_back("out_mlp"); + } + this->graph.outTensorNum = this->outTensorList.size(); + ATB_SPEED_LOG_DEBUG("this->graph.outTensorNum " << this->graph.outTensorNum); + this->tensorMap = atb_speed::common::GetTensorMap( + this->inTensorList, this->outTensorList, this->intermediateTensorList); + std::stringstream ss; + // 添加layer层 map打印 + for (auto tensor = this->tensorMap.cbegin(); tensor != this->tensorMap.cend(); ++tensor) { + ss << "tensor name: " << tensor->first << ", tensor id: " << tensor->second << std::endl; + } + ATB_SPEED_LOG_DEBUG("layer map tensor:\n" << ss.str()); + + CHECK_OPERATION_STATUS_RETURN(this->AddOperationToGraph()); + + uint32_t inHiddenStatesIdx = atb_speed::common::GetTensorIdx(this->tensorMap, "in_hidden_states"); + if (param.hasAttnDp && param.hasMlpTp) { + uint32_t inHiddenStatesIdx2 = atb_speed::common::GetTensorIdx(this->tensorMap, "in_final_hidden_state"); + this->graph.inferShapeFunc = [inHiddenStatesIdx, inHiddenStatesIdx2]( + const atb::SVector &inTensorDescs, + atb::SVector &outTensorDescs) { + outTensorDescs.at(0) = inTensorDescs.at(inHiddenStatesIdx); + outTensorDescs.at(1) = inTensorDescs.at(inHiddenStatesIdx2); + return atb::NO_ERROR; + }; + } else { + bool outputAddNorm = this->param.enableInterLayerAddNorm && \ + (this->param.layerId != (this->param.numHiddenLayers - 1)); + this->graph.inferShapeFunc = [inHiddenStatesIdx, outputAddNorm]( + const atb::SVector &inTensorDescs, atb::SVector &outTensorDescs) { + outTensorDescs.at(0) = inTensorDescs.at(inHiddenStatesIdx); + if (outputAddNorm) { + outTensorDescs.at(1) = inTensorDescs.at(inHiddenStatesIdx); + } + return atb::NO_ERROR; + }; + } + + CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(this->graph, operation)); + return atb::NO_ERROR; +} + +template +atb::Status DecoderLayer::AddOperationToGraph() +{ + if (!param.isAttnSkipLayer) { + CHECK_OPERATION_STATUS_RETURN(this->AddFusionAttention()); + CHECK_OPERATION_STATUS_RETURN(this->AddFusionAttentionResidualAdd()); + if (param.hasAttnDp && param.hasMlpTp) { + CHECK_OPERATION_STATUS_RETURN(this->AddFusedAllGather()); + } + } + + if (!param.isMlpSkipLayer) { + CHECK_OPERATION_STATUS_RETURN(this->AddMlp()); + CHECK_OPERATION_STATUS_RETURN(this->AddMlpResidualAdd()); + if (param.hasAttnDp && param.hasMlpTp) { + CHECK_OPERATION_STATUS_RETURN(this->AddRevertAllGather()); + ATB_SPEED_LOG_DEBUG("Revert AllGather finished"); + } + } + return atb::NO_ERROR; +} + +template +void DecoderLayer::SetFusionAttentionParam( + atb_speed::common::FusionAttentionParam &fusionAttentionParam) +{ + fusionAttentionParam.enableAddNorm = param.enableInterLayerAddNorm && (param.layerId != 0); + this->SetFusionAttentionNormParam(fusionAttentionParam); + this->SetFusionAttentionLinearParam(fusionAttentionParam); + + // rope param + if (param.positionEmbeddingType == ROPE) { + fusionAttentionParam.rotaryType = 
atb_speed::common::RotaryType::ALL_ROTARY; + fusionAttentionParam.ropeParam.rotaryCoeff = 2; // 2: 旋转系数 + fusionAttentionParam.selfAttentionParam.maskType = atb::infer::SelfAttentionParam::MaskType::MASK_TYPE_NORM; + } else if (param.positionEmbeddingType == ALIBI) { + fusionAttentionParam.rotaryType = atb_speed::common::RotaryType::NO_ROTARY; + fusionAttentionParam.selfAttentionParam.maskType = atb::infer::SelfAttentionParam::MaskType::MASK_TYPE_ALIBI; + fusionAttentionParam.pageAttentionParam.maskType = atb::infer::PagedAttentionParam::MaskType::MASK_TYPE_ALIBI; + } else if (param.positionEmbeddingType == ABSOLUTE) { + fusionAttentionParam.rotaryType = atb_speed::common::RotaryType::NO_ROTARY; + fusionAttentionParam.selfAttentionParam.maskType = atb::infer::SelfAttentionParam::MaskType::MASK_TYPE_NORM; + fusionAttentionParam.pageAttentionParam.maskType = atb::infer::PagedAttentionParam::MaskType::UNDEFINED; + } + + // attention param + fusionAttentionParam.isFA = param.isFA; + fusionAttentionParam.isPrefill = param.isPrefill; + fusionAttentionParam.enableSplitFuse = param.enableSplitFuse; + fusionAttentionParam.headDim = param.hiddenSizePerAttentionHead; + fusionAttentionParam.attnBackend = param.attnBackend; + fusionAttentionParam.enableRopeQuantKvcache = param.enableRopeQuantKvcache; + fusionAttentionParam.useQKNorm = param.useQKNorm; + fusionAttentionParam.enableFlashComm = param.enableFlashComm; + // self attention + this->SetFusionAttentionATBSelfAttentionParam(fusionAttentionParam); + // paged attention + this->SetFusionAttentionATBPagedAttentionParam(fusionAttentionParam); + // aclnnIncreAttention + this->SetFusionAttentionAclNNIncreAttentionParam(fusionAttentionParam); + // self out linear param + fusionAttentionParam.denseQuantType = atb_speed::common::ConvertQuantTypeToPackType(param.weightQuantType); +} + +template<> +void DecoderLayer::SetFusionAttentionNormParam( + atb_speed::common::FusionAttentionParam &fusionAttentionParam) +{ + atb::infer::RmsNormParam attenRmsNormParam; + attenRmsNormParam.layerType = atb::infer::RmsNormParam::RmsNormType::RMS_NORM_NORM; + attenRmsNormParam.normParam.epsilon = this->param.normEps; + fusionAttentionParam.normParamType = attenRmsNormParam; + atb::infer::RmsNormParam attenRmsNormQuantParam; + attenRmsNormQuantParam.layerType = atb::infer::RmsNormParam::RmsNormType::RMS_NORM_NORM; + attenRmsNormQuantParam.normParam.epsilon = this->param.normEps; + attenRmsNormQuantParam.normParam.quantType = atb::infer::QUANT_INT8; + fusionAttentionParam.normQuantParamType = attenRmsNormQuantParam; + if (fusionAttentionParam.enableAddNorm) { + fusionAttentionParam.normParamType.layerType = atb::infer::RmsNormParam::RmsNormType::RMS_NORM_PRENORM; + fusionAttentionParam.normParamType.preNormParam.epsilon = param.normEps; + fusionAttentionParam.normQuantParamType.layerType = atb::infer::RmsNormParam::RmsNormType::RMS_NORM_PRENORM; + fusionAttentionParam.normQuantParamType.preNormParam.epsilon = param.normEps; + fusionAttentionParam.normQuantParamType.preNormParam.quantType = atb::infer::QUANT_INT8; + } +} + +template<> +void DecoderLayer::SetFusionAttentionNormParam( + atb_speed::common::FusionAttentionParam &fusionAttentionParam) +{ + const int32_t beginParamsAxis = param.isFA ? 
2 : 1; + atb::infer::LayerNormParam attenLayerNormParam; + attenLayerNormParam.layerType = atb::infer::LayerNormParam::LayerNormType::LAYER_NORM_NORM; + attenLayerNormParam.normParam.epsilon = this->param.normEps; + attenLayerNormParam.normParam.beginNormAxis = beginParamsAxis; + attenLayerNormParam.normParam.beginParamsAxis = 1; + fusionAttentionParam.normParamType = attenLayerNormParam; + atb::infer::LayerNormParam attenLayerNormQuantParam; + attenLayerNormQuantParam.layerType = atb::infer::LayerNormParam::LayerNormType::LAYER_NORM_NORM; + attenLayerNormQuantParam.normParam.epsilon = this->param.normEps; + attenLayerNormQuantParam.normParam.quantType = atb::infer::QUANT_INT8; + attenLayerNormQuantParam.normParam.beginNormAxis = beginParamsAxis; + attenLayerNormQuantParam.normParam.beginParamsAxis = 1; + fusionAttentionParam.normQuantParamType = attenLayerNormQuantParam; + fusionAttentionParam.normHasBias = true; +} + +template +void DecoderLayer::SetFusionAttentionLinearParam( + atb_speed::common::FusionAttentionParam &fusionAttentionParam) +{ + // QKV param + fusionAttentionParam.isGroupedQueryAttention = \ + this->param.numAttentionHeadsPerRank != param.numKeyValueHeadsPerRank; + fusionAttentionParam.isBF16 = this->param.isBF16; + fusionAttentionParam.isAntiOutlier = this->param.isAntiOutlier.at(0); + fusionAttentionParam.layerLinearDescs = this->param.linearDescs; + fusionAttentionParam.layerLinearQuantType = this->param.linearQuantType; + fusionAttentionParam.layerLinearTransposeType = this->param.linearTransposeType; + fusionAttentionParam.packQuantType = this->param.packQuantType.at(0); + fusionAttentionParam.quantGroupSize = this->param.quantGroupSize; + fusionAttentionParam.matmulBackend = this->param.matmulBackend; + fusionAttentionParam.supportLora = this->param.enableLora; + fusionAttentionParam.enablePreFetchWeight = this->param.enablePreFetchWeight; + fusionAttentionParam.enableMC2 = param.enableMC2; + fusionAttentionParam.loraEnableGMM = this->param.loraEnableGMM; + fusionAttentionParam.qkvHasBias = this->param.linearHasBias.at(QKV_HASBIAS); + fusionAttentionParam.enableModelConfuscation = this->param.enableModelConfuscation; + fusionAttentionParam.modelConfuscationFd = this->param.modelConfuscationFd; + fusionAttentionParam.hiddenSizePerRank = \ + CheckIntMulOverFlow(this->param.hiddenSizePerAttentionHead, this->param.numAttentionHeadsPerRank); + // dense + fusionAttentionParam.selfAttnHasBias = this->param.linearHasBias.at(SELFATTENTION_HASBIAS); + fusionAttentionParam.supportLcoc = this->param.enableLcoc; + if (this->param.hasAttnDp) { + fusionAttentionParam.selfOutLinearTensorParallelInfo = { + this->param.attnTpRank, this->param.attnTpSize, this->param.backend, this->param.attnTpRankTableFile, + nullptr, this->param.attnTpDomain}; + } else { + fusionAttentionParam.selfOutLinearTensorParallelInfo = this->param.tensorParallelInfo; + if (this->param.mapping.isInitialized_) { + atb_speed::common::ParallelInfo parallelInfo = param.mapping.Get(base::ATTN_TP); + parallelInfo.InitCommDomain( + fusionAttentionParam.selfOutLinearTensorParallelInfo.hcommInfo, + fusionAttentionParam.selfOutLinearTensorParallelInfo.commDomain); + } + } + if (this->param.enableReduceQuant) { + fusionAttentionParam.selfOutLinearTensorParallelInfo.quantType = \ + atb::infer::AllReduceParam::QuantType::QUANT_TYPE_PER_CHANNEL; + fusionAttentionParam.selfOutLinearTensorParallelInfo.outDataType = ACL_FLOAT16; + } +} + +template +void DecoderLayer::SetFusionAttentionATBSelfAttentionParam( + 
atb_speed::common::FusionAttentionParam &fusionAttentionParam) +{ + fusionAttentionParam.selfAttentionParam.headNum = this->param.numAttentionHeadsPerRank; + fusionAttentionParam.selfAttentionParam.kvHeadNum = this->param.numKeyValueHeadsPerRank; + fusionAttentionParam.selfAttentionParam.qkScale = 1.0 / sqrt(this->param.hiddenSizePerAttentionHead); + if (this->param.isFA) { + fusionAttentionParam.selfAttentionParam.calcType = this->param.isPrefill ? \ + atb::infer::SelfAttentionParam::CalcType::ENCODER : atb::infer::SelfAttentionParam::CalcType::DECODER; + } else { + fusionAttentionParam.selfAttentionParam.isTriuMask = this->param.isPrefill ? 1 : 0; + fusionAttentionParam.selfAttentionParam.calcType = atb::infer::SelfAttentionParam::CalcType::PA_ENCODER; + } + if (this->param.attnBackend == atb_speed::common::OpBackend::ACLNN && param.isFA) { + fusionAttentionParam.selfAttentionParam.calcType = atb::infer::SelfAttentionParam::CalcType::PA_ENCODER; + } +} + +template +void DecoderLayer::SetFusionAttentionATBPagedAttentionParam( + atb_speed::common::FusionAttentionParam &fusionAttentionParam) +{ + fusionAttentionParam.pageAttentionParam.headNum = this->param.numAttentionHeadsPerRank; + fusionAttentionParam.pageAttentionParam.kvHeadNum = this->param.numKeyValueHeadsPerRank; + fusionAttentionParam.pageAttentionParam.qkScale = 1.0 / sqrt(this->param.hiddenSizePerAttentionHead); + if (this->param.enableCompressHead) { + if (this->param.positionEmbeddingType == PositionEmbeddingType::ROPE) { + fusionAttentionParam.pageAttentionParam.compressType = atb::infer::PagedAttentionParam::CompressType:: \ + COMPRESS_TYPE_KVHEAD_ROPE; + fusionAttentionParam.reshapeCacheParm.compressType = atb::infer::ReshapeAndCacheParam::CompressType:: \ + COMPRESS_TYPE_KVHEAD_ROPE; + } else { + fusionAttentionParam.pageAttentionParam.compressType = atb::infer::PagedAttentionParam::CompressType:: \ + COMPRESS_TYPE_KVHEAD; + fusionAttentionParam.reshapeCacheParm.compressType = atb::infer::ReshapeAndCacheParam::CompressType:: \ + COMPRESS_TYPE_KVHEAD; + } + } + if (this->param.enableOmniAttention) { + fusionAttentionParam.pageAttentionParam.compressType = atb::infer::PagedAttentionParam::CompressType:: \ + COMPRESS_TYPE_KVHEAD_ROPE; + fusionAttentionParam.reshapeCacheParm.compressType = atb::infer::ReshapeAndCacheParam::CompressType:: \ + COMPRESS_TYPE_KVHEAD_ROPE; + fusionAttentionParam.enableOmniattention = true; + fusionAttentionParam.isomnicompressed = this->param.isomnicompressed; + } + if (this->param.enableSpeculate || this->param.enableSplitFuse) { + fusionAttentionParam.pageAttentionParam.calcType = \ + atb::infer::PagedAttentionParam::CalcType::CALC_TYPE_SPEC; + fusionAttentionParam.pageAttentionParam.maskType = \ + atb::infer::PagedAttentionParam::MaskType::MASK_TYPE_SPEC; + } + if (this->param.enableKvQuant) { + fusionAttentionParam.pageAttentionParam.quantType = atb::infer::PagedAttentionParam::TYPE_DEQUANT_FUSION; + fusionAttentionParam.pageAttentionParam.maskType = atb::infer::PagedAttentionParam::UNDEFINED; + fusionAttentionParam.pageAttentionParam.hasQuantOffset = this->param.kvQuantHasOffset; + } + if (this->param.enableFA3) { + fusionAttentionParam.pageAttentionParam.quantType = atb::infer::PagedAttentionParam::TYPE_QUANT_QKV_ONLINE; + if (this->param.isBF16) { + fusionAttentionParam.pageAttentionParam.outDataType = ACL_BF16; + } else { + fusionAttentionParam.pageAttentionParam.outDataType = ACL_FLOAT16; + } + } +} + +template +void DecoderLayer::SetFusionAttentionAclNNIncreAttentionParam( + 
atb_speed::common::FusionAttentionParam &fusionAttentionParam) +{ + fusionAttentionParam.aclnnIncreAttentionParam.headNum = this->param.numAttentionHeadsPerRank; + fusionAttentionParam.aclnnIncreAttentionParam.kvHeadNum = this->param.numKeyValueHeadsPerRank; + fusionAttentionParam.aclnnIncreAttentionParam.headDim = this->param.hiddenSizePerAttentionHead; + fusionAttentionParam.aclnnIncreAttentionParam.hasMask = true; + fusionAttentionParam.aclnnIncreAttentionParam.isFA = this->param.isFA; + fusionAttentionParam.aclnnIncreAttentionParam.hasKVQuant = this->param.enableKvQuant; + if (this->param.enableKvQuant) { + fusionAttentionParam.aclnnIncreAttentionParam.hasQuantOffset = this->param.kvQuantHasOffset; + } +} + +template <> +atb::Status DecoderLayer::CreateFusionAttentionOperation(atb::Operation **op) +{ + atb_speed::common::FusionAttentionParam fusionAttentionParam; + this->SetFusionAttentionParam(fusionAttentionParam); + CHECK_OPERATION_STATUS_RETURN(Attention(fusionAttentionParam, op)); + return atb::NO_ERROR; +} + +template <> +atb::Status DecoderLayer::CreateFusionAttentionOperation(atb::Operation **op) +{ + atb_speed::common::FusionAttentionParam fusionAttentionParam; + this->SetFusionAttentionParam(fusionAttentionParam); + CHECK_OPERATION_STATUS_RETURN(Attention(fusionAttentionParam, op)); + return atb::NO_ERROR; +} + +template +std::map> DecoderLayer::GetAttentionIntensor() +{ + std::map> attnInTensor = {}; + attnInTensor[common::AttnInTensorCategory::ATTN_DEFAULT] = { + "in_hidden_states", "in_input_norm_weight", "in_input_norm_bias", "in_input_norm_new_weight", + "in_input_norm_new_bias", "in_qkv_weight_0", "in_qkv_scale_0", "in_qkv_offset_0", "in_qkv_descale_0", + "in_qkv_bias_0", "in_qkv_compress_idx_0", "in_qkv_weight_1", "in_qkv_scale_1", "in_qkv_offset_1", + "in_qkv_descale_1", "in_qkv_bias_1", "in_qkv_compress_idx_1", "in_qkv_weight_2", "in_qkv_scale_2", + "in_qkv_offset_2", "in_qkv_descale_2", "in_qkv_bias_2", "in_qkv_compress_idx_2", "in_cos_embedding", + "in_sin_embedding", "in_seq_len", "in_k_cache", "in_v_cache", "in_attention_mask", "in_token_offset", + "in_layer_id", "in_block_tables", "in_slots", "in_qkv_dense_weight", "in_qkv_dense_scale", + "in_qkv_dense_offset", "in_qkv_dense_descale", "in_qkv_dense_bias", "in_qkv_dense_compress_idx"}; + if (this->param.enableCompressHead) { + if (this->param.positionEmbeddingType == PositionEmbeddingType::ALIBI) { + attnInTensor[common::AttnInTensorCategory::ATTN_COMPRESS_HEAD_ALIBI] = \ + this->inTensorCandidates["compress_head_alibi"]; + } else { + attnInTensor[common::AttnInTensorCategory::ATTN_COMPRESS_HEAD_ROPE] = \ + this->inTensorCandidates["compress_head_rope"]; + } + } + if (this->param.enableOmniAttention) { + attnInTensor[common::AttnInTensorCategory::ATTN_OMNI] = \ + this->inTensorCandidates["compress_head_rope"]; + } + if (this->param.enableSpeculate || param.enableSplitFuse) { + attnInTensor[common::AttnInTensorCategory::ATTN_SPECULATE] = this->inTensorCandidates["q_len"]; + } + if (this->param.enableKvQuant) { + attnInTensor[common::AttnInTensorCategory::ATTN_KV_QUANT_SCALE] = this->inTensorCandidates["kv_quant_scale"]; + if (this->param.kvQuantHasOffset) { + attnInTensor[common::AttnInTensorCategory::ATTN_KV_QUANT_OFFSET] = \ + this->inTensorCandidates["kv_quant_offset"]; + } + } + if (this->param.enableFA3) { + attnInTensor[common::AttnInTensorCategory::ATTN_FA3] = this->inTensorCandidates["fa3_quant"]; + } + if (this->param.enableLora) { + attnInTensor[common::AttnInTensorCategory::ATTN_LORA] = 
{"in_seq_len_cum_sum"}; + for (std::string tensor : this->inTensorCandidates.at("lora_attn")) { + attnInTensor[common::AttnInTensorCategory::ATTN_LORA].push_back(tensor); + } + } + if (this->param.enableReduceQuant) { + attnInTensor[common::AttnInTensorCategory::ATTN_REDUCE_QUANT] = {}; + for (std::string tensor : this->inTensorCandidates.at("reduce_quant_attn")) { + attnInTensor[common::AttnInTensorCategory::ATTN_REDUCE_QUANT].push_back(tensor); + } + } + if (this->param.useQKNorm) { + attnInTensor[common::AttnInTensorCategory::ATTN_QK_NORM] = this->inTensorCandidates["qk_norm"]; + } + if (this->param.enableInterLayerAddNorm && (this->param.layerId != 0)) { + attnInTensor[common::AttnInTensorCategory::ATTN_ADD_RMS_NORM_QUANT].push_back("in_qkv_scale_fill"); + attnInTensor[common::AttnInTensorCategory::ATTN_ADD_RMS_NORM_QUANT].push_back("in_qkv_offset_fill"); + attnInTensor[common::AttnInTensorCategory::ATTN_ADD_NORM] = {"in_last_mlp_out"}; + } + if (this->param.enablePreFetchWeight) { + attnInTensor[common::AttnInTensorCategory::ATTN_CMO] = {"in_mlp_weight_0"}; + } + if (this->param.enableFlashComm) { + attnInTensor[common::AttnInTensorCategory::ATTN_FC].push_back("send_counts"); + attnInTensor[common::AttnInTensorCategory::ATTN_FC].push_back("sdispls"); + attnInTensor[common::AttnInTensorCategory::ATTN_FC].push_back("send_count"); + attnInTensor[common::AttnInTensorCategory::ATTN_FC].push_back("recv_counts"); + attnInTensor[common::AttnInTensorCategory::ATTN_FC].push_back("rdispls"); + attnInTensor[common::AttnInTensorCategory::ATTN_FC].push_back("recv_count"); + attnInTensor[common::AttnInTensorCategory::ATTN_FC].push_back("fake_rs_shape"); + attnInTensor[common::AttnInTensorCategory::ATTN_FC].push_back("fake_ag_shape"); + } + return attnInTensor; +} + +template +atb::Status DecoderLayer::AddFusionAttention() +{ + atb::Node attentionNode; + CHECK_OPERATION_STATUS_RETURN(this->CreateFusionAttentionOperation(&attentionNode.operation)); + + // 按指定顺序对输入tensor进行排序 + std::map> attnInTensor = this->GetAttentionIntensor(); + std::vector attnInTensorNames = {}; + for (unsigned int i = 0; i < common::AttnInTensorCategory::ATTN_END; i++) { + attnInTensorNames.insert(attnInTensorNames.end(), attnInTensor[i].begin(), attnInTensor[i].end()); + } + + attentionNode.inTensorIds = atb_speed::common::GetTensorIdxList(this->tensorMap, attnInTensorNames); + std::vector attnOutTensorName = {"intermediate_attn_out"}; + if (this->param.enableInterLayerAddNorm && (this->param.layerId != 0)) { + attnOutTensorName.push_back("in_hidden_states"); + } + attentionNode.outTensorIds = atb_speed::common::GetTensorIdxList(this->tensorMap, attnOutTensorName); + + this->graph.nodes.push_back(attentionNode); + + if (this->param.enablePreFetchWeight && !this->param.isPrefill) { + atb::Node computeRecordNode; + computeRecordNode.inTensorIds = {}; + computeRecordNode.outTensorIds = {}; + CHECK_OPERATION_STATUS_RETURN(atb_speed::EventManager::GetInstance().RecordEvent( + computeRecordNode.operation, + atb_speed::EventAction::PUSH, + atb_speed::common::CMO_COMPUTE)); + this->graph.nodes.push_back(computeRecordNode); + + atb::Node commWaitNode; + commWaitNode.inTensorIds = {}; + commWaitNode.outTensorIds = {}; + CHECK_OPERATION_STATUS_RETURN(atb_speed::EventManager::GetInstance().WaitEvent( + commWaitNode.operation, + atb_speed::EventAction::POP, + atb_speed::common::CMO_COMPUTE)); + atb::SetExecuteStreamId(commWaitNode.operation, 1); + this->graph.nodes.push_back(commWaitNode); + + atb::Node cmoNode; + cmoNode.operation = new 
atb_speed::common::AclrtCmoAsyncOperation("AclrtCmoAsync"); + cmoNode.inTensorIds = atb_speed::common::GetTensorIdxList(this->tensorMap, { + "in_mlp_weight_0" + }); + cmoNode.outTensorIds = atb_speed::common::GetTensorIdxList(this->tensorMap, {}); + atb::SetExecuteStreamId(cmoNode.operation, 1); + + this->graph.nodes.push_back(cmoNode); + } + + return atb::NO_ERROR; +} + +template +atb::Status DecoderLayer::AddFusionAttentionResidualAdd() +{ + atb::infer::ElewiseParam addParam; + addParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_ADD; + atb::Node selfResidualAddNode; + if (!param.enableIntraLayerAddNorm) { + CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(addParam, &selfResidualAddNode.operation)); + selfResidualAddNode.inTensorIds = \ + atb_speed::common::GetTensorIdxList(this->tensorMap, {"in_hidden_states", "intermediate_attn_out"}); + if (param.isMlpSkipLayer) { + selfResidualAddNode.outTensorIds = \ + atb_speed::common::GetTensorIdxList(this->tensorMap, {"out"}); + } else { + selfResidualAddNode.outTensorIds = \ + atb_speed::common::GetTensorIdxList(this->tensorMap, {"in_hidden_states"}); + } + this->graph.nodes.push_back(selfResidualAddNode); + } + return atb::NO_ERROR; +} + +template +void DecoderLayer::SetMlpParam(atb_speed::common::MlpParam &mlpParam) +{ + mlpParam.isBF16 = this->param.isBF16; + mlpParam.isPrefill = this->param.isPrefill; + mlpParam.layerLinearQuantType = this->param.linearQuantType; + mlpParam.layerLinearTransposeType = this->param.linearTransposeType; + mlpParam.layerLinearDescs = this->param.linearDescs; + mlpParam.packQuantType = this->param.packQuantType.at(1); + mlpParam.matmulBackend = this->param.matmulBackend; + mlpParam.quantGroupSize = this->param.quantGroupSize; + mlpParam.isEdgeHardware = this->param.isEdgeHardware; + mlpParam.enableFlashComm = param.enableFlashComm; + // norm + mlpParam.isAntiOutlier = this->param.isAntiOutlier.at(1); + this->SetMlpNormParam(mlpParam); + // gate up + mlpParam.mlpPackType = atb_speed::common::GetMlpPackType( + this->param.packQuantType.at(1), false, param.linearDescs); + mlpParam.gateUpHasBias = this->param.linearHasBias.at(atb_speed::base::GATEUP_HASBIAS); + mlpParam.enableAddNorm = this->param.enableIntraLayerAddNorm; + mlpParam.supportLora = this->param.enableLora; + mlpParam.loraEnableGMM = this->param.loraEnableGMM; + // down + mlpParam.downLinearTensorParallelInfo = this->param.tensorParallelInfo; + if (this->param.mapping.isInitialized_) { + atb_speed::common::ParallelInfo parallelInfo = param.mapping.Get(base::MLP_TP); + parallelInfo.InitCommDomain( + mlpParam.downLinearTensorParallelInfo.hcommInfo, + mlpParam.downLinearTensorParallelInfo.commDomain); + } + if (this->param.enableReduceQuant) { + mlpParam.downLinearTensorParallelInfo.quantType = \ + atb::infer::AllReduceParam::QuantType::QUANT_TYPE_PER_CHANNEL; + mlpParam.downLinearTensorParallelInfo.outDataType = ACL_FLOAT16; + } + mlpParam.downHasBias = this->param.linearHasBias.at(atb_speed::base::DOWN_HASBIAS); + mlpParam.supportLcoc = this->param.enableLcoc; + mlpParam.enableMC2 = this->param.enableMC2; + if (this->param.enableSwiGLU) { + mlpParam.activationParam.activationType = atb::infer::ActivationType::ACTIVATION_SWIGLU_FORWARD; + mlpParam.activationParam.dim = -1; + } else { + mlpParam.activationParam.activationType = atb::infer::ActivationType::ACTIVATION_SWISH; + } + mlpParam.downQuantType = atb_speed::common::ConvertQuantTypeToPackType(param.weightQuantType); + mlpParam.enableSwigluQuant = this->param.enableSwigluQuant; +} + 
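+// Note: SetMlpParam above only fills the parameter struct; CreateMlpOperation (further
+// below) turns it into an operation, dispatching on param.enableSwiGLU. A minimal
+// sketch of that flow for the RmsNorm instantiation, illustrative only and not part of
+// this change:
+//
+//     atb_speed::common::MlpParam<atb::infer::RmsNormParam> mlpParam;
+//     this->SetMlpParam(mlpParam);  // norm, gate-up, down and activation settings
+//     atb::Operation *op = nullptr;
+//     CHECK_OPERATION_STATUS_RETURN(
+//         param.enableSwiGLU ? MlpSwiGLU(mlpParam, &op) : Mlp(mlpParam, &op));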
+template <> +void DecoderLayer::SetMlpNormParam( + atb_speed::common::MlpParam &mlpParam) +{ + atb::infer::RmsNormParam mlpRmsNormParam; + if (this->param.enableIntraLayerAddNorm) { + mlpRmsNormParam.layerType = atb::infer::RmsNormParam::RmsNormType::RMS_NORM_PRENORM; + mlpRmsNormParam.preNormParam.epsilon = this->param.normEps; + } else { + mlpRmsNormParam.layerType = atb::infer::RmsNormParam::RmsNormType::RMS_NORM_NORM; + mlpRmsNormParam.normParam.epsilon = this->param.normEps; + } + mlpParam.normParamType = mlpRmsNormParam; + atb::infer::RmsNormParam mlpRmsNormQuantParam; + if (this->param.enableIntraLayerAddNorm) { + mlpRmsNormQuantParam.layerType = atb::infer::RmsNormParam::RmsNormType::RMS_NORM_PRENORM; + mlpRmsNormQuantParam.preNormParam.epsilon = this->param.normEps; + mlpRmsNormQuantParam.preNormParam.quantType = atb::infer::QUANT_INT8; + } else { + mlpRmsNormQuantParam.layerType = atb::infer::RmsNormParam::RmsNormType::RMS_NORM_NORM; + mlpRmsNormQuantParam.normParam.epsilon = this->param.normEps; + mlpRmsNormQuantParam.normParam.quantType = atb::infer::QUANT_INT8; + } + mlpParam.normQuantParamType = mlpRmsNormQuantParam; +} + +template <> +void DecoderLayer::SetMlpNormParam( + atb_speed::common::MlpParam &mlpParam) +{ + const int32_t beginParamsAxis = param.isFA ? 2 : 1; + atb::infer::LayerNormParam mlpLayerNormParam; + mlpLayerNormParam.layerType = param.enableIntraLayerAddNorm ? \ + atb::infer::LayerNormParam::LayerNormType::LAYER_NORM_PRENORM : \ + atb::infer::LayerNormParam::LayerNormType::LAYER_NORM_NORM; + mlpLayerNormParam.normParam.epsilon = param.normEps; + mlpLayerNormParam.normParam.beginNormAxis = beginParamsAxis; + mlpLayerNormParam.normParam.beginParamsAxis = 1; + mlpParam.normParamType = mlpLayerNormParam; + atb::infer::LayerNormParam mlpLayerNormQuantParam; + mlpLayerNormQuantParam.layerType = param.enableIntraLayerAddNorm ? 
\ + atb::infer::LayerNormParam::LayerNormType::LAYER_NORM_PRENORM : \ + atb::infer::LayerNormParam::LayerNormType::LAYER_NORM_NORM; + mlpLayerNormQuantParam.normParam.epsilon = param.normEps; + mlpLayerNormQuantParam.normParam.quantType = atb::infer::QUANT_INT8; + mlpLayerNormQuantParam.normParam.beginNormAxis = beginParamsAxis; + mlpLayerNormQuantParam.normParam.beginParamsAxis = 1; + mlpParam.normQuantParamType = mlpLayerNormQuantParam; + mlpParam.normHasBias = true; +} + +template <> +atb::Status DecoderLayer::CreateMlpOperation(atb::Operation **op) +{ + atb_speed::common::MlpParam mlpParam; + this->SetMlpParam(mlpParam); + if (param.enableSwiGLU) { + CHECK_OPERATION_STATUS_RETURN(MlpSwiGLU(mlpParam, op)); + } else { + CHECK_OPERATION_STATUS_RETURN(Mlp(mlpParam, op)); + } + return atb::NO_ERROR; +} + +template <> +atb::Status DecoderLayer::CreateMlpOperation(atb::Operation **op) +{ + atb_speed::common::MlpParam mlpParam; + this->SetMlpParam(mlpParam); + if (param.enableSwiGLU) { + CHECK_OPERATION_STATUS_RETURN(MlpSwiGLU(mlpParam, op)); + } else { + CHECK_OPERATION_STATUS_RETURN(Mlp(mlpParam, op)); + } + return atb::NO_ERROR; +} + +template +std::map> DecoderLayer::GetMlpIntensor() +{ + std::map> mlpInTensor = {}; + mlpInTensor[common::MlpInTensorCategory::MLP_DEFAULT] = { + "in_hidden_states", "in_post_attn_norm_weight", "in_post_attn_norm_bias", "in_post_attn_norm_new_weight", + "in_post_attn_norm_new_bias", "in_mlp_weight_0", "in_mlp_scale_0", "in_mlp_offset_0", "in_mlp_descale_0", + "in_mlp_bias_0", "in_mlp_compress_idx_0", "in_mlp_weight_1", "in_mlp_scale_1", "in_mlp_offset_1", + "in_mlp_descale_1", "in_mlp_bias_1", "in_mlp_compress_idx_1", "in_mlp_down_weight", "in_mlp_down_scale", + "in_mlp_down_offset", "in_mlp_down_descale", "in_mlp_down_bias", "in_mlp_down_compress_idx" + }; + if (param.enableIntraLayerAddNorm) { + mlpInTensor[common::MlpInTensorCategory::MLP_ADD_RMS_NORM_QUANT].push_back("in_mlp_scale_fill"); + mlpInTensor[common::MlpInTensorCategory::MLP_ADD_RMS_NORM_QUANT].push_back("in_mlp_offset_fill"); + mlpInTensor[common::MlpInTensorCategory::MLP_ADD_NORM] = {"intermediate_attn_out"}; + } + if (param.enableLora) { + mlpInTensor[common::MlpInTensorCategory::MLP_LORA] = {"in_seq_len_cum_sum"}; + for (std::string tensor : this->inTensorCandidates.at("lora_mlp")) { + mlpInTensor[common::MlpInTensorCategory::MLP_LORA].push_back(tensor); + } + } + if (param.enableReduceQuant) { + mlpInTensor[common::MlpInTensorCategory::MLP_REDUCE_QUANT] = {}; + for (std::string tensor : this->inTensorCandidates.at("reduce_quant_mlp")) { + mlpInTensor[common::MlpInTensorCategory::MLP_REDUCE_QUANT].push_back(tensor); + } + } + if (this->param.hasAttnDp && this->param.hasMlpTp) { + mlpInTensor[common::MlpInTensorCategory::MLP_DEFAULT][0] = "intermediate_dp_attn_gathered"; + } + if (this->param.enableFlashComm) { + mlpInTensor[common::MlpInTensorCategory::MLP_FC].push_back("send_counts"); + mlpInTensor[common::MlpInTensorCategory::MLP_FC].push_back("sdispls"); + mlpInTensor[common::MlpInTensorCategory::MLP_FC].push_back("send_count"); + mlpInTensor[common::MlpInTensorCategory::MLP_FC].push_back("recv_counts"); + mlpInTensor[common::MlpInTensorCategory::MLP_FC].push_back("rdispls"); + mlpInTensor[common::MlpInTensorCategory::MLP_FC].push_back("recv_count"); + mlpInTensor[common::MlpInTensorCategory::MLP_FC].push_back("fake_rs_shape"); + mlpInTensor[common::MlpInTensorCategory::MLP_FC].push_back("fake_ag_shape"); + } + return mlpInTensor; +} + +template +atb::Status DecoderLayer::AddMlp() +{ + 
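+    // Input tensors are grouped by MlpInTensorCategory and flattened in ascending
+    // category order below, so the underlying operation always receives them in a
+    // fixed, category-stable order regardless of which optional features are enabled.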
atb::Node mlpParallelNode; + CHECK_OPERATION_STATUS_RETURN(this->CreateMlpOperation(&mlpParallelNode.operation)); + + // 按指定顺序对输入tensor进行排序 + std::map> mlpInTensor = this->GetMlpIntensor(); + std::vector mlpInTensorNames = {}; + for (unsigned int i = 0; i < common::MlpInTensorCategory::MLP_END; i++) { + mlpInTensorNames.insert(mlpInTensorNames.end(), mlpInTensor[i].begin(), mlpInTensor[i].end()); + } + + std::vector mlpOutTensorName = {"intermediate_mlp_out"}; + if (param.enableInterLayerAddNorm && (param.layerId != (param.numHiddenLayers - 1))) { + mlpOutTensorName = {"out_mlp"}; + } + if (param.enableIntraLayerAddNorm) { + if (param.enableInterLayerAddNorm && (param.layerId != (param.numHiddenLayers - 1))) { + mlpOutTensorName.push_back("out"); + } else { + mlpOutTensorName.push_back("in_hidden_states"); + } + } + + mlpParallelNode.inTensorIds = atb_speed::common::GetTensorIdxList(this->tensorMap, mlpInTensorNames); + mlpParallelNode.outTensorIds = atb_speed::common::GetTensorIdxList(this->tensorMap, mlpOutTensorName); + + this->graph.nodes.push_back(mlpParallelNode); + return atb::NO_ERROR; +} + +template +atb::Status DecoderLayer::AddMlpResidualAdd() +{ + // 如果开启且不是最后一层, 则不走AddMlpResidualAdd逻辑 + if (param.enableInterLayerAddNorm && param.layerId != (param.numHiddenLayers - 1)) { + return atb::NO_ERROR; + } + atb::infer::ElewiseParam addParam; + addParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_ADD; + atb::Node mlpResidualAddNode; + CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(addParam, &mlpResidualAddNode.operation)); + mlpResidualAddNode.inTensorIds = atb_speed::common::GetTensorIdxList(this->tensorMap, { + param.hasAttnDp && param.hasMlpTp ? "intermediate_dp_attn_gathered" : "in_hidden_states", + "intermediate_mlp_out"}); + if (this->param.hasAttnDp && this->param.hasMlpTp) { + mlpResidualAddNode.outTensorIds = atb_speed::common::GetTensorIdxList( + this->tensorMap, {"out_attndp_last_layer"}); + } else { + mlpResidualAddNode.outTensorIds = atb_speed::common::GetTensorIdxList( + this->tensorMap, {"out"}); + } + + this->graph.nodes.push_back(mlpResidualAddNode); + return atb::NO_ERROR; +} + +template +atb::Status DecoderLayer::AddPadNode() +{ + atb::Node padNode; + atb::infer::GatherParam padParam; + CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(padParam, &padNode.operation)); + padNode.inTensorIds = atb_speed::common::GetTensorIdxList( + this->tensorMap, {"in_hidden_states", "in_token_index_with_padding"}); + padNode.outTensorIds = atb_speed::common::GetTensorIdxList( + this->tensorMap, {"intermediate_dp_attn_out_with_padding"}); + this->graph.nodes.push_back(padNode); + ATB_SPEED_LOG_DEBUG("Gather calculation success"); + return atb::NO_ERROR; +} + +template +atb::Status DecoderLayer::AddAllGatherNode() +{ + atb::Node allGatherNode; + atb::infer::AllGatherParam allGatherParam; + allGatherParam.rank = this->param.attnDpRank; + allGatherParam.rankSize = this->param.attnDpSize; + allGatherParam.backend = this->param.backend; + allGatherParam.rankTableFile = this->param.attnDpRankTableFile; + allGatherParam.commDomain = this->param.attnDpDomain; + CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(allGatherParam, &allGatherNode.operation)); + allGatherNode.inTensorIds = atb_speed::common::GetTensorIdxList( + this->tensorMap, {"intermediate_dp_attn_out_with_padding"}); + allGatherNode.outTensorIds = atb_speed::common::GetTensorIdxList( + this->tensorMap, {"intermediate_dp_attn_out_all_with_padding"}); + 
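+    // The record/wait event nodes inserted by AddDapEventsBeforeComm/AddDapEventsAfterComm
+    // (below) bracket the AllGather node; under dual-stream (DAP) execution this appears
+    // intended to let communication of one microbatch overlap with computation of the other.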
+    CHECK_OPERATION_STATUS_RETURN(common::AddDapEventsBeforeComm(this->graph));
+    this->graph.nodes.push_back(allGatherNode);
+    CHECK_OPERATION_STATUS_RETURN(common::AddDapEventsAfterComm(this->graph));
+    ATB_SPEED_LOG_DEBUG("AllGather calculation success");
+    return atb::NO_ERROR;
+}
+
+template <typename NormType>
+atb::Status DecoderLayer<NormType>::AddUnPadNode()
+{
+    atb::Node unpadNode;
+    atb::infer::GatherParam unpadParam;
+    CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(unpadParam, &unpadNode.operation));
+    unpadNode.inTensorIds = atb_speed::common::GetTensorIdxList(
+        this->tensorMap, {"intermediate_dp_attn_out_all_with_padding", "in_skip_padding_token_indices"});
+    unpadNode.outTensorIds = atb_speed::common::GetTensorIdxList(this->tensorMap, {"intermediate_dp_attn_gathered"});
+    unpadNode.inTensorReshapeFuncs.reserve(unpadNode.inTensorIds.size());
+    unpadNode.inTensorReshapeFuncs.resize(unpadNode.inTensorIds.size());
+    unpadNode.inTensorReshapeFuncs[0] = [](const atb::Dims &oldShape, atb::Dims &newShape) {
+        newShape.dimNum = 2;  // 2: number of dimensions of the new shape
+        newShape.dims[0] = oldShape.dims[0] * oldShape.dims[1];
+        newShape.dims[1] = oldShape.dims[2];  // 2: index of the last dimension in the old shape
+    };
+    this->graph.nodes.push_back(unpadNode);
+    ATB_SPEED_LOG_DEBUG("Gather calculation success");
+    return atb::NO_ERROR;
+}
+
+template <typename NormType>
+atb::Status DecoderLayer<NormType>::AddFusedAllGather()
+{
+    CHECK_OPERATION_STATUS_RETURN(AddPadNode());
+    CHECK_OPERATION_STATUS_RETURN(AddAllGatherNode());
+    CHECK_OPERATION_STATUS_RETURN(AddUnPadNode());
+    return atb::NO_ERROR;
+}
+
+template <typename NormType>
+atb::Status DecoderLayer<NormType>::AddRevertAllGather()
+{
+    atb::Node revertAllGatherNode;
+    atb::infer::GatherParam gatherParam;
+    // Check the creation status here as well, for consistency with the other nodes.
+    CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(gatherParam, &revertAllGatherNode.operation));
+    revertAllGatherNode.inTensorIds = atb_speed::common::GetTensorIdxList(this->tensorMap,
+        {"out_attndp_last_layer", "in_shard_effective_token_indices"});
+    revertAllGatherNode.outTensorIds = atb_speed::common::GetTensorIdxList(this->tensorMap, {"out"});
+    this->graph.nodes.push_back(revertAllGatherNode);
+    ATB_SPEED_LOG_DEBUG("create revertAllGatherNode");
+    return atb::NO_ERROR;
+}
+
+template class DecoderLayer<atb::infer::RmsNormParam>;
+template class DecoderLayer<atb::infer::LayerNormParam>;
+} // namespace base
+} // namespace atb_speed
\ No newline at end of file
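// Usage sketch (illustrative, not part of the patch): with the explicit instantiations
// above, a caller builds a layer operation roughly as follows; the `layerParam` fields
// shown here are hypothetical.
//
//     atb_speed::base::LayerParam layerParam;
//     layerParam.isPrefill = false;
//     atb_speed::base::DecoderLayer<atb::infer::RmsNormParam> layer(layerParam);
//     atb::Operation *op = nullptr;
//     if (layer.BuildGraph(&op) != atb::NO_ERROR) { /* handle creation failure */ }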
diff --git a/tests/proftest/layer_test_framework/models/base/layer/decoder_layer.h b/tests/proftest/layer_test_framework/models/base/layer/decoder_layer.h
new file mode 100644
index 00000000..9502dc5e
--- /dev/null
+++ b/tests/proftest/layer_test_framework/models/base/layer/decoder_layer.h
@@ -0,0 +1,187 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef ATB_SPEED_MODELS_BASE_DECODER_LAYER_H
+#define ATB_SPEED_MODELS_BASE_DECODER_LAYER_H
+
+#include <vector>
+#include "nlohmann/json.hpp"
+
+#include "atb/atb_infer.h"
+#include "atb_speed/log.h"
+
+#include "models/base/param/layer_param.h"
+#include "operations/fusion/attention/fusion_attention.h"
+#include "operations/fusion/mlp/mlp.h"
+
+namespace atb_speed {
+namespace base {
+
+/// Base class that defines the structure of a layer in a large language model; called by the `DecoderModel` class.
+/// \tparam NormType The type of normalization; refer to `NormType` for more details.
+template <typename NormType>
+class DecoderLayer {
+public:
+    explicit DecoderLayer(const LayerParam &param);
+    virtual ~DecoderLayer() {}
+
+    /// Create a graph operation that represents the structure of a layer.
+    /// \param operation the address of a pointer to a default operation
+    /// \return A flag indicating whether the operation was successfully created.
+    virtual int64_t BuildGraph(atb::Operation **operation);
+
+protected:
+    /// Construct the input tensor list by selecting tensors from all the available input tensor candidates
+    /// based on the provided parameters `param`.
+    virtual void ConstructInTensorMap();
+    /// Construct the internal tensor list by selecting tensors from all the available internal tensor candidates
+    /// based on the provided parameters `param`.
+    virtual void ConstructInternalTensorMap();
+    /// The main entrance to set the fusion attention module's parameters
+    /// \param fusionAttentionParam a reference to the fusion attention parameter to be set
+    virtual void SetFusionAttentionParam(atb_speed::common::FusionAttentionParam<NormType> &fusionAttentionParam);
+    /// Configure the parameters of the normalization component within the fusion attention module
+    /// \param fusionAttentionParam a reference to the fusion attention parameter to be set
+    virtual void SetFusionAttentionNormParam(atb_speed::common::FusionAttentionParam<NormType> &fusionAttentionParam);
+    /// Configure the parameters of the linear component within the fusion attention module
+    /// \param fusionAttentionParam a reference to the fusion attention parameter to be set
+    virtual void SetFusionAttentionLinearParam(
+        atb_speed::common::FusionAttentionParam<NormType> &fusionAttentionParam);
+    /// Configure the parameters of ATB's self-attention component within the fusion attention module
+    /// \param fusionAttentionParam a reference to the fusion attention parameter to be set
+    virtual void SetFusionAttentionATBSelfAttentionParam(
+        atb_speed::common::FusionAttentionParam<NormType> &fusionAttentionParam);
+    /// Configure the parameters of ATB's paged attention component within the fusion attention module
+    /// \param fusionAttentionParam a reference to the fusion attention parameter to be set
+    virtual void SetFusionAttentionATBPagedAttentionParam(
+        atb_speed::common::FusionAttentionParam<NormType> &fusionAttentionParam);
+    /// Configure the parameters of AclNN's attention component within the fusion attention module
+    /// (used in the decode phase)
+    /// \param fusionAttentionParam a reference to the fusion attention parameter to be set
+    virtual void SetFusionAttentionAclNNIncreAttentionParam(
+        atb_speed::common::FusionAttentionParam<NormType> &fusionAttentionParam);
+    /// The main entrance to set the mlp module's parameters
+    /// \param mlpParam a reference to the mlp parameter to be set
+    virtual void SetMlpParam(atb_speed::common::MlpParam<NormType> &mlpParam);
+    /// Configure the parameters of the normalization component within the mlp module
+    /// \param mlpParam a reference to the mlp parameter to be set
+    virtual void SetMlpNormParam(atb_speed::common::MlpParam<NormType> &mlpParam);
+    /// Create the fusion attention operation
+    /// \param op the address of a pointer to a default operation
+    atb::Status CreateFusionAttentionOperation(atb::Operation **op);
+    /// Create the mlp operation
+    /// \param op the address of a pointer to a default operation
+    atb::Status CreateMlpOperation(atb::Operation **op);
+    /// Get the fusion attention module's input tensors
+    /// \return A map of the attention module's input tensor categories and their corresponding tensor names
+    virtual std::map<unsigned int, std::vector<std::string>> GetAttentionIntensor();
+    /// Get the mlp module's input tensors
+    /// \return A map of the mlp module's input tensor categories and their corresponding tensor names
+    virtual std::map<unsigned int, std::vector<std::string>> GetMlpIntensor();
+    /// Update internal tensor candidates
+    virtual void SetDefaultInternalTensorCandidates();
+    /// The primary entry point for adding all operations to the graph in sequential order.
+    /// \return A flag indicating whether the operations were successfully added to the graph.
+    virtual atb::Status AddOperationToGraph();
+    /// Add the fusion attention node to the graph
+    /// \return A flag indicating whether the operation was successfully added to the graph.
+    virtual atb::Status AddFusionAttention();
+    /// Add the residual add node to the graph to conduct the add operation after the fusion attention node
+    /// \return A flag indicating whether the operation was successfully added to the graph.
+    virtual atb::Status AddFusionAttentionResidualAdd();
+    /// Add the mlp node to the graph
+    /// \return A flag indicating whether the operation was successfully added to the graph.
+    virtual atb::Status AddMlp();
+    /// Add the residual add node to the graph to conduct the add operation after the mlp node
+    /// \return A flag indicating whether the operation was successfully added to the graph.
+    virtual atb::Status AddMlpResidualAdd();
+    /// Add the fused all gather node to the graph to gather information across devices
+    /// in the same communication domain. This function calls AddPadNode(), AddAllGatherNode(), and AddUnPadNode().
+    /// \return A flag indicating whether the operation was successfully added to the graph.
+    virtual atb::Status AddFusedAllGather();
+    /// Add the revert all gather node to the graph to drop unnecessary information before the attention operation
+    /// of the next layer.
+    /// \return A flag indicating whether the operation was successfully added to the graph.
+    virtual atb::Status AddRevertAllGather();
+    /// Add the pad node to the graph to make sure that the lengths of the input tensors to the AllGather operator
+    /// across devices of the same communication domain are identical.
+    /// \return A flag indicating whether the operation was successfully added to the graph.
+    virtual atb::Status AddPadNode();
+    /// Add the all gather node to the graph to gather information across devices
+    /// in the same communication domain.
+    /// \return A flag indicating whether the operation was successfully added to the graph.
+    virtual atb::Status AddAllGatherNode();
+    /// Add the unpad node to the graph to drop the excess padding added before the AllGather operator.
+    /// \return A flag indicating whether the operation was successfully added to the graph.
+    virtual atb::Status AddUnPadNode();
+
+    /// Specifies all potential input tensors, where the key represents the feature name,
+    /// and the value corresponds to the input tensor name.
+ std::map> inTensorCandidates = {}; + /// Specifies all potential internal tensors, where the key represents the feature name, + /// and the value corresponds to the internal tensor name. + std::map> internalTensorCandidates = {}; + /// A vector contains names of all the required input tensors. + std::vector inTensorList = {}; + /// A vector contains names of all the required intermediate tensors. + std::vector intermediateTensorList = {}; + /// A vector contains names of all the required output tensors. + std::vector outTensorList = {"out"}; + /// Defines all the required tensors for the current graph, with the key representing the input tensor name + /// and the value corresponding to the tensor index. + /// Tensors are ordered by input tensors, output tensors and internal tensors. + std::map tensorMap = {}; + + /// Layer parameters + LayerParam param; + /// A layer graph to be created + atb::GraphParam graph; + +private: + /// Default weight names required by the fusion attention node + const std::vector attnWeight = { + // Pack: + // MHA [3 * numAttentionHeadsPerRank * hiddenSizePerAttentionHead, hiddenSize] + // GQA [(numAttentionHeadsPerRank + 2 * numKeyValueHeadsPerRank) * hiddenSizePerAttentionHead, hiddenSize] + // No pack: + // (Q) shape: [numAttentionHeadsPerRank * hiddenSizePerAttentionHead, hiddenSize] + "in_qkv_weight_0", "in_qkv_bias_0", "in_qkv_descale_0", "in_qkv_offset_0", "in_qkv_scale_0", + "in_qkv_compress_idx_0", + // Pack: no usage; No pack: (K) shape: [numKeyValueHeadsPerRank * hiddenSizePerAttentionHead, hiddenSize] + "in_qkv_weight_1", "in_qkv_bias_1", "in_qkv_descale_1", "in_qkv_offset_1", "in_qkv_scale_1", + "in_qkv_compress_idx_1", + // Pack: no usage; No pack: (V) shape: [numKeyValueHeadsPerRank * hiddenSizePerAttentionHead, hiddenSize] + "in_qkv_weight_2", "in_qkv_bias_2", "in_qkv_descale_2", "in_qkv_offset_2", "in_qkv_scale_2", + "in_qkv_compress_idx_2", + // shape: [hiddenSize, numAttentionHeadsPerRank * hiddenSizePerAttentionHead] + "in_qkv_dense_weight", "in_qkv_dense_bias", "in_qkv_dense_descale", "in_qkv_dense_offset", + "in_qkv_dense_scale", "in_qkv_dense_compress_idx"}; + + /// Default weight names required by the mlp attention node + const std::vector mlpWeight = { + // Pack: shape: [2 * intermediateSizePerRank, hiddenSize] + // No pack: (Gate) shape: [intermediateSizePerRank, hiddenSize] + "in_mlp_weight_0", "in_mlp_bias_0", "in_mlp_descale_0", "in_mlp_offset_0", "in_mlp_scale_0", + "in_mlp_compress_idx_0", + // Pack: no usage; No pack: (Up) shape: [intermediateSizePerRank, hiddenSize] + "in_mlp_weight_1", "in_mlp_bias_1", "in_mlp_descale_1", "in_mlp_offset_1", "in_mlp_scale_1", + "in_mlp_compress_idx_1", + // shape: [hiddenSize, intermediateSizePerRank] + "in_mlp_down_weight", "in_mlp_down_bias", "in_mlp_down_descale", "in_mlp_down_offset", + "in_mlp_down_scale", "in_mlp_down_compress_idx"}; +}; + +} // namespace base +} // namespace atb_speed +#endif diff --git a/tests/proftest/layer_test_framework/models/base/model/decoder_model.cpp b/tests/proftest/layer_test_framework/models/base/model/decoder_model.cpp new file mode 100644 index 00000000..03776032 --- /dev/null +++ b/tests/proftest/layer_test_framework/models/base/model/decoder_model.cpp @@ -0,0 +1,1362 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include "vector" +#include "nlohmann/json.hpp" +#include "atb/atb_infer.h" +#include "atb_speed/log.h" + +#include "atb_speed/utils/hccl_runner.h" +#include "models/base/model/decoder_model.h" +#include "models/base/layer/decoder_layer.h" +#include "operations/aclnn/ops/split_with_size_operation.h" +#include "operations/fusion/infer_shape_functions.h" + +namespace atb_speed { +namespace base { + +HcclComm DecoderModel::gHcommInfo = nullptr; + +DecoderModel::DecoderModel(const std::string ¶m) : Model("DecoderModel", param) +{ + this->param.FromString(param); + this->modelName_ += this->param.isPrefill ? "_Prefill" : "_Decoder"; + this->inTensorCandidates = { + {"default", { + "input_ids", "positional_ids", "cosine_table", "sine_table", "attention_mask", + "block_tables", "slots", "kv_cache_idx", "token_offset", "place_holder", "seq_len", "logits_indices"} + }, + {"token_off_set", {"logits_offset_tensor"}}, + {"compress_head_alibi", {"wins_global", "in_ra_seqlens"}}, + {"compress_head_rope_common", {"wins_global", "in_reshape_seqlen"}}, + {"compress_head_rope_per_layer", { + "ra_block_tables", "ra_slots", "in_ra_seqlens", "pffset_index", "razor_offset"}}, + {"q_len", {"q_len"}}, + {"lora_common", {"seq_len_cum_sum"}}, + {"lora_per_layer", { + "qkv_lora_a_0", "qkv_lora_b_0", "qkv_lora_a_1", "qkv_lora_b_1", + "qkv_lora_a_2", "qkv_lora_b_2", "qkv_dense_lora_a", "qkv_dense_lora_b", + "mlp_lora_a_0", "mlp_lora_b_0", "mlp_lora_a_1", "mlp_lora_b_1", + "mlp_down_lora_a", "mlp_down_lora_b"}}, + {"attn_dp", { + "in_final_hidden_state", "in_shard_effective_token_indices", "in_token_index_with_padding", + "in_skip_padding_token_indices"}}, + {"flash_comm", { + "send_counts", "sdispls", "send_count", "recv_counts", "rdispls", "recv_count", + "fake_rs_shape", "fake_ag_shape"}}, + }; + if (this->param.skipWordEmbedding) { + this->inTensorCandidates["default"].at(0) = "input_embedding"; + } + + this->internalTensorCandidates = { + {"default", {"hidden_states"}}, + {"rope", {"cosine_embedding", "sine_embedding"}}, + {"attn_dp", {"attn_dp_last_layer"}}, + {"input_add_norm", {"last_layer_mlp_out"}} + }; + + if (this->param.enableFlashComm) { + std::vector hiddenStatesRanks; + hiddenStatesRanks.resize(this->param.worldSize); + for (int i = 0; i < this->param.worldSize; i++) { + hiddenStatesRanks[i] = "hidden_states_rank" + std::to_string(i); + } + this->internalTensorCandidates["hidden_states_ranks"] = hiddenStatesRanks; + this->internalTensorCandidates["Allgather"] = {"hidden_states_allgather"}; + } + + this->outTensorCandidates = { + {"default", {"logits"}}, + }; + + if (gHcommInfo == nullptr && this->param.backend == "hccl" && this->param.enableMC2) { + atb_speed::HcclRunner hcclRunner(this->param.rank, this->param.worldSize, 0); + gHcommInfo = hcclRunner.CreateHcclCommInMulitProcessByRootInfo(); + } +} + +DecoderModel::~DecoderModel() {} + +void DecoderModel::ConstructInTensorMap() +{ + this->inTensorMap.clear(); + // 添加默认的Tensor + atb_speed::common::AssignTensorIdx(this->inTensorCandidates, "default", this->inTensorMap); + + // 添加头压缩特性的Tensor + if 
(this->param.enableCompressHead) { + if (param.positionEmbeddingType == PositionEmbeddingType::ALIBI) { + atb_speed::common::AssignTensorIdx(this->inTensorCandidates, "compress_head_alibi", this->inTensorMap); + } else if (param.positionEmbeddingType == PositionEmbeddingType::ROPE) { + atb_speed::common::AssignTensorIdx( + this->inTensorCandidates, "compress_head_rope_common", this->inTensorMap); + uint32_t currentTensorIdx = this->inTensorMap.size(); + for (uint32_t i = 0; i < this->param.numHiddenLayers; ++i) { + for (std::string raInputName : this->inTensorCandidates.at("compress_head_rope_per_layer")) { + this->inTensorMap["layer_" + std::to_string(i) + "_" + raInputName] = currentTensorIdx; + currentTensorIdx++; + } + } + } + } + // 添加omniattention特性的Tensor + if (this->param.enableOmniAttention) { + atb_speed::common::AssignTensorIdx( + this->inTensorCandidates, "compress_head_rope_common", this->inTensorMap); + uint32_t currentTensorIdx = this->inTensorMap.size(); + for (uint32_t i = 0; i < this->param.numHiddenLayers; ++i) { + for (std::string raInputName : this->inTensorCandidates.at("compress_head_rope_per_layer")) { + this->inTensorMap["layer_" + std::to_string(i) + "_" + raInputName] = currentTensorIdx; + currentTensorIdx++; + } + } + } + // 添加并行解码特性或SplitFuse的Tensor + if (this->param.enableSpeculate || this->param.enableSplitFuse) { + atb_speed::common::AssignTensorIdx( + this->inTensorCandidates, "q_len", this->inTensorMap); + } + + // 添加lora特性的Tensor + if (this->param.enableLora) { + atb_speed::common::AssignTensorIdx( + this->inTensorCandidates, "lora_common", this->inTensorMap); + uint32_t currentTensorIdx = this->inTensorMap.size(); + for (uint32_t i = 0; i < this->param.numHiddenLayers; i++) { + for (std::string loraWeightName : this->inTensorCandidates.at("lora_per_layer")) { + this->inTensorMap["layer_" + std::to_string(i) + loraWeightName] = currentTensorIdx; + currentTensorIdx++; + } + } + } + + // Append in-tensors for data parallelism of Attention + if (this->param.hasAttnDp) { + atb_speed::common::AssignTensorIdx( + this->inTensorCandidates, "attn_dp", this->inTensorMap); + } + + // 添加flashcomm1.0特性的Tensor + if (this->param.enableFlashComm) { + atb_speed::common::AssignTensorIdx(this->inTensorCandidates, "flash_comm", this->inTensorMap); + } +} + +void DecoderModel::ConstructInternalTensorMap() +{ + this->internalTensorMap.clear(); + // 添加默认的Tensor + if (!this->param.skipWordEmbedding) { + atb_speed::common::AssignTensorIdx( + this->internalTensorCandidates, "default", this->internalTensorMap); + } + + // 添加rope的Tensor + if (this->param.positionEmbeddingType == PositionEmbeddingType::ROPE) { + atb_speed::common::AssignTensorIdx( + this->internalTensorCandidates, "rope", this->internalTensorMap); + } + + // Append internal-tensors for data parallelism of Attention + if (this->param.hasAttnDp && this->param.hasMlpTp) { + atb_speed::common::AssignTensorIdx( + this->internalTensorCandidates, "attn_dp", this->internalTensorMap); + } + // 添加add rmsnorm融合特性的中间tensor + if (this->param.enableInterLayerAddNorm) { + atb_speed::common::AssignTensorIdx( + this->internalTensorCandidates, "input_add_norm", this->internalTensorMap); + } + if (this->param.enableFlashComm) { + atb_speed::common::AssignTensorIdx(this->internalTensorCandidates, + "hidden_states_ranks", this->internalTensorMap); + atb_speed::common::AssignTensorIdx(this->internalTensorCandidates, "Allgather", this->internalTensorMap); + } +} + +void DecoderModel::ConstructOutTensorMap() +{ + this->outTensorMap.clear(); 
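+    // AssignTensorIdx (below) walks the named candidate group and assigns each tensor
+    // the next free index in the map; for the "default" out-tensor group declared in
+    // the constructor this yields, schematically: {"logits" -> 0}.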
+    // Append default out-tensors
+    atb_speed::common::AssignTensorIdx(
+        this->outTensorCandidates, "default", this->outTensorMap);
+}
+
+void DecoderModel::PrintTensorMapInfo(std::map<std::string, uint32_t> &tensorMap) const
+{
+    std::stringstream ss;
+    ss << "TensorMap Info: ";
+    for (auto tensor = tensorMap.cbegin(); tensor != tensorMap.cend(); ++tensor) {
+        ss << "tensor name: " << tensor->first << ", tensor id: " << tensor->second << std::endl;
+    }
+    ATB_SPEED_LOG_DEBUG(ss.str());
+}
+
+uint32_t DecoderModel::GetInputNum() const { return graph_.inTensors.size(); }
+
+uint32_t DecoderModel::GetOutputNum() const { return graph_.outTensors.size(); }
+
+uint32_t DecoderModel::CalcWeightTensorSize()
+{
+    if (this->param.enableKvQuant) {
+        this->weightCountPerLayer += 8;  // 8: kv cache int8 adds 8 in-tensors per layer
+    }
+    if (this->param.enableFA3) {
+        this->weightCountPerLayer += 8;  // 8: FA3 adds 8 in-tensors per layer
+    }
+    if (this->param.enableReduceQuant) {
+        this->weightCountPerLayer += 8;  // 8: lccl reduce int8 adds 8 in-tensors per layer
+    }
+    if (this->param.enableInterLayerAddNorm || this->param.enableIntraLayerAddNorm) {
+        this->weightCountPerLayer += 4;  // 4: addRmsNormQuant adds 4 in-tensors per layer
+    }
+    if (this->param.normType == LAYER_NORM) {
+        this->weightCountFinalNorm = 2;  // 2: number of LayerNorm weights
+    }
+    if (this->param.useQKNorm) {
+        this->weightCountPerLayer += 2;  // 2: useQKNorm adds 2 in-tensors per layer
+    }
+    const uint64_t weightTensorSize =
+        this->weightCountWordEmbedding +
+        CheckIntMulOverFlow(this->weightCountPerLayer, this->param.numHiddenLayers) +
+        this->weightCountFinalNorm + this->weightCountLmHead;
+    return weightTensorSize;
+}
+
+void DecoderModel::DuplicateTensorMapForDap(
+    std::map<std::string, uint32_t> &tensorMap, std::vector<atb::Tensor> &targetTensors)
+{
+    std::map<uint32_t, uint32_t> tensorIndexMap = {};
+    tensorIndexMap = CopyMapWithSuffix(tensorMap);
+    targetTensors.resize(tensorMap.size());
+    std::stringstream ss;
+    ss << "Dap preceder to successor mapping info: ";
+    for (auto pair = tensorIndexMap.cbegin(); pair != tensorIndexMap.cend(); ++pair) {
+        this->precederToSuccessorTensorMap[&targetTensors.at(pair->first)] = &targetTensors.at(pair->second);
+        ss << "tensor src index: " << pair->first << ", tensor dst index: " << pair->second << std::endl;
+    }
+    ATB_SPEED_LOG_DEBUG(ss.str());
+}
+
+std::map<uint32_t, uint32_t> DecoderModel::CopyMapWithSuffix(std::map<std::string, uint32_t>& tensorMap) const
+{
+    std::map<std::string, uint32_t> tmpTensorMap = {};
+    std::map<uint32_t, uint32_t> tensorIndexMap = {};
+    std::string suffix = GetSingleton().GetSuccessorSuffix();
+    uint32_t tensorMapSize = tensorMap.size();
+    for (auto pair = tensorMap.cbegin(); pair != tensorMap.cend(); ++pair) {
+        tmpTensorMap[pair->first + suffix] = pair->second + tensorMapSize;
+        tensorIndexMap[pair->second] = pair->second + tensorMapSize;
+    }
+    for (auto pair = tmpTensorMap.cbegin(); pair != tmpTensorMap.cend(); ++pair) {
+        tensorMap[pair->first] = pair->second;
+    }
+    return tensorIndexMap;
+}
+
+void DecoderModel::ReplaceDapTensors(std::vector<atb::Tensor *>& tensors)
+{
+    for (uint32_t i = 0; i < tensors.size(); i++) {
+        auto it = this->precederToSuccessorTensorMap.find(tensors[i]);
+        if (it != this->precederToSuccessorTensorMap.end()) {
+            tensors[i] = it->second;
+        }
+    }
+}
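To make the DAP suffix duplication concrete, here is a minimal sketch (commentary, not part of the patch; it assumes `GetSuccessorSuffix()` returns a hypothetical `"_1"`):

    // std::map<std::string, uint32_t> m = {{"input_ids", 0}, {"seq_len", 1}};
    // after CopyMapWithSuffix(m):
    //   m == {{"input_ids", 0}, {"input_ids_1", 2}, {"seq_len", 1}, {"seq_len_1", 3}}
    //   returned index map == {{0, 2}, {1, 3}}
    // DuplicateTensorMapForDap then records &targetTensors[0] -> &targetTensors[2]
    // and &targetTensors[1] -> &targetTensors[3], which ReplaceDapTensors uses to
    // redirect successor-stream nodes onto the duplicated tensors.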
"input_embedding" : "input_ids"; + uint32_t inputIdsIdx = atb_speed::common::GetTensorIdx(this->inTensorMap, inputKey); + CHECK_TENSORDESC_DIMNUM_VALID(graph_.weightTensors.at(graph_.weightTensors.size() - 1).desc.shape.dimNum); + CHECK_TENSORDESC_DIMNUM_VALID(inTensorDescs.at(logitsIndicesIdx).shape.dimNum); + CHECK_TENSORDESC_DIMNUM_VALID(inTensorDescs.at(inputIdsIdx).shape.dimNum); + uint32_t dim = this->param.lmHeadTransposeType == atb_speed::common::TransposeType::NOT_TRANSPOSE ? 1 : 0; + const int64_t vocabSizePerRank = graph_.weightTensors.at(graph_.weightTensors.size() - 1).desc.shape.dims[dim]; + int64_t seqLenAxis = this->param.isUnpadInputs ? 0 : 1; // 2, 3: Axis + if (!this->param.enableGreedyPostProcessing) { + // unpadInputs: [batchSize, seqLen, vocabSize] padInputs: [seqLen, vocabSisze] + outTensorDescs.at(0).dtype = graph_.weightTensors.at(graph_.weightTensors.size() - 1).desc.dtype; + outTensorDescs.at(0).format = graph_.weightTensors.at(0).desc.format; + outTensorDescs.at(0).shape.dimNum = this->param.isUnpadInputs ? 2 : 3; // 2, 3: dimNum + CHECK_TENSORDESC_DIMNUM_VALID(outTensorDescs.at(0).shape.dimNum); + outTensorDescs.at(0).shape.dims[0] = inTensorDescs.at(inputIdsIdx).shape.dims[0]; + if (this->param.isPrefill || this->param.enablePrefixCache || this->param.hasAttnDp) { + outTensorDescs.at(0).shape.dims[seqLenAxis] = inTensorDescs.at(logitsIndicesIdx).shape.dims[0]; + } else { + outTensorDescs.at(0).shape.dims[seqLenAxis] = inTensorDescs.at(inputIdsIdx).shape.dims[seqLenAxis]; + } + outTensorDescs.at(0).shape.dims[outTensorDescs.at(0).shape.dimNum - 1] = this->param.isLmHeadParallel + ? CheckIntMulOverFlow(vocabSizePerRank, this->param.hasPp ? this->param.tpWorldSize : this->param.worldSize) + : vocabSizePerRank; + } else { + outTensorDescs.at(0).dtype = aclDataType::ACL_INT64; + outTensorDescs.at(0).format = graph_.weightTensors.at(0).desc.format; + outTensorDescs.at(0).shape.dimNum = RESULT_DIM_2; // 二维 [batch_size,1] + outTensorDescs.at(0).shape.dims[0] = + inTensorDescs.at(10).shape.dims[0]; // num 10 on behalf of seq_len, dims[0] is batch_size + outTensorDescs.at(0).shape.dims[1] = 1; + } + + if (this->param.enableDap) { + GetSingleton().SetRole(common::DapRole::SUCCESSOR); + std::string suffix = GetSingleton().GetSuccessorSuffix(); + logitsIndicesIdx = atb_speed::common::GetTensorIdx(this->inTensorMap, "logits_indices" + suffix); + inputIdsIdx = atb_speed::common::GetTensorIdx(this->inTensorMap, inputKey + suffix); + CHECK_TENSORDESC_DIMNUM_VALID(inTensorDescs.at(logitsIndicesIdx).shape.dimNum); + CHECK_TENSORDESC_DIMNUM_VALID(inTensorDescs.at(inputIdsIdx).shape.dimNum); + outTensorDescs.at(1) = outTensorDescs.at(0); + if (this->param.isPrefill || this->param.enablePrefixCache || this->param.hasAttnDp) { + outTensorDescs.at(1).shape.dims[seqLenAxis] = inTensorDescs.at(logitsIndicesIdx).shape.dims[0]; + } else { + outTensorDescs.at(1).shape.dims[seqLenAxis] = inTensorDescs.at(inputIdsIdx).shape.dims[seqLenAxis]; + } + GetSingleton().SetRole(common::DapRole::PRECEDER); + } + return atb::NO_ERROR; +} + +int64_t DecoderModel::BuildGraph() +{ + // 准备inTensor + this->ConstructInTensorMap(); + this->PrintTensorMapInfo(this->inTensorMap); + this->graph_.inTensors.resize(this->inTensorMap.size()); + if (this->param.enableDap) { + this->DuplicateTensorMapForDap(this->inTensorMap, this->graph_.inTensors); + } + ATB_SPEED_LOG_DEBUG("graph_.inTensors " << this->graph_.inTensors.size()); + + // 准备internalTensor + this->ConstructInternalTensorMap(); + 
+
+int64_t DecoderModel::BuildGraph()
+{
+    // Prepare in-tensors
+    this->ConstructInTensorMap();
+    this->PrintTensorMapInfo(this->inTensorMap);
+    this->graph_.inTensors.resize(this->inTensorMap.size());
+    if (this->param.enableDap) {
+        this->DuplicateTensorMapForDap(this->inTensorMap, this->graph_.inTensors);
+    }
+    ATB_SPEED_LOG_DEBUG("graph_.inTensors " << this->graph_.inTensors.size());
+
+    // Prepare internal tensors
+    this->ConstructInternalTensorMap();
+    this->PrintTensorMapInfo(this->internalTensorMap);
+    this->graph_.internalTensors.resize(this->internalTensorMap.size());
+    if (this->param.enableDap) {
+        this->DuplicateTensorMapForDap(this->internalTensorMap, this->graph_.internalTensors);
+    }
+    ATB_SPEED_LOG_DEBUG("graph_.internalTensors " << this->graph_.internalTensors.size());
+
+    // Prepare out-tensors
+    this->ConstructOutTensorMap();
+    this->PrintTensorMapInfo(this->outTensorMap);
+    this->graph_.outTensors.resize(this->outTensorMap.size());
+    if (this->param.enableDap) {
+        this->DuplicateTensorMapForDap(this->outTensorMap, this->graph_.outTensors);
+    }
+    ATB_SPEED_LOG_DEBUG("graph_.outTensors " << this->graph_.outTensors.size());
+
+    // Prepare weight tensors
+    graph_.weightTensors.resize(this->CalcWeightTensorSize());
+    ATB_SPEED_LOG_DEBUG("graph_.weightTensors " << this->graph_.weightTensors.size());
+
+    // Prepare kv cache
+    graph_.kCacheTensors.resize(this->param.numHiddenLayers);
+    graph_.vCacheTensors.resize(this->param.numHiddenLayers);
+
+    GetSingleton().SetRole(common::DapRole::UNDEFINED_ROLE);
+    GetSingleton().Reset();
+    auto ret = this->AddOperationToGraph();
+    ATB_SPEED_LOG_DEBUG(GetSingleton().PrintCommInfo());
+    return ret;
+}
+
+atb::Status DecoderModel::AddOperationToGraph()
+{
+    std::stringstream ss;
+    atb::Operation *op = nullptr;
+
+    // AddNodesBeforeLayer
+    // PRECEDER events
+    if (this->param.enableDap) {
+        GetSingleton().SetRole(common::DapRole::PRECEDER);
+    }
+    CHECK_OPERATION_STATUS_RETURN(this->AddNodesBeforeLayer());
+
+    // SUCCESSOR events
+    if (this->param.enableDap) {
+        GetSingleton().SetRole(common::DapRole::SUCCESSOR);
+        atb_speed::Model::Node computeWaitNode;
+        CHECK_OPERATION_STATUS_RETURN(atb_speed::EventManager::GetInstance().WaitEvent(
+            op, atb_speed::EventAction::POP, common::COMPUTE_EVENT));
+        computeWaitNode.inTensors = {};
+        computeWaitNode.outTensors = {};
+        computeWaitNode.operation.reset(op);
+        CHECK_OPERATION_STATUS_RETURN(SetNodeStreamId(computeWaitNode, 1));
+        graph_.nodes.push_back(computeWaitNode);
+        ss.str("");
+        ss << "[Events] [SUCCESSOR] [POP] [WAIT] [COMPUTE] will be pushed to the graph later";
+        ATB_SPEED_LOG_DEBUG(ss.str());
+
+        uint32_t nodeCount = graph_.nodes.size();
+        CHECK_OPERATION_STATUS_RETURN(this->AddNodesBeforeLayer());
+        for (uint32_t index = nodeCount; index < graph_.nodes.size(); index++) {
+            CHECK_OPERATION_STATUS_RETURN(SetNodeStreamId(graph_.nodes.at(index), 1));
+            ReplaceDapTensors(graph_.nodes.at(index).inTensors);
+            ReplaceDapTensors(graph_.nodes.at(index).outTensors);
+        }
+        GetSingleton().SetRole(common::DapRole::PRECEDER);
+    }
+
+    // AddLayer
+    CHECK_OPERATION_STATUS_RETURN(this->AddLayer());
+
+    // AddNodesAfterLayer
+    // PRECEDER events
+    CHECK_OPERATION_STATUS_RETURN(this->AddNodesAfterLayer());
+    if (param.enableDap) {
+        CHECK_OPERATION_STATUS_RETURN(
+            atb_speed::EventManager::GetInstance().WaitEvent(op, atb_speed::EventAction::PUSH, common::COMM_EVENT));
+        atb_speed::Model::Node commWaitNode;
+        commWaitNode.inTensors = {};
+        commWaitNode.outTensors = {};
+        commWaitNode.operation.reset(op);
+        graph_.nodes.push_back(commWaitNode);
+        ATB_SPEED_LOG_DEBUG("[Events] [PRECEDER] [PUSH] [WAIT] [COMM]");
+
+        CHECK_OPERATION_STATUS_RETURN(atb_speed::EventManager::GetInstance().RecordEvent(
+            op, atb_speed::EventAction::PUSH, common::COMPUTE_EVENT));
+        atb_speed::Model::Node computeRecordNode;
+        computeRecordNode.inTensors = {};
+        computeRecordNode.outTensors = {};
+        computeRecordNode.operation.reset(op);
+        graph_.nodes.push_back(computeRecordNode);
+        ATB_SPEED_LOG_DEBUG("[Events] [PRECEDER] [PUSH] [RECORD] [COMPUTE]");
+
+        CHECK_OPERATION_STATUS_RETURN(
+            atb_speed::EventManager::GetInstance().WaitEvent(op, atb_speed::EventAction::POP, common::END_EVENT));
+        atb_speed::Model::Node endWaitNode;
+        endWaitNode.inTensors = {};
+        endWaitNode.outTensors = {};
+        endWaitNode.operation.reset(op);
+        graph_.nodes.push_back(endWaitNode);
+        ATB_SPEED_LOG_DEBUG("[Events] [PRECEDER] [POP] [WAIT] [END]");
+    }
+
+    // SUCCESSOR events
+    if (this->param.enableDap) {
+        GetSingleton().SetRole(common::DapRole::SUCCESSOR);
+        uint32_t nodeCount = graph_.nodes.size();
+        CHECK_OPERATION_STATUS_RETURN(this->AddNodesAfterLayer());
+        for (uint32_t index = nodeCount; index < graph_.nodes.size(); index++) {
+            CHECK_OPERATION_STATUS_RETURN(SetNodeStreamId(graph_.nodes.at(index), 1));
+            ReplaceDapTensors(graph_.nodes.at(index).inTensors);
+            ReplaceDapTensors(graph_.nodes.at(index).outTensors);
+        }
+
+        CHECK_OPERATION_STATUS_RETURN(
+            atb_speed::EventManager::GetInstance().RecordEvent(op, atb_speed::EventAction::PUSH, common::END_EVENT));
+        atb_speed::Model::Node endRecordNode;
+        endRecordNode.inTensors = {};
+        endRecordNode.outTensors = {};
+        endRecordNode.operation.reset(op);
+        CHECK_OPERATION_STATUS_RETURN(SetNodeStreamId(endRecordNode, 1));
+        graph_.nodes.push_back(endRecordNode);
+        ATB_SPEED_LOG_DEBUG("[Events] [SUCCESSOR] [PUSH] [RECORD] [END]");
+        GetSingleton().SetRole(common::DapRole::PRECEDER);
+    }
+
+    return atb::NO_ERROR;
+}
+
+atb::Status DecoderModel::AddNodesBeforeLayer()
+{
+    if (!this->param.skipWordEmbedding) { CHECK_OPERATION_STATUS_RETURN(this->AddWordEmbedding()); }
+    if (this->param.positionEmbeddingType == PositionEmbeddingType::ROPE) {
+        CHECK_OPERATION_STATUS_RETURN(this->AddPositionalEmbedding());
+    }
+    if (this->param.enableFlashComm) {
+        CHECK_OPERATION_STATUS_RETURN(this->AddSplitHiddenStates());
+    }
+    return atb::NO_ERROR;
+}
+
+atb::Status DecoderModel::AddNodesAfterLayer()
+{
+    CHECK_OPERATION_STATUS_RETURN(this->AddFinalNorm());
+    if (this->param.enableFlashComm) {
+        CHECK_OPERATION_STATUS_RETURN(this->AddAllGather());
+    }
+    CHECK_OPERATION_STATUS_RETURN(this->AddLmhead());
+    return atb::NO_ERROR;
+}
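The DAP event choreography in AddOperationToGraph above is easier to follow as a two-stream timeline. This is a simplified reading aid, not part of the patch:

    // stream 0 (PRECEDER)              stream 1 (SUCCESSOR)
    // nodes before layer               wait COMPUTE (pop)
    // layers (even graph ids)          nodes before layer
    // nodes after layer                layers (odd graph ids)
    // wait COMM (push)                 nodes after layer
    // record COMPUTE (push)            record END (push)
    // wait END (pop)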
+
+void DecoderModel::SetWordEmbeddingParam(atb_speed::common::WordEmbeddingParam &wordEmbeddingParam)
+{
+    wordEmbeddingParam.unpadInputs = this->param.isUnpadInputs;
+    if (this->param.isEmbeddingParallel && this->param.hasPp) {
+        wordEmbeddingParam.tensorParallelInfo = {
+            this->param.tpRank, this->param.tpWorldSize, this->param.backend, this->param.tpRankTableFile, \
+            nullptr, this->param.tpDomain
+        };
+    } else if (this->param.isEmbeddingParallel && this->param.hasAttnDp) {
+        wordEmbeddingParam.tensorParallelInfo = {
+            this->param.attnTpRank, this->param.attnTpSize, this->param.backend, this->param.attnTpRankTableFile,
+            nullptr, this->param.attnTpDomain
+        };
+    } else if (this->param.isEmbeddingParallel) {
+        wordEmbeddingParam.tensorParallelInfo = {
+            this->param.rank, this->param.worldSize, this->param.backend, this->param.rankTableFile
+        };
+        if (this->param.mapping.isInitialized_) {
+            atb_speed::common::ParallelInfo parallelInfo = param.mapping.Get(base::WORD_EMBED_TP);
+            parallelInfo.InitCommDomain(
+                wordEmbeddingParam.tensorParallelInfo.hcommInfo,
+                wordEmbeddingParam.tensorParallelInfo.commDomain);
+        }
+    }
+}
+
+atb::Status DecoderModel::AddWordEmbedding()
+{
+    atb::Operation *op = nullptr;
+
+    atb_speed::Model::Node wordEmbeddingNode;
+    atb_speed::common::WordEmbeddingParam wordEmbeddingParam;
+    this->SetWordEmbeddingParam(wordEmbeddingParam);
+    CHECK_OPERATION_STATUS_RETURN(atb_speed::common::WordEmbedding(wordEmbeddingParam, &op));
+    wordEmbeddingNode.operation.reset(op);
+    wordEmbeddingNode.inTensors = {
+        &graph_.weightTensors.at(0),
+        &graph_.inTensors.at(atb_speed::common::GetTensorIdx(this->inTensorMap, "input_ids"))
+    };
+    wordEmbeddingNode.outTensors = {
+        &graph_.internalTensors.at(atb_speed::common::GetTensorIdx(this->internalTensorMap, "hidden_states"))
+    };
+    graph_.nodes.push_back(wordEmbeddingNode);
+    ATB_SPEED_LOG_DEBUG("[+] base wordEmbeddingNode");
+    return atb::NO_ERROR;
+}
+
+atb::Status DecoderModel::AddPositionalEmbedding()
+{
+    atb::Operation *op = nullptr;
+    atb_speed::Model::Node positionalEmbeddingGatherNode;
+    CHECK_OPERATION_STATUS_RETURN(atb_speed::common::PositionalEmbeddingGather(&op));
+    positionalEmbeddingGatherNode.operation.reset(op);
+    positionalEmbeddingGatherNode.inTensors = {
+        &graph_.inTensors.at(atb_speed::common::GetTensorIdx(this->inTensorMap, "positional_ids")),
+        &graph_.inTensors.at(atb_speed::common::GetTensorIdx(this->inTensorMap, "cosine_table")),
+        &graph_.inTensors.at(atb_speed::common::GetTensorIdx(this->inTensorMap, "sine_table")),
+    };
+    positionalEmbeddingGatherNode.outTensors = {
+        &graph_.internalTensors.at(atb_speed::common::GetTensorIdx(this->internalTensorMap, "cosine_embedding")),
+        &graph_.internalTensors.at(atb_speed::common::GetTensorIdx(this->internalTensorMap, "sine_embedding"))
+    };
+    graph_.nodes.push_back(positionalEmbeddingGatherNode);
+    ATB_SPEED_LOG_DEBUG("[+] base positionalEmbeddingGatherNode");
+    return atb::NO_ERROR;
+}
+
+atb::Status DecoderModel::AddSplitHiddenStates()
+{
+    atb::Operation *op = nullptr;
+    if (this->param.enableFlashComm) {
+        atb_speed::Model::Node splitHiddenStatesNode;
+        atb_speed::common::AclNNSplitWithSizeParam aclnnSplitWithSizeParam;
+        aclnnSplitWithSizeParam.num = this->param.worldSize;
+        op = new atb_speed::common::SplitWithSizeOperation("splitHiddenStatesNode", aclnnSplitWithSizeParam);
+        splitHiddenStatesNode.operation.reset(op);
+        splitHiddenStatesNode.inTensors = {
+            &graph_.internalTensors.at(atb_speed::common::GetTensorIdx(this->internalTensorMap, "hidden_states"))};
+        for (int i = 0; i < this->param.worldSize; i++) {
+            splitHiddenStatesNode.outTensors.emplace_back(&graph_.internalTensors.at(atb_speed::common::GetTensorIdx(
+                this->internalTensorMap, "hidden_states_rank" + std::to_string(i))));
+        }
+
+        ATB_SPEED_LOG_DEBUG("[+] splitHiddenStatesNode");
+        graph_.nodes.push_back(splitHiddenStatesNode);
+    }
+
+    return atb::NO_ERROR;
+}
+
+atb::Status DecoderModel::CreateLayerOperation(atb::Operation **op, uint32_t layerId)
+{
+    LayerParam layerParam;
+    this->SetLayerParam(layerParam, layerId);
+    if (this->param.normType == RMS_NORM) {
+        DecoderLayer<atb::infer::RmsNormParam> decoderLayer(layerParam);
+        CHECK_OPERATION_STATUS_RETURN(decoderLayer.BuildGraph(op));
+    } else {
+        DecoderLayer<atb::infer::LayerNormParam> decoderLayer(layerParam);
+        CHECK_OPERATION_STATUS_RETURN(decoderLayer.BuildGraph(op));
+    }
+    return atb::NO_ERROR;
+}
+
+atb::Status DecoderModel::AddSend()
+{
+    atb::Operation *op = nullptr;
+    atb_speed::Model::Node sendNode;
+    atb::infer::SendParam sendParam;
+    sendParam.rank = this->param.rank;
+    sendParam.rankSize = this->param.ppGroupSize * this->param.tpWorldSize;
+    sendParam.rankRoot = 0;
+    sendParam.destRank = this->param.nextPpRank;
+    sendParam.rankTableFile = this->param.rankTableFile;
+    sendParam.commDomain = "sendRecvDomain";
+    CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(sendParam, &op));
+    sendNode.operation.reset(op);
+    atb::Tensor *firstInTensor =
+        this->param.skipWordEmbedding
+        ? &graph_.inTensors.at(atb_speed::common::GetTensorIdx(this->inTensorMap, "input_embedding"))
+        : &graph_.internalTensors.at(atb_speed::common::GetTensorIdx(this->internalTensorMap, "hidden_states"));
+    sendNode.inTensors = {firstInTensor};
+    sendNode.outTensors = {};
+    graph_.nodes.push_back(sendNode);
+
+    return atb::NO_ERROR;
+}
+
+atb::Status DecoderModel::AddRecv()
+{
+    atb::Operation *op = nullptr;
+    atb_speed::Model::Node receiveNode;
+    atb::infer::RecvParam recvParam;
+    recvParam.rank = this->param.rank;
+    recvParam.rankSize = this->param.ppGroupSize * this->param.tpWorldSize;
+    recvParam.rankRoot = 0;
+    recvParam.srcRank = this->param.prevPpRank;
+    recvParam.rankTableFile = this->param.rankTableFile;
+    recvParam.commDomain = "sendRecvDomain";
+    CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(recvParam, &op));
+    receiveNode.operation.reset(op);
+    atb::Tensor *firstInTensor =
+        this->param.skipWordEmbedding
+        ? &graph_.inTensors.at(atb_speed::common::GetTensorIdx(this->inTensorMap, "input_embedding"))
+        : &graph_.internalTensors.at(atb_speed::common::GetTensorIdx(this->internalTensorMap, "hidden_states"));
+    receiveNode.inTensors = {firstInTensor};
+    receiveNode.outTensors = {firstInTensor};
+    graph_.nodes.push_back(receiveNode);
+
+    return atb::NO_ERROR;
+}
+
+atb::Status DecoderModel::AddLayer()
+{
+    uint32_t numHiddenLayers = param.enableDap ? param.numHiddenLayers * 2 : param.numHiddenLayers;
+    for (uint32_t layerId = 0; layerId < numHiddenLayers; ++layerId) {
+        if (param.enableDap && (layerId % 2 == 1)) {  // odd graph-layer ids run on the successor stream
+            GetSingleton().SetRole(common::DapRole::SUCCESSOR);
+        }
+        uint32_t trueLayerId = param.enableDap ? layerId / 2 : layerId;
+
+        uint32_t nodeCount = graph_.nodes.size();
+        this->AddSingleLayer(trueLayerId);
+        for (uint32_t index = nodeCount; index < graph_.nodes.size(); index++) {
+            if (GetSingleton().GetRole() == common::DapRole::SUCCESSOR) {
+                CHECK_OPERATION_STATUS_RETURN(SetNodeStreamId(graph_.nodes.at(index), 1));
+                ReplaceDapTensors(graph_.nodes.at(index).inTensors);
+                ReplaceDapTensors(graph_.nodes.at(index).outTensors);
+            }
+        }
+
+        if (param.enableDap) {
+            GetSingleton().SetRole(common::DapRole::PRECEDER);
+        }
+    }
+    return atb::NO_ERROR;
+}
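With DAP enabled, each decoder layer is added twice and the copies alternate between the two roles. A sketch of the id mapping (commentary, illustrative values only):

    // numHiddenLayers = 2 -> graph layerId: 0    1    2    3
    //                        trueLayerId:   0    0    1    1
    //                        role:          PRE  SUC  PRE  SUC
    // SUCCESSOR copies are moved to stream 1 and rewired onto the
    // suffixed tensor duplicates via ReplaceDapTensors.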
+
+atb::Status DecoderModel::AddSingleLayer(uint32_t layerId)
+{
+    atb::Operation *op = nullptr;
+    auto attnIt = std::find(this->param.attnSkipLayerSet.begin(), this->param.attnSkipLayerSet.end(), layerId);
+    auto mlpIt = std::find(this->param.mlpSkipLayerSet.begin(), this->param.mlpSkipLayerSet.end(), layerId);
+    if (attnIt != this->param.attnSkipLayerSet.end() && mlpIt != this->param.mlpSkipLayerSet.end()) {
+        return atb::NO_ERROR;
+    }
+    atb_speed::Model::Node layerNode;
+    CHECK_OPERATION_STATUS_RETURN(this->CreateLayerOperation(&op, layerId));
+
+    layerNode.operation.reset(op);
+    layerNode.inTensors.resize(layerNode.operation->GetInputNum());
+    layerNode.inTensorReshapeFuncs.resize(layerNode.operation->GetInputNum());
+    SetLayerNodeInput(layerNode, layerId);
+
+    if (this->param.hasAttnDp && this->param.hasMlpTp) {
+        layerNode.outTensors = {
+            layerNode.inTensors.at(weightCountPerLayer),
+            &graph_.internalTensors.at(
+                atb_speed::common::GetTensorIdx(this->internalTensorMap, "attn_dp_last_layer"))
+        };
+    } else {
+        layerNode.outTensors = {layerNode.inTensors.at(weightCountPerLayer)};  // write the output in place over the input
+    }
+    if (this->param.enableInterLayerAddNorm && (layerId != (param.numHiddenLayers - 1))) {
+        layerNode.outTensors.push_back(
+            &graph_.internalTensors.at(
+                atb_speed::common::GetTensorIdx(this->internalTensorMap, "last_layer_mlp_out")
+            )
+        );
+    }
+    graph_.nodes.push_back(layerNode);
+    ATB_SPEED_LOG_DEBUG("[+] add base layerNode num " << layerId);
+    return atb::NO_ERROR;
+}
+
+void DecoderModel::SetLayerParallelismParam(LayerParam &layerParam)
+{
+    layerParam.backend = this->param.backend;
+    if (this->param.hasPp) {
+        layerParam.tensorParallelInfo = {
+            this->param.tpRank, this->param.tpWorldSize, this->param.backend, this->param.tpRankTableFile, \
+            nullptr, this->param.tpDomain};
+    } else {
+        layerParam.tensorParallelInfo = {
+            this->param.rank, this->param.worldSize, this->param.backend, this->param.rankTableFile, gHcommInfo};
+    }
+    layerParam.hasAttnTp = this->param.hasAttnTp;
+    layerParam.attnTpRank = this->param.attnTpRank;
+    layerParam.attnTpSize = this->param.attnTpSize;
+    layerParam.attnTpDomain = this->param.attnTpDomain;
+    layerParam.attnTpRankTableFile = this->param.rankTableFile;
+    layerParam.hasAttnDp = this->param.hasAttnDp;
+    layerParam.attnDpRank = this->param.attnDpRank;
+    layerParam.attnDpSize = this->param.attnDpSize;
+    layerParam.attnDpDomain = this->param.attnDpDomain;
+    layerParam.attnDpRankTableFile = this->param.rankTableFile;
+    layerParam.hasMlpTp = this->param.hasMlpTp;
+    layerParam.mlpTpRank = this->param.mlpTpRank;
+    layerParam.mlpTpSize = this->param.mlpTpSize;
+    layerParam.mlpTpDomain = this->param.mlpTpDomain;
+    layerParam.mlpTpRankTableFile = this->param.rankTableFile;
+    layerParam.enableSwigluQuant = this->param.enableSwigluQuant;
+    layerParam.mapping = this->param.mapping;
+}
+
+void DecoderModel::SetLayerParam(LayerParam &layerParam, uint32_t layerId)
+{
+    layerParam.layerId = layerId;
+    layerParam.numHiddenLayers = this->param.numHiddenLayers;
+    layerParam.isFA = this->param.isFA;
+    layerParam.isUnpadInputs = this->param.isUnpadInputs;
+    layerParam.isPrefill = this->param.isPrefill;
+    layerParam.isBF16 = this->param.isBF16;
+    layerParam.isEdgeHardware = this->param.isEdgeHardware;
+    layerParam.enableSwiGLU = this->param.enableSwiGLU;
+    layerParam.enableLcoc = this->param.enableLcoc;
+    layerParam.enableMC2 = this->param.enableMC2;
+    layerParam.enableSpeculate = this->param.enableSpeculate;
+    layerParam.enableCompressHead = this->param.enableCompressHead;
+    layerParam.enableOmniAttention = this->param.enableOmniAttention;
+    layerParam.useQKNorm = this->param.useQKNorm;
+    if (layerParam.enableOmniAttention) {
+        layerParam.isomnicompressed = this->param.patternMask[layerId];
+        this->param.isomnicompressed = this->param.patternMask[layerId];
+    }
+    layerParam.enableSplitFuse = this->param.enableSplitFuse;
+    layerParam.enableLora = this->param.enableLora;
+    layerParam.enablePreFetchWeight = this->param.enablePreFetchWeight;
+    layerParam.loraEnableGMM = this->param.loraEnableGMM;
+    layerParam.enableKvQuant = this->param.enableKvQuant;
+    layerParam.enableFA3 = this->param.enableFA3;
+    layerParam.kvQuantHasOffset = this->param.kvQuantHasOffset;
+    layerParam.enableReduceQuant = this->param.enableReduceQuant;
+    layerParam.enableInterLayerAddNorm = this->param.enableInterLayerAddNorm;
+    layerParam.enableIntraLayerAddNorm = this->param.enableIntraLayerAddNorm;
+    layerParam.enablePrefixCache = this->param.enablePrefixCache;
+    layerParam.attnBackend = this->param.attnBackend;
+    layerParam.matmulBackend = this->param.matmulBackend;
+    layerParam.positionEmbeddingType = this->param.positionEmbeddingType;
+    layerParam.normEps = this->param.normEps;
+    layerParam.normType = this->param.normType;
+    layerParam.quantGroupSize = this->param.quantGroupSize;
+    layerParam.numAttentionHeadsPerRank = this->param.numAttentionHeadsPerRank;
+    layerParam.hiddenSizePerAttentionHead = this->param.hiddenSizePerAttentionHead;
+    layerParam.numKeyValueHeadsPerRank = this->param.numKeyValueHeadsPerRank;
+    layerParam.enableFlashComm = this->param.enableFlashComm;
+    layerParam.enableModelConfuscation = this->param.enableModelConfuscation;
+    layerParam.modelConfuscationFd = this->param.modelConfuscationFd;
+    if (layerId != 0) { layerParam.enableModelConfuscation = false; }
+    if (!this->param.packQuantType.empty()) {
+        layerParam.packQuantType = this->param.packQuantType[layerId];
+    }
+    if (!this->param.linearQuantType.empty()) {
+        layerParam.linearQuantType = this->param.linearQuantType[layerId];
+    }
+    layerParam.linearTransposeType = this->param.linearTransposeType[layerId];
+    if (!this->param.linearHasBias.empty()) {
+        layerParam.linearHasBias = this->param.linearHasBias[layerId];
+    }
+    if (!this->param.linearDescs.empty()) {
+        layerParam.linearDescs = this->param.linearDescs[layerId];
+    }
+    if (!this->param.isAntiOutlier.empty()) {
+        layerParam.isAntiOutlier = this->param.isAntiOutlier[layerId];
+    }
+    layerParam.weightQuantType = this->param.weightQuantType;
+    SetLayerParallelismParam(layerParam);
+    if (!layerParam.isPrefill) {
+        auto attnIt = std::find(this->param.attnSkipLayerSet.begin(), this->param.attnSkipLayerSet.end(), layerId);
+        if (attnIt != this->param.attnSkipLayerSet.end()) {
+            layerParam.isAttnSkipLayer = true;
+            ATB_SPEED_LOG_DEBUG("Skip attention layer, layer id is " << layerId);
+        }
+        auto mlpIt = std::find(this->param.mlpSkipLayerSet.begin(), this->param.mlpSkipLayerSet.end(), layerId);
+        if (mlpIt != this->param.mlpSkipLayerSet.end()) {
+            layerParam.isMlpSkipLayer = true;
+            ATB_SPEED_LOG_DEBUG("Skip mlp layer, layer id is " << layerId);
+        }
+    }
+}
+
+void DecoderModel::SetLayerNodeInput(atb_speed::Model::Node &layerNode, uint32_t layerId)
+{
+    uint32_t inTensorId = 0;
+    this->SetLayerNodeDefaultInput(layerNode, layerId, inTensorId);
+    this->SetLayerNodeOptionalInput(layerNode, layerId, inTensorId);
+}
+
+void DecoderModel::SetLayerNodeOptionalInput(
+    atb_speed::Model::Node &layerNode, uint32_t layerId, uint32_t &inTensorId)
+{
+    if (this->param.enableCompressHead) {
+        this->SetLayerNodeRaInput(layerNode, layerId, inTensorId);
+    }
+    if (this->param.enableOmniAttention) {
+        this->SetLayerNodeOmniInput(layerNode, layerId, inTensorId);
+    }
+    if (this->param.enableSpeculate || this->param.enableSplitFuse) {
+        layerNode.inTensors.at(inTensorId++) = \
+            &graph_.inTensors.at(atb_speed::common::GetTensorIdx(this->inTensorMap, "q_len"));
+    }
+    if (this->param.enableLora) {
+        this->SetLayerNodeLoraInput(layerNode, layerId, inTensorId);
+    }
+    if (this->param.hasAttnDp) {
+        this->SetLayerNodeAttnDpInput(layerNode, inTensorId);
+    }
+    if (param.enableInterLayerAddNorm && layerId != 0) {
+        layerNode.inTensors.at(inTensorId++) = \
+            &graph_.internalTensors.at(atb_speed::common::GetTensorIdx(this->internalTensorMap, "last_layer_mlp_out"));
+    }
+    if (this->param.enableFlashComm) {
+        layerNode.inTensors.at(inTensorId++) = \
+            &graph_.inTensors.at(atb_speed::common::GetTensorIdx(this->inTensorMap, "send_counts"));
+        layerNode.inTensors.at(inTensorId++) = \
+            &graph_.inTensors.at(atb_speed::common::GetTensorIdx(this->inTensorMap, "sdispls"));
+        layerNode.inTensors.at(inTensorId++) = \
+            &graph_.inTensors.at(atb_speed::common::GetTensorIdx(this->inTensorMap, "send_count"));
+        layerNode.inTensors.at(inTensorId++) = \
+            &graph_.inTensors.at(atb_speed::common::GetTensorIdx(this->inTensorMap, "recv_counts"));
+        layerNode.inTensors.at(inTensorId++) = \
+            &graph_.inTensors.at(atb_speed::common::GetTensorIdx(this->inTensorMap, "rdispls"));
+        layerNode.inTensors.at(inTensorId++) = \
+            &graph_.inTensors.at(atb_speed::common::GetTensorIdx(this->inTensorMap, "recv_count"));
+        layerNode.inTensors.at(inTensorId++) = \
+            &graph_.inTensors.at(atb_speed::common::GetTensorIdx(this->inTensorMap, "fake_rs_shape"));
+        layerNode.inTensors.at(inTensorId++) = \
+            &graph_.inTensors.at(atb_speed::common::GetTensorIdx(this->inTensorMap, "fake_ag_shape"));
+    }
+}
+
+void DecoderModel::SetLayerNodeDefaultInput(
+    atb_speed::Model::Node &layerNode, uint32_t layerId, uint32_t &inTensorId)
+{
+    for (uint32_t weightTensorId = 0; weightTensorId < this->weightCountPerLayer; ++weightTensorId) {
+        layerNode.inTensors.at(inTensorId++) = &graph_.weightTensors.at(
+            CheckIntMulOverFlow(layerId, this->weightCountPerLayer) + weightTensorId + this->weightCountWordEmbedding);
+    }
+    if (this->param.enableFlashComm) {
+        layerNode.inTensors.at(inTensorId++) = &graph_.internalTensors.at(
+            atb_speed::common::GetTensorIdx(this->internalTensorMap,
+                "hidden_states_rank" + std::to_string(this->param.rank)));
+    } else {
+        layerNode.inTensors.at(inTensorId++) = this->param.skipWordEmbedding ? \
+            &graph_.inTensors.at(atb_speed::common::GetTensorIdx(this->inTensorMap, "input_embedding")) : \
+            &graph_.internalTensors.at(atb_speed::common::GetTensorIdx(this->internalTensorMap, "hidden_states"));
+    }
+    if (this->param.positionEmbeddingType == atb_speed::base::PositionEmbeddingType::ROPE) {
+        layerNode.inTensors.at(inTensorId++) = &graph_.internalTensors.at(
+            atb_speed::common::GetTensorIdx(this->internalTensorMap, "cosine_embedding"));
+        layerNode.inTensors.at(inTensorId++) = &graph_.internalTensors.at(
+            atb_speed::common::GetTensorIdx(this->internalTensorMap, "sine_embedding"));
+    } else {
+        layerNode.inTensors.at(inTensorId++) = &graph_.inTensors.at(
+            atb_speed::common::GetTensorIdx(this->inTensorMap, "place_holder"));
+        layerNode.inTensors.at(inTensorId++) = &graph_.inTensors.at(
+            atb_speed::common::GetTensorIdx(this->inTensorMap, "place_holder"));
+    }
+    layerNode.inTensors.at(inTensorId++) = \
+        &graph_.inTensors.at(atb_speed::common::GetTensorIdx(this->inTensorMap, "attention_mask"));
+    layerNode.inTensors.at(inTensorId++) = &graph_.kCacheTensors.at(layerId);
+    layerNode.inTensors.at(inTensorId++) = &graph_.vCacheTensors.at(layerId);
+    layerNode.inTensors.at(inTensorId++) = \
+        &graph_.inTensors.at(atb_speed::common::GetTensorIdx(this->inTensorMap, "seq_len"));
+    layerNode.inTensors.at(inTensorId++) = \
+        &graph_.inTensors.at(atb_speed::common::GetTensorIdx(this->inTensorMap, "token_offset"));
+    layerNode.inTensors.at(inTensorId++) = \
+        &graph_.inTensors.at(atb_speed::common::GetTensorIdx(this->inTensorMap, "kv_cache_idx"));
+    if (this->param.enableCompressHead && this->param.positionEmbeddingType == PositionEmbeddingType::ROPE) {
+        layerNode.inTensors.at(inTensorId++) = &graph_.inTensors.at(atb_speed::common::GetTensorIdx(
+            this->inTensorMap, "layer_" + std::to_string(layerId) + "_" + "ra_block_tables"));
+        layerNode.inTensors.at(inTensorId++) = &graph_.inTensors.at(atb_speed::common::GetTensorIdx(
+            this->inTensorMap, "layer_" + std::to_string(layerId) + "_" + "ra_slots"));
+    } else if (this->param.enableOmniAttention) {
+        layerNode.inTensors.at(inTensorId++) = &graph_.inTensors.at(atb_speed::common::GetTensorIdx(
+            this->inTensorMap, "layer_" + std::to_string(layerId) + "_" + "ra_block_tables"));
+        layerNode.inTensors.at(inTensorId++) = &graph_.inTensors.at(atb_speed::common::GetTensorIdx(
+            this->inTensorMap, "layer_" + std::to_string(layerId) + "_" + "ra_slots"));
+    } else {
+        layerNode.inTensors.at(inTensorId++) = \
+            &graph_.inTensors.at(atb_speed::common::GetTensorIdx(this->inTensorMap, "block_tables"));
+        layerNode.inTensors.at(inTensorId++) = \
+            &graph_.inTensors.at(atb_speed::common::GetTensorIdx(this->inTensorMap, "slots"));
+    }
+}
+
+void DecoderModel::SetLayerNodeRaInput(atb_speed::Model::Node &layerNode, uint32_t layerId, uint32_t &inTensorId)
+{
+    if (param.positionEmbeddingType == PositionEmbeddingType::ALIBI) {
+        for (std::string raInputName : this->inTensorCandidates.at("compress_head_alibi")) {
+            layerNode.inTensors.at(inTensorId++) = &graph_.inTensors.at(
+                atb_speed::common::GetTensorIdx(this->inTensorMap, raInputName));
+        }
+    } else if (param.positionEmbeddingType == PositionEmbeddingType::ROPE) {
+        layerNode.inTensors.at(inTensorId++) = &graph_.inTensors.at(atb_speed::common::GetTensorIdx(
+            this->inTensorMap, "wins_global"));
+        layerNode.inTensors.at(inTensorId++) = &graph_.inTensors.at(atb_speed::common::GetTensorIdx(
+            this->inTensorMap, "layer_" + std::to_string(layerId) + "_" + "in_ra_seqlens"));
+        layerNode.inTensors.at(inTensorId++) = &graph_.inTensors.at(atb_speed::common::GetTensorIdx(
+            this->inTensorMap, "layer_" + std::to_string(layerId) + "_" + "pffset_index"));
+        layerNode.inTensors.at(inTensorId++) = &graph_.inTensors.at(atb_speed::common::GetTensorIdx(
+            this->inTensorMap, "layer_" + std::to_string(layerId) + "_" + "razor_offset"));
+        layerNode.inTensors.at(inTensorId++) = &graph_.inTensors.at(atb_speed::common::GetTensorIdx(
+            this->inTensorMap, "in_reshape_seqlen"));
+    }
+}
+
+void DecoderModel::SetLayerNodeOmniInput(atb_speed::Model::Node &layerNode, uint32_t layerId, uint32_t &inTensorId)
+{
+    layerNode.inTensors.at(inTensorId++) = &graph_.inTensors.at(atb_speed::common::GetTensorIdx(
+        this->inTensorMap, "wins_global"));
+    layerNode.inTensors.at(inTensorId++) = &graph_.inTensors.at(atb_speed::common::GetTensorIdx(
+        this->inTensorMap, "layer_" + std::to_string(layerId) + "_" + "in_ra_seqlens"));
+    layerNode.inTensors.at(inTensorId++) = &graph_.inTensors.at(atb_speed::common::GetTensorIdx(
+        this->inTensorMap, "layer_" + std::to_string(layerId) + "_" + "pffset_index"));
+    layerNode.inTensors.at(inTensorId++) = &graph_.inTensors.at(atb_speed::common::GetTensorIdx(
+        this->inTensorMap, "layer_" + std::to_string(layerId) + "_" + "razor_offset"));
+    layerNode.inTensors.at(inTensorId++) = &graph_.inTensors.at(atb_speed::common::GetTensorIdx(
+        this->inTensorMap, "in_reshape_seqlen"));
+}
+
+void DecoderModel::SetLayerNodeLoraInput(atb_speed::Model::Node &layerNode, uint32_t layerId, uint32_t &inTensorId)
+{
+    layerNode.inTensors.at(inTensorId++) = \
+        &graph_.inTensors.at(atb_speed::common::GetTensorIdx(this->inTensorMap, "seq_len_cum_sum"));
+    for (std::string loraWeightName : this->inTensorCandidates.at("lora_per_layer")) {
+        layerNode.inTensors.at(inTensorId++) = &graph_.inTensors.at(
+            atb_speed::common::GetTensorIdx(
+                this->inTensorMap, "layer_" + std::to_string(layerId) + loraWeightName)
+        );
+    }
+}
+
+void DecoderModel::SetLayerNodeAttnDpInput(atb_speed::Model::Node &layerNode, uint32_t &inTensorId)
+{
+    for (std::string attnDpInputName : this->inTensorCandidates.at("attn_dp")) {
+        layerNode.inTensors.at(inTensorId++) = &graph_.inTensors.at(
+            atb_speed::common::GetTensorIdx(this->inTensorMap, attnDpInputName)
+        );
+    }
+    ATB_SPEED_LOG_DEBUG("decoder model has pushed attn_dp in-tensors");
+}
+
+void DecoderModel::SetFinalNormParam(atb::infer::RmsNormParam &normParam)
+{
+    normParam.layerType = atb::infer::RmsNormParam::RmsNormType::RMS_NORM_NORM;
+    normParam.normParam.epsilon = this->param.normEps;
+}
+
+void DecoderModel::SetFinalNormParam(atb::infer::LayerNormParam &normParam)
+{
+    int32_t beginParamsAxis = this->param.isFA ? 2 : 1;
+    normParam.layerType = atb::infer::LayerNormParam::LAYER_NORM_NORM;
+    normParam.normParam.epsilon = this->param.normEps;
+    normParam.normParam.beginNormAxis = beginParamsAxis;
+    normParam.normParam.beginParamsAxis = 1;
+}
+
+atb::Status DecoderModel::AddFinalNorm()
+{
+    atb::Operation *op = nullptr;
+
+    atb_speed::Model::Node finalNormNode;
+    if (this->param.normType == NormType::RMS_NORM) {
+        atb::infer::RmsNormParam finalNormParam;
+        this->SetFinalNormParam(finalNormParam);
+        CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(finalNormParam, &op));
+    } else {
+        atb::infer::LayerNormParam finalNormParam;
+        this->SetFinalNormParam(finalNormParam);
+        CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(finalNormParam, &op));
+    }
+    finalNormNode.operation.reset(op);
+    const uint32_t finalLayerNormWeightTensorId =
+        graph_.weightTensors.size() - this->weightCountFinalNorm - this->weightCountLmHead;
+    if (this->param.hasAttnDp && this->param.hasMlpTp) {
+        finalNormNode.inTensors = {
+            this->param.skipWordEmbedding ? \
+            &graph_.inTensors.at(atb_speed::common::GetTensorIdx(this->inTensorMap, "input_embedding")) : \
+            &graph_.internalTensors.at(atb_speed::common::GetTensorIdx(
+                this->internalTensorMap, "attn_dp_last_layer")),
+            &graph_.weightTensors.at(finalLayerNormWeightTensorId)};
+    } else if (this->param.enableFlashComm) {
+        finalNormNode.inTensors = { &graph_.internalTensors.at(atb_speed::common::GetTensorIdx(this->internalTensorMap,
+            "hidden_states_rank" + std::to_string(this->param.rank))),
+            &graph_.weightTensors.at(finalLayerNormWeightTensorId)};
+    } else {
+        finalNormNode.inTensors = {
+            this->param.skipWordEmbedding ? \
+            &graph_.inTensors.at(atb_speed::common::GetTensorIdx(this->inTensorMap, "input_embedding")) : \
+            &graph_.internalTensors.at(atb_speed::common::GetTensorIdx(this->internalTensorMap, "hidden_states")),
+            &graph_.weightTensors.at(finalLayerNormWeightTensorId)};
+    }
+    if (this->param.normType == NormType::LAYER_NORM) {
+        finalNormNode.inTensors.push_back(&graph_.weightTensors.at(finalLayerNormWeightTensorId + 1));
+    }
+    finalNormNode.outTensors = {finalNormNode.inTensors.at(0)};  // write the output in place over the input
+    graph_.nodes.push_back(finalNormNode);
+    ATB_SPEED_LOG_DEBUG("[+] base finalNormNode");
+    return atb::NO_ERROR;
+}
+
+atb::Status DecoderModel::AddAllGather()
+{
+    atb::Operation *op = nullptr;
+    if (this->param.enableFlashComm) {
+        atb_speed::Model::Node allGatherVNode;
+        atb::infer::AllGatherVParam allGatherVParam;
+        allGatherVParam.rank = this->param.rank;
+        allGatherVParam.rankSize = this->param.worldSize;
+        allGatherVParam.backend = this->param.backend;
+        if (this->param.mapping.isInitialized_) {
+            atb_speed::common::ParallelInfo parallelInfo = param.mapping.Get(base::LM_HEAD_TP);
+            parallelInfo.InitCommDomain(allGatherVParam.hcclComm, allGatherVParam.commDomain);
+        }
+        CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(allGatherVParam, &op));
+        allGatherVNode.operation.reset(op);
+        allGatherVNode.inTensors = {&graph_.internalTensors.at(atb_speed::common::GetTensorIdx(
+            this->internalTensorMap, "hidden_states_rank" + std::to_string(param.rank)))};
+        allGatherVNode.inTensors.emplace_back(
+            &graph_.inTensors.at(atb_speed::common::GetTensorIdx(this->inTensorMap, "recv_count")));
+        allGatherVNode.inTensors.emplace_back(
+            &graph_.inTensors.at(atb_speed::common::GetTensorIdx(this->inTensorMap, "send_counts")));
+        allGatherVNode.inTensors.emplace_back(
+            &graph_.inTensors.at(atb_speed::common::GetTensorIdx(this->inTensorMap, "sdispls")));
+        allGatherVNode.inTensors.emplace_back(
+            &graph_.inTensors.at(atb_speed::common::GetTensorIdx(this->inTensorMap, "fake_ag_shape")));
+        allGatherVNode.outTensors = {&graph_.internalTensors.at(atb_speed::common::GetTensorIdx(
+            this->internalTensorMap, "hidden_states_allgather"))};
+        ATB_SPEED_LOG_DEBUG("[+] allGatherVNode");
+        graph_.nodes.push_back(allGatherVNode);
+    }
+    return atb::NO_ERROR;
+}
+
+void DecoderModel::SetLmHeadParam(atb_speed::common::LmHeadParam &lmHeadParam)
+{
+    lmHeadParam.unpadInputs = this->param.isUnpadInputs;
+    lmHeadParam.gatherAhead = this->param.isPrefill || this->param.enablePrefixCache || this->param.hasAttnDp;
+    lmHeadParam.hiddenSizePerAttentionHead = this->param.hiddenSizePerAttentionHead;
+    lmHeadParam.linearParallelParam.fusionLinearParam.isBF16 = this->param.isBF16;
+    lmHeadParam.linearParallelParam.fusionLinearParam.transposeType = this->param.lmHeadTransposeType;
+    lmHeadParam.linearParallelParam.fusionLinearParam.matmulBackend = param.matmulBackend;
+    lmHeadParam.linearParallelParam.unpadInputs = !this->param.isFA;
+    lmHeadParam.linearParallelParam.enableMC2 = this->param.enableMC2;
+    lmHeadParam.linearParallelParam.isArgmaxlogits = this->param.enableGreedyPostProcessing;
+    lmHeadParam.linearParallelParam.worldSize = this->param.worldSize;
+    if (this->param.isLmHeadParallel && this->param.hasPp) {
+        lmHeadParam.linearParallelParam.parallelType = atb_speed::common::COLUMN_PARALLEL;
+        lmHeadParam.linearParallelParam.tensorParallelInfo = {
+            this->param.tpRank, this->param.tpWorldSize, this->param.backend, this->param.tpRankTableFile, \
+            gHcommInfo, this->param.tpDomain};
+    } else if (this->param.isLmHeadParallel &&
+               this->param.hasAttnDp && this->param.hasMlpTp) {
+        lmHeadParam.linearParallelParam.parallelType = atb_speed::common::COLUMN_PARALLEL;
+        lmHeadParam.linearParallelParam.tensorParallelInfo = {
+            this->param.mlpTpRank, this->param.mlpTpSize, this->param.backend, this->param.mlpTpRankTableFile,
+            gHcommInfo, this->param.mlpTpDomain};
+    } else if (this->param.isLmHeadParallel) {
+        lmHeadParam.linearParallelParam.parallelType = atb_speed::common::COLUMN_PARALLEL;
+        lmHeadParam.linearParallelParam.tensorParallelInfo = {
+            this->param.rank, this->param.worldSize, this->param.backend, this->param.rankTableFile, gHcommInfo};
+
+        if (this->param.mapping.isInitialized_) {
+            atb_speed::common::ParallelInfo parallelInfo = param.mapping.Get(base::LM_HEAD_TP);
+            parallelInfo.InitCommDomain(
+                lmHeadParam.linearParallelParam.tensorParallelInfo.hcommInfo,
+                lmHeadParam.linearParallelParam.tensorParallelInfo.commDomain);
+        }
+    }
+}
+
+atb::Status DecoderModel::AddLmhead()
+{
+    atb::Operation *op = nullptr;
+
+    atb_speed::Model::Node lmHeadNode;
+    atb_speed::common::LmHeadParam lmHeadParam;
+    this->SetLmHeadParam(lmHeadParam);
+    CHECK_OPERATION_STATUS_RETURN(LmHead(lmHeadParam, &op));
+    lmHeadNode.operation.reset(op);
+    const uint64_t finalLinearWeightTensorId = graph_.weightTensors.size() - this->weightCountLmHead;
+    uint32_t placeHolderIdx = atb_speed::common::GetTensorIdx(this->inTensorMap, "place_holder");
+
+    if (this->param.hasAttnDp && this->param.hasMlpTp) {
+        lmHeadNode.inTensors = {
+            this->param.skipWordEmbedding ? \
+            &graph_.inTensors.at(atb_speed::common::GetTensorIdx(this->inTensorMap, "input_embedding")) : \
+            &graph_.internalTensors.at(
+                atb_speed::common::GetTensorIdx(this->internalTensorMap, "attn_dp_last_layer"))
+        };
+    } else if (this->param.enableFlashComm) {
+        lmHeadNode.inTensors = { &graph_.internalTensors.at(atb_speed::common::GetTensorIdx(this->internalTensorMap,
+            "hidden_states_allgather"))};
+    } else {
+        lmHeadNode.inTensors = {
+            this->param.skipWordEmbedding ? \
+            &graph_.inTensors.at(atb_speed::common::GetTensorIdx(this->inTensorMap, "input_embedding")) : \
+            &graph_.internalTensors.at(atb_speed::common::GetTensorIdx(this->internalTensorMap, "hidden_states")),
+        };
+    }
+    // shape: [vocabSizePerRank, hiddenSize]
+    lmHeadNode.inTensors.push_back(&graph_.weightTensors.at(finalLinearWeightTensorId));
+    // LmHead is not wired for quantization yet; placeholders stand in for the quantized weights
+    lmHeadNode.inTensors.push_back(&graph_.inTensors.at(placeHolderIdx));
+    lmHeadNode.inTensors.push_back(&graph_.inTensors.at(placeHolderIdx));
+    lmHeadNode.inTensors.push_back(&graph_.inTensors.at(placeHolderIdx));
+    lmHeadNode.inTensors.push_back(&graph_.inTensors.at(placeHolderIdx));
+    lmHeadNode.inTensors.push_back(&graph_.inTensors.at(placeHolderIdx));
+    lmHeadNode.inTensors.push_back(&graph_.inTensors.at(
+        atb_speed::common::GetTensorIdx(this->inTensorMap, "logits_indices")));
+    if (this->param.enableGreedyPostProcessing) {
+        lmHeadNode.inTensors.emplace_back(&graph_.inTensors.at(atb_speed::common::GetTensorIdx(
+            this->inTensorMap, "logits_offset_tensor")));
+    } else {
+        lmHeadNode.inTensors.emplace_back(&graph_.inTensors.at(placeHolderIdx));
+    }
+    // shape: FA: [batchSize, seqLen, vocabSize] PA: [seqLen, vocabSize]
+    lmHeadNode.outTensors = {&graph_.outTensors.at(atb_speed::common::GetTensorIdx(this->outTensorMap, "logits"))};
+    if (this->param.enableFlashComm) {
+        lmHeadNode.inTensorReshapeFuncs.resize(lmHeadNode.inTensors.size());
+        lmHeadNode.inTensorReshapeFuncs.at(0) = &atb_speed::common::SqueezeBatchAndSeq;
+    }
+    graph_.nodes.push_back(lmHeadNode);
+    ATB_SPEED_LOG_DEBUG("[+] base lmHeadNode");
+    return atb::NO_ERROR;
+}
+
+
+void DecoderModel::ParseDapParam(nlohmann::json &paramJson)
+{
+    this->seqLenForDap.Parse("seqLen", paramJson);
+    this->tokenOffsetForDap.Parse("tokenOffset", paramJson);
+    this->qLenForDap.Parse("qLen", paramJson);
+}
+
+void DecoderModel::ParseParallelParam(nlohmann::json &paramJson)
+{
+    if (paramJson.contains("seqLenCp")) {
+        this->seqLenCp.Parse("seqLenCp", paramJson);
+    }
+    if (paramJson.contains("seqLenSp")) {
+        this->seqLenSp.Parse("seqLenSp", paramJson);
+    }
+}
+
+void DecoderModel::ParseFlashCommParam(nlohmann::json &paramJson)
+{
+    if (param.enableFlashComm) {
+        this->sendCounts.Parse("sendCounts", paramJson);
+        this->sdispls.Parse("sdispls", paramJson);
+        this->sendCount.Parse("sendCount", paramJson);
+        this->recvCounts.Parse("recvCounts", paramJson);
+        this->rdispls.Parse("rdispls", paramJson);
+        this->recvCount.Parse("recvCount", paramJson);
+    }
+}
+
+atb::Status DecoderModel::ParseParam(const std::string &paramString)
+{
+    CHECK_PARAM_LT(paramString.size(), MAX_PARAM_STRING_LENGTH);
+    nlohmann::json paramJson = StringToJson(paramString);
+
+    // With DAP, all per-forward params are stored in DynamicParam objects
+    // instead of one attribute per param
+    if (param.enableDap) {
+        ParseDapParam(paramJson);
+        return atb::NO_ERROR;
+    }
+
+    this->tokenOffset.clear();
+    for (auto item : paramJson["tokenOffset"]) {
+        this->tokenOffset.push_back(item.get<int>());
+        ATB_SPEED_LOG_DEBUG("token offset value: " << item);
+    }
+
+    if (param.enableOmniAttention) {
+        this->seqLenTmp.clear();
+        for (auto item : paramJson["seqLen"]) {
+            this->seqLenTmp.push_back(item.get<int>());
+            ATB_SPEED_LOG_DEBUG("seqLen value: " << item);
+        }
+        this->seqLen.clear();
+        ExpandVectorToN(this->seqLenTmp, this->seqLen, param.numHiddenLayers);
+    } else {
+        this->seqLen.clear();
+        for (auto item : paramJson["seqLen"]) {
+            this->seqLen.push_back(item.get<int>());
+            ATB_SPEED_LOG_DEBUG("seqLen value: " << item);
+        }
+    }
+
+    this->qLen.clear();
+    for (auto item : paramJson["qLen"]) {
+        this->qLen.push_back(item.get<int>());
+        ATB_SPEED_LOG_DEBUG("qLen value: " << item);
+    }
+
+    ParseParallelParam(paramJson);
+
+    ParseFlashCommParam(paramJson);
+    if (param.enablePrefixCache) {
+        this->ringCurSeqlen.Parse("ringCurSeqlen", paramJson);
+        this->ringCacheSeqlen.Parse("ringCacheSeqlen", paramJson);
+    }
+    return atb::NO_ERROR;
+}
+
+void DecoderModel::BindDapHostTensor(DynamicParam<std::vector<int>>& dynamicParam, std::string tensorName)
+{
+    uint32_t tensorIdx = atb_speed::common::GetTensorIdx(this->inTensorMap, tensorName);
+    if (tensorIdx != UINT32_MAX) {
+        graph_.inTensors.at(tensorIdx).hostData = dynamicParam.Get().data();
+    }
+    if (param.enableDap) {
+        GetSingleton().SetRole(common::DapRole::SUCCESSOR);
+        std::string suffix = GetSingleton().GetSuccessorSuffix();
+        tensorIdx = atb_speed::common::GetTensorIdx(this->inTensorMap, tensorName + suffix);
+        if (tensorIdx != UINT32_MAX) {
+            graph_.inTensors.at(tensorIdx).hostData = dynamicParam.Get().data();
+        }
+        GetSingleton().SetRole(common::DapRole::PRECEDER);
+    }
+}
+
+void DecoderModel::BindFlashcommHostTensor()
+{
+    std::map<std::string, DynamicParam<std::vector<int>>&> bindFlashCommTensorMap = {
+        {"send_counts", this->sendCounts},
+        {"sdispls", this->sdispls},
+        {"send_count", this->sendCount},
+        {"recv_counts", this->recvCounts},
+        {"rdispls", this->rdispls},
+        {"recv_count", this->recvCount},
+    };
+    uint32_t tensorIdx = 0;
+    for (auto item : bindFlashCommTensorMap) {
+        tensorIdx = atb_speed::common::GetTensorIdx(this->inTensorMap, item.first);
+        if (tensorIdx != UINT32_MAX) {
+            graph_.inTensors.at(tensorIdx).hostData = item.second.Get().data();
+        }
+    }
+}
+
+atb::Status DecoderModel::BindParamHostTensor(uint32_t nodeId)
+{
+    ATB_SPEED_LOG_DEBUG("BindParamHostTensor nodeId = " << nodeId);
+
+    if (nodeId != 0) {
+        // binding is only needed once, on the graph's in-tensors
+        return atb::NO_ERROR;
+    }
+
+    if (param.enableDap) {
+        BindDapHostTensor(this->seqLenForDap, "seq_len");
+        BindDapHostTensor(this->tokenOffsetForDap, "token_offset");
+        BindDapHostTensor(this->qLenForDap, "q_len");
+        return atb::NO_ERROR;
+    }
+
+    uint32_t tensorIdx = atb_speed::common::GetTensorIdx(this->inTensorMap, "token_offset");
+    if (tensorIdx != UINT32_MAX) {
+        graph_.inTensors.at(tensorIdx).hostData = this->tokenOffset.data();
+    }
+
+    tensorIdx = atb_speed::common::GetTensorIdx(this->inTensorMap, "seq_len");
+    if (tensorIdx != UINT32_MAX) {
+        graph_.inTensors.at(tensorIdx).hostData = this->seqLen.data();
+    }
+
+    tensorIdx = atb_speed::common::GetTensorIdx(this->inTensorMap, "q_len");
+    if (tensorIdx != UINT32_MAX) {
+        graph_.inTensors.at(tensorIdx).hostData = this->qLen.data();
+    }
+    if (param.enableFlashComm) {
+        BindFlashcommHostTensor();
+    }
+    if (param.enablePrefixCache) {
+        tensorIdx = atb_speed::common::GetTensorIdx(this->inTensorMap, "ring_cur_seqlen");
+        if (tensorIdx != UINT32_MAX) {
+            graph_.inTensors.at(tensorIdx).hostData = this->ringCurSeqlen.Get().data();
+        }
+        tensorIdx = atb_speed::common::GetTensorIdx(this->inTensorMap, "ring_cache_seqlen");
+        if (tensorIdx != UINT32_MAX) {
+            graph_.inTensors.at(tensorIdx).hostData = this->ringCacheSeqlen.Get().data();
+        }
+    }
+    ATB_SPEED_LOG_DEBUG("BindParamHostTensor end");
+    return atb::NO_ERROR;
+}
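For reference, a minimal runtime payload that ParseParam above accepts on the non-DAP path might look like this (commentary only; the field values are illustrative assumptions for a batch of two prefill requests):

    {
        "tokenOffset": [128, 96],
        "seqLen": [128, 96],
        "qLen": [1, 1]
    }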
+
+/**
+ * Generate the sequence length for each layer of each batch; the result looks like
+ * [layer, batch] = compressed_head ? min(384, seq_len) : seq_len
+ */
+atb::Status DecoderModel::ExpandVectorToN(const std::vector<int>& input, std::vector<int>& output, uint32_t layernums)
+{
+    // the number of active batches
+    size_t batchSize = input.size();
+    // Check if the input vector is empty or the layer count is 0
+    if (layernums == 0 || batchSize == 0) {
+        throw std::invalid_argument("Input vector cannot be empty and layernums must be positive");
+    }
+    for (size_t layer = 0; layer < layernums; layer++) {
+        for (size_t batch = 0; batch < batchSize; batch++) {
+            bool isCompressed = param.patternMask[layer] == 1;
+            int batchSeqlen = input[batch];
+            const int omniLimitSeqLen = 384;
+            isCompressed ? output.push_back(std::min(batchSeqlen, omniLimitSeqLen)) :
+                output.push_back(batchSeqlen);
+        }
+    }
+
+    return atb::NO_ERROR;
+}
+
+} // namespace base
+} // namespace atb_speed
\ No newline at end of file
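A worked example of the ExpandVectorToN expansion (commentary, illustrative values only):

    // patternMask = {1, 0}, input (per-batch seqLen) = {500, 100}, layernums = 2
    // layer 0 is compressed -> min(500, 384) = 384, min(100, 384) = 100
    // layer 1 is not        -> 500, 100
    // output = {384, 100, 500, 100}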
diff --git a/tests/proftest/layer_test_framework/models/base/model/decoder_model.h b/tests/proftest/layer_test_framework/models/base/model/decoder_model.h
new file mode 100644
index 00000000..66c0b2bf
--- /dev/null
+++ b/tests/proftest/layer_test_framework/models/base/model/decoder_model.h
@@ -0,0 +1,264 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef ATB_SPEED_MODELS_BASE_DECODER_MODEL_H
+#define ATB_SPEED_MODELS_BASE_DECODER_MODEL_H
+
+#include
+#include "atb_speed/base/model.h"
+#include "atb_speed/base/external_comm_manager.h"
+#include "models/base/param/model_param.h"
+#include "models/base/param/dynamic_param.h"
+#include "models/base/layer/decoder_layer.h"
+#include "operations/fusion/embedding/word_embedding.h"
+#include "operations/fusion/embedding/positional_embedding.h"
+#include "operations/fusion/lmhead/lmhead.h"
+#include "operations/fusion/utils.h"
+#include "atb_speed/utils/singleton.h"
+#include "atb_speed/utils/tensor_util.h"
+
+namespace atb_speed {
+namespace base {
+
+/// Base class for large language models, inherited from the `Model` class.
+class DecoderModel : public Model {
+public:
+    explicit DecoderModel(const std::string &param);
+    ~DecoderModel() override;
+    uint32_t GetInputNum() const override;
+    uint32_t GetOutputNum() const override;
+    atb::Status InferShape(const std::vector<atb::TensorDesc> &inTensorDescs,
+                           std::vector<atb::TensorDesc> &outTensorDescs) override;
+
+protected:
+    /// Select input tensors from all the available input tensor candidates based on the provided parameters `param`.
+    /// Automatically assign an index to each tensor.
+    /// The order of the input tensors should align with the sequence passed from the Python side.
+    virtual void ConstructInTensorMap();
+    /// Select internal tensors from all the available internal tensor candidates
+    /// based on the provided parameters `param`. Automatically assign an index to each tensor.
+    virtual void ConstructInternalTensorMap();
+    /// Select output tensors from all the available output tensor candidates
+    /// based on the provided parameters `param`. Automatically assign an index to each tensor.
+    virtual void ConstructOutTensorMap();
+    /// Return the total number of weight tensors. It can differ depending on various features.
+    virtual uint32_t CalcWeightTensorSize();
+    atb::Status ParseParam(const std::string &paramString) override;
+    virtual void ParseDapParam(nlohmann::json &paramJson);
+    virtual void ParseFlashCommParam(nlohmann::json &paramJson);
+    virtual void ParseParallelParam(nlohmann::json &paramJson);
+    atb::Status BindParamHostTensor(uint32_t nodeId) override;
+    void BindDapHostTensor(DynamicParam<std::vector<int>>& dynamicParam, std::string tensorName);
+    void BindFlashcommHostTensor();
+    /// Update the `wordEmbeddingParam` using the values from `param`.
+    /// \param wordEmbeddingParam a `WordEmbeddingParam` object that needs to be updated
+    virtual void SetWordEmbeddingParam(atb_speed::common::WordEmbeddingParam &wordEmbeddingParam);
+    /// Update the `layerParam` using the values from `param`.
+    /// \param layerParam a `LayerParam` object that needs to be updated
+    /// \param layerId the index of the current layer
+    virtual void SetLayerParam(LayerParam &layerParam, uint32_t layerId);
+    /// Update the `layerParam` using the values from `param`.
+    /// \param layerParam a `LayerParam` object that needs to be updated
+    virtual void SetLayerParallelismParam(LayerParam &layerParam);
+    /// Update the `normParam` using the values from `param`.
+    /// \param normParam an `RmsNormParam` object that needs to be updated
+    virtual void SetFinalNormParam(atb::infer::RmsNormParam &normParam);
+    /// Update the `normParam` using the values from `param`.
+    /// \param normParam a `LayerNormParam` object that needs to be updated
+    virtual void SetFinalNormParam(atb::infer::LayerNormParam &normParam);
+    /// Update the `lmHeadParam` using the values from `param`.
+    /// \param lmHeadParam an `LmHeadParam` object that needs to be updated
+    virtual void SetLmHeadParam(atb_speed::common::LmHeadParam &lmHeadParam);
+    /// The main entrance to set `layerNode`'s input tensors.
+    /// It will call `SetLayerNodeDefaultInput` and `SetLayerNodeOptionalInput`.
+    /// \param layerNode an `atb_speed::Model::Node` object that needs to be updated
+    /// \param layerId the index of the current layer
+    virtual void SetLayerNodeInput(atb_speed::Model::Node &layerNode, uint32_t layerId);
+    /// Set `layerNode`'s default input tensors based on the values from `param`.
+    /// \param layerNode an `atb_speed::Model::Node` object that needs to be updated
+    /// \param layerId the index of the current layer
+    /// \param inTensorId the starting input tensor ID
+    virtual void SetLayerNodeDefaultInput(atb_speed::Model::Node &layerNode, uint32_t layerId, uint32_t &inTensorId);
+    /// Set `layerNode`'s optional input tensors based on the values from `param`.
+    /// \param layerNode an `atb_speed::Model::Node` object that needs to be updated
+    /// \param layerId the index of the current layer
+    /// \param inTensorId the starting input tensor ID
+    virtual void SetLayerNodeOptionalInput(atb_speed::Model::Node &layerNode, uint32_t layerId, uint32_t &inTensorId);
+    /// Set `layerNode`'s input tensors for razor attention based on the values from `param`.
+    /// It will be called by `SetLayerNodeOptionalInput`.
+    /// \param layerNode an `atb_speed::Model::Node` object that needs to be updated
+    /// \param layerId the index of the current layer
+    /// \param inTensorId the starting input tensor ID
+    virtual void SetLayerNodeRaInput(atb_speed::Model::Node &layerNode, uint32_t layerId, uint32_t &inTensorId);
+    /// Set `layerNode`'s input tensors for omni attention based on the values from `param`.
+    /// It will be called by `SetLayerNodeOptionalInput`.
+    /// \param layerNode an `atb_speed::Model::Node` object that needs to be updated
+    /// \param layerId the index of the current layer
+    /// \param inTensorId the starting input tensor ID
+    virtual void SetLayerNodeOmniInput(atb_speed::Model::Node &layerNode, uint32_t layerId, uint32_t &inTensorId);
+    /// Set `layerNode`'s input tensors for multi-LoRA based on the values from `param`.
+    /// It will be called by `SetLayerNodeOptionalInput`.
+    /// \param layerNode an `atb_speed::Model::Node` object that needs to be updated
+    /// \param layerId the index of the current layer
+    /// \param inTensorId the starting input tensor ID
+    virtual void SetLayerNodeLoraInput(atb_speed::Model::Node &layerNode, uint32_t layerId, uint32_t &inTensorId);
+    /// Set `layerNode`'s input tensors for data parallelism in the attention module based on the values from `param`.
+    /// It will be called by `SetLayerNodeOptionalInput`.
+    /// \param layerNode an `atb_speed::Model::Node` object that needs to be updated
+    /// \param inTensorId the starting input tensor ID
+    virtual void SetLayerNodeAttnDpInput(atb_speed::Model::Node &layerNode, uint32_t &inTensorId);
+    /// Create a `LayerParam` object, call `SetLayerParam` and
+    /// call `DecoderLayer`'s `BuildGraph` function to create an operation.
+    /// \param op the address of a pointer to a default operation
+    /// \param layerId the index of the current layer
+    /// \return A flag that indicates whether the operation was successfully created.
+    virtual atb::Status CreateLayerOperation(atb::Operation **op, uint32_t layerId);
+    /// Add a word embedding node to the graph to convert token ids to embeddings.
+    /// \return A flag that indicates whether the operation was successfully added to the graph.
+    virtual atb::Status AddWordEmbedding();
+    /// Add a positional embedding node to the graph to convert positional ids to embeddings.
+    /// \return A flag that indicates whether the operation was successfully added to the graph.
+    virtual atb::Status AddPositionalEmbedding();
+    /// Add all layer nodes to the graph.
+    /// \return A flag that indicates whether the operation was successfully added to the graph.
+    virtual atb::Status AddLayer();
+    /// Add a single layer node to the graph.
+    /// Overriding this function is advised in order to easily adapt the model to DAP.
+    /// \return A flag that indicates whether the operation was successfully added to the graph.
+    virtual atb::Status AddSingleLayer(uint32_t layerId);
+    /// Add a normalization node to the graph.
+    /// \return A flag that indicates whether the operation was successfully added to the graph.
+    virtual atb::Status AddFinalNorm();
+    /// Add a lmhead node to the graph.
+    /// \return A flag that indicates whether the operation was successfully added to the graph.
+    virtual atb::Status AddLmhead();
+    /// Add a send node to the graph to send hidden states to the next pipeline parallelism stage.
+    /// \return A flag that indicates whether the operation was successfully added to the graph.
+    virtual atb::Status AddSend();
+    /// Add a receive node to the graph to receive hidden states from the previous pipeline parallelism stage.
+    /// \return A flag indicating whether the operation was successfully added to the graph.
+    virtual atb::Status AddRecv();
+    /// Add all the operations before the layer nodes.
+    /// Overriding this function is advised in order to easily adapt the model to DAP.
+    /// \return A flag indicating whether the operations were successfully added to the graph.
+    virtual atb::Status AddNodesBeforeLayer();
+    /// Add all the operations after the layer nodes.
+    /// Overriding this function is advised in order to easily adapt the model to DAP.
+    /// \return A flag indicating whether the operations were successfully added to the graph.
+    virtual atb::Status AddNodesAfterLayer();
+    /// The primary entry point for adding all operations to the graph in sequential order.
+    /// It is recommended to override `AddNodesBeforeLayer`, `AddNodesAfterLayer`, and `AddSingleLayer`
+    /// instead of this function.
+    /// \return A flag indicating whether the operations were successfully added to the graph.
+    virtual atb::Status AddOperationToGraph();
+    /// Add a node that splits the hidden states for flashcomm.
+    /// \return A flag indicating whether the operation was successfully added to the graph.
+    virtual atb::Status AddSplitHiddenStates();
+    /// Add an all-gather node on the hidden states for flashcomm.
+    /// \return A flag indicating whether the operation was successfully added to the graph.
+    virtual atb::Status AddAllGather();
+    /// For flashcomm 1.0:
+    /// Amount of data sent to every card
+    DynamicParam<std::vector<int>> sendCounts;
+    /// Indicates that the current rank starts sending data to rank i
+    /// at an offset of n with respect to the start of the input
+    DynamicParam<std::vector<int>> sdispls;
+    /// Amount of data sent by this card
+    DynamicParam<std::vector<int>> sendCount;
+    /// Amount of data received from every card
+    DynamicParam<std::vector<int>> recvCounts;
+    /// Indicates that the current rank starts receiving data from rank i
+    /// at an offset of n with respect to the start of the input
+    DynamicParam<std::vector<int>> rdispls;
+    /// Amount of data received by this card
+    DynamicParam<std::vector<int>> recvCount;
+
+    int64_t BuildGraph() override;
+    /// For omni attention
+    atb::Status ExpandVectorToN(const std::vector<int>& input, std::vector<int>& output, uint32_t N);
+
+    /// Parameters that will change during each forward pass.
+    /// The position ID of the current token in the inference sequence.
+    /// Number of elements should be equal to the batch size.
+    std::vector<int> tokenOffset = {};
+    /// The total number of input and output tokens.
+    /// In the prefill phase, each element equals the length of the prompt.
+    /// For flash attention, each element is set to 1 in the decode phase.
+    /// For paged attention, each element is set to the number of input tokens plus output tokens in the decode phase.
+    /// Number of elements should be equal to the batch size.
+    std::vector<int> seqLen = {};
+    // For omni_attention
+    std::vector<int> seqLenTmp = {};
+    /// Number of input tokens for the current forward pass.
+    /// Number of elements should be equal to the batch size.
+    std::vector<int> qLen = {};
+    // For Attn inner SP
+    DynamicParam<std::vector<int>> seqLenSp;
+    // For Attn CP
+    DynamicParam<std::vector<int>> seqLenCp;
+    // For prefix cache: ring_cur + ring_cache
+    DynamicParam<std::vector<int>> ringCurSeqlen = {};
+    DynamicParam<std::vector<int>> ringCacheSeqlen = {};
+
+    // For DAP
+    DynamicParam<std::vector<int>> seqLenForDap;
+    DynamicParam<std::vector<int>> tokenOffsetForDap;
+    DynamicParam<std::vector<int>> qLenForDap;
+
+    /// Specifies all potential input tensors, where the key represents the feature name,
+    /// and the value corresponds to the input tensor names.
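+    /// A hypothetical illustration of the layout (the feature key and tensor names below are
+    /// examples only, not the framework's authoritative set):
+    /// \code
+    /// inTensorCandidates = {{"default", {"input_ids", "position_ids", "seq_len"}}};
+    /// \endcode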
+    std::map<std::string, std::vector<std::string>> inTensorCandidates = {};
+    /// Specifies all potential internal tensors, where the key represents the feature name,
+    /// and the value corresponds to the internal tensor names.
+    std::map<std::string, std::vector<std::string>> internalTensorCandidates = {};
+    /// Specifies all potential output tensors, where the key represents the feature name,
+    /// and the value corresponds to the output tensor names.
+    std::map<std::string, std::vector<std::string>> outTensorCandidates = {};
+    /// Defines all the required input tensors for the current graph, with the key representing the input tensor name
+    /// and the value corresponding to the tensor index.
+    std::map<std::string, uint32_t> inTensorMap = {};
+    /// Defines all the required internal tensors for the current graph,
+    /// with the key representing the internal tensor name
+    /// and the value corresponding to the tensor index.
+    std::map<std::string, uint32_t> internalTensorMap = {};
+    /// Defines all the required output tensors for the current graph, with the key representing the output tensor
+    /// name and the value corresponding to the tensor index.
+    std::map<std::string, uint32_t> outTensorMap = {};
+    /// Number of weights per layer
+    uint32_t weightCountPerLayer = 50;
+    /// Number of weights for the word embedding node
+    uint32_t weightCountWordEmbedding = 1;
+    /// Number of weights for the normalization node
+    uint32_t weightCountFinalNorm = 1;
+    /// Number of weights for the lmhead node
+    uint32_t weightCountLmHead = 1;
+    // Pointer to the HCCL communication domain for MC2
+    static HcclComm gHcommInfo;
+
+    /// Model parameters
+    ModelParam param;
+
+private:
+    void PrintTensorMapInfo(std::map<std::string, uint32_t> &tensorMap) const;
+    void DuplicateTensorMapForDap(
+        std::map<std::string, uint32_t> &tensorMap, std::vector<std::string> &targetTensors);
+    std::map<std::string, uint32_t> CopyMapWithSuffix(std::map<std::string, uint32_t>& tensorMap) const;
+    std::map<std::string, std::string> precederToSuccessorTensorMap = {};
+    void ReplaceDapTensors(std::vector<std::string>& tensors);
+};
+
+} // namespace base
+} // namespace atb_speed
+#endif
diff --git a/tests/proftest/layer_test_framework/models/base/param/dynamic_param.h b/tests/proftest/layer_test_framework/models/base/param/dynamic_param.h
new file mode 100644
index 00000000..e552030f
--- /dev/null
+++ b/tests/proftest/layer_test_framework/models/base/param/dynamic_param.h
@@ -0,0 +1,60 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef ATB_SPEED_DYNAMIC_PARAM_H
+#define ATB_SPEED_DYNAMIC_PARAM_H
+#include
+#include "atb_speed/utils/singleton.h"
+#include "operations/fusion/utils.h"
+#include "models/base/param/param_utils.h"
+
+namespace atb_speed {
+namespace base {
+
+template <typename T>
+class DynamicParam {
+public:
+    void Parse(std::string name, nlohmann::json &paramJson)
+    {
+        this->enableDap_ = false;  // reset
+
+        this->name_ = name;
+        if (!paramJson.contains(name)) { return; }
+        this->data_ = FetchJsonParam<T>(paramJson, name);
+
+        std::string suffix = GetSingleton().GetSuccessorSuffix();
+        if (!paramJson.contains(name + suffix)) { return; }
+        this->enableDap_ = true;
+        GetSingleton().SetRole(common::DapRole::SUCCESSOR);
+        this->successorData_ = FetchJsonParam<T>(paramJson, name + suffix);
+        GetSingleton().SetRole(common::DapRole::PRECEDER);
+    }
+
+    T& Get()
+    {
+        common::DapRole role = GetSingleton().GetRole();
+        return role == common::DapRole::SUCCESSOR ? this->successorData_ : this->data_;
+    }
+
+    std::string name_ = "";
+
+private:
+    T data_;
+    T successorData_;
+    bool enableDap_ = false;
+};
+} // namespace base
+} // namespace atb_speed
+#endif
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/models/base/param/layer_param.cpp b/tests/proftest/layer_test_framework/models/base/param/layer_param.cpp
new file mode 100644
index 00000000..b1fb48d7
--- /dev/null
+++ b/tests/proftest/layer_test_framework/models/base/param/layer_param.cpp
@@ -0,0 +1,63 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "models/base/param/layer_param.h"
+
+namespace atb_speed {
+namespace base {
+
+void LayerParam::PrintParam()
+{
+    Param::PrintParam();
+    std::stringstream ss;
+    ss << "Base Layer Param:"
+       << ", tensorParallelInfo.rank:" << this->tensorParallelInfo.rank
+       << ", tensorParallelInfo.worldSize:" << this->tensorParallelInfo.worldSize
+       << ", tensorParallelInfo.backend:" << this->tensorParallelInfo.backend
+       << ", tensorParallelInfo.rankTableFile:" << this->tensorParallelInfo.rankTableFile
+       << ", tensorParallelInfo.quantType:" << this->tensorParallelInfo.quantType
+       << ", tensorParallelInfo.outDataType:" << this->tensorParallelInfo.outDataType;
+    for (size_t i = 0; i < packQuantType.size(); ++i) {
+        ss << "packQuantType[" << i << "]:" << packQuantType.at(i) << std::endl;
+    }
+    for (size_t i = 0; i < linearQuantType.size(); ++i) {
+        ss << "linearQuantType[" << i << "]:" << linearQuantType.at(i) << std::endl;
+    }
+    for (size_t i = 0; i < linearHasBias.size(); ++i) {
+        ss << "linearHasBias[" << i << "]:" << linearHasBias.at(i) << std::endl;
+    }
+    for (size_t i = 0; i < linearTransposeType.size(); ++i) {
+        ss << "linearTransposeType[" << i << "]:" << linearTransposeType.at(i) << std::endl;
+    }
+    for (size_t i = 0; i < linearDescs.size(); ++i) {
+        ss << "linearDescs[" << i << "]:" << linearDescs.at(i) << std::endl;
+    }
+    for (size_t i = 0; i < isAntiOutlier.size(); ++i) {
+        ss << "isAntiOutlier[" << i << "]:" << isAntiOutlier.at(i) << std::endl;
+    }
+    ATB_SPEED_LOG_DEBUG(ss.str());
+}
+
+void LayerParam::CheckParam()
+{
+    if (this->hiddenSizePerAttentionHead == 0) {
+        std::stringstream ss;
+        ss << "Cannot divide by zero. Param hiddenSizePerAttentionHead is zero!" << std::endl;
+        throw std::runtime_error(ss.str());
+    }
+}
+} // namespace base
+} // namespace atb_speed
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/models/base/param/layer_param.h b/tests/proftest/layer_test_framework/models/base/param/layer_param.h
new file mode 100644
index 00000000..c65f0d15
--- /dev/null
+++ b/tests/proftest/layer_test_framework/models/base/param/layer_param.h
@@ -0,0 +1,74 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef ATB_SPEED_BASE_LAYER_PARAM_H
+#define ATB_SPEED_BASE_LAYER_PARAM_H
+#include
+#include
+#include "models/base/param/param.h"
+#include "operations/fusion/linear/linear_parallel.h"
+
+namespace atb_speed {
+namespace base {
+
+/// Parameters for the base layer, inherited from the `Param` class.
+///
+/// In addition to the parameters defined in the `Param` class,
+/// this class introduces additional parameters specific to the base `DecoderLayer` class.
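+///
+/// A minimal usage sketch (the field values below are illustrative, not recommended defaults):
+/// \code
+/// atb_speed::base::LayerParam layerParam;
+/// layerParam.layerId = 0;
+/// layerParam.hiddenSizePerAttentionHead = 128;  // must be non-zero, otherwise CheckParam() throws
+/// layerParam.CheckParam();
+/// layerParam.PrintParam();
+/// \endcode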
+class LayerParam : public Param {
+public:
+    LayerParam() {};
+    ~LayerParam() override {};
+    void PrintParam() override;
+    void CheckParam() override;
+
+    /// The layer index, starting from 0
+    int layerId = 0;
+    /// Number of hidden layers
+    int numHiddenLayers = 0;
+    /// Information for tensor parallelism
+    atb_speed::common::TensorParallelInfo tensorParallelInfo;
+    /// Indicates the pack type and the quantization type of the qkv linear and gate up linear.
+    std::vector<int> packQuantType = {
+        common::PackQuantType::PACK_QUANT_UNDEFINED, common::PackQuantType::PACK_QUANT_UNDEFINED
+    };
+    /// Specifies the quantization type for each of the following linear modules:
+    /// q linear, k linear, v linear, dense linear, gate linear, up linear, and down linear.
+    std::vector<int> linearQuantType = {
+        common::LinearType::INVALID, common::LinearType::INVALID, common::LinearType::INVALID,
+        common::LinearType::INVALID, common::LinearType::INVALID, common::LinearType::INVALID,
+        common::LinearType::INVALID
+    };
+    /// Defines the transpose type of the second matrix in the matmul operation for each of the following
+    /// linear modules: q linear, k linear, v linear, dense linear, gate linear, up linear, and down linear.
+    std::vector<int> linearTransposeType = {};
+    /// Specifies whether the following linear modules have bias:
+    /// qkv linear, dense linear, gateup linear and down linear.
+    std::vector<bool> linearHasBias = {false, false, false, false};
+    /// Specifies the weight description of the following linear modules:
+    /// qkv linear, dense linear, gateup linear and down linear.
+    std::vector<int> linearDescs = {
+        common::LinearDesc::INVALID_DESC, common::LinearDesc::INVALID_DESC, common::LinearDesc::INVALID_DESC,
+        common::LinearDesc::INVALID_DESC, common::LinearDesc::INVALID_DESC, common::LinearDesc::INVALID_DESC,
+        common::LinearDesc::INVALID_DESC
+    };
+    /// Specifies whether the input norm and the post attention norm enable the anti-outlier feature
+    std::vector<bool> isAntiOutlier = {false, false};
+    /// A flag indicating whether the current layer is compressed
+    bool isomnicompressed = false;
+};
+} // namespace base
+} // namespace atb_speed
+#endif
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/models/base/param/mapping.cpp b/tests/proftest/layer_test_framework/models/base/param/mapping.cpp
new file mode 100644
index 00000000..1694feb3
--- /dev/null
+++ b/tests/proftest/layer_test_framework/models/base/param/mapping.cpp
@@ -0,0 +1,105 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "atb_speed/utils/singleton.h"
+#include "atb_speed/base/external_comm_manager.h"
+#include "models/base/param/mapping.h"
+
+namespace atb_speed {
+namespace base {
+
+void Mapping::ParseParam(const nlohmann::json &paramJson)
+{
+    this->worldSize_ = FetchJsonParam<uint32_t>(paramJson, "worldSize");
+    this->rank_ = FetchJsonParam<uint32_t>(paramJson, "rank");
+    this->rankTableFile_ = FetchJsonParam<std::string>(paramJson, "rankTableFile");
+    this->localWorldSize_ = FetchJsonParam<uint32_t>(paramJson, "localWorldSize");
+    GetSingleton<ExternalCommManager>().SetLcclCommDomainRange(
+        FetchJsonParam<uint32_t>(paramJson, "lcclCommDomainLowerBound"),
+        FetchJsonParam<uint32_t>(paramJson, "lcclCommDomainUpperBound")
+    );
+    std::map<ParallelType, std::string> strategyKeyMap = {
+        {WORD_EMBED_TP, "wordEmbedTp"},
+        {WORD_EMBED_DP, "wordEmbedDp"},
+        {ATTN_TP, "attnTp"},
+        {ATTN_DP, "attnDp"},
+        {ATTN_CP, "attnCp"},
+        {ATTN_INNER_SP, "attnInnerSp"},
+        {ATTN_O_PROJ_TP, "attnOProjTp"},
+        {ATTN_O_PROJ_DP, "attnOProjDp"},
+        {MLP_TP, "mlpTp"},
+        {MLP_DP, "mlpDp"},
+        {MOE_TP, "moeTp"},
+        {MOE_EP, "moeEp"},
+        {LM_HEAD_TP, "lmHeadTp"},
+        {LM_HEAD_DP, "lmHeadDp"},
+    };
+    for (auto it = strategyKeyMap.begin(); it != strategyKeyMap.end(); it++) {
+        atb_speed::common::ParallelInfo parallelInfo = atb_speed::common::ParallelInfo();
+        const nlohmann::json &curParamJson = paramJson[it->second];
+        parallelInfo.rank = FetchJsonParam<uint32_t>(curParamJson, "rank");
+        parallelInfo.rankIds = FetchJsonParam<std::vector<uint32_t>>(curParamJson["rankIds"], "rankIds", true);
+        parallelInfo.bufferSize = FetchJsonParam<uint32_t>(curParamJson, "bufferSize");
+        parallelInfo.groupId = FetchJsonParam<uint32_t>(curParamJson, "groupId");
+        this->Register(it->first, parallelInfo);
+    }
+}
+
+void Mapping::InitGlobalCommDomain(std::string defaultBackend)
+{
+    this->defaultBackend_ = defaultBackend;
+    uint32_t streamId = GetSingleton().GetStreamId();
+    std::vector<uint32_t> rankIds = {};
+    for (uint32_t id = 0; id < this->worldSize_; id++) {
+        rankIds.push_back(id);
+    }
+    std::vector<uint32_t> fixedRankIds = rankIds;
+    std::string backend = atb_speed::common::InitCommBackend(
+        this->localWorldSize_, fixedRankIds, this->defaultBackend_);
+    // Create the global comm
+    ATB_SPEED_LOG_DEBUG("External Comm Manager: InitCommDomain: init");
+    GetSingleton<ExternalCommManager>().Init(this->worldSize_, this->rank_,
+        backend, this->rankTableFile_, streamId);
+    this->isInitialized_ = true;
+}
+
+void Mapping::Register(ParallelType parallelType, atb_speed::common::ParallelInfo parallelInfo)
+{
+    this->parallelStrategies_[parallelType] = parallelInfo;
+}
+
+const atb_speed::common::ParallelInfo Mapping::Get(ParallelType parallelType) const
+{
+    std::stringstream ss;
+    auto it = this->parallelStrategies_.find(parallelType);
+    if (it == this->parallelStrategies_.end()) {
+        ss << "Mapping: Parallel type [" << parallelType << "] is not found. "
" + << "Existing strategies are "; + for (auto item = this->parallelStrategies_.begin(); item != this->parallelStrategies_.end(); item++) { + ss << item->first << " "; + } + throw std::out_of_range(ss.str()); + } + + atb_speed::common::ParallelInfo parallelInfo = it->second; + std::string backend = atb_speed::common::InitCommBackend( + this->localWorldSize_, parallelInfo.rankIds, this->defaultBackend_); + parallelInfo.defaultBackend = backend; + + return parallelInfo; +} + +} // namespace base +} // namesapce atb_speed \ No newline at end of file diff --git a/tests/proftest/layer_test_framework/models/base/param/mapping.h b/tests/proftest/layer_test_framework/models/base/param/mapping.h new file mode 100644 index 00000000..fe315466 --- /dev/null +++ b/tests/proftest/layer_test_framework/models/base/param/mapping.h @@ -0,0 +1,107 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef ATB_SPEED_BASE_MAPPING_H +#define ATB_SPEED_BASE_MAPPING_H +#include +#include "atb_speed/log.h" +#include "models/base/param/param_utils.h" +#include "operations/fusion/parallel_info.h" + +namespace atb_speed { +namespace base { + +/// The different realizations of expert parallelism strategies +enum class MoeExpertParallelDegree : uint32_t { + /// The expert parallelism where experts are deterministic on each of the devices + STATIC = 0, + /// The expert parallelism where experts on each device are chosen upon expected + /// "workload" of each expert to ideally even out the amount of calculation on each device + DYNAMIC, + /// The mixture of static and dynamic Ep strategies, where static Ep is applied for Prefill Stage + /// and dynamic Ep is applied for Decode Stage + DUO_GRAPH +}; + +/// The different realizations of data parallelism strategies +enum class AttnDataParallelDegree : uint32_t { + /// The data parallelism where weights are duplicated across devices in the same Dp communication group + OUTER = 0, + /// The data parallelism where weights are loaded just as if tensor parallelism were applied, yet data parallelism + /// mechanism is applied in calculation + INNER +}; + +/// Defines different types of parallelism and corresponding modules +enum ParallelType : uint32_t { + WORD_EMBED_TP = 0, + WORD_EMBED_DP, + LM_HEAD_TP, + LM_HEAD_DP, + ATTN_TP, + ATTN_DP, + ATTN_CP, + ATTN_INNER_SP, + ATTN_O_PROJ_TP, + ATTN_O_PROJ_DP, + MLP_TP, + MLP_DP, + MOE_TP, + MOE_EP, + PARALLEL_TYPE_END, +}; + +class Mapping { +public: + /// Global world size + uint32_t worldSize_; + /// An indicator that shows whether commDomains are assigned to each parallel strategy + bool isInitialized_ = false; + + /// Convert and `nlohmann::json` object to a `Mapping` object + /// \param paramJson An `nolhmann::json` object holds all the required parameters. 
+    void ParseParam(const nlohmann::json &paramJson);
+    /// Add a `ParallelInfo` strategy into `parallelStrategies_` with the key `parallelType`
+    /// \param parallelType The key of the strategy
+    /// \param parallelInfo A `ParallelInfo` object that holds the info of the communication group
+    void Register(ParallelType parallelType, atb_speed::common::ParallelInfo parallelInfo);
+    /// Get a `ParallelInfo` object from `parallelStrategies_` by the key `parallelType`
+    /// \param parallelType The key of the strategy
+    /// \throw Throws an out-of-range error if the key `parallelType` is not in `parallelStrategies_`
+    /// \return A `ParallelInfo` object corresponding to the parallelism strategy of the target module
+    const atb_speed::common::ParallelInfo Get(ParallelType parallelType) const;
+    /// Initialize the communication group of each parallelism strategy
+    /// \param defaultBackend The communication backend
+    void InitGlobalCommDomain(std::string defaultBackend);
+
+private:
+    /// A map that holds the `ParallelInfo` object of each corresponding module
+    std::map<ParallelType, atb_speed::common::ParallelInfo> parallelStrategies_;
+    /// Number of devices in the current node
+    uint32_t localWorldSize_ = 0;
+    /// Global rank
+    uint32_t rank_;
+    /// The default communication backend; `hccl` and `lccl` are currently supported
+    std::string defaultBackend_ = "";
+    /// Path of the file that contains the devices' IP and rank info used to construct communication groups
+    std::string rankTableFile_ = "";
+};
+
+} // namespace base
+} // namespace atb_speed
+
+
+#endif
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/models/base/param/model_param.cpp b/tests/proftest/layer_test_framework/models/base/param/model_param.cpp
new file mode 100644
index 00000000..b322fa5e
--- /dev/null
+++ b/tests/proftest/layer_test_framework/models/base/param/model_param.cpp
@@ -0,0 +1,364 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "models/base/param/model_param.h"
+
+namespace atb_speed {
+namespace base {
+
+nlohmann::json StringToJson(const std::string &param)
+{
+    nlohmann::json paramJson;
+    try {
+        paramJson = nlohmann::json::parse(param);
+    } catch (const std::exception &e) {
+        std::stringstream ss;
+        ss << "Failed to parse param, please check the param's format, error: " << e.what() << std::endl;
+        throw std::runtime_error(ss.str());
+    }
+    return paramJson;
+}
+
+void ModelParam::FromString(const std::string &param)
+{
+    nlohmann::json paramJson = StringToJson(param);
+    ParseParam(paramJson);
+    CheckParam();
+    PrintParam();
+}
+
+void ModelParam::PrintParam()
+{
+    Param::PrintParam();
+    ATB_SPEED_LOG_DEBUG("Model Param: isEmbeddingParallel: " << this->isEmbeddingParallel
+                  << ", isLmHeadParallel: " << this->isLmHeadParallel
+                  << ", lmHeadTransposeType: " << this->lmHeadTransposeType
+                  << ", numHiddenLayers: " << this->numHiddenLayers
+                  << ", rank: " << this->rank
+                  << ", worldSize: " << this->worldSize
+                  << ", backend: " << this->backend
+                  << ", rankTableFile: " << this->rankTableFile
+                  << ", enableDap: " << this->enableDap
+                  << ", enableFlashComm: " << this->enableFlashComm);
+}
+
+void ModelParam::ParseParam(const nlohmann::json &paramJson)
+{
+    this->isUnpadInputs = FetchJsonParam<bool>(paramJson, "isUnpadInputs");
+    this->isPrefill = FetchJsonParam<bool>(paramJson, "isPrefill");
+    this->isBF16 = FetchJsonParam<bool>(paramJson, "isBF16");
+    this->normEps = FetchJsonParam<float>(paramJson, "normEps");
+    this->normType = FetchJsonParam<NormType>(paramJson, "normType");
+    if (paramJson.contains("isEdgeHardware")) {
+        this->isEdgeHardware = FetchJsonParam<bool>(paramJson, "isEdgeHardware");
+    }
+    this->numHiddenLayers = CheckNumHiddenLayersValid(FetchJsonParam<uint32_t>(paramJson, "numHiddenLayers"));
+    if (paramJson.contains("skipWordEmbedding")) {
+        this->skipWordEmbedding = FetchJsonParam<bool>(paramJson, "skipWordEmbedding");
+    }
+    if (paramJson.contains("positionEmbeddingType")) {
+        this->positionEmbeddingType = FetchJsonParam<PositionEmbeddingType>(paramJson, "positionEmbeddingType");
+    }
+    if (paramJson.contains("enablePrefixCache")) {
+        this->enablePrefixCache = FetchJsonParam<bool>(paramJson, "enablePrefixCache");
+    }
+    if (paramJson.contains("weightQuantType")) {
+        this->weightQuantType = FetchJsonParam<std::string>(paramJson, "weightQuantType");
+    }
+    if (paramJson.contains("enableGreedySearchOpt")) {
+        this->enableGreedyPostProcessing = paramJson["enableGreedySearchOpt"].get<bool>();
+    }
+    if (paramJson.contains("attnSkipLayerSet")) {
+        for (auto item : paramJson["attnSkipLayerSet"]) {
+            attnSkipLayerSet.push_back(item.get<int>());
+        }
+        CheckSkipLayerSet(attnSkipLayerSet, this->numHiddenLayers);
+    }
+    if (paramJson.contains("mlpSkipLayerSet")) {
+        for (auto item : paramJson["mlpSkipLayerSet"]) {
+            mlpSkipLayerSet.push_back(item.get<int>());
+        }
+        CheckSkipLayerSet(mlpSkipLayerSet, this->numHiddenLayers);
+    }
+    if (paramJson.contains("enableDap")) { this->enableDap = FetchJsonParam<bool>(paramJson, "enableDap"); }
+    if (paramJson.contains("enableCVOverlap")) {
+        this->enableCVOverlap = FetchJsonParam<bool>(paramJson, "enableCVOverlap");
+    }
+    if (paramJson.contains("enableFlashComm")) {
+        this->enableFlashComm = paramJson["enableFlashComm"].get<bool>();
+    }
+    ParseNormParam(paramJson);
+    ParseAttentionParam(paramJson);
+    ParseMlpParam(paramJson);
+    ParseMatmulParam(paramJson);
+    ParseTensorParallelParam(paramJson);
+    ParseParallelismParam(paramJson);
+}
+
+void ModelParam::ParseNormParam(const nlohmann::json &paramJson)
+{
+    if (paramJson.contains("enableIntraLayerAddNorm")) {
+        this->enableIntraLayerAddNorm = FetchJsonParam<bool>(paramJson, "enableIntraLayerAddNorm");
"enableIntraLayerAddNorm"); + } + if (paramJson.contains("enableInterLayerAddNorm")) { + this->enableInterLayerAddNorm = FetchJsonParam(paramJson, "enableInterLayerAddNorm"); + } + if (paramJson.contains("isAntiOutlier")) { + for (auto item : paramJson["isAntiOutlier"]) { + this->isAntiOutlier.push_back(FetchJsonParam>(item, "isAntiOutlier", true)); + } + CheckLinearParamsSufficient(this->isAntiOutlier, this->numHiddenLayers, 2); // 2: two norm in one layer + } +} + +void ModelParam::ParseAttentionParam(const nlohmann::json ¶mJson) +{ + this->isFA = FetchJsonParam(paramJson, "isFA"); + this->numAttentionHeadsPerRank = FetchJsonParam(paramJson, "numAttentionHeadsPerRank"); + this->hiddenSizePerAttentionHead = FetchJsonParam(paramJson, "hiddenSizePerAttentionHead"); + this->numKeyValueHeadsPerRank = FetchJsonParam(paramJson, "numKeyValueHeadsPerRank"); + if (paramJson.contains("enableKvQuant")) { + this->enableKvQuant = FetchJsonParam(paramJson, "enableKvQuant"); + } + if (paramJson.contains("enableFA3")) { this->enableFA3 = FetchJsonParam(paramJson, "enableFA3"); } + if (paramJson.contains("attnBackend")) { + this->attnBackend = FetchJsonParam(paramJson, "attnBackend"); + } + if (paramJson.contains("enableSpeculate")) { + this->enableSpeculate = FetchJsonParam(paramJson, "enableSpeculate"); + } + if (paramJson.contains("enableSplitFuse")) { + this->enableSplitFuse = FetchJsonParam(paramJson, "enableSplitFuse"); + } + if (paramJson.contains("enableCompressHead")) { + this->enableCompressHead = FetchJsonParam(paramJson, "enableCompressHead"); + } + if (paramJson.contains("enableOmniAttention")) { + this->enableOmniAttention = FetchJsonParam(paramJson, "enableOmniAttention"); + if (this->enableOmniAttention) { + for (auto item : paramJson["pattern_mask"]) { + this->patternMask.push_back(item.get()); + } + } + } + if (paramJson.contains("enableRopeQuantKvcache")) { + this->enableRopeQuantKvcache = paramJson["enableRopeQuantKvcache"].get(); + } + if (paramJson.contains("useQKNorm")) { + this->useQKNorm = paramJson["useQKNorm"].get(); + } + if (paramJson.contains("enableModelConfuscation")) { + this->enableModelConfuscation = paramJson["enableModelConfuscation"].get(); + } + if (paramJson.contains("modelConfuscationFd")) { + this->modelConfuscationFd = paramJson["modelConfuscationFd"].get(); + } +} + +void ModelParam::ParseMlpParam(const nlohmann::json ¶mJson) +{ + if (paramJson.contains("enableSwiGLU")) { + this->enableSwiGLU = FetchJsonParam(paramJson, "enableSwiGLU"); + } + if (paramJson.contains("enableSwigluQuant")) { + this->enableSwigluQuant = FetchJsonParam(paramJson, "enableSwigluQuant"); + } +} + +void ModelParam::ParseMatmulParam(const nlohmann::json ¶mJson) +{ + this->lmHeadTransposeType = FetchJsonParam(paramJson, "lmHeadTransposeType"); + if (paramJson.contains("packQuantType")) { + for (auto item : paramJson["packQuantType"]) { + this->packQuantType.push_back(FetchJsonParam>(item, "packQuantType", true)); + } + CheckPackQuantParamsSufficient(this->packQuantType, this->numHiddenLayers); + } + if (paramJson.contains("linearQuantType")) { + for (auto item : paramJson["linearQuantType"]) { + this->linearQuantType.push_back(FetchJsonParam>(item, "linearQuantType", true)); + } + CheckLinearPackParamsSufficient(this->linearQuantType, this->numHiddenLayers); + } + if (paramJson.contains("linearTransposeType")) { + for (auto item : paramJson["linearTransposeType"]) { + this->linearTransposeType.push_back(FetchJsonParam>(item, "linearTransposeType", true)); + } + 
+        CheckLinearPackParamsSufficient(this->linearTransposeType, this->numHiddenLayers);
+    }
+    if (paramJson.contains("linearHasBias")) {
+        for (auto item : paramJson["linearHasBias"]) {
+            this->linearHasBias.push_back(FetchJsonParam<std::vector<bool>>(item, "linearHasBias", true));
+        }
+        CheckLinearHasBiasSufficient(this->linearHasBias, this->numHiddenLayers);
+    }
+    if (paramJson.contains("linearDescs")) {
+        for (auto item : paramJson["linearDescs"]) {
+            this->linearDescs.push_back(FetchJsonParam<std::vector<int>>(item, "linearDescs", true));
+        }
+        CheckLinearPackParamsSufficient(this->linearDescs, this->numHiddenLayers);
+    }
+    if (paramJson.contains("enableReduceQuant")) {
+        this->enableReduceQuant = FetchJsonParam<bool>(paramJson, "enableReduceQuant");
+    }
+    if (paramJson.contains("enableLora")) {
+        this->enableLora = FetchJsonParam<bool>(paramJson, "enableLora");
+    }
+    if (paramJson.contains("enablePreFetchWeight")) {
+        this->enablePreFetchWeight = FetchJsonParam<bool>(paramJson, "enablePreFetchWeight");
+    }
+    if (paramJson.contains("loraEnableGMM")) {
+        this->loraEnableGMM = FetchJsonParam<bool>(paramJson, "loraEnableGMM");
+    }
+    if (paramJson.contains("quantGroupSize")) {
+        this->quantGroupSize = FetchJsonParam<uint32_t>(paramJson, "quantGroupSize");
+    }
+    if (paramJson.contains("matmulBackend")) {
+        this->matmulBackend = FetchJsonParam<atb_speed::common::OpBackend>(paramJson, "matmulBackend");
+    }
+}
+
+void ModelParam::ParseTensorParallelParam(const nlohmann::json &paramJson)
+{
+    if (paramJson.contains("isEmbeddingParallel")) {
+        this->isEmbeddingParallel = FetchJsonParam<bool>(paramJson, "isEmbeddingParallel");
+    }
+    if (paramJson.contains("isLmHeadParallel")) {
+        this->isLmHeadParallel = FetchJsonParam<bool>(paramJson, "isLmHeadParallel");
+    }
+    this->backend = FetchJsonParam<std::string>(paramJson, "backend");
+    if (paramJson.contains("mapping")) {
+        this->mapping.ParseParam(paramJson["mapping"]);
+        // prepare the communication groups
+        this->mapping.InitGlobalCommDomain(this->backend);
+    }
+    this->rank = FetchJsonParam<int>(paramJson, "rank");
+    this->worldSize = FetchJsonParam<int>(paramJson, "worldSize");
+    this->worldSize = CheckPositive(this->worldSize);
+    if (paramJson.contains("rankTableFile")) {
+        this->rankTableFile = FetchJsonParam<std::string>(paramJson, "rankTableFile");
+    }
+    if (paramJson.contains("tpRankTableFile")) {
+        tpRankTableFile = paramJson["tpRankTableFile"].get<std::string>();
+    }
+    if (paramJson.contains("hasPp")) { this->hasPp = paramJson["hasPp"].get<bool>(); }
+    if (paramJson.contains("ppGroupSize")) { this->ppGroupSize = paramJson["ppGroupSize"].get<int>(); }
+    if (paramJson.contains("firstPpRank")) { this->firstPpRank = paramJson["firstPpRank"].get<bool>(); }
+    if (paramJson.contains("lastPpRank")) { this->lastPpRank = paramJson["lastPpRank"].get<bool>(); }
+    if (paramJson.contains("prevPpRank")) { this->prevPpRank = paramJson["prevPpRank"].get<int>(); }
+    if (paramJson.contains("nextPpRank")) { this->nextPpRank = paramJson["nextPpRank"].get<int>(); }
+    if (paramJson.contains("tpRank")) { this->tpRank = paramJson["tpRank"].get<int>(); }
+    if (paramJson.contains("tpWorldSize")) { this->tpWorldSize = paramJson["tpWorldSize"].get<int>(); }
+    if (paramJson.contains("tpDomain")) { this->tpDomain = paramJson["tpDomain"].get<std::string>(); }
+}
+
+
+void ModelParam::ParseParallelismParam(const nlohmann::json &paramJson)
+{
+    if (paramJson.contains("hasAttnTp")) {
+        this->hasAttnTp = atb_speed::base::FetchJsonParam<bool>(paramJson, "hasAttnTp");
+    }
+    if (paramJson.contains("attnTpRank")) {
+        this->attnTpRank = atb_speed::base::FetchJsonParam<int>(paramJson, "attnTpRank");
+    }
+    if (paramJson.contains("attnTpSize")) {
+        this->attnTpSize = CheckPositive(atb_speed::base::FetchJsonParam<int>(paramJson, "attnTpSize"));
"attnTpSize")); + } + if (paramJson.contains("attnTpDomain")) { + this->attnTpDomain = atb_speed::base::FetchJsonParam(paramJson, "attnTpDomain"); + } + if (paramJson.contains("hasAttnDp")) { + this->hasAttnDp = atb_speed::base::FetchJsonParam(paramJson, "hasAttnDp"); + } + if (paramJson.contains("attnDpRank")) { + this->attnDpRank = atb_speed::base::FetchJsonParam(paramJson, "attnDpRank"); + } + if (paramJson.contains("attnDpSize")) { + this->attnDpSize = CheckPositive(atb_speed::base::FetchJsonParam(paramJson, "attnDpSize")); + } + if (paramJson.contains("attnDpDomain")) { + this->attnDpDomain = atb_speed::base::FetchJsonParam(paramJson, "attnDpDomain"); + } + if (paramJson.contains("hasMlpTp")) { + this->hasMlpTp = atb_speed::base::FetchJsonParam(paramJson, "hasMlpTp"); + } + if (paramJson.contains("mlpTpRank")) { + this->mlpTpRank = atb_speed::base::FetchJsonParam(paramJson, "mlpTpRank"); + } + if (paramJson.contains("mlpTpSize")) { + this->mlpTpSize = CheckPositive(atb_speed::base::FetchJsonParam(paramJson, "mlpTpSize")); + } + if (paramJson.contains("mlpTpDomain")) { + this->mlpTpDomain = atb_speed::base::FetchJsonParam(paramJson, "mlpTpDomain"); + } + if (paramJson.contains("enableMC2")) { + this->enableMC2 = paramJson["enableMC2"].get(); + } + if (paramJson.contains("enableLcoc")) { + this->enableLcoc = FetchJsonParam(paramJson, "enableLcoc"); + } +} + +void ModelParam::CheckParam() +{ + if (this->hasPp && this->tpRank >= this->tpWorldSize) { + throw std::runtime_error("tpWorldSize must be greater than tpRank, please check."); + } + if (this->rank >= this->worldSize) { + throw std::runtime_error("worldSize must be greater than rank, please check."); + } + if (this->positionEmbeddingType != ROPE && this->positionEmbeddingType != ALIBI && \ + this->positionEmbeddingType != ABSOLUTE) { + throw std::runtime_error("positionEmbeddingType is an enumeration variable with possible values: ROPE = 0, " + "ALIBI = 1 or ABSOLUTE = 2, please check."); + } + if (this->normType != RMS_NORM && this->normType != LAYER_NORM) { + throw std::runtime_error("normType is an enumeration variable with possible values: RMS_NORM = 0 or " + "LAYER_NORM = 1, please check."); + } + if (this->attnBackend != atb_speed::common::ATB && this->attnBackend != atb_speed::common::ACLNN) { + throw std::runtime_error("attnBackend is an enumeration variable with possible values: ACLNN = 0 or " + "ATB = 1, please check."); + } + if (this->lmHeadTransposeType != atb_speed::common::TRANSPOSE_INVALID && this->lmHeadTransposeType != \ + atb_speed::common::NOT_TRANSPOSE && this->lmHeadTransposeType != atb_speed::common::TRANSPOSE) { + throw std::runtime_error("lmHeadTransposeType is an enumeration variable with possible values: " + "TRANSPOSE_INVALID = -1, NOT_TRANSPOSE = 0 or TRANSPOSE = 1, please check."); + } + auto packType = atb_speed::common::ConvertQuantTypeToPackType(this->weightQuantType); + if (packType == atb_speed::common::PackQuantType::PACK_QUANT_UNDEFINED && !this->weightQuantType.empty()) { + throw std::runtime_error( + "weightQuantType should be float, w8a8, w8a8s, w8a8sc, w8a8_dynamic, w8a16, w4a16 or an empty string."); + } + // skip attention/mlp与addNorm不兼容 + if ((this->enableIntraLayerAddNorm || this->enableInterLayerAddNorm) && \ + (attnSkipLayerSet.size() != 0 || mlpSkipLayerSet.size() != 0)) { + throw std::runtime_error("'enableIntraLayerAddNorm/enableInterLayerAddNorm' and " + "'attnSkipLayerSet/mlpSkipLayerSet' are incompatible, do not enable them at the same time, please check."); + } + // hasAttnDp 
+    if ((this->enableIntraLayerAddNorm || this->enableInterLayerAddNorm) && this->hasAttnDp) {
+        throw std::runtime_error("'enableIntraLayerAddNorm or enableInterLayerAddNorm' and "
+            "'hasAttnDp' are incompatible, do not enable them at the same time, please check.");
+    }
+    // enableDap is incompatible with enableOmniAttention/enableCompressHead
+    if (this->enableDap && (this->enableOmniAttention || this->enableCompressHead)) {
+        throw std::runtime_error("'DAP' and 'enableOmniAttention/enableCompressHead' are incompatible, "
+            "do not enable them at the same time, please check.");
+    }
+}
+} // namespace base
+} // namespace atb_speed
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/models/base/param/model_param.h b/tests/proftest/layer_test_framework/models/base/param/model_param.h
new file mode 100644
index 00000000..a1cd9239
--- /dev/null
+++ b/tests/proftest/layer_test_framework/models/base/param/model_param.h
@@ -0,0 +1,145 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef ATB_SPEED_BASE_MODEL_PARAM_H
+#define ATB_SPEED_BASE_MODEL_PARAM_H
+#include
+#include
+#include "models/base/param/param.h"
+#include "models/base/param/param_utils.h"
+
+namespace atb_speed {
+namespace base {
+
+/// Parse a string to `nlohmann::json`
+/// \param param parameters in JSON string format passed from the Python side
+/// \return parameters in `nlohmann::json` format
+nlohmann::json StringToJson(const std::string &param);
+
+/// Parameters for the base model, inherited from the `Param` class.
+///
+/// In addition to the parameters defined in the `Param` class,
+/// this class introduces additional parameters specific to the base `DecoderModel` class.
+/// Models can inherit from this class and define further parameters tailored to their specific requirements.
+class ModelParam : public Param {
+public:
+    ModelParam() {};
+    ~ModelParam() override {};
+    /// Parse the input JSON string to a `ModelParam` object, validate its contents, and print the parsed values.
+    /// \param param parameters in JSON string format passed from the Python side
+    void FromString(const std::string &param);
+    void PrintParam() override;
+    void CheckParam() override;
+
+    /// When `skipWordEmbedding` is true, the input embedding is provided and the word embedding module is skipped;
+    /// otherwise, input token ids are used.
+    bool skipWordEmbedding = false;
+    /// When `isEmbeddingParallel` is true, the embedding weights are partitioned along the hiddenSize dimension;
+    /// otherwise, the weights are not partitioned.
+    bool isEmbeddingParallel = false;
+    /// When `isLmHeadParallel` is true, the LmHead weights are partitioned; otherwise, the weight is not partitioned.
+    bool isLmHeadParallel = true;
+    /// A flag indicating whether the second matrix in the matmul operation within the LmHead module is transposed.
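+    /// The accepted values mirror the `TransposeType` enumerator as validated in `CheckParam`:
+    /// TRANSPOSE_INVALID = -1, NOT_TRANSPOSE = 0, TRANSPOSE = 1.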
+    int lmHeadTransposeType = -1;
+    /// A flag indicating whether greedy-search post-processing is enabled
+    bool enableGreedyPostProcessing = false;
+    /// Whether to enable data async parallelism (DAP) for the model.
+    bool enableDap = false;
+    /// Number of hidden layers
+    uint32_t numHiddenLayers = 0;
+    // Model Parallelism
+    int rank = 0;
+    int worldSize = 1;
+    int localWorldSize = 1;
+    bool hasPp = false;
+    int ppGroupSize = 0;
+    bool firstPpRank = true;
+    bool lastPpRank = true;
+    int prevPpRank = 0;
+    int nextPpRank = 0;
+    int tpRank = 0;
+    int tpWorldSize = 0;
+    std::string tpDomain = "";
+    std::string backend = "hccl";
+    std::string rankTableFile = "";
+    std::string tpRankTableFile = "";
+    /// Indicates the pack type and the quantization type of the QKV linear and gate-up linear for each layer.
+    /// The number of inner vectors corresponds to `numHiddenLayers`.
+    /// Each inner vector contains two integers: the first represents the pack and quantization type
+    /// of the qkv linear, and the second represents the pack and quantization type of the gate up linear.
+    /// The pack and quantization types are defined in the `PackQuantType` enumerator.
+    std::vector<std::vector<int>> packQuantType = {};
+    /// Specifies the quantization type for each linear in every layer.
+    /// The number of inner vectors corresponds to `numHiddenLayers`.
+    /// Each inner vector contains seven integers, representing the quantization types of the following
+    /// linear modules: q linear, k linear, v linear, dense linear, gate linear, up linear, and down linear.
+    /// The quantization types are defined in the `LinearType` enumerator.
+    std::vector<std::vector<int>> linearQuantType = {};
+    /// Defines the transpose type of the second matrix in the matmul operation for each linear in every layer.
+    /// The number of inner vectors corresponds to `numHiddenLayers`.
+    /// Each inner vector contains seven integers, representing the transpose types of the following
+    /// linear modules: q linear, k linear, v linear, dense linear, gate linear, up linear, and down linear.
+    /// The transpose types are defined in the `TransposeType` enumerator.
+    std::vector<std::vector<int>> linearTransposeType = {};
+    /// The indices of the layers where the attention module should be skipped.
+    /// The size of `attnSkipLayerSet` must be within the range [0, `numHiddenLayers`].
+    /// Each element represents a layer index, which must also fall within the range [0, `numHiddenLayers`].
+    std::vector<int> attnSkipLayerSet = {};
+    /// The indices of the layers where the mlp module should be skipped.
+    /// The size of `mlpSkipLayerSet` must be within the range [0, `numHiddenLayers`].
+    /// Each element represents a layer index, which must also fall within the range [0, `numHiddenLayers`].
+    std::vector<int> mlpSkipLayerSet = {};
+    /// Specifies whether each linear module has a bias.
+    /// The number of inner vectors corresponds to `numHiddenLayers`.
+    /// Each inner vector contains four boolean values, indicating whether the following linear modules have bias:
+    /// qkv linear, dense linear, gateup linear and down linear.
+    std::vector<std::vector<bool>> linearHasBias = {};
+    std::vector<std::vector<int>> linearDescs = {};
+    std::vector<std::vector<bool>> isAntiOutlier = {};
+
+protected:
+    /// Convert an `nlohmann::json` object to a `ModelParam` object.
+    /// \param paramJson An `nlohmann::json` object that holds all the required parameters.
+    virtual void ParseParam(const nlohmann::json &paramJson);
+    /// Converts normalization-related parameters from the `nlohmann::json` object
+    /// into attributes of the `ModelParam` object.
+    /// This function is called by `ParseParam`.
+    /// \param paramJson An `nlohmann::json` object that holds all the required parameters.
+    void ParseNormParam(const nlohmann::json &paramJson);
+    /// Converts attention-related parameters from the `nlohmann::json` object
+    /// into attributes of the `ModelParam` object.
+    /// This function is called by `ParseParam`.
+    /// \param paramJson An `nlohmann::json` object that holds all the required parameters.
+    virtual void ParseAttentionParam(const nlohmann::json &paramJson);
+    /// Converts mlp-related parameters from the `nlohmann::json` object
+    /// into attributes of the `ModelParam` object.
+    /// This function is called by `ParseParam`.
+    /// \param paramJson An `nlohmann::json` object that holds all the required parameters.
+    void ParseMlpParam(const nlohmann::json &paramJson);
+    /// Converts matmul-related parameters from the `nlohmann::json` object into attributes of the `ModelParam` object.
+    /// This function is called by `ParseParam`.
+    /// \param paramJson An `nlohmann::json` object that holds all the required parameters.
+    virtual void ParseMatmulParam(const nlohmann::json &paramJson);
+    /// Converts tensor-parallelism-related parameters from the `nlohmann::json` object
+    /// into attributes of the `ModelParam` object.
+    /// This function is called by `ParseParam`.
+    /// \param paramJson An `nlohmann::json` object that holds all the required parameters.
+    virtual void ParseTensorParallelParam(const nlohmann::json &paramJson);
+    /// Converts the remaining parallelism-related parameters (attention/mlp TP and DP)
+    /// from the `nlohmann::json` object into attributes of the `ModelParam` object.
+    /// This function is called by `ParseParam`.
+    /// \param paramJson An `nlohmann::json` object that holds all the required parameters.
+    virtual void ParseParallelismParam(const nlohmann::json &paramJson);
+};
+
+} // namespace base
+} // namespace atb_speed
+#endif
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/models/base/param/param.h b/tests/proftest/layer_test_framework/models/base/param/param.h
new file mode 100644
index 00000000..65abd19e
--- /dev/null
+++ b/tests/proftest/layer_test_framework/models/base/param/param.h
@@ -0,0 +1,225 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef ATB_SPEED_BASE_PARAM_H
+#define ATB_SPEED_BASE_PARAM_H
+#include "atb_speed/log.h"
+#include "atb_speed/utils/check_util.h"
+#include "operations/fusion/utils.h"
+#include "models/base/param/mapping.h"
+
+namespace atb_speed {
+namespace base {
+
+/// An enumerator specifying the various types of position embeddings
+enum PositionEmbeddingType : uint32_t {
+    /// Rotary position embedding
+    ROPE = 0,
+    /// Attention with linear biases
+    ALIBI,
+    /// Absolute position encodings
+    ABSOLUTE,
+};
+
+/// An enumerator specifying the various types of layer normalization
+enum NormType : uint32_t {
+    /// Root mean square normalization
+    RMS_NORM = 0,
+    /// Layer normalization
+    LAYER_NORM,
+};
+
+/// An enumerator representing the positional index of each linear bias
+enum HasBias : uint32_t {
+    /// The positional index of the bias in the QKV linear
+    QKV_HASBIAS = 0,
+    /// The positional index of the bias in the attention's dense linear
+    SELFATTENTION_HASBIAS,
+    /// The positional index of the bias in the gate up linear
+    GATEUP_HASBIAS,
+    /// The positional index of the bias in the mlp's down linear
+    DOWN_HASBIAS,
+};
+
+/// Common parameters shared between the model and layer classes
+class Param {
+public:
+    /// If `isFA` is true, Flash Attention is used; otherwise, Paged Attention is used
+    bool isFA = true;
+    /// A flag that indicates whether the input includes padding
+    bool isUnpadInputs = true;
+    /// A flag indicating the prefill and decode phases
+    bool isPrefill = false;
+    /// When `isBF16` is true, bfloat16 precision is used; otherwise, float16 precision is used.
+    bool isBF16 = false;
+    /// A flag indicating that an edge device is used
+    bool isEdgeHardware = false;
+    /// A flag that indicates whether the MLP module utilizes the SwiGLU fusion operation
+    bool enableSwiGLU = false;
+    /// A flag indicating whether q_norm and k_norm are enabled
+    bool useQKNorm = false;
+    /// A flag that indicates whether the shared experts module utilizes the SwiGLUQuant fusion operation
+    bool enableSwiGLUQuantForSharedExperts = false;
+    /// A flag that indicates whether low-latency computation over communication is enabled
+    bool enableLcoc = false;
+    /// A flag that indicates whether MC2 is enabled
+    bool enableMC2 = false;
+    /// A flag indicating whether speculation is enabled
+    bool enableSpeculate = false;
+    /// A flag indicating whether razor attention is enabled
+    bool enableCompressHead = false;
+    /// A flag indicating whether omni attention is enabled
+    bool enableOmniAttention = false;
+    bool isomnicompressed = false;
+    /// A vector that stores the compressed info of each layer
+    std::vector<int> patternMask = {};
+    /// A flag indicating whether split fuse is enabled
+    bool enableSplitFuse = false;
+    /// A flag indicating whether LoRA is enabled
+    bool enableLora = false;
+    /// A flag indicating whether the group matmul operation is enabled;
+    /// it should be activated when batch inputs include multiple LoRA adapters
+    bool loraEnableGMM = false;
+    /// A flag indicating whether to use int8 quantization for the KV cache
+    bool enableKvQuant = false;
+    /// A flag indicating whether to use int8 quantization for the KV cache layer
+    bool enableKvQuantLayer = false;
+    /// A flag indicating whether int8 quantization for the KV cache has offset (i.e., asymmetric)
+    bool kvQuantHasOffset = true;
+    /// A flag indicating whether RopeQuantKvcache is enabled
+    bool enableRopeQuantKvcache = false;
+    /// A flag indicating whether flash attention 3 is enabled
+    bool enableFA3 = false;
+    /// A flag indicating whether all-reduce quantization is enabled.
+    /// It can be enabled only when the communication backend is set to "lccl".
+    bool enableReduceQuant = false;
+    /// Whether to enable inter-layer addRmsNorm fusion; default false.
+    bool enableInterLayerAddNorm = false;
+    /// Whether to enable intra-layer addRmsNorm fusion; default false.
+    bool enableIntraLayerAddNorm = false;
+    /// A flag indicating whether prefix cache is enabled
+    bool enablePrefixCache = false;
+    /// A flag indicating whether to prefetch weights
+    bool enablePreFetchWeight = false;
+    /// A flag indicating whether the model uses cube and vector parallelism
+    bool enableCVOverlap = false;
+    /// A flag indicating whether the attention module is skipped in the layer
+    bool isAttnSkipLayer = false;
+    /// A flag indicating whether the mlp module is skipped in the layer
+    bool isMlpSkipLayer = false;
+    /// A flag indicating whether to use swigluQuant
+    bool enableSwigluQuant = false;
+    /// A flag indicating whether flashcomm 1.0 is enabled
+    bool enableFlashComm = false;
+    /// A flag indicating whether to use pmcc model obfuscation
+    bool enableModelConfuscation = false;
+    /// A handle used by pmcc model obfuscation
+    int32_t modelConfuscationFd = 0;
+    /// The backend of the attention module; refer to `OpBackend` for the supported values
+    atb_speed::common::OpBackend attnBackend = atb_speed::common::OpBackend::ATB;
+    /// The backend of the matmul module; refer to `OpBackend` for the supported values
+    atb_speed::common::OpBackend matmulBackend = atb_speed::common::OpBackend::ATB;
+    /// The type of the position embedding; refer to `PositionEmbeddingType` for the supported values
+    PositionEmbeddingType positionEmbeddingType = PositionEmbeddingType::ROPE;
+    /// The epsilon value used for normalization
+    float normEps = 0;
+    /// The type of the normalization; refer to `NormType` for the supported values
+    NormType normType = NormType::RMS_NORM;
+    /// The group size used for dequantizing the weight tensor in the per-group quantization approach
+    uint32_t quantGroupSize = 0;
+    /// Number of attention heads per rank
+    uint32_t numAttentionHeadsPerRank = 0;
+    /// The dimension of the hidden representations for each attention head
+    uint32_t hiddenSizePerAttentionHead = 0;
+    /// If `numKeyValueHeadsPerRank` equals `numAttentionHeadsPerRank`, the model will use Multi Head Attention;
+    /// otherwise, Grouped Query Attention is used
+    uint32_t numKeyValueHeadsPerRank = 0;
+    /// The quantization type applied to the model
+    std::string weightQuantType = "";
+    // Model Parallelism
+    Mapping mapping;
+    std::string backend = "hccl";
+    bool hasAttnTp = false;
+    int attnTpRank = 0;
+    int attnTpSize = 1;
+    std::string attnTpDomain = "";
+    std::string attnTpRankTableFile = "";
+    std::string attnTpBackend = "";
+    bool hasAttnDp = false;
+    int attnDpRank = 0;
+    int attnDpSize = 1;
+    std::string attnDpDomain = "";
+    std::string attnDpRankTableFile = "";
+    std::string attnDpBackend = "";
+
+    bool hasMlpEp = false;
+    int mlpEpRank = 0;
+    int mlpEpSize = 1;
+    std::string mlpEpDomain = "";
+    std::string mlpEpRankTableFile = "";
+    std::string mlpEpBackend = "";
+    bool hasMlpTp = false;
+    int mlpTpRank = 0;
+    int mlpTpSize = 1;
+    std::string mlpTpDomain = "";
+    std::string mlpTpRankTableFile = "";
+    std::string mlpTpBackend = "";
+
+    Param() {};
+    virtual ~Param() {};
+
+    /// A member function that outputs the values of all parameters
+    virtual void PrintParam()
+    {
+        ATB_SPEED_LOG_DEBUG("Param: " << "isFA: " << isFA
+            << ", isUnpadInputs: " << isUnpadInputs
+ << ", isPrefill: " << isPrefill + << ", isBF16: " << isBF16 + << ", isEdgeHardware: " << isEdgeHardware + << ", enableSwiGLU: " << enableSwiGLU + << ", enableLcoc: " << enableLcoc + << ", enableSpeculate: " << enableSpeculate + << ", enableCompressHead: " << enableCompressHead + << ", enableOmniAttention: " << enableOmniAttention + << ", enableSplitFuse: " << enableSplitFuse + << ", enableLora: " << enableLora + << ", useQKNorm: " << useQKNorm + << ", loraEnableGMM: " << loraEnableGMM + << ", enableKvQuant: " << enableKvQuant + << ", enableReduceQuant: " << enableReduceQuant + << ", enableInterLayerAddNorm: " << enableInterLayerAddNorm + << ", enableIntraLayerAddNorm: " << enableIntraLayerAddNorm + << ", enablePrefixCache: " << enablePrefixCache + << ", attnBackend: " << attnBackend + << ", positionEmbeddingType: " << positionEmbeddingType + << ", normType: " << normType + << ", normEps: " << normEps + << ", quantGroupSize: " << quantGroupSize + << ", numAttentionHeadsPerRank: " << numAttentionHeadsPerRank + << ", hiddenSizePerAttentionHead: " << hiddenSizePerAttentionHead + << ", numKeyValueHeadsPerRank: " << numKeyValueHeadsPerRank + << ", enableMC2: " << enableMC2 + << ", weightQuantType: " << weightQuantType + << ", enableSwigluQuant" << enableSwigluQuant + << ", matmulBackend" << matmulBackend); + } + /// A member function that checks and validates the values of all parameters + virtual void CheckParam() {}; +}; +} // namespace base +} // namespace atb_speed + + +#endif \ No newline at end of file diff --git a/tests/proftest/layer_test_framework/models/base/param/param_utils.h b/tests/proftest/layer_test_framework/models/base/param/param_utils.h new file mode 100644 index 00000000..37f40371 --- /dev/null +++ b/tests/proftest/layer_test_framework/models/base/param/param_utils.h @@ -0,0 +1,50 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef ATB_SPEED_PARAM_UTILS_H +#define ATB_SPEED_PARAM_UTILS_H +#include +#include "atb_speed/log.h" + +namespace atb_speed { +namespace base { + +/// A template function to verify the type of the parameter. +/// It will call `nlohmann::json`'s `get` method to extract the JSON value and convert to the target type. +/// \tparam T The acceptable data types are int, bool, float, string, uint32_t, std::vector, std::vector. +/// \param paramJson An `nlohmann::json` object holds all the required parameters. +/// \param key The key used to retrieve the value from the `nlohmann::json` object. +/// \param isVector A flag indicates whether the target value is in the vector format. +/// \return The extracted value after type conversion. 
+template <typename T>
+T FetchJsonParam(const nlohmann::json& paramJson, const std::string& key, bool isVector = false)
+{
+    try {
+        if (isVector) {
+            return paramJson.get<T>();
+        } else {
+            return paramJson.at(key).get<T>();
+        }
+    } catch (const std::exception& e) {
+        std::stringstream ss;
+        ss << "Failed to parse parameter " << key << ": " << e.what() << ". Please check the type of param.";
+        ATB_SPEED_LOG_ERROR(ss.str(), ATB_MODELS_MODEL_PARAM_JSON_INVALID);
+        throw std::runtime_error(ss.str());
+    }
+}
+
+} // namespace base
+} // namespace atb_speed
+#endif
\ No newline at end of file
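A minimal usage sketch of `FetchJsonParam` (the keys and the surrounding function are illustrative, not part of this patch): scalar values are fetched by key, while vector values are fetched element by element with `isVector = true`, since in that case `paramJson` is already the array element.

```cpp
#include <cstdint>
#include <vector>
#include <nlohmann/json.hpp>
#include "models/base/param/param_utils.h"

void ParseExampleParam(const nlohmann::json &paramJson)
{
    // Scalar field: looked up by key and converted to the requested type.
    bool isFA = atb_speed::base::FetchJsonParam<bool>(paramJson, "isFA");
    uint32_t numHeads = atb_speed::base::FetchJsonParam<uint32_t>(paramJson, "numAttentionHeadsPerRank");
    // Vector field: each array element is passed in directly with isVector = true;
    // the key is then only used for error reporting.
    std::vector<int> packQuantType;
    for (const auto &item : paramJson.at("packQuantType")) {
        packQuantType.push_back(atb_speed::base::FetchJsonParam<int>(item, "packQuantType", true));
    }
    (void)isFA; (void)numHeads;
}
```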
diff --git a/tests/proftest/layer_test_framework/models/bloom/layer/bloom_decoder_layer.cpp b/tests/proftest/layer_test_framework/models/bloom/layer/bloom_decoder_layer.cpp
new file mode 100644
index 00000000..823b4249
--- /dev/null
+++ b/tests/proftest/layer_test_framework/models/bloom/layer/bloom_decoder_layer.cpp
@@ -0,0 +1,50 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "models/bloom/layer/bloom_decoder_layer.h"
+
+namespace atb_speed {
+namespace bloom {
+
+BloomDecoderLayer::BloomDecoderLayer(
+    const atb_speed::base::LayerParam &param) : atb_speed::base::DecoderLayer<atb::infer::LayerNormParam>(
+    static_cast<atb_speed::base::LayerParam>(param))
+{
+    this->param = param;
+    this->param.CheckParam();
+};
+
+
+void BloomDecoderLayer::SetFusionAttentionLinearParam(
+    atb_speed::common::FusionAttentionParam<atb::infer::LayerNormParam> &fusionAttentionParam)
+{
+    DecoderLayer::SetFusionAttentionLinearParam(fusionAttentionParam);
+    fusionAttentionParam.qkvHasBias = true;
+    fusionAttentionParam.splitWithStride = true;
+}
+
+
+void BloomDecoderLayer::SetMlpParam(atb_speed::common::MlpParam<atb::infer::LayerNormParam> &mlpParam)
+{
+    DecoderLayer::SetMlpParam(mlpParam);
+    mlpParam.normHasBias = true;
+    mlpParam.mlpPackType = atb_speed::common::GetMlpPackType(this->param.packQuantType.at(1), true);
+    mlpParam.activationParam.geluMode = atb::infer::ActivationParam::TANH_MODE;
+    mlpParam.activationParam.activationType = atb::infer::ActivationType::ACTIVATION_GELU;
+}
+
+} // namespace bloom
+} // namespace atb_speed
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/models/bloom/layer/bloom_decoder_layer.h b/tests/proftest/layer_test_framework/models/bloom/layer/bloom_decoder_layer.h
new file mode 100644
index 00000000..a9feeb3e
--- /dev/null
+++ b/tests/proftest/layer_test_framework/models/bloom/layer/bloom_decoder_layer.h
@@ -0,0 +1,42 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef ATB_SPEED_MODELS_BLOOM_DECODER_LAYER_H
+#define ATB_SPEED_MODELS_BLOOM_DECODER_LAYER_H
+
+#include "atb/atb_infer.h"
+#include "models/base/param/layer_param.h"
+#include "models/base/layer/decoder_layer.h"
+
+
+namespace atb_speed {
+namespace bloom {
+
+class BloomDecoderLayer : public atb_speed::base::DecoderLayer<atb::infer::LayerNormParam> {
+public:
+    explicit BloomDecoderLayer(const atb_speed::base::LayerParam &param);
+    ~BloomDecoderLayer() override {};
+
+protected:
+    void SetFusionAttentionLinearParam(
+        atb_speed::common::FusionAttentionParam<atb::infer::LayerNormParam> &fusionAttentionParam) override;
+    void SetMlpParam(atb_speed::common::MlpParam<atb::infer::LayerNormParam> &mlpParam) override;
+
+    atb_speed::base::LayerParam param;
+};
+
+} // namespace bloom
+} // namespace atb_speed
+#endif
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/models/bloom/model/bloom_decoder_model.cpp b/tests/proftest/layer_test_framework/models/bloom/model/bloom_decoder_model.cpp
new file mode 100644
index 00000000..4bad2f24
--- /dev/null
+++ b/tests/proftest/layer_test_framework/models/bloom/model/bloom_decoder_model.cpp
@@ -0,0 +1,82 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "atb/atb_infer.h"
+#include "models/bloom/model/bloom_decoder_model.h"
+
+namespace atb_speed {
+namespace bloom {
+
+BloomDecoderModel::BloomDecoderModel(const std::string &param) : atb_speed::base::DecoderModel(param)
+{
+    this->param.FromString(param);
+    this->weightCountWordEmbedding = 3;  // 3: word embedding weight, first norm weight, first norm bias
+    this->weightCountFinalNorm = 2;  // 2: final norm weight, final norm bias
+    this->inTensorCandidates = {
+        {"default", {
+            "input_ids", "positional_ids", "cosine_table", "sine_table", "attention_mask",
+            "block_tables", "slots", "kv_cache_idx", "token_offset", "place_holder", "seq_len", "logits_indices"}
+        },
+    };
+    this->internalTensorCandidates = {
+        {"default", {"hidden_states"}},
+    };
+}
+
+
+atb::Status BloomDecoderModel::AddOperationToGraph()
+{
+    CHECK_OPERATION_STATUS_RETURN(this->AddWordEmbedding());
+    CHECK_OPERATION_STATUS_RETURN(this->AddFirstNorm());
+    CHECK_OPERATION_STATUS_RETURN(this->AddLayer());
+    CHECK_OPERATION_STATUS_RETURN(this->AddFinalNorm());
+    CHECK_OPERATION_STATUS_RETURN(this->AddLmhead());
+    return atb::NO_ERROR;
+}
+
+
+atb::Status BloomDecoderModel::AddFirstNorm()
+{
+    atb::Operation *op = nullptr;
+
+    atb_speed::Model::Node firstNormNode;
+    atb::infer::LayerNormParam firstNormParam;
+    this->SetFinalNormParam(firstNormParam);  // the set-param logic is shared between the first and the final norm
+    CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(firstNormParam, &op));
+    firstNormNode.operation.reset(op);
+    firstNormNode.inTensors = {
+        &graph_.internalTensors.at(atb_speed::common::GetTensorIdx(this->internalTensorMap, "hidden_states")),
+        &graph_.weightTensors.at(1),
+        &graph_.weightTensors.at(2)
+    };
+    firstNormNode.outTensors = {firstNormNode.inTensors.at(0)};  // the output is written in place over the input
+    graph_.nodes.push_back(firstNormNode);
+
+    return atb::NO_ERROR;
+}
+
+
+atb::Status BloomDecoderModel::CreateLayerOperation(atb::Operation **op, uint32_t layerId)
+{
+    atb_speed::base::LayerParam layerParam;
+    this->SetLayerParam(layerParam, layerId);
+    BloomDecoderLayer decoderLayer(layerParam);
+    CHECK_OPERATION_STATUS_RETURN(decoderLayer.BuildGraph(op));
+    return atb::NO_ERROR;
+}
+
+} // namespace bloom
+} // namespace atb_speed
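The first-norm wiring above relies on a fixed weight-tensor layout; the sketch below makes that assumption explicit. The helper names are hypothetical and only document the indexing inferred from `weightCountWordEmbedding = 3`, `weightCountFinalNorm = 2`, and `AddFirstNorm`.

```cpp
#include <vector>
#include "atb/atb_infer.h"

// Hypothetical helpers documenting the weight layout assumed by BloomDecoderModel:
// index 0 holds the word-embedding table, indices 1 and 2 hold the first
// LayerNorm's weight and bias, and the last two tensors hold the final
// LayerNorm's weight and bias.
inline const atb::Tensor &FirstNormWeight(const std::vector<atb::Tensor> &weights) { return weights.at(1); }
inline const atb::Tensor &FirstNormBias(const std::vector<atb::Tensor> &weights) { return weights.at(2); }
inline const atb::Tensor &FinalNormWeight(const std::vector<atb::Tensor> &weights)
{
    return weights.at(weights.size() - 2);  // the final norm weight precedes the final norm bias
}
```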
diff --git a/tests/proftest/layer_test_framework/models/bloom/model/bloom_decoder_model.h b/tests/proftest/layer_test_framework/models/bloom/model/bloom_decoder_model.h
new file mode 100644
index 00000000..2d2ee7ba
--- /dev/null
+++ b/tests/proftest/layer_test_framework/models/bloom/model/bloom_decoder_model.h
@@ -0,0 +1,45 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef ATB_SPEED_MODELS_BLOOM_DECODER_MODEL_H
+#define ATB_SPEED_MODELS_BLOOM_DECODER_MODEL_H
+
+
+#include "atb_speed/base/model.h"
+#include "models/base/param/model_param.h"
+#include "atb_speed/utils/model_factory.h"
+#include "models/base/model/decoder_model.h"
+#include "models/bloom/layer/bloom_decoder_layer.h"
+
+namespace atb_speed {
+namespace bloom {
+
+class BloomDecoderModel : public atb_speed::base::DecoderModel {
+public:
+    explicit BloomDecoderModel(const std::string &param);
+
+protected:
+    atb::Status CreateLayerOperation(atb::Operation **op, uint32_t layerId) override;
+    atb::Status AddOperationToGraph() override;
+    atb::Status AddFirstNorm();
+
+    atb_speed::base::ModelParam param;
+};
+
+REGISTER_MODEL(bloom, BloomDecoderModel);
+
+} // namespace bloom
+} // namespace atb_speed
+#endif
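For context, a minimal sketch of how such a model is typically instantiated from a serialized parameter string. The JSON content below is illustrative; real parameters are produced by the serving layer and parsed via `ModelParam::FromString`.

```cpp
#include <memory>
#include <string>
#include "models/bloom/model/bloom_decoder_model.h"

std::shared_ptr<atb_speed::bloom::BloomDecoderModel> BuildExampleModel()
{
    // Illustrative parameter JSON; field names follow the Param struct above.
    std::string param = R"({"isFA": false, "normEps": 1e-05, "numAttentionHeadsPerRank": 8,
                            "hiddenSizePerAttentionHead": 128, "numKeyValueHeadsPerRank": 8})";
    return std::make_shared<atb_speed::bloom::BloomDecoderModel>(param);
}
```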
diff --git a/tests/proftest/layer_test_framework/operations/aclnn/core/acl_nn_global_cache.cpp b/tests/proftest/layer_test_framework/operations/aclnn/core/acl_nn_global_cache.cpp
new file mode 100644
index 00000000..f9727edd
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/aclnn/core/acl_nn_global_cache.cpp
@@ -0,0 +1,129 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <cstdlib>
+#include "atb_speed/log.h"
+#include "atb_speed/utils/operation_util.h"
+#include "operations/aclnn/utils/utils.h"
+#include "acl_nn_global_cache.h"
+
+namespace atb_speed {
+namespace common {
+
+AclNNGlobalCache::AclNNGlobalCache()
+{
+    const char *envStr = std::getenv("ATB_ACLNN_CACHE_GLOABL_COUNT");
+    uint64_t globalCacheCountMax = DEFAULT_ACLNN_GLOBAL_CACHE_SIZE;
+    if (envStr != nullptr) {
+        globalCacheCountMax = static_cast<uint64_t>(strtol(envStr, nullptr, DECIMAL));
+    }
+    envStr = std::getenv("MINDIE_ACLNN_CACHE_GLOBAL_COUNT");
+    if (envStr != nullptr) {
+        globalCacheCountMax = static_cast<uint64_t>(strtol(envStr, nullptr, DECIMAL));
+    }
+
+    this->globalCacheCountMax_ = globalCacheCountMax;
+    if (this->globalCacheCountMax_ >= 100) {  // 100: threshold
+        std::stringstream ss;
+        ss << "The size of AclNN operations' global cache should be less than 100." << std::endl;
+        throw std::runtime_error(ss.str());
+    }
+}
+
+std::shared_ptr<AclNNOpCache> AclNNGlobalCache::GetGlobalCache(std::string opName, atb::VariantPack variantPack)
+{
+    // Get the list of global caches registered for this op name
+    std::map<std::string, std::vector<std::shared_ptr<AclNNOpCache>>>::iterator it = \
+        this->aclnnGlobalCache_.find(opName);
+    if (it == this->aclnnGlobalCache_.end()) {
+        ATB_SPEED_LOG_DEBUG("Plugin Op Cache: Op name[" << opName << "] not found in AclNNGlobalCache");
+        return nullptr;
+    }
+    std::vector<std::shared_ptr<AclNNOpCache>> &opGlobalCacheList = it->second;
+
+    // Find a matching cache in the list based on the variantPack
+    for (size_t i = 0; i < opGlobalCacheList.size(); i++) {
+        if (opGlobalCacheList[i] == nullptr) {
+            ATB_SPEED_LOG_DEBUG("Plugin Op Cache: Global Cache index " << i << " is nullptr");
+            continue;
+        }
+        ATB_SPEED_LOG_DEBUG("Plugin Op Cache: Global Cache index " << i << " call IsVariankPackEqual");
+        if (opGlobalCacheList[i]->executorRepeatable && \
+            IsVariankPackEqual(opGlobalCacheList[i]->aclnnVariantPack, variantPack)) {
+            // Global cache hit
+            return opGlobalCacheList[i];
+        }
+    }
+
+    return nullptr;
+}
+
+atb::Status AclNNGlobalCache::UpdateGlobalCache(std::string opName, std::shared_ptr<AclNNOpCache> cache)
+{
+    // If the executor in the local cache is not reusable, do not update the global cache
+    if (!cache->executorRepeatable) {
+        ATB_SPEED_LOG_DEBUG("Plugin Op Cache: Op name[" << opName << "] not repeatable, do not update global cache");
+        return atb::NO_ERROR;
+    }
+
+    // Check the global cache size
+    if (this->globalCacheCountMax_ == 0) {
+        return atb::NO_ERROR;
+    }
+
+    // Get the list of global caches registered for this op name
+    std::map<std::string, std::vector<std::shared_ptr<AclNNOpCache>>>::iterator it = \
+        this->aclnnGlobalCache_.find(opName);
+    if (it == this->aclnnGlobalCache_.end()) {
+        // No cache list exists for this op name yet
+        ATB_SPEED_LOG_DEBUG("Plugin Op Cache: Op name[" << opName << "] not found in AclNNGlobalCache, add one");
+        this->aclnnGlobalCache_[opName] = {cache};
+        return atb::NO_ERROR;
+    }
+    std::vector<std::shared_ptr<AclNNOpCache>> &opGlobalCacheList = it->second;
+
+    // The cache list is not full yet
+    if (opGlobalCacheList.size() < this->globalCacheCountMax_) {
+        ATB_SPEED_LOG_DEBUG("Plugin Op Cache: Op name[" << opName << "] global cache is not full, add one");
+        opGlobalCacheList.push_back(cache);
+        return atb::NO_ERROR;
+    }
+
+    // The cache list is full; overwrite the slot at nextUpdateIndex_
+    ATB_SPEED_LOG_DEBUG("Plugin Op Cache: Op name["
+        << opName << "] global cache is full, update index " << nextUpdateIndex_);
+    opGlobalCacheList[nextUpdateIndex_] = cache;
+    CHECK_PARAM_NE(globalCacheCountMax_, 0);
+    nextUpdateIndex_ = (nextUpdateIndex_ + 1) % globalCacheCountMax_;
+    return atb::NO_ERROR;
+}
+
+std::string AclNNGlobalCache::PrintGlobalCache()
+{
+    std::stringstream ss;
+    ss << "Plugin Op Cache: Global Cache Summary ";
+    std::map<std::string, std::vector<std::shared_ptr<AclNNOpCache>>>::iterator it;
+    for (it = this->aclnnGlobalCache_.begin(); it != this->aclnnGlobalCache_.end(); it++) {
+        ss << "Op name[" << it->first << "] ";
+        std::vector<std::shared_ptr<AclNNOpCache>> &opGlobalCacheList = it->second;
+        for (size_t i = 0; i < opGlobalCacheList.size(); i++) {
+            ss << "Cache Addr[" << opGlobalCacheList[i].get() << "] ";
+        }
+    }
+    return ss.str();
+}
+
+} // namespace common
+} // namespace atb_speed
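A sketch of the intended lookup/update flow around this global cache, as seen from an operation's setup path. It assumes the `GetSingleton` helper from `atb_speed/utils/singleton.h`; the surrounding function is illustrative.

```cpp
#include <memory>
#include "atb_speed/utils/singleton.h"
#include "operations/aclnn/core/acl_nn_global_cache.h"

std::shared_ptr<atb_speed::common::AclNNOpCache> LookUpOrInsert(
    const std::string &opName, const atb::VariantPack &variantPack)
{
    auto &globalCache = atb_speed::GetSingleton<atb_speed::common::AclNNGlobalCache>();
    // Hit: an entry with an equal variant pack and a repeatable executor exists.
    std::shared_ptr<atb_speed::common::AclNNOpCache> cache = globalCache.GetGlobalCache(opName, variantPack);
    if (cache == nullptr) {
        // Miss: build a fresh cache entry (executor/workspace setup omitted here),
        // then publish it; entries with non-repeatable executors are silently skipped.
        cache = std::make_shared<atb_speed::common::AclNNOpCache>();
        globalCache.UpdateGlobalCache(opName, cache);
    }
    return cache;
}
```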
diff --git a/tests/proftest/layer_test_framework/operations/aclnn/core/acl_nn_global_cache.h b/tests/proftest/layer_test_framework/operations/aclnn/core/acl_nn_global_cache.h
new file mode 100644
index 00000000..e4cddc39
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/aclnn/core/acl_nn_global_cache.h
@@ -0,0 +1,87 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef ATB_SPEED_PLUGIN_ACLNN_NN_OPERATION_GLOBAL_CACHE_H
+#define ATB_SPEED_PLUGIN_ACLNN_NN_OPERATION_GLOBAL_CACHE_H
+
+#include <map>
+#include <memory>
+#include <string>
+#include <vector>
+#include "acl_nn_operation_cache.h"
+#include "acl_nn_operation.h"
+
+namespace atb_speed {
+namespace common {
+
+/// Maximum number of objects stored in the global cache by default
+const uint16_t DEFAULT_ACLNN_GLOBAL_CACHE_SIZE = 16;
+constexpr int32_t DECIMAL = 10;  /// Decimal base
+
+/// A class that manages the global cache.
+///
+/// This class keeps a private container of `AclNNOpCache` objects that may be shared between operations.
+/// It provides get and update methods for retrieving and modifying the private container's state.
+/// There is also a print function for debugging.
+class AclNNGlobalCache {
+public:
+    /// The class constructor.
+    ///
+    /// Updates `globalCacheCountMax_` from the environment variable `ATB_ACLNN_CACHE_GLOABL_COUNT`
+    /// or `MINDIE_ACLNN_CACHE_GLOBAL_COUNT`. (`ATB_ACLNN_CACHE_GLOABL_COUNT` will be deprecated.)
+    ///
+    /// Throws a runtime_error if `globalCacheCountMax_` is 100 or larger.
+    explicit AclNNGlobalCache();
+    /// Retrieve the cache object on a cache hit.
+    ///
+    /// A cache is hit if the `variantPack` is the same, except for the tensors' device data.
+    ///
+    /// \param opName An operation's name.
+    /// \param variantPack Information about the input and output tensors of an ATB operation.
+    /// \return A pointer to an `AclNNOpCache` object on a cache hit; otherwise a nullptr.
+    std::shared_ptr<AclNNOpCache> GetGlobalCache(std::string opName, atb::VariantPack variantPack);
+    /// Add or replace a cache object.
+    ///
+    /// Locates the global cache list for the current operation using `opName`
+    /// and adds the `cache` at the index specified by `nextUpdateIndex_`.
+    /// If the slot already contains a cache object, it is replaced.
+    /// The cache is not added if its executor is not repeatable.
+    ///
+    /// \param opName An operation's name.
+    /// \param cache The cache to be added to the `aclnnGlobalCache_` container.
+    /// \return A status code indicating whether the update operation was successful.
+    atb::Status UpdateGlobalCache(std::string opName, std::shared_ptr<AclNNOpCache> cache);
+    /// Print a summary of the objects stored in `aclnnGlobalCache_`.
+    ///
+    /// The operation names and the corresponding global cache addresses are printed.
+    ///
+    /// \return Cache info.
+    std::string PrintGlobalCache();
+
+private:
+    /// An index that keeps track of the next cache slot to be overwritten
+    int nextUpdateIndex_ = 0;
+    /// Maximum number of objects stored in the global cache
+    uint16_t globalCacheCountMax_ = 16;
+    /// A map storing `AclNNOpCache` objects.
+    ///
+    /// The key is an operation's name; the value is a vector of pointers to `AclNNOpCache` objects.
+    /// Caches are not shared between different types of operations.
+    std::map<std::string, std::vector<std::shared_ptr<AclNNOpCache>>> aclnnGlobalCache_;
+};
+
+} // namespace common
+} // namespace atb_speed
+#endif
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/aclnn/core/acl_nn_operation.cpp b/tests/proftest/layer_test_framework/operations/aclnn/core/acl_nn_operation.cpp
new file mode 100644
index 00000000..b3963413
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/aclnn/core/acl_nn_operation.cpp
@@ -0,0 +1,255 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "atb_speed/log.h"
+#include "atb_speed/utils/timer.h"
+#include "atb_speed/utils/statistic.h"
+#include "operations/aclnn/utils/utils.h"
+#include "atb_speed/utils/singleton.h"
+#include "executor_manager.h"
+#include "acl_nn_global_cache.h"
+#include "acl_nn_operation.h"
+
+namespace atb_speed {
+namespace common {
+
+AclNNOperation::AclNNOperation(const std::string &opName) : opName_(opName)
+{
+    this->aclnnOpCache_ = std::make_shared<AclNNOpCache>();
+}
+
+AclNNOperation::~AclNNOperation()
+{
+    ATB_SPEED_LOG_DEBUG("AclNNOperation destructor");
+    this->DestroyOperation();
+}
+
+std::string AclNNOperation::GetName() const { return this->opName_; }
+
+void AclNNOperation::DestroyOperation() const
+{
+    this->aclnnOpCache_->Destroy();
+}
+
+atb::Status AclNNOperation::Setup(const atb::VariantPack &variantPack, uint64_t &workspaceSize, atb::Context *context)
+{
+    ATB_SPEED_LOG_DEBUG(this->opName_ << " setup start");
+
+    // 1. Check that the context is not null
+    if (context == nullptr) {
+        ATB_SPEED_LOG_ERROR(this->opName_ << " setup context is null");
+        return atb::ERROR_INVALID_PARAM;
+    }
+
+    // 2. Obtain the executor and workspace
+    int ret = UpdateAclNNOpCache(variantPack);
+    if (ret != 0) {
+        ATB_SPEED_LOG_ERROR(this->opName_ << " call UpdateAclNNOpCache, error:" << ret);
+        this->aclnnOpCache_->Destroy();
+        return ret;
+    }
+
+    // 3. Propagate the cached workspace size to the caller
+    workspaceSize = this->aclnnOpCache_->workspaceSize;
+
+    ATB_SPEED_LOG_DEBUG(GetSingleton<AclNNGlobalCache>().PrintGlobalCache());
+    ATB_SPEED_LOG_DEBUG(GetSingleton<ExecutorManager>().PrintExecutorCount());
+    return atb::NO_ERROR;
+}
+
+atb::Status AclNNOperation::UpdateAclNNOpCache(const atb::VariantPack &variantPack)
+{
+    // This method prepares the executor and workspace required by Execute.
+    // Precondition: any executor held in the global cache is also referenced by a local cache;
+    // executors are only ever released through the local cache.
+
+    // 1. Check whether the executor in the local cache can be reused
+    ATB_SPEED_LOG_DEBUG("Plugin Op Cache: Local Cache call IsVariankPackEqual");
+    if (this->aclnnOpCache_->executorRepeatable && \
+        IsVariankPackEqual(this->aclnnOpCache_->aclnnVariantPack, variantPack)) {
+        // Local cache hit
+        ATB_SPEED_LOG_DEBUG("Plugin Op Cache: Op name[" << this->opName_ << "] Op addr[" <<
+            (this) << "] Cache addr[" << this->aclnnOpCache_.get() << "] Executor addr[" <<
+            this->aclnnOpCache_->aclExecutor << "] Local Cache Hit");
+        return atb::NO_ERROR;
+    }
+
+    // 2. Check whether an executor in the global cache can be reused
+    std::shared_ptr<AclNNOpCache> globalCache = \
+        GetSingleton<AclNNGlobalCache>().GetGlobalCache(this->opName_, variantPack);
+    if (globalCache != nullptr) {
+        // Global cache hit
+        ATB_SPEED_LOG_DEBUG("Plugin Op Cache: Op name[" << this->opName_ << "] Op addr[" << (this) << "] Cache addr["
+            << globalCache.get() << "] Executor addr[" << globalCache->aclExecutor << "] Global Cache Hit");
+        // 2.1 Release the old local cache
+        ATB_SPEED_LOG_DEBUG("Plugin Op Cache: destroy local cache before switching to global cache");
+        this->aclnnOpCache_->Destroy();
+        // 2.2 Update the local cache
+        this->aclnnOpCache_ = globalCache;
+        // 2.3 Update the ExecutorManager
+        int count = GetSingleton<ExecutorManager>().IncreaseReference(this->aclnnOpCache_->aclExecutor);
+        ATB_SPEED_LOG_DEBUG("Plugin Op Cache: Op name[" << this->opName_ << "] Executor addr[" <<
+            this->aclnnOpCache_->aclExecutor << "] count update to " << count);
+        return atb::NO_ERROR;
+    }
+
+    // 3. Neither the local cache nor the global cache was hit
+    // 3.1 Release the local cache
+    ATB_SPEED_LOG_DEBUG("Plugin Op Cache: destroy local cache before create a new one");
+    this->aclnnOpCache_->Destroy();
+    // 3.2 Rebuild aclnnOpCache_ from the variantPack to obtain the workspace and executor
+    this->aclnnOpCache_ = std::make_shared<AclNNOpCache>();
+    int ret = CreateAclNNOpCache(variantPack);
+    if (ret != 0) {
+        ATB_SPEED_LOG_ERROR(this->opName_ << " call CreateAclNNOpCache fail, error:" << ret);
+        return ret;
+    }
+    ATB_SPEED_LOG_DEBUG("Plugin Op Cache: Op name[" << this->opName_ << "] Op addr[" <<
+        (this) << "] Cache addr[" << this->aclnnOpCache_.get() << "] Executor addr[" <<
+        this->aclnnOpCache_->aclExecutor << "] create Local Cache");
+    // 3.3 Update the ExecutorManager: register the new executor with a count of 1
+    int count = GetSingleton<ExecutorManager>().IncreaseReference(this->aclnnOpCache_->aclExecutor);
+    ATB_SPEED_LOG_DEBUG("Plugin Op Cache: Op name[" << this->opName_ << "] increase Executor addr[" <<
+        this->aclnnOpCache_->aclExecutor << "] count update to " << count);
+
+    // 3.4 Update the global cache (the old entry is simply replaced by the new pointer)
+    GetSingleton<AclNNGlobalCache>().UpdateGlobalCache(this->opName_, this->aclnnOpCache_);
+
+    return atb::NO_ERROR;
+}
+
+atb::Status AclNNOperation::CreateAclNNOpCache(const atb::VariantPack &variantPack)
+{
+    atb::Status ret = CreateAclNNVariantPack(variantPack);
+    if (ret != 0) {
+        ATB_SPEED_LOG_ERROR(this->opName_ << " call CreateAclNNVariantPack fail, error:" << ret);
+        return atb::ERROR_CANN_ERROR;
+    }
+
+    ret = SetAclNNWorkspaceExecutor();
+    if (ret != 0) {
+        ATB_SPEED_LOG_ERROR(this->opName_ << " call SetAclNNWorkspaceExecutor fail, error:" << ret);
+        return atb::ERROR_CANN_ERROR;
+    }
+
+    // Guard against a null local cache at this point
+    if (this->aclnnOpCache_ == nullptr) {
+        ATB_SPEED_LOG_ERROR("Plugin Op Cache: Op name[" << this->opName_ << "] cache is nullptr after " <<
+            "initialization, please check.");
+        return atb::ERROR_INTERNAL_ERROR;
+    }
+
+    ATB_SPEED_LOG_DEBUG("Plugin Op Cache: create Executor addr[" << this->aclnnOpCache_->aclExecutor << "]");
+
+    // Mark the aclExecutor in the local cache as repeatable; returns 0 on success
+    ret = aclSetAclOpExecutorRepeatable(this->aclnnOpCache_->aclExecutor);
+    if (ret != 0) {
+        // Marking failed; flag the executor in the local cache as not reusable
+        ATB_SPEED_LOG_WARN(this->opName_ << " call aclSetAclOpExecutorRepeatable fail: " << ret);
+        this->aclnnOpCache_->executorRepeatable = false;
+    } else {
+        // Marking succeeded; flag the executor in the local cache as reusable
+        this->aclnnOpCache_->executorRepeatable = true;
+    }
+
+    return atb::NO_ERROR;
+}
+
+atb::Status AclNNOperation::Execute(const atb::VariantPack &variantPack, uint8_t *workspace, uint64_t workspaceSize,
+    atb::Context *context)
+{
+    ATB_SPEED_LOG_DEBUG(this->opName_ << " execute start");
+    if (!context) {
+        ATB_SPEED_LOG_ERROR(this->opName_ << " execute fail, context param is null. Enable log: "
+            << "export ASDOPS_LOG_LEVEL=ERROR, export ASDOPS_LOG_TO_STDOUT=1 to find the first error. "
+            << "For more details, see the MindIE official document." << std::endl, ATB_MODELS_EXECUTION_FAILURE);
+        return atb::ERROR_INVALID_PARAM;
+    }
+
+    aclrtStream stream = GetExecuteStream(context);
+    if (!stream) {
+        ATB_SPEED_LOG_ERROR(this->opName_ << " execute fail, execute stream in context is null. "
+            << "Enable log: export ASDOPS_LOG_LEVEL=ERROR, export ASDOPS_LOG_TO_STDOUT=1 to find the first error. "
+            << "For more details, see the MindIE official document." << std::endl, ATB_MODELS_EXECUTION_FAILURE);
+        return atb::ERROR_INVALID_PARAM;
+    }
+
+    // Refresh the device addresses the cached executor points at
+    int ret = this->aclnnOpCache_->UpdateAclNNVariantPack(variantPack);
+    if (ret != 0) {
+        ATB_SPEED_LOG_ERROR(this->opName_ << " call UpdateAclNNVariantPack fail, error:" << ret);
+        return atb::ERROR_CANN_ERROR;
+    }
+
+    ATB_SPEED_LOG_DEBUG("Input workspaceSize " << workspaceSize << " localCache workspaceSize " <<
+        this->aclnnOpCache_->workspaceSize);
+    ret = ExecuteAclNNOp(workspace, stream);
+    if (ret != 0) {
+        ATB_SPEED_LOG_ERROR(this->opName_ << " call ExecuteAclNNOp fail, error:" << ret);
+        return atb::ERROR_CANN_ERROR;
+    }
+
+    ATB_SPEED_LOG_DEBUG(this->opName_ << " execute end");
+
+    return atb::NO_ERROR;
+}
+
+atb::Status AclNNOperation::CreateAclNNVariantPack(const atb::VariantPack &variantPack)
+{
+    ATB_SPEED_LOG_DEBUG(this->opName_ << " CreateAclNNVariantPack start");
+    atb::Status ret = 0;
+    ret = CreateAclNNInTensorVariantPack(variantPack);
+    if (ret != 0) {
+        ATB_SPEED_LOG_ERROR(this->opName_ << " AclNNTensor CreateAclNNInTensorVariantPack fail");
+        return ret;
+    }
+
+    ret = CreateAclNNOutTensorVariantPack(variantPack);
+    if (ret != 0) {
+        ATB_SPEED_LOG_ERROR(this->opName_ << " AclNNTensor CreateAclNNOutTensorVariantPack fail");
+        return ret;
+    }
+
+    ATB_SPEED_LOG_DEBUG(opName_ << " CreateAclNNVariantPack end");
+    return atb::NO_ERROR;
+}
+
+atb::Status AclNNOperation::CreateAclNNInTensorVariantPack(const atb::VariantPack &variantPack)
+{
+    AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack;
+    aclnnVariantPack.aclInTensors.resize(variantPack.inTensors.size());
+    for (size_t i = 0; i < aclnnVariantPack.aclInTensors.size(); ++i) {
+        aclnnVariantPack.aclInTensors[i] = CreateTensor(variantPack.inTensors.at(i), i);
+        if (aclnnVariantPack.aclInTensors[i]->tensor == nullptr) {
+            return atb::ERROR_INTERNAL_ERROR;
+        }
+    }
+    return atb::NO_ERROR;
+}
+
+atb::Status AclNNOperation::CreateAclNNOutTensorVariantPack(const atb::VariantPack &variantPack)
+{
+    AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack;
+    aclnnVariantPack.aclOutTensors.resize(variantPack.outTensors.size());
+    for (size_t i = 0; i < aclnnVariantPack.aclOutTensors.size(); ++i) {
+        aclnnVariantPack.aclOutTensors[i] = CreateTensor(variantPack.outTensors.at(i), i);
+        if (aclnnVariantPack.aclOutTensors[i]->tensor == nullptr) {
+            return atb::ERROR_INTERNAL_ERROR;
+        }
+    }
+    return atb::NO_ERROR;
+}
+
+} // namespace common
+} // namespace atb_speed
diff --git a/tests/proftest/layer_test_framework/operations/aclnn/core/acl_nn_operation.h b/tests/proftest/layer_test_framework/operations/aclnn/core/acl_nn_operation.h
new file mode 100644
index 00000000..62f3d15f
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/aclnn/core/acl_nn_operation.h
@@ -0,0 +1,106 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef ATB_SPEED_PLUGIN_ACLNN_NN_OPERATION_H
+#define ATB_SPEED_PLUGIN_ACLNN_NN_OPERATION_H
+#include "atb/atb_infer.h"
+#include "acl_nn_operation_cache.h"
+
+namespace atb_speed {
+namespace common {
+
+/// A class that inherits from the `atb::Operation` class. An `atb::Operation` class defines a series of
+/// interfaces required for operation preparation and execution.
+class AclNNOperation : public atb::OperationInfra {
+public:
+    /// Class constructor.
+    ///
+    /// Initializes an `AclNNOpCache` pointer for the object's local cache (`aclnnOpCache_`) and sets `opName`.
+    /// \param opName The name of the AclNN operation.
+    explicit AclNNOperation(const std::string &opName);
+    ~AclNNOperation() override;
+    /// Return the AclNN operation's name.
+    /// \return The object's `opName`.
+    std::string GetName() const override;
+    /// Preparations before operation execution.
+    ///
+    /// This function calls `UpdateAclNNOpCache` to update `aclnnOpCache_`
+    /// and calculates the memory space that needs to be allocated during operation execution.
+    /// \param variantPack The operation's input and output tensor info.
+    /// \param workspaceSize The size of the workspace.
+    /// \param context The context in which the operation's preparation is performed.
+    /// \return A status code indicating whether the setup process was successful.
+    atb::Status Setup(const atb::VariantPack &variantPack, uint64_t &workspaceSize, atb::Context *context) override;
+    /// Operation execution process.
+    ///
+    /// Calls `GetExecuteStream` on `context`, calls `UpdateAclNNVariantPack` to refresh the tensors' device data,
+    /// and then executes the operation.
+    /// \param variantPack The operation's input and output tensor info.
+    /// \param workspace A pointer to the memory address allocated for the operation.
+    /// \param workspaceSize The size of the workspace.
+    /// \param context The context in which the operation's execution is performed.
+    /// \return A status code indicating whether the execute process was successful.
+    atb::Status Execute(const atb::VariantPack &variantPack, uint8_t *workspace, uint64_t workspaceSize,
+        atb::Context *context) override;
+    /// Release all occupied resources, particularly those stored in `aclnnOpCache_`.
+    void DestroyOperation() const;
+
+protected:
+    /// Create the operation's local cache (`aclnnOpCache_`).
+    ///
+    /// Creates the operation's input and output tensors by calling `CreateAclNNVariantPack`,
+    /// calls `SetAclNNWorkspaceExecutor` to get the workspace size and the `aclOpExecutor`,
+    /// and calls `aclSetAclOpExecutorRepeatable` to make the `aclOpExecutor` reusable.
+    /// \param variantPack The operation's input and output tensor info passed from the ATB framework.
+    /// \return A status code indicating whether `aclnnOpCache_` was successfully created.
+    atb::Status CreateAclNNOpCache(const atb::VariantPack &variantPack);
+    /// Verify whether the local cache or the global cache is hit. If neither is hit, create a new instance
+    /// by calling `CreateAclNNOpCache`, then update both the `ExecutorManager` and the `AclNNGlobalCache`.
+    /// \param variantPack The operation's input and output tensor info.
+    /// \return A status code indicating whether `aclnnOpCache_` was successfully updated.
+    atb::Status UpdateAclNNOpCache(const atb::VariantPack &variantPack);
+    /// Prepare the operation's input tensors and output tensors.
+    ///
+    /// This function calls `CreateAclNNInTensorVariantPack` and `CreateAclNNOutTensorVariantPack`.
+    /// \param variantPack An `atb::VariantPack` object containing tensor info passed through the ATB framework.
+    /// \return A status code indicating whether the variantPack was created successfully.
+    virtual atb::Status CreateAclNNVariantPack(const atb::VariantPack &variantPack);
+    /// Prepare the operation's input tensors.
+    ///
+    /// \param variantPack An `atb::VariantPack` object containing tensor info passed through the ATB framework.
+    /// \return A status code indicating whether the variantPack was created successfully.
+    virtual atb::Status CreateAclNNInTensorVariantPack(const atb::VariantPack &variantPack);
+    /// Prepare the operation's output tensors.
+    ///
+    /// \param variantPack An `atb::VariantPack` object containing tensor info passed through the ATB framework.
+    /// \return A status code indicating whether the variantPack was created successfully.
+    virtual atb::Status CreateAclNNOutTensorVariantPack(const atb::VariantPack &variantPack);
+    /// Call the AclNN operation's first-phase API to get the workspace size and the `aclOpExecutor`.
+    ///
+    /// \return The return value of AclNN's first-phase API.
+    virtual int SetAclNNWorkspaceExecutor() = 0;
+    /// Call the AclNN operation's second-phase API to execute the operation.
+    ///
+    /// \return The return value of AclNN's second-phase API.
+    virtual int ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream) = 0;
+
+    /// An `AclNNOpCache` object that can be reused within the current operation object.
+    std::shared_ptr<AclNNOpCache> aclnnOpCache_ = nullptr;
+    /// A human-readable name identifying the operation.
+    std::string opName_;
+};
+} // namespace common
+} // namespace atb_speed
+#endif
\ No newline at end of file
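The two-phase contract above (Setup computes a workspace size, Execute consumes it) is driven from outside roughly as sketched below; the allocation policy and the surrounding function are illustrative, not part of this patch.

```cpp
#include <acl/acl.h>
#include "operations/aclnn/core/acl_nn_operation.h"

atb::Status RunOnce(atb_speed::common::AclNNOperation *op,
                    const atb::VariantPack &variantPack, atb::Context *context)
{
    // Phase 1: build or reuse the executor and learn how much device memory is needed.
    uint64_t workspaceSize = 0;
    atb::Status st = op->Setup(variantPack, workspaceSize, context);
    if (st != atb::NO_ERROR) { return st; }

    // Allocate the workspace (a real integration would pool and reuse this buffer).
    void *workspace = nullptr;
    if (workspaceSize > 0 &&
        aclrtMalloc(&workspace, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST) != ACL_SUCCESS) {
        return atb::ERROR_CANN_ERROR;
    }

    // Phase 2: refresh device addresses and launch the cached executor.
    st = op->Execute(variantPack, static_cast<uint8_t *>(workspace), workspaceSize, context);
    if (workspace != nullptr) { aclrtFree(workspace); }
    return st;
}
```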
diff --git a/tests/proftest/layer_test_framework/operations/aclnn/core/acl_nn_operation_cache.cpp b/tests/proftest/layer_test_framework/operations/aclnn/core/acl_nn_operation_cache.cpp
new file mode 100644
index 00000000..4c9b2996
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/aclnn/core/acl_nn_operation_cache.cpp
@@ -0,0 +1,134 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "atb_speed/log.h"
+#include "atb_speed/utils/singleton.h"
+#include "executor_manager.h"
+#include "acl_nn_operation_cache.h"
+
+namespace atb_speed {
+namespace common {
+
+void AclNNOpCache::Destroy()
+{
+    ATB_SPEED_LOG_DEBUG("Plugin Op Cache: AclNNOpCache addr [" << (this) << "] destroy");
+    if (this->aclExecutor == nullptr) { return; }
+
+    // Decrement the reference count held in the ExecutorManager
+    int count = GetSingleton<ExecutorManager>().DecreaseReference(this->aclExecutor);
+    if (count != 0) { return; }  // while the executor is still referenced, keep it and its aclTensors alive
+
+    // The aclExecutor exists and has no remaining references, so destroy it
+    int ret = -1;
+    ATB_SPEED_LOG_DEBUG("Plugin Op Cache: destroy Executor addr[" << this->aclExecutor << "]");
+    if (this->executorRepeatable) {
+        // Destroy only if the executor is reusable; otherwise skip the destroy
+        // to avoid double-freeing the aclExecutor
+        ret = aclDestroyAclOpExecutor(this->aclExecutor);
+        if (ret != 0) {
+            ATB_SPEED_LOG_ERROR("Plugin Op Cache: destroy Executor failed.");
+        }
+    }
+    this->aclExecutor = nullptr;
+
+    // Release the structures that were created to build the aclExecutor
+    for (size_t i = 0; i < this->aclnnVariantPack.aclInTensors.size(); ++i) {
+        if (this->aclnnVariantPack.aclInTensors[i]->tensorListidx == AclNNTensor::notInTensorList) {
+            ret = aclDestroyTensor(this->aclnnVariantPack.aclInTensors[i]->tensor);
+            if (ret != 0) { ATB_SPEED_LOG_ERROR("Plugin Op Cache: destroy aclInTensors " << i << " failed."); }
+        }
+        ret = aclDestroyIntArray(this->aclnnVariantPack.aclInTensors[i]->intArrayHostData.intArray);
+        if (ret != 0) {
+            ATB_SPEED_LOG_ERROR("Plugin Op Cache: destroy aclInTensors " << i << " intArrayHostData failed.");
+        }
+    }
+    this->aclnnVariantPack.aclInTensors.clear();
+
+    for (size_t i = 0; i < this->aclnnVariantPack.aclOutTensors.size(); ++i) {
+        if (this->aclnnVariantPack.aclOutTensors[i]->tensorListidx == AclNNTensor::notInTensorList) {
+            ret = aclDestroyTensor(this->aclnnVariantPack.aclOutTensors[i]->tensor);
+            if (ret != 0) { ATB_SPEED_LOG_ERROR("Plugin Op Cache: destroy aclOutTensors " << i << " failed."); }
+        }
+    }
+    this->aclnnVariantPack.aclOutTensors.clear();
+
+    for (size_t i = 0; i < this->aclnnVariantPack.aclInTensorList.size(); ++i) {
+        ret = aclDestroyTensorList(this->aclnnVariantPack.aclInTensorList[i]);
+        if (ret != 0) { ATB_SPEED_LOG_ERROR("Plugin Op Cache: destroy aclInTensorList " << i << " failed."); }
+    }
+    this->aclnnVariantPack.aclInTensorList.clear();
+
+    for (size_t i = 0; i < this->aclnnVariantPack.aclOutTensorList.size(); ++i) {
+        ret = aclDestroyTensorList(this->aclnnVariantPack.aclOutTensorList[i]);
+        if (ret != 0) { ATB_SPEED_LOG_ERROR("Plugin Op Cache: destroy aclOutTensorList " << i << " failed."); }
+    }
+    this->aclnnVariantPack.aclOutTensorList.clear();
+}
+
+atb::Status AclNNOpCache::UpdateAclNNVariantPack(const atb::VariantPack &variantPack)
+{
+    ATB_SPEED_LOG_DEBUG("call UpdateAclNNVariantPack ");
+    for (size_t i = 0; i < this->aclnnVariantPack.aclInTensors.size(); ++i) {
+        int ret = -1;
+        if (!this->aclnnVariantPack.aclInTensors[i]->needUpdateTensorDataPtr) {
+            continue;
+        }
+        this->aclnnVariantPack.aclInTensors[i]->atbTensor = variantPack.inTensors.at(i);
+        if (this->aclnnVariantPack.aclInTensors[i]->tensorListidx == AclNNTensor::notInTensorList) {
+            ret = aclSetInputTensorAddr(this->aclExecutor,
+                this->aclnnVariantPack.aclInTensors[i]->tensorIdx,
+                this->aclnnVariantPack.aclInTensors[i]->tensor,
+                this->aclnnVariantPack.aclInTensors[i]->atbTensor.deviceData);
+        } else {
+            ret = aclSetDynamicInputTensorAddr(this->aclExecutor,
+                this->aclnnVariantPack.aclInTensors[i]->tensorListidx,
+                this->aclnnVariantPack.aclInTensors[i]->tensorIdx,
+                this->aclnnVariantPack.aclInTensorList[this->aclnnVariantPack.aclInTensors[i]->tensorListidx],
+                this->aclnnVariantPack.aclInTensors[i]->atbTensor.deviceData);
+        }
+        if (ret != 0) {
+            ATB_SPEED_LOG_ERROR("inTensor " << i << " call UpdateAclTensorDataPtr fail, error: " << ret);
+            return atb::ERROR_CANN_ERROR;
+        }
+    }
+
+    for (size_t i = 0; i < this->aclnnVariantPack.aclOutTensors.size(); ++i) {
+        int ret = -1;
+        if (!this->aclnnVariantPack.aclOutTensors[i]->needUpdateTensorDataPtr) {
+            continue;
+        }
+        this->aclnnVariantPack.aclOutTensors[i]->atbTensor = variantPack.outTensors.at(i);
+        if (this->aclnnVariantPack.aclOutTensors[i]->tensorListidx == AclNNTensor::notInTensorList) {
+            ret = aclSetOutputTensorAddr(this->aclExecutor,
+                this->aclnnVariantPack.aclOutTensors[i]->tensorIdx,
+                this->aclnnVariantPack.aclOutTensors[i]->tensor,
+                this->aclnnVariantPack.aclOutTensors[i]->atbTensor.deviceData);
+        } else {
+            ret = aclSetDynamicOutputTensorAddr(this->aclExecutor,
+                this->aclnnVariantPack.aclOutTensors[i]->tensorListidx,
+                this->aclnnVariantPack.aclOutTensors[i]->tensorIdx,
+                this->aclnnVariantPack.aclOutTensorList[this->aclnnVariantPack.aclOutTensors[i]->tensorListidx],
+                this->aclnnVariantPack.aclOutTensors[i]->atbTensor.deviceData);
+        }
+        if (ret != 0) {
+            ATB_SPEED_LOG_ERROR("outTensor " << i << " call UpdateAclTensorDataPtr fail, error: " << ret);
+            return atb::ERROR_CANN_ERROR;
+        }
+    }
+
+    return atb::NO_ERROR;
+}
+
+} // namespace common
+} // namespace atb_speed
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/aclnn/core/acl_nn_operation_cache.h b/tests/proftest/layer_test_framework/operations/aclnn/core/acl_nn_operation_cache.h
new file mode 100644
index 00000000..919ba45b
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/aclnn/core/acl_nn_operation_cache.h
@@ -0,0 +1,63 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef ATB_SPEED_PLUGIN_ACLNN_NN_OPERATION_LOCAL_CACHE_H
+#define ATB_SPEED_PLUGIN_ACLNN_NN_OPERATION_LOCAL_CACHE_H
+#include "acl_nn_tensor.h"
+
+namespace atb_speed {
+namespace common {
+
+/// Information about the input and output tensors of an AclNN operation.
+struct AclNNVariantPack {
+    /// A container storing an AclNN operation's in tensors in order.
+    /// Each `AclNNTensor` object contains one `aclTensor`.
+    atb::SVector<std::shared_ptr<AclNNTensor>> aclInTensors;
+    /// A container storing an AclNN operation's out tensors in order.
+    /// Each `AclNNTensor` object contains one `aclTensor`.
+    atb::SVector<std::shared_ptr<AclNNTensor>> aclOutTensors;
+    /// A container storing an AclNN operation's input `aclTensorList` objects in order.
+    /// Each `aclTensorList` object may contain multiple `aclTensor` objects.
+    atb::SVector<aclTensorList *> aclInTensorList;
+    /// A container storing an AclNN operation's output `aclTensorList` objects in order.
+    /// Each `aclTensorList` object may contain multiple `aclTensor` objects.
+    atb::SVector<aclTensorList *> aclOutTensorList;
+};
+
+/// AclNNOpCache stores information of an operation that can be reused between operations.
+struct AclNNOpCache {
+    /// Information about the input and output tensors of an AclNN operation.
+    AclNNVariantPack aclnnVariantPack;
+    /// The AclNN operation's executor, which contains the operator computation process.
+    aclOpExecutor *aclExecutor = nullptr;
+    /// An indicator showing whether the `aclOpExecutor` is repeatable.
+    bool executorRepeatable = false;
+    /// Size of the workspace to be allocated on the device.
+    uint64_t workspaceSize = 0;
+    /// Update the device memory addresses in the `aclTensor` objects when the device memory changes.
+    ///
+    /// \param variantPack Information about the input and output tensors of an AclNN operation.
+    /// \return A status code indicating whether the update operation was successful.
+    atb::Status UpdateAclNNVariantPack(const atb::VariantPack &variantPack);
+    /// Destroy the resources allocated in `AclNNOpCache`.
+    ///
+    /// Destroys the `aclOpExecutor` if it is repeatable and has no remaining references.
+    /// Destroys the `aclTensor` and `aclTensorList` objects if the `aclOpExecutor` is destroyed.
+    void Destroy();
+};
+
+} // namespace common
+} // namespace atb_speed
+#endif
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/aclnn/core/acl_nn_tensor.h b/tests/proftest/layer_test_framework/operations/aclnn/core/acl_nn_tensor.h
new file mode 100644
index 00000000..25e6b095
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/aclnn/core/acl_nn_tensor.h
@@ -0,0 +1,68 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef ATB_SPEED_PLUGIN_ACLNN_NN_TENSOR_H
+#define ATB_SPEED_PLUGIN_ACLNN_NN_TENSOR_H
+#include <vector>
+#include <acl/acl.h>
+#include <aclnn/acl_meta.h>
+#include <atb/atb_infer.h>
+
+namespace atb_speed {
+namespace common {
+
+/// A struct containing a tensor's host data.
+///
+/// This struct stores the tensor's host data in an int-array format.
+/// Host data is treated as an operation parameter and is used during the setup phase to create the `aclOpExecutor`.
+struct AclNNIntArray {
+    /// This struct is created by calling `aclCreateIntArray` and should be destroyed by calling `aclDestroyIntArray`.
+    /// It is used to create the `aclOpExecutor`.
+    aclIntArray* intArray = nullptr;
+    /// Data used to create the `aclIntArray*`. It is copied from the atb::Tensor's hostData.
+    std::vector<int64_t> data = {};
+    std::vector<int64_t> dataOri = {};
+    /// The size of `data` in bytes.
+    uint64_t dataSize = 0;
+};
+
+/// A class containing tensor information.
+///
+/// AclNN operations and ATB operations organize tensors in different formats.
+/// This class stores the information necessary for easy conversion and tensor usage.
+class AclNNTensor {
+public:
+    /// A constant value indicating that `tensorListidx` is invalid.
+    static const int64_t notInTensorList = -1;
+
+    /// Tensor passed through the ATB framework.
+    atb::Tensor atbTensor;
+    /// The stride of each dimension in the tensor's view shape. Used when creating the `aclTensor`.
+    atb::SVector<int64_t> strides = {};
+    /// Tensor passed into the AclNN operation.
+    aclTensor *tensor = nullptr;
+    /// An AclNNIntArray object containing the tensor's host data in int-array format.
+    AclNNIntArray intArrayHostData;
+    /// The index of the tensor in the tensor list. Used when the `aclTensor` is passed into an `aclTensorList`.
+    int tensorListidx = notInTensorList;
+    /// The index of the tensor in the `aclOpExecutor`'s parameter list.
+    int tensorIdx = -1;
+    /// An indicator showing whether the tensor's device data needs to be updated before execution.
+    bool needUpdateTensorDataPtr = false;
+};
+
+} // namespace common
+} // namespace atb_speed
+#endif
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/aclnn/core/executor_manager.cpp b/tests/proftest/layer_test_framework/operations/aclnn/core/executor_manager.cpp
new file mode 100644
index 00000000..0ca206ac
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/aclnn/core/executor_manager.cpp
@@ -0,0 +1,72 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <sstream>
+#include "atb_speed/log.h"
+#include "executor_manager.h"
+
+namespace atb_speed {
+namespace common {
+
+
+int ExecutorManager::IncreaseReference(aclOpExecutor *executor)
+{
+    std::map<aclOpExecutor *, int>::iterator it = this->executorCount_.find(executor);
+    if (it == this->executorCount_.end()) {
+        ATB_SPEED_LOG_DEBUG("Plugin Op Cache: Executor addr[" << executor
+            << "] not found in ExecutorManager, add one");
+        this->executorCount_[executor] = 1;
+        return 1;
+    }
+
+    int &count = it->second;
+    count += 1;
+    ATB_SPEED_LOG_DEBUG("Plugin Op Cache: ExecutorManager Executor addr["
+        << executor << "] increase reference to " << count);
+    return count;
+}
+
+int ExecutorManager::DecreaseReference(aclOpExecutor *executor)
+{
+    std::map<aclOpExecutor *, int>::iterator it = this->executorCount_.find(executor);
+    if (it == this->executorCount_.end()) {
+        ATB_SPEED_LOG_ERROR("Plugin Op Cache: Executor addr[" << executor << "] not found in ExecutorManager");
+        return 0;
+    }
+    int &count = it->second;
+    if (count == 1) {
+        ATB_SPEED_LOG_DEBUG("Plugin Op Cache: delete Executor addr[" << executor << "]");
+        this->executorCount_.erase(executor);
+        return 0;
+    }
+
+    count -= 1;
+    ATB_SPEED_LOG_DEBUG("Plugin Op Cache: ExecutorManager Executor addr["
+        << executor << "] decrease reference to " << count);
+    return count;
+}
+
+std::string ExecutorManager::PrintExecutorCount()
+{
+    std::stringstream ss;
+    ss << "Plugin Op Cache: Executor Summary ";
+    std::map<aclOpExecutor *, int>::iterator it;
+    for (it = this->executorCount_.begin(); it != this->executorCount_.end(); it++) {
+        ss << "Executor Addr[" << it->first << "] count " << it->second << " ";
+    }
+    return ss.str();
+}
+
+} // namespace common
+} // namespace atb_speed
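A minimal sketch of the reference-counting protocol these two methods implement; the executor pointer is a stand-in, since real executors come from the aclnn get-workspace calls.

```cpp
#include "operations/aclnn/core/executor_manager.h"

void RefCountSketch(aclOpExecutor *executor)
{
    atb_speed::common::ExecutorManager manager;
    int count = manager.IncreaseReference(executor);  // first holder -> 1
    count = manager.IncreaseReference(executor);      // shared via the global cache -> 2
    count = manager.DecreaseReference(executor);      // one holder releases -> 1
    if (manager.DecreaseReference(executor) == 0) {
        // Last reference gone: the owner may now call aclDestroyAclOpExecutor(executor)
        // (AclNNOpCache::Destroy does exactly this for repeatable executors).
    }
    (void)count;
}
```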
diff --git a/tests/proftest/layer_test_framework/operations/aclnn/core/executor_manager.h b/tests/proftest/layer_test_framework/operations/aclnn/core/executor_manager.h
new file mode 100644
index 00000000..6ad4a7d7
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/aclnn/core/executor_manager.h
@@ -0,0 +1,61 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef ATB_SPEED_PLUGIN_ACLNN_NN_OPERATION_EXECUTOR_MANAGER_H
+#define ATB_SPEED_PLUGIN_ACLNN_NN_OPERATION_EXECUTOR_MANAGER_H
+
+#include <map>
+#include <string>
+#include <aclnn/acl_meta.h>
+
+namespace atb_speed {
+namespace common {
+
+/// A class that manages `aclOpExecutor` objects and their reference counts.
+///
+/// Each `AclNNOpCache` object has a unique `aclOpExecutor` object.
+/// Since an `aclOpExecutor` object can be accessed through both the local cache and the global cache,
+/// it must only be destroyed once it no longer has any references.
+class ExecutorManager {
+public:
+    /// Increase the reference count of the given `executor` by 1.
+    /// If the `executor` had no references before, the reference count is set to 1.
+    ///
+    /// \param executor An `aclOpExecutor` object whose reference count needs to be increased.
+    /// \return The number of references after the increase.
+    int IncreaseReference(aclOpExecutor *executor);
+    /// Decrease the reference count of the given `executor` by 1.
+    /// If the `executor` has no references after the decrease, it is removed from `executorCount_`.
+    ///
+    /// \param executor An `aclOpExecutor` object whose reference count needs to be decreased.
+    /// \return The number of references after the decrease.
+    int DecreaseReference(aclOpExecutor *executor);
+    /// Print a summary of the objects stored in `executorCount_`.
+    ///
+    /// The `aclOpExecutor` addresses and the corresponding reference counts are printed.
+    ///
+    /// \return `aclOpExecutor` info.
+    std::string PrintExecutorCount();
+
+private:
+    /// A map storing the reference counts of `aclOpExecutor` objects.
+    ///
+    /// The key is an `aclOpExecutor` object's address; the value is its reference count.
+    std::map<aclOpExecutor *, int> executorCount_;
+};
+
+} // namespace common
+} // namespace atb_speed
+#endif
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/add_rms_norm_dynamic_quant_operation.cpp b/tests/proftest/layer_test_framework/operations/aclnn/ops/add_rms_norm_dynamic_quant_operation.cpp
new file mode 100644
index 00000000..666e3801
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/add_rms_norm_dynamic_quant_operation.cpp
@@ -0,0 +1,161 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <cmath>
+#include <memory>
+
+#include "acl/acl.h"
+#include "aclnnop/aclnn_add_rms_norm_dynamic_quant.h"
+#include "atb_speed/log.h"
+#include "operations/aclnn/utils/utils.h"
+#include "add_rms_norm_dynamic_quant_operation.h"
+
+namespace atb_speed {
+namespace common {
+
+const double EPSILON_THRESHOLD = 1e-9;  // a tiny threshold below which the epsilon argument is ignored
+
+AddRmsNormDynamicQuantOperation::AddRmsNormDynamicQuantOperation(
+    const std::string &name, double epsilon) : AclNNOperation(name)
+{
+    opName_ = name;
+    if (std::abs(epsilon) > EPSILON_THRESHOLD) {
+        epsilon_ = epsilon;
+    }
+}
+
+AddRmsNormDynamicQuantOperation::~AddRmsNormDynamicQuantOperation()
+{
+    ATB_SPEED_LOG_DEBUG(opName_ << "AddRmsNormDynamicQuantOperation destructor");
+    this->DestroyOperation();
+}
+
+atb::Status AddRmsNormDynamicQuantOperation::InferShape(
+    const atb::SVector<atb::TensorDesc> &inTensorDescs,
+    atb::SVector<atb::TensorDesc> &outTensorDescs) const
+{
+    ATB_SPEED_LOG_DEBUG(opName_ << "AddRmsNormDynamicQuantOperation infer shape start");
+    for (size_t i = 0; i < outTensorDescs.size(); i++) {
+        outTensorDescs.at(i).format = inTensorDescs.at(0).format;
+        if (i == 0 || i == NUM1) {  // y1Out and y2Out are always INT8
+            outTensorDescs.at(i).dtype = aclDataType::ACL_INT8;
+        } else if (i == NUM3 || i == NUM4) {  // scale1Out and scale2Out are FLOAT32
+            outTensorDescs.at(i).dtype = aclDataType::ACL_FLOAT;
+        } else {  // xOut keeps the dtype of input x1
+            outTensorDescs.at(i).dtype = inTensorDescs.at(0).dtype;
+        }
+        // No smoothScale is passed in this scenario, so y2Out and scale2Out
+        // only need 1-D placeholders; their content is irrelevant
+        if (i == NUM1 || i == NUM4) {
+            outTensorDescs.at(i).shape.dimNum = NUM1;
+            outTensorDescs.at(i).shape.dims[0] = 1;
+        } else if (i < NUM3) {
+            // y1Out and xOut support 2-8 dimensions; their shape matches x1/x2
+            outTensorDescs.at(i).shape.dimNum = inTensorDescs.at(1).shape.dimNum;
+            for (size_t j = 0; j < outTensorDescs.at(i).shape.dimNum; j++) {
+                outTensorDescs.at(i).shape.dims[j] = inTensorDescs.at(1).shape.dims[j];
+            }
+        } else {
+            // scale1Out: the shape of x with the last dimension removed
+            outTensorDescs.at(i).shape.dimNum = inTensorDescs.at(1).shape.dimNum - 1;
+            for (size_t j = 0; j < outTensorDescs.at(i).shape.dimNum; j++) {
+                outTensorDescs.at(i).shape.dims[j] = inTensorDescs.at(1).shape.dims[j];
+            }
+        }
+    }
+
+    ATB_SPEED_LOG_DEBUG(opName_ << "AddRmsNormDynamicQuantOperation infer shape end");
+    return 0;
+}
+
+uint32_t AddRmsNormDynamicQuantOperation::GetInputNum() const { return NUM3; }
+
+uint32_t AddRmsNormDynamicQuantOperation::GetOutputNum() const { return NUM5; }
+
+int AddRmsNormDynamicQuantOperation::CreateAclNNVariantPack(const atb::VariantPack &variantPack)
+{
+    ATB_SPEED_LOG_DEBUG(opName_ << " CreateAclTensor start");
+    AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack;
+    // Create the aclInTensors
+    aclnnVariantPack.aclInTensors.resize(variantPack.inTensors.size());
+    for (size_t i = 0; i < aclnnVariantPack.aclInTensors.size(); ++i) {
+        aclnnVariantPack.aclInTensors[i] = CreateTensor(variantPack.inTensors.at(i), i);
+    }
+    ATB_SPEED_LOG_DEBUG(opName_ << " Create aclInTensor end");
+    // Create the aclOutTensors
+    aclnnVariantPack.aclOutTensors.resize(variantPack.outTensors.size());
+    for (size_t i = 0; i < aclnnVariantPack.aclOutTensors.size(); ++i) {
+        aclnnVariantPack.aclOutTensors[i] = CreateTensor(variantPack.outTensors.at(i), i);
+    }
+    ATB_SPEED_LOG_DEBUG(opName_ << "Create aclOutTensor end; CreateAclTensor end");
+    return 0;
+}
+
+int AddRmsNormDynamicQuantOperation::SetAclNNWorkspaceExecutor()
+{
+    ATB_SPEED_LOG_DEBUG(opName_ << " aclnnAddRmsNormDynamicQuantGetWorkspaceSize start");
+    AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack;
+    // No smoothScale inputs are passed in this scenario
+    int ret = aclnnAddRmsNormDynamicQuantGetWorkspaceSize(aclnnVariantPack.aclInTensors.at(0)->tensor,  // x1
+        aclnnVariantPack.aclInTensors.at(1)->tensor,   // x2
+        aclnnVariantPack.aclInTensors.at(2)->tensor,   // gamma (weight)
+        nullptr,                                       // smoothScale1Optional
+        nullptr,                                       // smoothScale2Optional
+        epsilon_,                                      // epsilonOptional
+        aclnnVariantPack.aclOutTensors.at(0)->tensor,  // y1Out
+        aclnnVariantPack.aclOutTensors.at(1)->tensor,  // y2Out: 1-D placeholder; content is irrelevant
+        aclnnVariantPack.aclOutTensors.at(2)->tensor,  // xOut
+        aclnnVariantPack.aclOutTensors.at(3)->tensor,  // scale1Out
+        aclnnVariantPack.aclOutTensors.at(4)->tensor,  // scale2Out: 1-D placeholder; content is irrelevant
+        &this->aclnnOpCache_->workspaceSize,           // workspaceSize
+        &this->aclnnOpCache_->aclExecutor);            // executor
+    ATB_SPEED_LOG_DEBUG(opName_ << " aclnnAddRmsNormDynamicQuantGetWorkspaceSize end, ret:" << ret
+        << ", workspaceSize:" << this->aclnnOpCache_->workspaceSize << ", aclExecutor:"
+        << this->aclnnOpCache_->aclExecutor);
+
+    return ret;
+}
+
+int AddRmsNormDynamicQuantOperation::ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream)
+{
+    ATB_SPEED_LOG_DEBUG(opName_ << " aclnnAddRmsNormDynamicQuant start");
+    int ret = aclnnAddRmsNormDynamicQuant(
+        workspace,
+        this->aclnnOpCache_->workspaceSize,
+        this->aclnnOpCache_->aclExecutor,
+        stream);
+    ATB_SPEED_LOG_DEBUG(opName_ << " aclnnAddRmsNormDynamicQuant end, ret:" << ret);
+    return ret;
+}
+
+std::shared_ptr<AclNNTensor> AddRmsNormDynamicQuantOperation::CreateTensor(atb::Tensor atbTensor, int tensorIdx) const
+{
+    std::shared_ptr<AclNNTensor> aclnnTensor = std::make_shared<AclNNTensor>();
+    aclnnTensor->needUpdateTensorDataPtr = true;
+    aclnnTensor->atbTensor = atbTensor;
+    aclnnTensor->tensorIdx = tensorIdx;
+    aclnnTensor->strides = GetCopyTensorStride(atbTensor.desc.shape);
+    aclnnTensor->tensor = aclCreateTensor(
+        atbTensor.desc.shape.dims, atbTensor.desc.shape.dimNum,
+        atbTensor.desc.dtype, aclnnTensor->strides.data(), 0,
+        atbTensor.desc.format, atbTensor.desc.shape.dims,
+        atbTensor.desc.shape.dimNum, atbTensor.deviceData);
+    return aclnnTensor;
+}
+} // namespace common
+} // namespace atb_speed
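To summarize the contract implemented above: the operation takes 3 inputs (x1, x2, gamma) and produces 5 outputs, two of which are 1-D placeholders because no smooth scales are passed. A minimal instantiation sketch (the name string and epsilon are illustrative):

```cpp
#include "operations/aclnn/ops/add_rms_norm_dynamic_quant_operation.h"

atb::Operation *MakeAddRmsNormDynamicQuant()
{
    // Outputs, in order: y1 (INT8), y2 (INT8 placeholder), x (input dtype),
    // scale1 (FLOAT), scale2 (FLOAT placeholder).
    return new atb_speed::common::AddRmsNormDynamicQuantOperation("AddRmsNormDynamicQuant", 1e-6);
}
```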
+class AddRmsNormDynamicQuantOperation : public AclNNOperation {
+public:
+    explicit AddRmsNormDynamicQuantOperation(const std::string &name, double epsilon);
+    ~AddRmsNormDynamicQuantOperation() override;
+    atb::Status InferShape(const atb::SVector<atb::TensorDesc> &inTensorDescs,
+                           atb::SVector<atb::TensorDesc> &outTensorDescs) const override;
+    uint32_t GetInputNum() const override;
+    uint32_t GetOutputNum() const override;
+
+private:
+    double epsilon_ = 1e-6;
+    int CreateAclNNVariantPack(const atb::VariantPack &variantPack) override;
+    int SetAclNNWorkspaceExecutor() override;
+    int ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream) override;
+    std::shared_ptr<AclNNTensor> CreateTensor(atb::Tensor atbTensor, int tensorIdx) const;
+};
+} // namespace common
+} // namespace atb_speed
+#endif
diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/add_rms_norm_operation.cpp b/tests/proftest/layer_test_framework/operations/aclnn/ops/add_rms_norm_operation.cpp
new file mode 100644
index 00000000..cffe42f4
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/add_rms_norm_operation.cpp
@@ -0,0 +1,140 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <memory>
+#include <string>
+
+#include "acl/acl.h"
+#include "aclnnop/aclnn_add_rms_norm.h"
+#include "atb_speed/log.h"
+#include "operations/aclnn/utils/utils.h"
+#include "add_rms_norm_operation.h"
+
+namespace atb_speed {
+namespace common {
+
+AddRmsNormOperation::AddRmsNormOperation(const std::string &name, float epsilon) : AclNNOperation(name)
+{
+    this->opName_ = name;
+    this->epsilon = epsilon;
+}
+
+atb::Status AddRmsNormOperation::InferShape(const atb::SVector<atb::TensorDesc> &inTensorDescs,
+                                            atb::SVector<atb::TensorDesc> &outTensorDescs) const
+{
+    ATB_SPEED_LOG_DEBUG(opName_ << " infer shape start");
+    for (size_t i = 0; i < outTensorDescs.size(); i++) {
+        outTensorDescs.at(i).format = inTensorDescs.at(0).format;
+        if (i == NUM1) {
+            outTensorDescs.at(i).dtype = aclDataType::ACL_FLOAT;
+        } else {
+            outTensorDescs.at(i).dtype = inTensorDescs.at(0).dtype;
+        }
+
+        outTensorDescs.at(i).shape.dimNum = inTensorDescs.at(0).shape.dimNum;
+
+        if (inTensorDescs.at(0).shape.dimNum == DIM3) {
+            ATB_SPEED_LOG_DEBUG("[input0 dimNum = 3] check AddRmsNorm input shape: [input0] "
+                          << inTensorDescs.at(0).shape.dims[DIM0] << ", " << inTensorDescs.at(0).shape.dims[DIM1]
+                          << ", " << inTensorDescs.at(0).shape.dims[DIM2]);
+            outTensorDescs.at(i).shape.dims[DIM0] = inTensorDescs.at(0).shape.dims[DIM0];
+            outTensorDescs.at(i).shape.dims[DIM1] = inTensorDescs.at(0).shape.dims[DIM1];
+            outTensorDescs.at(i).shape.dims[DIM2] = inTensorDescs.at(0).shape.dims[DIM2];
+        } else if (inTensorDescs.at(0).shape.dimNum == DIM2) {
+            ATB_SPEED_LOG_DEBUG("[input0 dimNum = 2] check AddRmsNorm input shape: [input0] "
+                          << inTensorDescs.at(0).shape.dims[DIM0] << ", "
+                          << inTensorDescs.at(0).shape.dims[DIM1]);
+            outTensorDescs.at(i).shape.dims[DIM0] = inTensorDescs.at(0).shape.dims[DIM0];
+            outTensorDescs.at(i).shape.dims[DIM1] = inTensorDescs.at(0).shape.dims[DIM1];
+            if (i == NUM1) {
+                outTensorDescs.at(i).shape.dims[DIM1] = 1;
+            }
+        } else {
+            ATB_SPEED_LOG_ERROR(opName_ << " invalid dim num:" << inTensorDescs.at(DIM0).shape.dimNum);
+        }
+    }
+
+    ATB_SPEED_LOG_DEBUG(opName_ << " infer shape end");
+    return 0;
+}
+
+uint32_t AddRmsNormOperation::GetInputNum() const { return NUM3; }
+
+uint32_t AddRmsNormOperation::GetOutputNum() const { return NUM3; }
+
+atb::Status AddRmsNormOperation::CreateAclNNVariantPack(const atb::VariantPack &variantPack)
+{
+    ATB_SPEED_LOG_DEBUG(opName_ << " CreateAclTensor start");
+    AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack;
+    aclnnVariantPack.aclInTensors.resize(variantPack.inTensors.size());
+    for (size_t i = 0; i < aclnnVariantPack.aclInTensors.size(); ++i) {
+        aclnnVariantPack.aclInTensors[i] = CreateTensor(variantPack.inTensors.at(i), i);
+    }
+
+    ATB_SPEED_LOG_DEBUG(opName_ << " Create aclInTensor end");
+
+    aclnnVariantPack.aclOutTensors.resize(variantPack.outTensors.size());
+    for (size_t i = 0; i < aclnnVariantPack.aclOutTensors.size(); ++i) {
+        aclnnVariantPack.aclOutTensors[i] = CreateTensor(variantPack.outTensors.at(i), i);
+    }
+
+    ATB_SPEED_LOG_DEBUG(opName_ << " Create aclOutTensor end");
+    ATB_SPEED_LOG_DEBUG(opName_ << " CreateAclTensor end");
+    return 0;
+}
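+// Worked shape example (illustrative, per InferShape above): with 2-dim inputs x1/x2 of
+// shape [B*S, H] in FLOAT16, yOut and xOut come out as [B*S, H] FLOAT16, while rstdOut is
+// reduced to [B*S, 1] FLOAT.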
+int AddRmsNormOperation::SetAclNNWorkspaceExecutor()
+{
+    ATB_SPEED_LOG_DEBUG(opName_ << " aclnnAddRmsNormGetWorkspaceSize start");
+    AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack;
+    int ret = aclnnAddRmsNormGetWorkspaceSize(aclnnVariantPack.aclInTensors.at(0)->tensor,
+                                              aclnnVariantPack.aclInTensors.at(1)->tensor,
+                                              aclnnVariantPack.aclInTensors.at(2)->tensor,
+                                              this->epsilon,
+                                              aclnnVariantPack.aclOutTensors.at(0)->tensor,
+                                              aclnnVariantPack.aclOutTensors.at(1)->tensor,
+                                              aclnnVariantPack.aclOutTensors.at(2)->tensor,
+                                              &this->aclnnOpCache_->workspaceSize,
+                                              &this->aclnnOpCache_->aclExecutor);
+    ATB_SPEED_LOG_DEBUG(opName_ << " aclnnAddRmsNormGetWorkspaceSize end, ret:" << ret
+                  << ", workspaceSize:" << this->aclnnOpCache_->workspaceSize << ", aclExecutor:"
+                  << this->aclnnOpCache_->aclExecutor);
+
+    return ret;
+}
+
+int AddRmsNormOperation::ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream)
+{
+    ATB_SPEED_LOG_DEBUG(opName_ << " aclnnAddRmsNorm start");
+    int ret = aclnnAddRmsNorm(workspace, this->aclnnOpCache_->workspaceSize, this->aclnnOpCache_->aclExecutor, stream);
+    ATB_SPEED_LOG_DEBUG(opName_ << " aclnnAddRmsNorm end, ret:" << ret);
+    return ret;
+}
+
+std::shared_ptr<AclNNTensor> AddRmsNormOperation::CreateTensor(atb::Tensor atbTensor, int tensorIdx) const
+{
+    std::shared_ptr<AclNNTensor> aclnnTensor = std::make_shared<AclNNTensor>();
+    aclnnTensor->needUpdateTensorDataPtr = true;
+    aclnnTensor->atbTensor = atbTensor;
+    aclnnTensor->tensorIdx = tensorIdx;
+    aclnnTensor->strides = GetCopyTensorStride(atbTensor.desc.shape);
+    CallAclCreateTensor(atbTensor.desc.shape, atbTensor.desc.shape, atbTensor, aclnnTensor);
+    return aclnnTensor;
+}
+} // namespace common
+} // namespace atb_speed
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/add_rms_norm_operation.h b/tests/proftest/layer_test_framework/operations/aclnn/ops/add_rms_norm_operation.h
new file mode 100644
index 00000000..ec0cc92a
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/add_rms_norm_operation.h
@@ -0,0 +1,83 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef ATB_SPEED_PLUGIN_ACLNN_ADDRMSNORM_OPERATION_V2_H
+#define ATB_SPEED_PLUGIN_ACLNN_ADDRMSNORM_OPERATION_V2_H
+#include "operations/aclnn/core/acl_nn_operation.h"
+#include "atb_speed/utils/operation_util.h"
+
+namespace atb_speed {
+namespace common {
+/// This class defines a fused operation that applies an elementwise Add before the RmsNorm,
+/// reducing the cost of moving data in and out between the two kernels.
+///
+/// This class makes use of `aclnnAddRmsNormGetWorkspaceSize` and `aclnnAddRmsNorm` from the AscendCL API.
+///
+/// Operation's Inputs:
+/// Name            | Dtype                       | Shape   |
+/// ----------------|-----------------------------|---------|
+/// x1              | FLOAT, FLOAT16, BFLOAT16    | 1-8 dim |
+/// x2              | FLOAT, FLOAT16, BFLOAT16    | 1-8 dim |
+/// gamma           | FLOAT, FLOAT16, BFLOAT16    | 1-8 dim |
+/// epsilon         | double                      | Scalar  |
+///
+/// Operation's Outputs:
+/// Name    | Dtype                       | Shape   |
+/// --------|-----------------------------|---------|
+/// yOut    | FLOAT, FLOAT16, BFLOAT16    | 1-8 dim |
+/// rstdOut | FLOAT                       | 1-8 dim |
+/// xOut    | FLOAT, FLOAT16, BFLOAT16    | 1-8 dim |
+///
+/// Example:
+/// \code
+/// enum InTensorIdx : uint32_t {
+///     IN_INPUT1 = 0,
+///     IN_INPUT2,
+///     IN_WEIGHT,
+/// };
+///
+/// enum OutTensorIdx : uint32_t {
+///     OUT1 = 0,
+///     OUT2,
+///     OUT3,
+/// };
+///
+/// atb::Node addNormNode;
+/// addNormNode.operation = new atb_speed::common::AddRmsNormOperation("AddRmsNormNode", param.rmsNormEps);
+/// addNormNode.outTensorIds = {OUT1, OUT2, OUT3};
+/// addNormNode.inTensorIds = {IN_INPUT1, IN_INPUT2, IN_WEIGHT};
+///
+/// // Add the operation node to the graph as required
+/// atb::GraphParam opGraph;
+/// opGraph.nodes.push_back(addNormNode);
+/// \endcode
+class AddRmsNormOperation : public AclNNOperation {
+public:
+    explicit AddRmsNormOperation(const std::string &name, float epsilon);
+    atb::Status InferShape(const atb::SVector<atb::TensorDesc> &inTensorDescs,
+                           atb::SVector<atb::TensorDesc> &outTensorDescs) const override;
+    uint32_t GetInputNum() const override;
+    uint32_t GetOutputNum() const override;
+
+private:
+    float epsilon = 1e-5;
+    atb::Status CreateAclNNVariantPack(const atb::VariantPack &variantPack) override;
+    int SetAclNNWorkspaceExecutor() override;
+    int ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream) override;
+    std::shared_ptr<AclNNTensor> CreateTensor(atb::Tensor atbTensor, int tensorIdx) const;
+};
+} // namespace common
+} // namespace atb_speed
+#endif
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/add_rms_norm_quant_operation.cpp b/tests/proftest/layer_test_framework/operations/aclnn/ops/add_rms_norm_quant_operation.cpp
new file mode 100644
index 00000000..05f68619
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/add_rms_norm_quant_operation.cpp
@@ -0,0 +1,166 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <cmath>
+#include <memory>
+#include <string>
+
+#include "acl/acl.h"
+#include "aclnnop/aclnn_add_rms_norm_quant.h"
+#include "atb_speed/log.h"
+#include "operations/aclnn/utils/utils.h"
+#include "add_rms_norm_quant_operation.h"
+
+namespace atb_speed {
+namespace common {
+
+const double EPSILON_THRESHOLD = 1e-9;  // a tiny threshold used to validate epsilon
+
+AddRmsNormQuantOperation::AddRmsNormQuantOperation(const std::string &name, double epsilon) : AclNNOperation(name)
+{
+    opName_ = name;
+    if (std::abs(epsilon) > EPSILON_THRESHOLD) {
+        epsilon_ = epsilon;
+    }
+}
+
+AddRmsNormQuantOperation::~AddRmsNormQuantOperation()
+{
+    ATB_SPEED_LOG_DEBUG(opName_ << "AddRmsNormQuantOperation destructor");
+    this->DestroyOperation();
+}
+
+atb::Status AddRmsNormQuantOperation::InferShape(
+    const atb::SVector<atb::TensorDesc> &inTensorDescs,
+    atb::SVector<atb::TensorDesc> &outTensorDescs) const
+{
+    ATB_SPEED_LOG_DEBUG(opName_ << "AddRmsNormQuantOperation infer shape start");
+    for (size_t i = 0; i < outTensorDescs.size(); i++) {
+        outTensorDescs.at(i).format = inTensorDescs.at(0).format;
+        if (i == 0 || i == NUM1) {  // y1Out and y2Out are fixed to INT8
+            outTensorDescs.at(i).dtype = aclDataType::ACL_INT8;
+        } else {  // xOut keeps the dtype of input x1
+            outTensorDescs.at(i).dtype = inTensorDescs.at(0).dtype;
+        }
+
+        // Outputs support 1-8 dims and share the shape of x1/x2
+        outTensorDescs.at(i).shape.dimNum = inTensorDescs.at(1).shape.dimNum;
+        for (size_t j = 0; j < outTensorDescs.at(i).shape.dimNum; j++) {
+            outTensorDescs.at(i).shape.dims[j] = inTensorDescs.at(1).shape.dims[j];
+        }
+    }
+
+    ATB_SPEED_LOG_DEBUG(opName_ << "AddRmsNormQuantOperation infer shape end");
+    return 0;
+}
+
+uint32_t AddRmsNormQuantOperation::GetInputNum() const { return NUM5; }
+
+uint32_t AddRmsNormQuantOperation::GetOutputNum() const { return NUM3; }
+
+int AddRmsNormQuantOperation::CreateAclNNVariantPack(const atb::VariantPack &variantPack)
+{
+    ATB_SPEED_LOG_DEBUG(opName_ << " CreateAclTensor start");
+    AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack;
+    // Create aclInTensor
+    aclnnVariantPack.aclInTensors.resize(variantPack.inTensors.size());
+    for (size_t i = 0; i < aclnnVariantPack.aclInTensors.size(); ++i) {
+        std::shared_ptr<AclNNTensor> aclnnTensor = std::make_shared<AclNNTensor>();
+        aclnnTensor->tensorIdx = i;
+        aclnnTensor->needUpdateTensorDataPtr = true;
+        aclnnTensor->atbTensor = variantPack.inTensors.at(i);
+        atb::Tensor atbTensor = variantPack.inTensors.at(i);
+
+        aclnnTensor->strides = GetCopyTensorStride(atbTensor.desc.shape);
+
+        atb::Dims viewDims = atbTensor.desc.shape;
+        if (i == NUM4) {  // zeroPoints1Optional: DT_INT32 for fp16, DT_BFLOAT16 for bf16
+            // tensorIdx maps one-to-one onto the operator's 14 call-argument slots, while i only follows
+            // the 5 external inTensors; when nullptr arguments precede an inTensor, tensorIdx must track
+            // the actual argument slot (a tensor must not keep a tensorIdx whose slot is passed as nullptr)
+            aclnnTensor->tensorIdx = NUM5;
+        }
+        aclnnTensor->tensor = aclCreateTensor(
+            viewDims.dims, atbTensor.desc.shape.dimNum, atbTensor.desc.dtype,
+            aclnnTensor->strides.data(), 0, atbTensor.desc.format, viewDims.dims,
+            atbTensor.desc.shape.dimNum, atbTensor.deviceData);
+        aclnnVariantPack.aclInTensors[i] = aclnnTensor;
+    }
+    ATB_SPEED_LOG_DEBUG(opName_ << " Create aclInTensor end");
+    // Create aclOutTensor
+    aclnnVariantPack.aclOutTensors.resize(variantPack.outTensors.size());
+    for (size_t i = 0; i < aclnnVariantPack.aclOutTensors.size(); ++i) {
+        aclnnVariantPack.aclOutTensors[i] = CreateTensor(variantPack.outTensors.at(i), i);
+    }
+
+    ATB_SPEED_LOG_DEBUG(opName_ << " Create aclOutTensor end; CreateAclTensor end");
+    return 0;
+}
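+// Mapping between the 5 external inTensors and the aclnnAddRmsNormQuant argument slots,
+// derived from the remapping above (slot 4, scales2Optional, is passed as nullptr):
+//   inTensor 0 (x1)      -> arg 0
+//   inTensor 1 (x2)      -> arg 1
+//   inTensor 2 (gamma)   -> arg 2
+//   inTensor 3 (scales1) -> arg 3
+//   inTensor 4 (zeroPoints1Optional) -> arg 5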
<< " aclnnAddRmsNormQuantGetWorkspaceSize start"); + AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack; + int ret = aclnnAddRmsNormQuantGetWorkspaceSize(aclnnVariantPack.aclInTensors.at(0)->tensor, // x1 + aclnnVariantPack.aclInTensors.at(1)->tensor, // x2 + aclnnVariantPack.aclInTensors.at(2)->tensor, // gamma(weight) + aclnnVariantPack.aclInTensors.at(3)->tensor, // scales1 + nullptr, // scales2Optional -> 实际未使用 + aclnnVariantPack.aclInTensors.at(4)->tensor, // zeroPoints1Optional + nullptr, // zeroPoints2Optional -> 实际未使用 + -1, + epsilon_, // epsilonOptional + true, // divMode + aclnnVariantPack.aclOutTensors.at(0)->tensor, // y1Out + aclnnVariantPack.aclOutTensors.at(1)->tensor, // y2Out, shape为1, 内容无所谓 + aclnnVariantPack.aclOutTensors.at(2)->tensor, // xOut + &this->aclnnOpCache_->workspaceSize, // workspaceSize + &this->aclnnOpCache_->aclExecutor); // executor + ATB_SPEED_LOG_DEBUG(opName_ << " aclnnAddRmsNormQuantGetWorkspaceSize end, ret:" << ret + << ", workspaceSize:" << this->aclnnOpCache_->workspaceSize << ", aclExecutor:" + << this->aclnnOpCache_->aclExecutor); + + return ret; +} + +int AddRmsNormQuantOperation::ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream) +{ + ATB_SPEED_LOG_DEBUG(opName_ << " aclnnAddRmsNormQuant start"); + int ret = aclnnAddRmsNormQuant( + workspace, + this->aclnnOpCache_->workspaceSize, + this->aclnnOpCache_->aclExecutor, + stream); + ATB_SPEED_LOG_DEBUG(opName_ << " aclnnAddRmsNormQuant end, ret:" << ret); + return ret; +} + +std::shared_ptr AddRmsNormQuantOperation::CreateTensor(atb::Tensor atbTensor, int tensorIdx) const +{ + std::shared_ptr aclnnTensor = std::make_shared(); + aclnnTensor->needUpdateTensorDataPtr = true; + aclnnTensor->atbTensor = atbTensor; + aclnnTensor->tensorIdx = tensorIdx; + aclnnTensor->strides = GetCopyTensorStride(atbTensor.desc.shape); + aclnnTensor->tensor = aclCreateTensor( + atbTensor.desc.shape.dims, atbTensor.desc.shape.dimNum, + atbTensor.desc.dtype, aclnnTensor->strides.data(), 0, + atbTensor.desc.format, atbTensor.desc.shape.dims, + atbTensor.desc.shape.dimNum, atbTensor.deviceData); + return aclnnTensor; +} +} // namespace common +} // namespace atb_speed \ No newline at end of file diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/add_rms_norm_quant_operation.h b/tests/proftest/layer_test_framework/operations/aclnn/ops/add_rms_norm_quant_operation.h new file mode 100644 index 00000000..2025ef33 --- /dev/null +++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/add_rms_norm_quant_operation.h @@ -0,0 +1,40 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+class AddRmsNormQuantOperation : public AclNNOperation {
+public:
+    explicit AddRmsNormQuantOperation(const std::string &name, double epsilon);
+    ~AddRmsNormQuantOperation() override;
+    atb::Status InferShape(const atb::SVector<atb::TensorDesc> &inTensorDescs,
+                           atb::SVector<atb::TensorDesc> &outTensorDescs) const override;
+    uint32_t GetInputNum() const override;
+    uint32_t GetOutputNum() const override;
+
+private:
+    double epsilon_ = 1e-6;
+    int CreateAclNNVariantPack(const atb::VariantPack &variantPack) override;
+    int SetAclNNWorkspaceExecutor() override;
+    int ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream) override;
+    std::shared_ptr<AclNNTensor> CreateTensor(atb::Tensor atbTensor, int tensorIdx) const;
+};
+} // namespace common
+} // namespace atb_speed
+#endif
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/argmax_operation.cpp b/tests/proftest/layer_test_framework/operations/aclnn/ops/argmax_operation.cpp
new file mode 100644
index 00000000..a6e24499
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/argmax_operation.cpp
@@ -0,0 +1,157 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "argmax_operation.h"
+#include "operations/aclnn/utils/utils.h"
+#include "acl/acl.h"
+#include "aclnnop/aclnn_argmax.h"
+#include "atb_speed/log.h"
+
+namespace atb_speed {
+namespace common {
+
+ArgMaxOperation::ArgMaxOperation(const std::string &name) : AclNNOperation(name) {}
+
+ArgMaxOperation::ArgMaxOperation(const std::string &name, atb_speed::common::AclNNArgMaxParam param)
+    : AclNNOperation(name), param_(param)
+{
+}
+
+ArgMaxOperation::~ArgMaxOperation()
+{
+    ATB_SPEED_LOG_DEBUG("ArgMaxOperation destructor");
+    this->DestroyOperation();
+}
+
+uint32_t ArgMaxOperation::GetInputNum() const { return NUM1; }
+
+uint32_t ArgMaxOperation::GetOutputNum() const { return NUM1; }
+
+atb::Status ArgMaxOperation::InferShape(const atb::SVector<atb::TensorDesc> &inTensorDesc,
+                                        atb::SVector<atb::TensorDesc> &outTensorDesc) const
+{
+    ATB_SPEED_LOG_DEBUG(opName_ << " ArgMaxOperation infer shape start");
+    outTensorDesc.at(0).format = inTensorDesc.at(0).format;
+    outTensorDesc.at(0).dtype = ACL_INT32;
+    uint32_t inputDimNum = inTensorDesc.at(0).shape.dimNum;
+    uint32_t outputDimNum = inputDimNum;
+    uint32_t realDim = this->param_.dim < 0 ?
this->param_.dim + inputDimNum : this->param_.dim; + + if (!param_.keepdim) { + outputDimNum -= 1; + } + outTensorDesc.at(0).shape.dimNum = outputDimNum; + + uint32_t j = 0; + for (uint32_t i = 0; i < outputDimNum; ++i) { + if (i == realDim && param_.keepdim) { + outTensorDesc.at(0).shape.dims[i] = 1; + j++; + } else { + outTensorDesc.at(0).shape.dims[j++] = inTensorDesc.at(0).shape.dims[i]; + } + } + + ATB_SPEED_LOG_DEBUG(opName_ << "ArgMaxOperation InferShape end"); + + return atb::NO_ERROR; +} + +int ArgMaxOperation::CreateAclNNVariantPack(const atb::VariantPack &variantPack) +{ + ATB_SPEED_LOG_DEBUG(opName_ << " CreateAclNNVariantPack start"); + int ret; + + ret = CreateAclNNInTensorVariantPack(variantPack); + if (ret != 0) { + ATB_SPEED_LOG_ERROR(this->opName_ << " AclNNTensor CreateAclNNInTensorVariantPack fail"); + return ret; + } + + ret = CreateAclNNOutTensorVariantPack(variantPack); + if (ret != 0) { + ATB_SPEED_LOG_ERROR(this->opName_ << " AclNNTensor CreateAclNNOutTensorVariantPack fail"); + } + + ATB_SPEED_LOG_DEBUG(opName_ << " CreateAclNNVariantPack end"); + return atb::NO_ERROR; +} + +int ArgMaxOperation::CreateAclNNInTensorVariantPack(const atb::VariantPack &variantPack) +{ + AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack; + aclnnVariantPack.aclInTensors.resize(GetInputNum()); + for (size_t i = 0; i < aclnnVariantPack.aclInTensors.size(); ++i) { + std::shared_ptr aclnnTensor = CreateTensor(variantPack.inTensors.at(i), i); + if (aclnnTensor->tensor == nullptr) { + ATB_SPEED_LOG_ERROR(this->opName_ << " InTensor aclCreateTensor index " << i << " fail"); + return atb::ERROR_INTERNAL_ERROR; + } + aclnnVariantPack.aclInTensors[i] = aclnnTensor; + } + return atb::NO_ERROR; +} + + +int ArgMaxOperation::CreateAclNNOutTensorVariantPack(const atb::VariantPack &variantPack) +{ + AclNNVariantPack &aclNnVariantPack = this->aclnnOpCache_->aclnnVariantPack; + aclNnVariantPack.aclOutTensors.resize(GetOutputNum()); + for (size_t i = 0; i < aclNnVariantPack.aclOutTensors.size(); ++i) { + std::shared_ptr aclnnTensor = CreateTensor(variantPack.outTensors.at(i), i); + if (aclnnTensor->tensor == nullptr) { + ATB_SPEED_LOG_ERROR(this->opName_ << " outTensor aclCreateTensor index " << i << " fail"); + return atb::ERROR_INTERNAL_ERROR; + } + aclNnVariantPack.aclOutTensors[i] = aclnnTensor; + } + return atb::NO_ERROR; +} + +std::shared_ptr ArgMaxOperation::CreateTensor(atb::Tensor atbTensor, int tensorIdx) const +{ + std::shared_ptr aclnnTensor = std::make_shared(); + aclnnTensor->needUpdateTensorDataPtr = true; + aclnnTensor->atbTensor = atbTensor; + aclnnTensor->tensorIdx = tensorIdx; + aclnnTensor->strides = GetCopyTensorStride(atbTensor.desc.shape); + aclnnTensor->tensor = aclCreateTensor(atbTensor.desc.shape.dims, atbTensor.desc.shape.dimNum, atbTensor.desc.dtype, + aclnnTensor->strides.data(), 0, atbTensor.desc.format, + atbTensor.desc.shape.dims, atbTensor.desc.shape.dimNum, atbTensor.deviceData); + return aclnnTensor; +} + +int ArgMaxOperation::SetAclNNWorkspaceExecutor() +{ + ATB_SPEED_LOG_DEBUG(opName_ << " SetAclNNWorkspaceExecutor start"); + AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack; + int ret = aclnnArgMaxGetWorkspaceSize(aclnnVariantPack.aclInTensors.at(0)->tensor, this->param_.dim, + this->param_.keepdim, aclnnVariantPack.aclOutTensors.at(0)->tensor, + &this->aclnnOpCache_->workspaceSize, &this->aclnnOpCache_->aclExecutor); + ATB_SPEED_LOG_DEBUG(opName_ << " SetAclNNWorkspaceExecutor end, ret:" << ret + << ", 
workspaceSize:" << this->aclnnOpCache_->workspaceSize + << ", aclExecutor:" << this->aclnnOpCache_->aclExecutor); + return ret; +} + +int ArgMaxOperation::ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream) +{ + ATB_SPEED_LOG_DEBUG(opName_ << " ExecuteAclNNOp start"); + int ret = aclnnArgMax(workspace, this->aclnnOpCache_->workspaceSize, this->aclnnOpCache_->aclExecutor, stream); + ATB_SPEED_LOG_DEBUG(opName_ << " ExecuteAclNNOp end, ret:" << ret); + return ret; +} +} // namespace common +} // namespace atb_speed \ No newline at end of file diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/argmax_operation.h b/tests/proftest/layer_test_framework/operations/aclnn/ops/argmax_operation.h new file mode 100644 index 00000000..7b7f01b9 --- /dev/null +++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/argmax_operation.h @@ -0,0 +1,53 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDIE_LLM_PLUGIN_ACLNN_ARGMAX_OPERATION_H +#define MINDIE_LLM_PLUGIN_ACLNN_ARGMAX_OPERATION_H + +#include "aclnn/core/acl_nn_operation.h" + +namespace atb_speed { +namespace common { + +struct AclNNArgMaxParam { + int64_t dim = -1; + bool keepdim = false; +}; + +class ArgMaxOperation : public AclNNOperation { +public: + explicit ArgMaxOperation(const std::string &name); + explicit ArgMaxOperation(const std::string &name, AclNNArgMaxParam param); + ~ArgMaxOperation() override; + atb::Status InferShape(const atb::SVector &inTensorDesc, + atb::SVector &outTensorDesc) const override; + uint32_t GetInputNum() const override; + uint32_t GetOutputNum() const override; + +protected: + int CreateAclNNVariantPack(const atb::VariantPack &variantPack) override; + int SetAclNNWorkspaceExecutor() override; + int ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream) override; + int CreateAclNNInTensorVariantPack(const atb::VariantPack &variantPack) override; + int CreateAclNNOutTensorVariantPack(const atb::VariantPack &variantPack) override; + std::shared_ptr CreateTensor(atb::Tensor atbTensor, int tensorIdx) const; + +private: + AclNNArgMaxParam param_; +}; +} // namespace common +} // namespace atb_speed + +#endif // MINDIE_LLM_PLUGIN_ACLNN_ARGMAX_OPERATION_H diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/argsort_operation.cpp b/tests/proftest/layer_test_framework/operations/aclnn/ops/argsort_operation.cpp new file mode 100644 index 00000000..103c07bf --- /dev/null +++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/argsort_operation.cpp @@ -0,0 +1,121 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "argsort_operation.h"
+#include <memory>
+#include <string>
+#include "acl/acl.h"
+#include "atb_speed/log.h"
+#include "aclnnop/aclnn_argsort.h"
+#include "operations/aclnn/utils/utils.h"
+namespace atb_speed {
+namespace common {
+
+ArgSortOperation::ArgSortOperation(const std::string &name) : AclNNOperation(name) {
+}
+
+ArgSortOperation::~ArgSortOperation() {
+}
+
+// There is exactly one input and one output; shape, format and dimNum are identical
+atb::Status ArgSortOperation::InferShape(
+    const atb::SVector<atb::TensorDesc> &inTensorDescs, atb::SVector<atb::TensorDesc> &outTensorDescs) const
+{
+    ATB_SPEED_LOG_DEBUG(opName_ << "ArgSortOperation infer shape start");
+    outTensorDescs.at(0).format = inTensorDescs.at(0).format;
+    outTensorDescs.at(0).dtype = ACL_INT64;
+    outTensorDescs.at(0).shape.dimNum = inTensorDescs.at(0).shape.dimNum;
+    for (size_t i = 0; i < inTensorDescs.at(0).shape.dimNum; i++) {
+        outTensorDescs.at(0).shape.dims[i] = inTensorDescs.at(0).shape.dims[i];
+    }
+    ATB_SPEED_LOG_DEBUG(opName_ << "ArgSortOperation infer shape end"
+                  << " format: " << inTensorDescs.at(0).format << " dimNum: " << inTensorDescs.at(0).shape.dimNum
+                  << " dims: " << inTensorDescs.at(0).shape.dims[0]);
+    return 0;
+}
+
+uint32_t ArgSortOperation::GetInputNum() const
+{
+    return NUM1;
+}
+
+uint32_t ArgSortOperation::GetOutputNum() const
+{
+    return NUM1;
+}
+
+atb::Status ArgSortOperation::CreateAclNNInTensorVariantPack(const atb::VariantPack &variantPack)
+{
+    AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack;
+    aclnnVariantPack.aclInTensors.resize(GetInputNum());
+    for (size_t i = 0; i < aclnnVariantPack.aclInTensors.size(); ++i) {
+        std::shared_ptr<AclNNTensor> aclnnTensor = std::make_shared<AclNNTensor>();
+        aclnnTensor->tensorIdx = i;
+        aclnnTensor->needUpdateTensorDataPtr = true;
+        aclnnTensor->atbTensor = variantPack.inTensors.at(i);
+        atb::Tensor squeezedAtbTensor = SqueezeBatchSeq(variantPack.inTensors.at(i));
+
+        aclnnTensor->strides = GetCopyTensorStride(squeezedAtbTensor.desc.shape);
+        CHECK_OPERATION_STATUS_RETURN(CallAclCreateTensor(squeezedAtbTensor.desc.shape, squeezedAtbTensor.desc.shape,
+                                                          squeezedAtbTensor, aclnnTensor));
+        aclnnVariantPack.aclInTensors[i] = aclnnTensor;
+    }
+    return atb::NO_ERROR;
+}
+
+atb::Status ArgSortOperation::CreateAclNNOutTensorVariantPack(const atb::VariantPack &variantPack)
+{
+    AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack;
+    aclnnVariantPack.aclOutTensors.resize(GetOutputNum());
+    for (size_t i = 0; i < aclnnVariantPack.aclOutTensors.size(); ++i) {
+        std::shared_ptr<AclNNTensor> aclnnTensor = std::make_shared<AclNNTensor>();
+        aclnnTensor->tensorIdx = i;
+        aclnnTensor->needUpdateTensorDataPtr = true;
+        aclnnTensor->atbTensor = variantPack.outTensors.at(i);
+        atb::Tensor squeezedAtbTensor = SqueezeBatchSeq(variantPack.outTensors.at(i));
+        aclnnTensor->strides = GetCopyTensorStride(squeezedAtbTensor.desc.shape);
+        CHECK_OPERATION_STATUS_RETURN(CallAclCreateTensor(squeezedAtbTensor.desc.shape, squeezedAtbTensor.desc.shape,
+                                                          squeezedAtbTensor, aclnnTensor));
+        aclnnVariantPack.aclOutTensors[i] = aclnnTensor;
+    }
+    return atb::NO_ERROR;
+}
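+// Note: the workspace call below sorts along dim 0 in ascending order (descending=false);
+// e.g. an input column [3.0, 1.0, 2.0] would yield indices [1, 2, 0] (illustrative values).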
+int ArgSortOperation::SetAclNNWorkspaceExecutor()
+{
+    ATB_SPEED_LOG_DEBUG(opName_ << " SetAclNNWorkspaceExecutor start");
+    AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack;
+    int ret = aclnnArgsortGetWorkspaceSize(aclnnVariantPack.aclInTensors.at(0)->tensor,
+                                           0,
+                                           false,
+                                           aclnnVariantPack.aclOutTensors.at(0)->tensor,
+                                           &this->aclnnOpCache_->workspaceSize,
+                                           &this->aclnnOpCache_->aclExecutor);
+    ATB_SPEED_LOG_DEBUG(opName_ << " SetAclNNWorkspaceExecutor end, ret:" << ret
+                  << ", workspaceSize:" << this->aclnnOpCache_->workspaceSize
+                  << ", aclExecutor:" << this->aclnnOpCache_->aclExecutor);
+    return ret;
+}
+
+int ArgSortOperation::ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream)
+{
+    ATB_SPEED_LOG_DEBUG(opName_ << " ExecuteAclNNOp start");
+    int ret = aclnnArgsort(workspace, this->aclnnOpCache_->workspaceSize, this->aclnnOpCache_->aclExecutor, stream);
+    ATB_SPEED_LOG_DEBUG(opName_ << " ExecuteAclNNOp end, ret:" << ret);
+    return ret;
+}
+} // namespace common
+} // namespace atb_speed
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/argsort_operation.h b/tests/proftest/layer_test_framework/operations/aclnn/ops/argsort_operation.h
new file mode 100644
index 00000000..34c4cc6b
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/argsort_operation.h
@@ -0,0 +1,75 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef ATB_SPEED_PLUGIN_ACLNN_ARGSORT_OPERATION_H
+#define ATB_SPEED_PLUGIN_ACLNN_ARGSORT_OPERATION_H
+
+#include "operations/aclnn/core/acl_nn_operation.h"
+#include "atb_speed/utils/operation_util.h"
+namespace atb_speed {
+namespace common {
+
+/// This class defines an operator that returns the indices that would sort the input.
+///
+/// This class makes use of `aclnnArgsortGetWorkspaceSize` and `aclnnArgsort` from the AscendCL API.
+/// +/// Inputs to the operator: +/// Name | Dtype | Shape | +/// -------------|---------------------------------------------|-------| +/// input | float16, float32, int8, int32, int64, uint8 | [m,n] | +/// +/// Outputs of the operator: +/// Name | Dtype | Shape | +/// -------------|-------|-------| +/// output | int64 | [m,n] | +/// +/// Example: +/// \code +/// enum TensorIdx : uint32_t { +/// INPUT = 0, +/// OUT, +/// }; +/// +/// atb::Node &argsortNode = opGraph.nodes.at(nodeId++); +/// argsortNode.operation = new atb_speed::common::ArgSortOperation("ArgsortNode"); +/// argsortNode.inTensorIds = {INPUT}; +/// argsortNode.outTensorIds = {OUTPUT}; +/// \endcode + +class ArgSortOperation : public AclNNOperation { +public: + explicit ArgSortOperation(const std::string &name); + + ~ArgSortOperation() override; + + atb::Status InferShape(const atb::SVector &inTensorDescs, + atb::SVector &outTensorDescs) const override; + + uint32_t GetInputNum() const override; + + uint32_t GetOutputNum() const override; + +private: + int SetAclNNWorkspaceExecutor() override; + int ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream) override; + + atb::Status CreateAclNNInTensorVariantPack(const atb::VariantPack &variantPack) override; + atb::Status CreateAclNNOutTensorVariantPack(const atb::VariantPack &variantPack) override; +}; + +} // namespace common +} // namespace atb_speed + +#endif // ATB_SPEED_PLUGIN_ACLNN_ARGSORT_OPERATION_H \ No newline at end of file diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/attn_operation.cpp b/tests/proftest/layer_test_framework/operations/aclnn/ops/attn_operation.cpp new file mode 100644 index 00000000..c9f89518 --- /dev/null +++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/attn_operation.cpp @@ -0,0 +1,199 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#include <algorithm>
+#include <cmath>
+#include <cstdint>
+#include <memory>
+#include <string>
+
+#include "acl/acl.h"
+#include "aclnnop/aclnn_fused_infer_attention_score_v2.h"
+#include "atb/types.h"
+#include "atb_speed/log.h"
+#include "operations/aclnn/utils/utils.h"
+#include "attn_operation.h"
+
+namespace atb_speed {
+namespace common {
+
+AttnOperation::AttnOperation(const std::string &name, AclNNAttnParam param) : AclNNOperation(name), param_(param)
+{
+    tensorsOfValue[0] = nullptr;
+    tensorsOfKey[0] = nullptr;
+}
+
+AttnOperation::~AttnOperation()
+{
+    tensorsOfKey[0] = nullptr;
+    tensorsOfValue[0] = nullptr;
+}
+
+atb::Status AttnOperation::InferShape(const atb::SVector<atb::TensorDesc> &inTensorDescs,
+                                      atb::SVector<atb::TensorDesc> &outTensorDescs) const
+{
+    ATB_SPEED_LOG_DEBUG(opName_ << " infer shape start");
+    // FA uses [B, S, H]; PA uses [B, S, N, D]
+    outTensorDescs.at(0) = inTensorDescs.at(0);
+    if (!param_.isFA) {
+        outTensorDescs.at(0).shape.dims[DIM0] = inTensorDescs.at(DIM0).shape.dims[DIM0];  // B
+        outTensorDescs.at(0).shape.dims[DIM1] = inTensorDescs.at(DIM0).shape.dims[DIM2];  // N
+        outTensorDescs.at(0).shape.dims[DIM2] = inTensorDescs.at(DIM0).shape.dims[DIM3];  // D
+        outTensorDescs.at(0).shape.dimNum = inTensorDescs.at(0).shape.dimNum - 1;
+    }
+    ATB_SPEED_LOG_DEBUG(opName_ << " infer shape end");
+    return 0;
+}
+
+uint32_t AttnOperation::GetInputNum() const
+{
+    uint32_t inputNum = 6;
+    if (param_.hasKVQuant) {
+        ++inputNum;
+        if (param_.hasQuantOffset) {
+            ++inputNum;
+        }
+    }
+    return inputNum;
+}
+
+uint32_t AttnOperation::GetOutputNum() const { return NUM1; }
+
+const int ACLNN_TENSOR_INDEX[8] = {0, 0, 0, 4, 6, 14, 12, 13};
+const int ACLNN_TENSOR_LIST_INDEX[8] = {-1, 1, 2, -1, -1, -1, -1, -1};
+
+atb::Status AttnOperation::CreateAclNNInTensorVariantPack(const atb::VariantPack &variantPack)
+{
+    AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack;
+    aclnnVariantPack.aclInTensors.resize(GetInputNum());
+    for (size_t i = 0; i < aclnnVariantPack.aclInTensors.size(); ++i) {
+        std::shared_ptr<AclNNTensor> aclnnTensor = std::make_shared<AclNNTensor>();
+        if (param_.isFA && i == 5) {  // 5: idx of block tables in in tensor
+            aclnnVariantPack.aclInTensors[i] = aclnnTensor;
+            continue;
+        }
+        aclnnTensor->atbTensor = variantPack.inTensors.at(i);
+        aclnnTensor->tensorIdx = ACLNN_TENSOR_INDEX[i];
+        aclnnTensor->tensorListidx = ACLNN_TENSOR_LIST_INDEX[i];
+        if (i == NUM4) {  // 4: actual seqLength index
+            aclnnTensor->needUpdateTensorDataPtr = false;
+            aclnnTensor->intArrayHostData.dataSize = aclnnTensor->atbTensor.dataSize / NUM4;  // int32 has 4 bytes
+            aclnnTensor->intArrayHostData.data.resize(aclnnTensor->intArrayHostData.dataSize);
+            aclnnTensor->intArrayHostData.dataOri.resize(aclnnTensor->intArrayHostData.dataSize);
+            std::transform(
+                static_cast<int32_t *>(aclnnTensor->atbTensor.hostData),
+                static_cast<int32_t *>(aclnnTensor->atbTensor.hostData) + aclnnTensor->atbTensor.dataSize / NUM4,
+                aclnnTensor->intArrayHostData.data.data(), [](int32_t value) {
+                    return static_cast<int64_t>(value);
+                });
+            std::copy(static_cast<int32_t *>(aclnnTensor->atbTensor.hostData),
+                      static_cast<int32_t *>(aclnnTensor->atbTensor.hostData) +
+                          aclnnTensor->atbTensor.dataSize / sizeof(int32_t),
+                      aclnnTensor->intArrayHostData.dataOri.data());
+            aclnnTensor->intArrayHostData.intArray = aclCreateIntArray(
+                static_cast<int64_t *>(aclnnTensor->intArrayHostData.data.data()),
+                aclnnTensor->intArrayHostData.dataSize);
+        } else if (i == 3 && !param_.isFA) {  // 3: idx of mask tensor
+            aclnnTensor->needUpdateTensorDataPtr = false;
+        } else {
+            aclnnTensor->needUpdateTensorDataPtr = true;
+            atb::Tensor atbTensor = variantPack.inTensors.at(i);
+            aclnnTensor->strides = GetCopyTensorStride(atbTensor.desc.shape);
+            CHECK_OPERATION_STATUS_RETURN(
+                CallAclCreateTensor(atbTensor.desc.shape, atbTensor.desc.shape, atbTensor, aclnnTensor));
+        }
+        aclnnVariantPack.aclInTensors[i] = aclnnTensor;
+    }
+    tensorsOfKey[0] = aclnnVariantPack.aclInTensors.at(1)->tensor;    // 1: key tensor index
+    tensorsOfValue[0] = aclnnVariantPack.aclInTensors.at(2)->tensor;  // 2: value tensor index
+    auto tensorKeyList = aclCreateTensorList(tensorsOfKey, 1);
+    auto tensorValueList = aclCreateTensorList(tensorsOfValue, 1);
+    aclnnVariantPack.aclInTensorList.clear();
+    aclnnVariantPack.aclInTensorList.push_back(nullptr);
+    aclnnVariantPack.aclInTensorList.push_back(tensorKeyList);
+    aclnnVariantPack.aclInTensorList.push_back(tensorValueList);
+    return atb::NO_ERROR;
+}
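+// Mapping of this operation's inTensors onto aclnnFusedInferAttentionScoreV2 argument
+// slots, as encoded in ACLNN_TENSOR_INDEX / ACLNN_TENSOR_LIST_INDEX above:
+//   0 query -> arg 0; 1 key -> tensor list 1; 2 value -> tensor list 2;
+//   3 mask -> arg 4; 4 seq lengths -> host int array, arg 6;
+//   5 block table -> arg 14; 6 antiquantScale -> arg 12; 7 antiquantOffset -> arg 13.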
+atb::Status AttnOperation::CreateAclNNOutTensorVariantPack(const atb::VariantPack &variantPack)
+{
+    AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack;
+    aclnnVariantPack.aclOutTensors.resize(NUM1);
+    for (size_t i = 0; i < aclnnVariantPack.aclOutTensors.size(); ++i) {
+        std::shared_ptr<AclNNTensor> aclnnTensor = std::make_shared<AclNNTensor>();
+        aclnnTensor->tensorIdx = i;
+        aclnnTensor->tensorListidx = AclNNTensor::notInTensorList;
+        aclnnTensor->needUpdateTensorDataPtr = true;
+        aclnnTensor->atbTensor = variantPack.outTensors.at(i);
+        atb::Tensor squeezedAtbTensor = variantPack.outTensors.at(i);
+        aclnnTensor->strides = GetCopyTensorStride(squeezedAtbTensor.desc.shape);
+        CHECK_OPERATION_STATUS_RETURN(CallAclCreateTensor(squeezedAtbTensor.desc.shape, squeezedAtbTensor.desc.shape,
+                                                          squeezedAtbTensor, aclnnTensor));
+        aclnnVariantPack.aclOutTensors[i] = aclnnTensor;
+    }
+    return atb::NO_ERROR;
+}
+
+int AttnOperation::SetAclNNWorkspaceExecutor()
+{
+    ATB_SPEED_LOG_DEBUG(opName_ << " aclnnAttnGetWorkspaceSize start");
+    char inputLayoutFA[5] = "BSH";
+    char inputLayoutPA[5] = "BSND";
+    double scaleValue = 1 / sqrt(param_.headDim);
+    AclNNVariantPack &task = this->aclnnOpCache_->aclnnVariantPack;
+    aclTensor *maskTensor = param_.isFA ? task.aclInTensors.at(3)->tensor : nullptr;   // 3: attention mask tensor index
+    aclTensor *blockTensor = param_.isFA ? nullptr : task.aclInTensors.at(5)->tensor;  // 5: blocktable tensor index
+    aclTensor *antiquantScaleTensor =
+        param_.hasKVQuant ? task.aclInTensors.at(6)->tensor : nullptr;  // 6: dequantScale tensor index
+    aclTensor *antiquantOffsetTensor = param_.hasKVQuant && param_.hasQuantOffset ?
task.aclInTensors.at(7)->tensor + : nullptr; // 7: dequantOffset index + // query - 0; key - 1; value - 2; pseShift - 3; attenMask - 4; actualSeqLengths - 5; + // ++1 actualSeqLengthsKv - 6; + // dequantScale1 - 6; quantScale1 - 7; dequantScale2 - 8; quantScale2 - 9; + // quantScale2 - 10; antiquantScale - 11; antiquantOffset - 12; blocktable - 13; + // numHeads - 14; scaleValue - 15; inputLayout - 16; numKeyValueHeads - 17; + // blockSize - 18; innerPrecise - 19; + // innerPrecise - 20; workspaceSize - 21; workspaceSize - 22 + int ret = aclnnFusedInferAttentionScoreV2GetWorkspaceSize( + task.aclInTensors.at(0)->tensor, // 0: query index + task.aclInTensorList.at(1), // 1: key cache index + task.aclInTensorList.at(2), // 2: value cache index + nullptr, maskTensor, // 4: attenMask + nullptr, task.aclInTensors.at(4)->intArrayHostData.intArray, // 6: seq length index + nullptr, nullptr, nullptr, nullptr, nullptr, + antiquantScaleTensor, // 12: antiquantScale + antiquantOffsetTensor, // 13: antiquantOffset + blockTensor, // 14: blocktable + nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, + param_.headNum, scaleValue, 2147483647, 2147483647, + param_.isFA ? inputLayoutFA : inputLayoutPA, + param_.kvHeadNum, 0, param_.innerPrecise, + param_.isFA ? 0 : param_.blockSize, 0, false, 0, 0, + task.aclOutTensors.at(0)->tensor, // 0: out tensor + nullptr, &this->aclnnOpCache_->workspaceSize, &this->aclnnOpCache_->aclExecutor); + ATB_SPEED_LOG_DEBUG(opName_ << " aclnnAttnGetWorkspaceSize end, ret:" << ret << + ", workspaceSize:" << this->aclnnOpCache_->workspaceSize << + ", aclExecutor:" << this->aclnnOpCache_->aclExecutor); + return ret; +} + +int AttnOperation::ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream) +{ + return aclnnFusedInferAttentionScoreV2(workspace, this->aclnnOpCache_->workspaceSize, + this->aclnnOpCache_->aclExecutor, stream); +} +} // namespace common +} // namespace atb_speed diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/attn_operation.h b/tests/proftest/layer_test_framework/operations/aclnn/ops/attn_operation.h new file mode 100644 index 00000000..dc50e6c0 --- /dev/null +++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/attn_operation.h @@ -0,0 +1,115 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#ifndef ATTN_OPERATION_H
+#define ATTN_OPERATION_H
+#include "operations/aclnn/core/acl_nn_operation.h"
+#include "atb_speed/utils/operation_util.h"
+#include <cstring>
+namespace atb_speed {
+namespace common {
+struct AclNNAttnParam {
+    /// A flag indicating whether the model uses a mask
+    bool hasMask = false;
+    /// A flag indicating whether the model uses FA
+    bool isFA = false;
+    /// A flag indicating whether the model is in the prefill phase
+    bool isPrefill = false;
+    /// A flag indicating whether the KV cache is int8-quantized
+    bool hasKVQuant = false;
+    /// A flag indicating whether the quantized KV cache has an offset (zero point) weight
+    bool hasQuantOffset = false;
+    /// Enable prefix attention
+    bool enablePrefixAttn = false;
+    /// The number of attention heads
+    int64_t headNum = 0;
+    /// The number of KV heads
+    int64_t kvHeadNum = 0;
+    /// The head dimension
+    int64_t headDim = 0;
+    /// Precision mode; default 1 (high performance), high accuracy otherwise
+    int64_t innerPrecise = 1;
+    /// Max number of tokens per block for page attention stored in the KV cache
+    int64_t blockSize = 128;
+};
+
+/// This class defines an operator that calculates attention for both FA and PA.
+///
+/// This class makes use of `aclnnFusedInferAttentionScoreV2GetWorkspaceSize` and
+/// `aclnnFusedInferAttentionScoreV2` from the AscendCL API.
+///
+/// Inputs to the operator:
+/// Name                      | Dtype                       | Shape                               |
+/// --------------------------|-----------------------------|-------------------------------------|
+/// query                     | float16, bfloat16 or int8   | [batchsize, headNum, dim]           |
+/// key                       | float16, bfloat16 or int8   | [blocknum, blocksize, headNum, dim] |
+/// value                     | float16, bfloat16 or int8   | [blocknum, blocksize, headNum, dim] |
+/// actualSeqLengthsOptional  | int64                       | [bs]                                |
+/// blockTableOptional        | float16, bfloat16 or float32| [bs,blocknum]                       |
+/// antiquantScaleOptional    | float16, bfloat16 or float32| [bs,dim]                            |
+/// antiquantOffsetOptional   | float16, bfloat16 or float32| [bs,dim]                            |
+///
+/// Outputs of the operator:
+/// Name                      | Dtype                       | Shape                               |
+/// --------------------------|-----------------------------|-------------------------------------|
+/// output                    | float16, bfloat16 or int8   | [batchsize, headNum, dim]           |
+///
+/// Example:
+/// \code
+/// enum TensorIdx : uint32_t {
+///     QUERY,
+///     KEY,
+///     VALUE,
+///     SEQ_LEN,
+///     BLOCK_TABLE,
+///     DEQUANT_SCALE,
+///     DEQUANT_OFFSET,
+///     OUT,
+/// };
+///
+/// atb_speed::common::AclNNAttnParam attnParam;
+/// atb::Node &attnNode = opGraph.nodes.at(nodeId++);
+/// attnNode.operation = new atb_speed::common::AttnOperation("AttentionNode", attnParam);
+/// attnNode.inTensorIds = {QUERY, KEY, VALUE, SEQ_LEN, BLOCK_TABLE, DEQUANT_SCALE, DEQUANT_OFFSET};
+/// attnNode.outTensorIds = {OUT};
+/// \endcode
+
+class AttnOperation : public AclNNOperation {
+public:
+    explicit AttnOperation(const std::string &name, AclNNAttnParam param);
+    ~AttnOperation() override;
+    atb::Status InferShape(const atb::SVector<atb::TensorDesc> &inTensorDescs,
+                           atb::SVector<atb::TensorDesc> &outTensorDescs) const override;
+    uint32_t GetInputNum() const override;
+    uint32_t GetOutputNum() const override;
+
+protected:
+    int SetAclNNWorkspaceExecutor() override;
+    int ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream) override;
+    atb::Status CreateAclNNInTensorVariantPack(const atb::VariantPack &variantPack) override;
+    atb::Status CreateAclNNOutTensorVariantPack(const atb::VariantPack &variantPack) override;
+
+private:
+    int ProcessSeqLengthTensor(atb::Tensor &tensor);
+
+private:
+    aclTensor *tensorsOfValue[1]{nullptr};
+    aclTensor *tensorsOfKey[1]{nullptr};
+
AclNNAttnParam param_; + std::string opName_; +}; +} // namespace common +} // namespace atb_speed +#endif diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/cast_operation.cpp b/tests/proftest/layer_test_framework/operations/aclnn/ops/cast_operation.cpp new file mode 100644 index 00000000..e14a7c52 --- /dev/null +++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/cast_operation.cpp @@ -0,0 +1,166 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "acl/acl.h" +#include "atb_speed/log.h" +#include "operations/aclnn/utils/utils.h" +#include "operations/aclnn/core/acl_nn_operation.h" +#include "aclnnop/aclnn_cast.h" +#include "cast_operation.h" + +namespace atb_speed { +namespace common { + +CastOperation::CastOperation(const std::string &name, AclNNCastParam param) + : AclNNOperation(name), param_(param) {} + +CastOperation::~CastOperation() +{ + ATB_SPEED_LOG_DEBUG("CastOperation deconstructor"); + this->DestroyOperation(); +} + +constexpr int MAX_DIMENSION = 8; + +atb::Status CastOperation::InferShape(const atb::SVector &inTensorDescs, + atb::SVector &outTensorDescs) const +{ + ATB_SPEED_LOG_DEBUG(opName_ << " infer shape start"); + + if (inTensorDescs.at(0).shape.dimNum > MAX_DIMENSION) { + ATB_SPEED_LOG_ERROR(opName_ << " tensor dimension exceeds limit"); + return atb::ERROR_INVALID_PARAM; + } + + outTensorDescs.at(0).shape = inTensorDescs.at(0).shape; + outTensorDescs.at(0).dtype = param_.dtype; + outTensorDescs.at(0).format = inTensorDescs.at(0).format; + + ATB_SPEED_LOG_DEBUG(opName_ << " infer shape end"); + return atb::NO_ERROR; +} + +uint32_t CastOperation::GetInputNum() const { return NUM1; } + +uint32_t CastOperation::GetOutputNum() const { return NUM1; } + +int CastOperation::CreateAclNNVariantPack(const atb::VariantPack &variantPack) +{ + ATB_SPEED_LOG_DEBUG(opName_ << " CreateAclNNVariantPack start"); + + int ret = 0; + ret = CreateAclNNInTensorVariantPack(variantPack); + if (ret != 0) { + ATB_SPEED_LOG_ERROR(opName_ << " CreateAclNNInTensorVariantPack failed"); + return ret; + } + + ret = CreateAclNNOutTensorVariantPack(variantPack); + if (ret != 0) { + ATB_SPEED_LOG_ERROR(opName_ << " CreateAclNNOutTensorVariantPack failed"); + return ret; + } + + ATB_SPEED_LOG_DEBUG(opName_ << " CreateAclNNVariantPack end"); + return atb::NO_ERROR; +} + +atb::Status CastOperation::CreateAclNNInTensorVariantPack(const atb::VariantPack &variantPack) +{ + AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack; + aclnnVariantPack.aclInTensors.resize(GetInputNum()); + + std::shared_ptr aclnnTensor = std::make_shared(); + aclnnTensor->tensorIdx = 0; + aclnnTensor->needUpdateTensorDataPtr = true; + aclnnTensor->atbTensor = variantPack.inTensors.at(0); + aclnnTensor->strides = GetCopyTensorStride(aclnnTensor->atbTensor.desc.shape); + + aclnnTensor->tensor = aclCreateTensor( + aclnnTensor->atbTensor.desc.shape.dims, 
aclnnTensor->atbTensor.desc.shape.dimNum, + aclnnTensor->atbTensor.desc.dtype, aclnnTensor->strides.data(), 0, + aclnnTensor->atbTensor.desc.format, aclnnTensor->atbTensor.desc.shape.dims, + aclnnTensor->atbTensor.desc.shape.dimNum, aclnnTensor->atbTensor.deviceData); + + if (aclnnTensor->tensor == nullptr) { + ATB_SPEED_LOG_ERROR(opName_ << " Create input tensor failed"); + return atb::ERROR_INTERNAL_ERROR; + } + + aclnnVariantPack.aclInTensors[0] = aclnnTensor; + return atb::NO_ERROR; +} + +atb::Status CastOperation::CreateAclNNOutTensorVariantPack(const atb::VariantPack &variantPack) +{ + AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack; + aclnnVariantPack.aclOutTensors.resize(GetOutputNum()); + + std::shared_ptr aclnnTensor = std::make_shared(); + aclnnTensor->tensorIdx = 0; + aclnnTensor->needUpdateTensorDataPtr = true; + aclnnTensor->atbTensor = variantPack.outTensors.at(0); + aclnnTensor->strides = GetCopyTensorStride(aclnnTensor->atbTensor.desc.shape); + + aclnnTensor->tensor = aclCreateTensor( + aclnnTensor->atbTensor.desc.shape.dims, aclnnTensor->atbTensor.desc.shape.dimNum, + param_.dtype, aclnnTensor->strides.data(), 0, + aclnnTensor->atbTensor.desc.format, aclnnTensor->atbTensor.desc.shape.dims, + aclnnTensor->atbTensor.desc.shape.dimNum, aclnnTensor->atbTensor.deviceData); + + if (aclnnTensor->tensor == nullptr) { + ATB_SPEED_LOG_ERROR(opName_ << " Create output tensor failed"); + return atb::ERROR_INTERNAL_ERROR; + } + + aclnnVariantPack.aclOutTensors[0] = aclnnTensor; + return atb::NO_ERROR; +} + +int CastOperation::SetAclNNWorkspaceExecutor() +{ + ATB_SPEED_LOG_DEBUG(opName_ << " SetAclNNWorkspaceExecutor start"); + AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack; + + int ret = aclnnCastGetWorkspaceSize( + aclnnVariantPack.aclInTensors[0]->tensor, + param_.dtype, + aclnnVariantPack.aclOutTensors[0]->tensor, + &this->aclnnOpCache_->workspaceSize, + &this->aclnnOpCache_->aclExecutor); + if (ret != atb::NO_ERROR) { + ATB_SPEED_LOG_ERROR(opName_ << " GetWorkspaceSize failed with error code: " << ret); + return ret; + } + + ATB_SPEED_LOG_DEBUG(opName_ << " SetAclNNWorkspaceExecutor end"); + return atb::NO_ERROR; +} + +int CastOperation::ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream) +{ + int ret = aclnnCast( + workspace, + this->aclnnOpCache_->workspaceSize, + this->aclnnOpCache_->aclExecutor, + stream); + if (ret != atb::NO_ERROR) { + ATB_SPEED_LOG_ERROR(opName_ << " ExecuteAclNNOp failed"); + } + return ret; +} + +} // namespace common +} // namespace atb_speed \ No newline at end of file diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/cast_operation.h b/tests/proftest/layer_test_framework/operations/aclnn/ops/cast_operation.h new file mode 100644 index 00000000..ce3dcd10 --- /dev/null +++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/cast_operation.h @@ -0,0 +1,53 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef ATB_SPEED_PLUGIN_ACLNN_CAST_OPERATION_H +#define ATB_SPEED_PLUGIN_ACLNN_CAST_OPERATION_H + +#include +#include "acl/acl.h" +#include "operations/aclnn/core/acl_nn_operation.h" + +namespace atb_speed { +namespace common { + +struct AclNNCastParam { + aclDataType dtype; +}; + +class CastOperation : public AclNNOperation { +public: + explicit CastOperation(const std::string &name, AclNNCastParam param); + ~CastOperation() override; + atb::Status InferShape(const atb::SVector &inTensorDescs, + atb::SVector &outTensorDescs) const override; + uint32_t GetInputNum() const override; + uint32_t GetOutputNum() const override; + +protected: + int CreateAclNNVariantPack(const atb::VariantPack &variantPack) override; + int SetAclNNWorkspaceExecutor() override; + int ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream) override; + atb::Status CreateAclNNInTensorVariantPack(const atb::VariantPack &variantPack) override; + atb::Status CreateAclNNOutTensorVariantPack(const atb::VariantPack &variantPack) override; + +private: + AclNNCastParam param_; +}; + +} // namespace common +} // namespace atb_speed + +#endif // ATB_SPEED_PLUGIN_ACLNN_CAST_OPERATION_H \ No newline at end of file diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/concat_operation.cpp b/tests/proftest/layer_test_framework/operations/aclnn/ops/concat_operation.cpp new file mode 100644 index 00000000..5442c145 --- /dev/null +++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/concat_operation.cpp @@ -0,0 +1,142 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#include "concat_operation.h"
+#include "operations/aclnn/utils/utils.h"
+#include "acl/acl.h"
+#include "aclnnop/aclnn_cat.h"
+#include "atb_speed/log.h"
+
+namespace atb_speed {
+namespace common {
+
+ConcatOperation::ConcatOperation(const std::string &name, atb_speed::common::AclNNConcatParam param)
+    : AclNNOperation(name), param_(param) {}
+
+ConcatOperation::~ConcatOperation()
+{
+    ATB_SPEED_LOG_DEBUG("ConcatOperation destructor");
+    this->DestroyOperation();
+}
+
+uint32_t ConcatOperation::GetInputNum() const { return NUM2; }
+
+uint32_t ConcatOperation::GetOutputNum() const { return NUM1; }
+
+atb::Status ConcatOperation::InferShape(const atb::SVector<atb::TensorDesc> &inTensorDescs,
+    atb::SVector<atb::TensorDesc> &outTensorDescs) const
+{
+    ATB_SPEED_LOG_DEBUG(opName_ << " ConcatOperation infer shape start");
+    outTensorDescs.at(0).format = inTensorDescs.at(0).format;
+    outTensorDescs.at(0).dtype = inTensorDescs.at(0).dtype;
+    outTensorDescs.at(0).shape = inTensorDescs.at(0).shape;
+    outTensorDescs.at(0).shape.dims[this->param_.dim] = inTensorDescs.at(0).shape.dims[this->param_.dim] + \
+        inTensorDescs.at(1).shape.dims[this->param_.dim];
+    ATB_SPEED_LOG_DEBUG(opName_ << " ConcatOperation InferShape end");
+    return atb::NO_ERROR;
+}
+
+int ConcatOperation::CreateAclNNVariantPack(const atb::VariantPack &variantPack)
+{
+    ATB_SPEED_LOG_DEBUG(opName_ << " CreateAclNNVariantPack start");
+
+    int ret = 0;
+    ret = CreateAclNNInTensorVariantPack(variantPack);
+    if (ret != 0) {
+        ATB_SPEED_LOG_ERROR(opName_ << " CreateAclNNInTensorVariantPack failed");
+        return ret;
+    }
+
+    ret = CreateAclNNOutTensorVariantPack(variantPack);
+    if (ret != 0) {
+        ATB_SPEED_LOG_ERROR(opName_ << " CreateAclNNOutTensorVariantPack failed");
+        return ret;
+    }
+
+    ATB_SPEED_LOG_DEBUG(opName_ << " CreateAclNNVariantPack end");
+    return atb::NO_ERROR;
+}
+
+int ConcatOperation::CreateAclNNInTensorVariantPack(const atb::VariantPack &variantPack)
+{
+    AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack;
+    aclnnVariantPack.aclInTensors.resize(GetInputNum());
+    for (size_t i = 0; i < aclnnVariantPack.aclInTensors.size(); ++i) {
+        std::shared_ptr<AclNNTensor> aclnnTensor = CreateTensor(variantPack.inTensors.at(i), i);
+        if (aclnnTensor->tensor == nullptr) {
+            ATB_SPEED_LOG_ERROR(this->opName_ << " InTensor aclCreateTensor index " << i << " fail");
+            return atb::ERROR_INTERNAL_ERROR;
+        }
+        aclnnVariantPack.aclInTensors[i] = aclnnTensor;
+    }
+    return atb::NO_ERROR;
+}
+
+int ConcatOperation::CreateAclNNOutTensorVariantPack(const atb::VariantPack &variantPack)
+{
+    AclNNVariantPack &aclNnVariantPack = this->aclnnOpCache_->aclnnVariantPack;
+    aclNnVariantPack.aclOutTensors.resize(GetOutputNum());
+    for (size_t i = 0; i < aclNnVariantPack.aclOutTensors.size(); ++i) {
+        std::shared_ptr<AclNNTensor> aclnnTensor = CreateTensor(variantPack.outTensors.at(i), i);
+        if (aclnnTensor->tensor == nullptr) {
+            ATB_SPEED_LOG_ERROR(this->opName_ << " outTensor aclCreateTensor index " << i << " fail");
+            return atb::ERROR_INTERNAL_ERROR;
+        }
+        aclNnVariantPack.aclOutTensors[i] = aclnnTensor;
+    }
+    return atb::NO_ERROR;
+}
+
+std::shared_ptr<AclNNTensor> ConcatOperation::CreateTensor(atb::Tensor atbTensor, int tensorIdx) const
+{
+    std::shared_ptr<AclNNTensor> aclnnTensor = std::make_shared<AclNNTensor>();
+    aclnnTensor->needUpdateTensorDataPtr = true;
+    aclnnTensor->atbTensor = atbTensor;
+    aclnnTensor->tensorIdx = tensorIdx;
+    aclnnTensor->strides = GetCopyTensorStride(atbTensor.desc.shape);
+    aclnnTensor->tensor = aclCreateTensor(atbTensor.desc.shape.dims, atbTensor.desc.shape.dimNum,
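+        // argument order below follows the standard aclCreateTensor signature (an
+        // assumption worth noting): view dims and dim count first, then dtype, strides
+        // and offset, then format plus the storage dims, and finally the device buffer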
atbTensor.desc.dtype, aclnnTensor->strides.data(), + 0, atbTensor.desc.format, atbTensor.desc.shape.dims, + atbTensor.desc.shape.dimNum, atbTensor.deviceData); + return aclnnTensor; +} + +int ConcatOperation::SetAclNNWorkspaceExecutor() +{ + ATB_SPEED_LOG_DEBUG(opName_ << " SetAclNNWorkspaceExecutor start"); + AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack; + std::vector tmp{aclnnVariantPack.aclInTensors.at(0)->tensor, \ + aclnnVariantPack.aclInTensors.at(1)->tensor}; + aclTensorList* tensorList = aclCreateTensorList(tmp.data(), tmp.size()); + int ret = aclnnCatGetWorkspaceSize(tensorList, this->param_.dim, + aclnnVariantPack.aclOutTensors.at(0)->tensor, + &this->aclnnOpCache_->workspaceSize, &this->aclnnOpCache_->aclExecutor); + ATB_SPEED_LOG_DEBUG(opName_ << " SetAclNNWorkspaceExecutor end, ret: " << ret + << ", workspaceSize: " << this->aclnnOpCache_->workspaceSize + << ", aclExecutor address: " << &this->aclnnOpCache_->aclExecutor + << ", aclExecutor: " << this->aclnnOpCache_->aclExecutor); + return ret; +} + +int ConcatOperation::ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream) +{ + ATB_SPEED_LOG_DEBUG(opName_ << " ExecuteAclNNOp start"); + int ret = aclnnCat(workspace, this->aclnnOpCache_->workspaceSize, this->aclnnOpCache_->aclExecutor, stream); + ATB_SPEED_LOG_DEBUG(opName_ << " ExecuteAclNNOp end, ret:" << ret); + return ret; +} + +} +} \ No newline at end of file diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/concat_operation.h b/tests/proftest/layer_test_framework/operations/aclnn/ops/concat_operation.h new file mode 100644 index 00000000..897ece84 --- /dev/null +++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/concat_operation.h @@ -0,0 +1,52 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDIE_LLM_PLUGIN_ACLNN_CONCAT_OPERATION_H +#define MINDIE_LLM_PLUGIN_ACLNN_CONCAT_OPERATION_H + +#include "aclnn/core/acl_nn_operation.h" + +namespace atb_speed { +namespace common { + +struct AclNNConcatParam { + int64_t dim = -1; +}; + +class ConcatOperation : public AclNNOperation { +public: + explicit ConcatOperation(const std::string &name, AclNNConcatParam param); + ~ConcatOperation() override; + atb::Status InferShape(const atb::SVector &inTensorDescs, + atb::SVector &outTensorDescs) const override; + uint32_t GetInputNum() const override; + uint32_t GetOutputNum() const override; + +protected: + int CreateAclNNVariantPack(const atb::VariantPack &variantPack) override; + int SetAclNNWorkspaceExecutor() override; + int ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream) override; + int CreateAclNNInTensorVariantPack(const atb::VariantPack &variantPack) override; + int CreateAclNNOutTensorVariantPack(const atb::VariantPack &variantPack) override; + std::shared_ptr CreateTensor(atb::Tensor atbTensor, int tensorIdx) const; + +private: + AclNNConcatParam param_; +}; + +} // namespace common +} // namespace atb_speed + +#endif \ No newline at end of file diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/dequant_rope_quant_kvcache_operation.cpp b/tests/proftest/layer_test_framework/operations/aclnn/ops/dequant_rope_quant_kvcache_operation.cpp new file mode 100644 index 00000000..15ddca65 --- /dev/null +++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/dequant_rope_quant_kvcache_operation.cpp @@ -0,0 +1,190 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#include "acl/acl.h"
+#include "atb_speed/log.h"
+#include "operations/aclnn/utils/utils.h"
+#include "aclnnop/aclnn_dequant_rope_quant_kvcache.h"
+#include "dequant_rope_quant_kvcache_operation.h"
+
+namespace atb_speed {
+namespace common {
+
+DequantRopeQuantKvcacheOperation::DequantRopeQuantKvcacheOperation(
+    const std::string &name,
+    AclNNDequantRopeQuantKvcacheParam param) : AclNNOperation(name), param_(param) {}
+
+DequantRopeQuantKvcacheOperation::~DequantRopeQuantKvcacheOperation()
+{
+    ATB_SPEED_LOG_DEBUG("DequantRopeQuantKvcacheOperation destructor");
+    this->DestroyOperation();
+}
+
+atb::Status DequantRopeQuantKvcacheOperation::InferShape(
+    const atb::SVector<atb::TensorDesc> &inTensorDescs,
+    atb::SVector<atb::TensorDesc> &outTensorDescs) const
+{
+    ATB_SPEED_LOG_DEBUG(opName_ << " infer shape start");
+
+    for (int i = 0; i < NUM3; ++i) {
+        outTensorDescs.at(i).format = inTensorDescs.at(0).format;
+        outTensorDescs.at(i).dtype = inTensorDescs.at(1).dtype;
+        outTensorDescs.at(i).shape.dimNum = NUM3;
+    }
+
+    const int64_t batchsize = inTensorDescs.at(DIM0).shape.dims[DIM0]; // x.size(0) [1024, 1280]
+    const int64_t kvHeaddim = inTensorDescs.at(NUM4).shape.dims[DIM2]; // v_cache_ref.size(2); // 1
+    const int64_t dim = inTensorDescs.at(NUM4).shape.dims[DIM3]; // v_cache_ref.size(3); // [9, 128, 1, 128]
+    const int64_t qHeaddim = (dim == 0) ? 0 :
+        (inTensorDescs.at(DIM0).shape.dims[DIM1] - kvHeaddim * dim * NUM2) / dim;
+
+    for (int i = 0; i < NUM3; ++i) {
+        outTensorDescs.at(i).shape.dims[DIM0] = batchsize;
+        outTensorDescs.at(i).shape.dims[DIM1] = (i > 0) ? kvHeaddim : qHeaddim;
+        outTensorDescs.at(i).shape.dims[DIM2] = dim;
+    }
+
+    // Log the aclInTensors addresses
+    AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack;
+    for (size_t i = 0; i < aclnnVariantPack.aclInTensors.size(); ++i) {
+        ATB_SPEED_LOG_DEBUG("Input tensor[" << i << "] address: "
+                            << aclnnVariantPack.aclInTensors.at(i)->tensor);
+    }
+
+    // Log the aclOutTensors addresses
+    for (size_t i = 0; i < aclnnVariantPack.aclOutTensors.size(); ++i) {
+        ATB_SPEED_LOG_DEBUG("Output tensor[" << i << "] address: "
+                            << aclnnVariantPack.aclOutTensors.at(i)->tensor);
+    }
+    ATB_SPEED_LOG_DEBUG(opName_ << " infer shape end");
+    return 0;
+}
+
+uint32_t DequantRopeQuantKvcacheOperation::GetInputNum() const
+{
+    return param_.enableDequant ?
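+        // the two extra dequant-mode inputs are weight_scale and bias
+        // (see the argument list in SetAclNNWorkspaceExecutor below)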
        12 : 10; // with dequant exposed: 12 inputs; without: 10
+}
+
+uint32_t DequantRopeQuantKvcacheOperation::GetOutputNum() const { return NUM3; }
+
+int DequantRopeQuantKvcacheOperation::CreateAclNNInTensorVariantPack(const atb::VariantPack &variantPack)
+{
+    AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack;
+    aclnnVariantPack.aclInTensors.resize(GetInputNum());
+    for (size_t i = 0; i < aclnnVariantPack.aclInTensors.size(); ++i) {
+        std::shared_ptr<AclNNTensor> aclnnTensor = std::make_shared<AclNNTensor>();
+        aclnnTensor->tensorIdx = i;
+        if (i == 11) { // graph input 11 (bias) maps to aclnn slot 12; slot 11 (activation_scale) is skipped
+            aclnnTensor->tensorIdx++;
+        }
+        aclnnTensor->needUpdateTensorDataPtr = true;
+        aclnnTensor->atbTensor = variantPack.inTensors.at(i);
+        atb::Tensor atbTensor = variantPack.inTensors.at(i);
+        aclnnTensor->strides = GetCopyTensorStride(atbTensor.desc.shape);
+        aclnnTensor->tensor = aclCreateTensor(
+            atbTensor.desc.shape.dims, atbTensor.desc.shape.dimNum, atbTensor.desc.dtype,
+            aclnnTensor->strides.data(), 0, atbTensor.desc.format, atbTensor.desc.shape.dims,
+            atbTensor.desc.shape.dimNum, atbTensor.deviceData);
+        if (aclnnTensor->tensor == nullptr) {
+            ATB_SPEED_LOG_ERROR(this->opName_ << " InTensor aclCreateTensor index " << i << " fail");
+            return atb::ERROR_INTERNAL_ERROR;
+        }
+        aclnnVariantPack.aclInTensors[i] = aclnnTensor;
+    }
+    return atb::NO_ERROR;
+}
+
+int DequantRopeQuantKvcacheOperation::CreateAclNNOutTensorVariantPack(const atb::VariantPack &variantPack)
+{
+    AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack;
+    aclnnVariantPack.aclOutTensors.resize(GetOutputNum());
+    for (size_t i = 0; i < aclnnVariantPack.aclOutTensors.size(); ++i) {
+        aclnnVariantPack.aclOutTensors[i] = CreateTensor(variantPack.outTensors.at(i), i);
+        if (aclnnVariantPack.aclOutTensors[i]->tensor == nullptr) {
+            return atb::ERROR_INTERNAL_ERROR;
+        }
+    }
+    return atb::NO_ERROR;
+}
+
+int DequantRopeQuantKvcacheOperation::CreateAclNNVariantPack(const atb::VariantPack &variantPack)
+{
+    ATB_SPEED_LOG_DEBUG(opName_ << " CreateAclNNVariantPack start");
+    int ret = 0;
+    ret = CreateAclNNInTensorVariantPack(variantPack);
+    if (ret != 0) {
+        ATB_SPEED_LOG_ERROR(this->opName_ << " AclNNTensor CreateAclNNInTensorVariantPack fail");
+        return ret;
+    }
+
+    ret = CreateAclNNOutTensorVariantPack(variantPack);
+    if (ret != 0) {
+        ATB_SPEED_LOG_ERROR(this->opName_ << " AclNNTensor CreateAclNNOutTensorVariantPack fail");
+        return ret;
+    }
+    ATB_SPEED_LOG_DEBUG(opName_ << " CreateAclNNVariantPack end");
+    return atb::NO_ERROR;
+}
+
+int DequantRopeQuantKvcacheOperation::SetAclNNWorkspaceExecutor()
+{
+    ATB_SPEED_LOG_DEBUG(opName_ << " start");
+    AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack;
+    char cacheMode[5] = "page";
+    int ret = aclnnDequantRopeQuantKvcacheGetWorkspaceSize(
+        aclnnVariantPack.aclInTensors.at(0)->tensor, // 0: x
+        aclnnVariantPack.aclInTensors.at(1)->tensor, // 1: cos
+        aclnnVariantPack.aclInTensors.at(2)->tensor, // 2: sin
+        aclnnVariantPack.aclInTensors.at(3)->tensor, // 3: k_cache
+        aclnnVariantPack.aclInTensors.at(4)->tensor, // 4: v_cache
+        aclnnVariantPack.aclInTensors.at(5)->tensor, // 5: indices
+        aclnnVariantPack.aclInTensors.at(6)->tensor, // 6: scale_k
+        aclnnVariantPack.aclInTensors.at(7)->tensor, // 7: scale_v
+        aclnnVariantPack.aclInTensors.at(8)->tensor, // 8: offset_k
+        aclnnVariantPack.aclInTensors.at(9)->tensor, // 9: offset_v
+        param_.enableDequant ? aclnnVariantPack.aclInTensors.at(10)->tensor : nullptr, // 10: weight_scale
+        nullptr, // 11: activation_scale
+        param_.enableDequant ?
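+            // optional aclnn inputs that the graph does not wire up are passed as
+            // nullptr; slot 11 (activation_scale) above is always nullptr here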
            aclnnVariantPack.aclInTensors.at(11)->tensor : nullptr, // 12: bias
+        aclCreateIntArray(param_.sizeSpilts.data(), param_.sizeSpilts.size()), // 13: sizeSpilts
+        const_cast<char *>(param_.quantMode.c_str()), // 14: quantMode, char
+        const_cast<char *>(param_.layout.c_str()), // 15: layoutOptional
+        param_.kvOutput, // 16: kvOutputOptional
+        cacheMode, // 17: cachemode
+        aclnnVariantPack.aclOutTensors.at(0)->tensor, // 18: qOut
+        aclnnVariantPack.aclOutTensors.at(1)->tensor, // 19: kOut
+        aclnnVariantPack.aclOutTensors.at(2)->tensor, // 20: vOut
+        &this->aclnnOpCache_->workspaceSize, // 21: workspaceSize
+        &this->aclnnOpCache_->aclExecutor); // 22: executor
+
+    ATB_SPEED_LOG_DEBUG(opName_ << " end, ret:"
+        << ret << ", workspaceSize:" << this->aclnnOpCache_->workspaceSize
+        << ", aclExecutor:" << this->aclnnOpCache_->aclExecutor);
+    return ret;
+}
+
+int DequantRopeQuantKvcacheOperation::ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream)
+{
+    int ret = aclnnDequantRopeQuantKvcache(
+        workspace,
+        this->aclnnOpCache_->workspaceSize,
+        this->aclnnOpCache_->aclExecutor,
+        stream);
+    if (ret != 0) {
+        ATB_SPEED_LOG_ERROR(opName_ << " aclnnDequantRopeQuantKvcache failed, ret: " << ret);
+    }
+    return ret;
+}
+} // namespace common
+} // namespace atb_speed
diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/dequant_rope_quant_kvcache_operation.h b/tests/proftest/layer_test_framework/operations/aclnn/ops/dequant_rope_quant_kvcache_operation.h
new file mode 100644
index 00000000..44a6c537
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/dequant_rope_quant_kvcache_operation.h
@@ -0,0 +1,52 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +#ifndef ATB_SPEED_PLUGIN_ACLNN_DEQUANT_ROPE_QUANT_KVCACHE_OPERATION_H +#define ATB_SPEED_PLUGIN_ACLNN_DEQUANT_ROPE_QUANT_KVCACHE_OPERATION_H +#include "acl/acl.h" +#include "operations/aclnn/core/acl_nn_operation.h" + +namespace atb_speed { +namespace common { +struct AclNNDequantRopeQuantKvcacheParam { + std::vector sizeSpilts = {128 * 8, 128, 128}; + bool kvOutput = true; + std::string quantMode = "static"; + std::string layout = "BSND"; + bool enableDequant = false; +}; + +class DequantRopeQuantKvcacheOperation : public AclNNOperation { +public: + explicit DequantRopeQuantKvcacheOperation(const std::string &name, AclNNDequantRopeQuantKvcacheParam param); + ~DequantRopeQuantKvcacheOperation() override; + atb::Status InferShape(const atb::SVector &inTensorDescs, + atb::SVector &outTensorDescs) const override; + uint32_t GetInputNum() const override; + uint32_t GetOutputNum() const override; + +protected: + int SetAclNNWorkspaceExecutor() override; + int ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream) override; + int CreateAclNNVariantPack(const atb::VariantPack &variantPack) override; + int CreateAclNNInTensorVariantPack(const atb::VariantPack &variantPack) override; + int CreateAclNNOutTensorVariantPack(const atb::VariantPack &variantPack) override; + +private: + AclNNDequantRopeQuantKvcacheParam param_; +}; +} // namespace common +} // namespace atb_speed +#endif diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/dequant_swiglu_quant_operation.cpp b/tests/proftest/layer_test_framework/operations/aclnn/ops/dequant_swiglu_quant_operation.cpp new file mode 100644 index 00000000..4cd7f9fa --- /dev/null +++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/dequant_swiglu_quant_operation.cpp @@ -0,0 +1,199 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#include "acl/acl.h"
+#include "atb_speed/log.h"
+#include "operations/aclnn/utils/utils.h"
+#include "aclnnop/aclnn_dequant_swiglu_quant.h"
+#include "dequant_swiglu_quant_operation.h"
+
+namespace atb_speed {
+namespace common {
+
+DequantSwigluQuantOperation::DequantSwigluQuantOperation(
+    const std::string &name,
+    AclNNDequantSwigluQuantParam param) : AclNNOperation(name), param_(param) {}
+
+DequantSwigluQuantOperation::~DequantSwigluQuantOperation()
+{
+    ATB_SPEED_LOG_DEBUG("DequantSwigluQuantOperation destructor");
+    this->DestroyOperation();
+}
+
+atb::Status DequantSwigluQuantOperation::InferShape(
+    const atb::SVector<atb::TensorDesc> &inTensorDescs,
+    atb::SVector<atb::TensorDesc> &outTensorDescs) const
+{
+    ATB_SPEED_LOG_DEBUG(opName_ << " infer shape start");
+    outTensorDescs.at(DIM0).format = inTensorDescs.at(DIM0).format;
+    outTensorDescs.at(DIM0).dtype = aclDataType::ACL_INT8;
+    outTensorDescs.at(DIM0).shape.dimNum = inTensorDescs.at(DIM0).shape.dimNum;
+    outTensorDescs.at(DIM1).format = inTensorDescs.at(DIM0).format;
+    outTensorDescs.at(DIM1).dtype = aclDataType::ACL_FLOAT;
+    outTensorDescs.at(DIM1).shape.dimNum = inTensorDescs.at(DIM0).shape.dimNum;
+    if (inTensorDescs.at(DIM0).shape.dimNum == DIM3) {
+        outTensorDescs.at(DIM0).shape.dims[DIM0] = inTensorDescs.at(DIM0).shape.dims[DIM0];
+        outTensorDescs.at(DIM0).shape.dims[DIM1] = inTensorDescs.at(DIM0).shape.dims[DIM1];
+        outTensorDescs.at(DIM0).shape.dims[DIM2] = inTensorDescs.at(DIM0).shape.dims[DIM2] / NUM2;
+        outTensorDescs.at(DIM1).shape.dims[DIM0] = inTensorDescs.at(DIM0).shape.dims[DIM0];
+        outTensorDescs.at(DIM1).shape.dims[DIM1] = inTensorDescs.at(DIM0).shape.dims[DIM1];
+        outTensorDescs.at(DIM1).shape.dims[DIM2] = NUM1;
+    } else if (inTensorDescs.at(DIM0).shape.dimNum == DIM2) {
+        outTensorDescs.at(DIM0).shape.dims[DIM0] = inTensorDescs.at(DIM0).shape.dims[DIM0];
+        outTensorDescs.at(DIM0).shape.dims[DIM1] = inTensorDescs.at(DIM0).shape.dims[DIM1] / NUM2;
+        outTensorDescs.at(DIM1).shape.dims[DIM0] = inTensorDescs.at(DIM0).shape.dims[DIM0];
+        outTensorDescs.at(DIM1).shape.dims[DIM1] = NUM1;
+    } else {
+        ATB_SPEED_LOG_ERROR(opName_ << " invalid dim num:" << inTensorDescs.at(DIM0).shape.dimNum);
+    }
+    ATB_SPEED_LOG_DEBUG(opName_ << " infer shape end");
+    return 0;
+}
+
+uint32_t DequantSwigluQuantOperation::GetInputNum() const
+{
+    if (param_.inTensorsNum < NUM1 || param_.inTensorsNum > NUM5) {
+        ATB_SPEED_LOG_DEBUG("DequantSwigluQuantOperation param inTensorsNum is wrong!
reset to 5."); + return NUM5; + } + return param_.inTensorsNum; +} + +uint32_t DequantSwigluQuantOperation::GetOutputNum() const { return NUM2; } + +int DequantSwigluQuantOperation::CreateAclNNInTensorVariantPack(const atb::VariantPack &variantPack) +{ + AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack; + aclnnVariantPack.aclInTensors.resize(GetInputNum()); + for (size_t i = 0; i < aclnnVariantPack.aclInTensors.size(); ++i) { + std::shared_ptr aclnnTensor = std::make_shared(); + if (param_.inTensorsNum == NUM3 && i > 0) { + aclnnTensor->tensorIdx = i + NUM3; + } else if (param_.inTensorsNum == NUM5 && i > 1) { + aclnnTensor->tensorIdx = i + 1; + } else { + aclnnTensor->tensorIdx = i; + } + aclnnTensor->needUpdateTensorDataPtr = true; + aclnnTensor->atbTensor = variantPack.inTensors.at(i); + atb::Tensor atbTensor = variantPack.inTensors.at(i); + aclnnTensor->strides = GetCopyTensorStride(atbTensor.desc.shape); + aclnnTensor->tensor = aclCreateTensor( + atbTensor.desc.shape.dims, atbTensor.desc.shape.dimNum, atbTensor.desc.dtype, + aclnnTensor->strides.data(), 0, atbTensor.desc.format, atbTensor.desc.shape.dims, + atbTensor.desc.shape.dimNum, atbTensor.deviceData); + + if (aclnnTensor->tensor == nullptr) { + ATB_SPEED_LOG_ERROR(this->opName_ << " InTensor aclCreateTensor index " << i << " fail"); + return atb::ERROR_INTERNAL_ERROR; + } + aclnnVariantPack.aclInTensors[i] = aclnnTensor; + } + return atb::NO_ERROR; +} + +int DequantSwigluQuantOperation::CreateAclNNOutTensorVariantPack(const atb::VariantPack &variantPack) +{ + AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack; + aclnnVariantPack.aclOutTensors.resize(GetOutputNum()); + for (size_t i = 0; i < aclnnVariantPack.aclOutTensors.size(); ++i) { + std::shared_ptr aclnnTensor = std::make_shared(); + aclnnTensor->tensorIdx = i; + aclnnTensor->needUpdateTensorDataPtr = true; + aclnnTensor->atbTensor = variantPack.outTensors.at(i); + atb::Tensor atbTensor = variantPack.outTensors.at(i); + aclnnTensor->strides = GetCopyTensorStride(atbTensor.desc.shape); + aclnnTensor->tensor = aclCreateTensor( + atbTensor.desc.shape.dims, atbTensor.desc.shape.dimNum, atbTensor.desc.dtype, + aclnnTensor->strides.data(), 0, atbTensor.desc.format, atbTensor.desc.shape.dims, + atbTensor.desc.shape.dimNum, atbTensor.deviceData); + if (aclnnTensor->tensor == nullptr) { + ATB_SPEED_LOG_ERROR(this->opName_ << " OutTensor aclCreateTensor index " << i << " fail"); + return atb::ERROR_INTERNAL_ERROR; + } + aclnnVariantPack.aclOutTensors[i] = aclnnTensor; + } + return atb::NO_ERROR; +} + +int DequantSwigluQuantOperation::CreateAclNNVariantPack(const atb::VariantPack &variantPack) +{ + ATB_SPEED_LOG_DEBUG(opName_ << " CreateAclNNVariantPack start"); + int ret = 0; + ret = CreateAclNNInTensorVariantPack(variantPack); + if (ret != 0) { + ATB_SPEED_LOG_ERROR(this->opName_ << " AclNNTensor CreateAclNNInTensorVariantPack fail"); + return ret; + } + + ret = CreateAclNNOutTensorVariantPack(variantPack); + if (ret != 0) { + ATB_SPEED_LOG_ERROR(this->opName_ << " AclNNTensor CreateAclNNOutTensorVariantPack fail"); + return ret; + } + + ATB_SPEED_LOG_DEBUG(opName_ << " CreateAclNNVariantPack end"); + return atb::NO_ERROR; +} + +int DequantSwigluQuantOperation::SetAclNNWorkspaceExecutor() +{ + ATB_SPEED_LOG_DEBUG(opName_ << " start"); + AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack; + uint32_t inputIdx = 0; + aclTensor* xTensor = aclnnVariantPack.aclInTensors.at(inputIdx++)->tensor; + aclTensor* 
weightScaleTensor = (param_.inTensorsNum > NUM3) ? + aclnnVariantPack.aclInTensors.at(inputIdx++)->tensor : nullptr; + aclTensor* biasTensor = (param_.inTensorsNum > NUM3) ? + aclnnVariantPack.aclInTensors.at(inputIdx++)->tensor : nullptr; + aclTensor* quantScaleTensor = (param_.inTensorsNum > NUM1) ? + aclnnVariantPack.aclInTensors.at(inputIdx++)->tensor : nullptr; + aclTensor* quantOffsetTensor = (param_.inTensorsNum > NUM1) ? + aclnnVariantPack.aclInTensors.at(inputIdx++)->tensor : nullptr; + + int ret = aclnnDequantSwigluQuantGetWorkspaceSize( + xTensor, // x + weightScaleTensor, // weightScaleOptional + nullptr, // activationScaleOptional + biasTensor, // biasOptional + quantScaleTensor, // quantScaleOptional + quantOffsetTensor, // quantOffsetOptional + nullptr, // groupIndexOptional + param_.activateLeft, // activateLeft + const_cast(param_.quantMode.c_str()), // quantMode, char + aclnnVariantPack.aclOutTensors.at(DIM0)->tensor, // y + aclnnVariantPack.aclOutTensors.at(DIM1)->tensor, // scaleOptional + &this->aclnnOpCache_->workspaceSize, + &this->aclnnOpCache_->aclExecutor); + ATB_SPEED_LOG_DEBUG(opName_ << " end, ret:" + << ret << ", workspaceSize:" << this->aclnnOpCache_->workspaceSize + << ", aclExecutor:" << this->aclnnOpCache_->aclExecutor); + return ret; +} + +int DequantSwigluQuantOperation::ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream) +{ + ATB_SPEED_LOG_DEBUG(opName_ << " DequantSwigluQuantOperation start"); + int ret = aclnnDequantSwigluQuant( + workspace, + this->aclnnOpCache_->workspaceSize, + this->aclnnOpCache_->aclExecutor, + stream); + ATB_SPEED_LOG_DEBUG(opName_ << " DequantSwigluQuantOperation end, ret: " << ret); + return ret; +} +} // namespace common +} // namespace atb_speed \ No newline at end of file diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/dequant_swiglu_quant_operation.h b/tests/proftest/layer_test_framework/operations/aclnn/ops/dequant_swiglu_quant_operation.h new file mode 100644 index 00000000..46687245 --- /dev/null +++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/dequant_swiglu_quant_operation.h @@ -0,0 +1,50 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef ATB_SPEED_PLUGIN_ACLNN_DEQUANT_SWIGLU_QUANT_OPERATION_H +#define ATB_SPEED_PLUGIN_ACLNN_DEQUANT_SWIGLU_QUANT_OPERATION_H +#include "operations/aclnn/core/acl_nn_operation.h" +#include "operations/aclnn/utils/utils.h" + +namespace atb_speed { +namespace common { +struct AclNNDequantSwigluQuantParam { + bool activateLeft = false; + std::string quantMode = "static"; + int inTensorsNum = NUM3; +}; + +class DequantSwigluQuantOperation : public AclNNOperation { +public: + explicit DequantSwigluQuantOperation(const std::string &name, AclNNDequantSwigluQuantParam param); + ~DequantSwigluQuantOperation() override; + atb::Status InferShape(const atb::SVector &inTensorDescs, + atb::SVector &outTensorDescs) const override; + uint32_t GetInputNum() const override; + uint32_t GetOutputNum() const override; + +protected: + int CreateAclNNVariantPack(const atb::VariantPack &variantPack) override; + int SetAclNNWorkspaceExecutor() override; + int ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream) override; + int CreateAclNNInTensorVariantPack(const atb::VariantPack &variantPack) override; + int CreateAclNNOutTensorVariantPack(const atb::VariantPack &variantPack) override; + +private: + AclNNDequantSwigluQuantParam param_; +}; +} // namespace common +} // namespace atb_speed +#endif \ No newline at end of file diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/dynamic_quant_operation.cpp b/tests/proftest/layer_test_framework/operations/aclnn/ops/dynamic_quant_operation.cpp new file mode 100644 index 00000000..5a4b1ecf --- /dev/null +++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/dynamic_quant_operation.cpp @@ -0,0 +1,98 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "dynamic_quant_operation.h" +#include +#include +#include +#include +#include "acl/acl.h" +#include "aclnnop/aclnn_dynamic_quant.h" +#include "atb_speed/log.h" +#include "atb_speed/utils/timer.h" +#include "operations/aclnn/utils/utils.h" + +namespace atb_speed { +namespace common { + +DynamicQuantOperation::DynamicQuantOperation(const std::string &name) : AclNNOperation(name) {} +DynamicQuantOperation::~DynamicQuantOperation() {} + +atb::Status DynamicQuantOperation::InferShape( + const atb::SVector &inTensorDescs, atb::SVector &outTensorDescs) const +{ + ATB_SPEED_LOG_DEBUG(opName_ << "DynamicQuantOperation infer shape start"); + + outTensorDescs.at(DIM0).format = inTensorDescs.at(DIM0).format; + outTensorDescs.at(DIM0).dtype = aclDataType::ACL_INT8; + outTensorDescs.at(DIM0).shape.dimNum = inTensorDescs.at(DIM0).shape.dimNum; + + outTensorDescs.at(DIM1).format = inTensorDescs.at(DIM0).format; + outTensorDescs.at(DIM1).dtype = aclDataType::ACL_FLOAT; + outTensorDescs.at(DIM1).shape.dimNum = inTensorDescs.at(DIM0).shape.dimNum - 1; + + ATB_SPEED_LOG_DEBUG(opName_ << "DynamicQuantOperation infer shape origin inTensorDescs.at(DIM0).shape.dims[DIM0]" + << inTensorDescs.at(DIM0).shape.dims[DIM0]); + + for (uint64_t i = 0; i < inTensorDescs.at(DIM0).shape.dimNum; i++) { + outTensorDescs.at(DIM0).shape.dims[i] = inTensorDescs.at(DIM0).shape.dims[i]; + } + for (uint64_t i = 0; i < inTensorDescs.at(DIM0).shape.dimNum - 1; i++) { + outTensorDescs.at(DIM1).shape.dims[i] = inTensorDescs.at(DIM0).shape.dims[i]; + } + + ATB_SPEED_LOG_DEBUG(opName_ << "DynamicQuantOperation infer shape end"); + return 0; +} + +uint32_t DynamicQuantOperation::GetInputNum() const +{ + return DIM1; +} + +uint32_t DynamicQuantOperation::GetOutputNum() const +{ + return DIM2; +} + +int DynamicQuantOperation::SetAclNNWorkspaceExecutor() +{ + ATB_SPEED_LOG_DEBUG(opName_ << " DynamicQuantOperation start"); + AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack; + int ret = aclnnDynamicQuantGetWorkspaceSize( + aclnnVariantPack.aclInTensors.at(DIM0)->tensor, + nullptr, + aclnnVariantPack.aclOutTensors.at(DIM0)->tensor, + aclnnVariantPack.aclOutTensors.at(DIM1)->tensor, + &this->aclnnOpCache_->workspaceSize, + &this->aclnnOpCache_->aclExecutor); + ATB_SPEED_LOG_DEBUG(opName_ << " SetAclNNWorkspaceExecutor end, ret:" << ret + << ", workspaceSize:" << this->aclnnOpCache_->workspaceSize + << ", aclExecutor:" << this->aclnnOpCache_->aclExecutor); + return ret; +} + +int DynamicQuantOperation::ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream) +{ + ATB_SPEED_LOG_DEBUG(opName_ << " DynamicQuantOperation start"); + + int ret = aclnnDynamicQuant( + workspace, this->aclnnOpCache_->workspaceSize, this->aclnnOpCache_->aclExecutor, stream); + ATB_SPEED_LOG_DEBUG(opName_ << " DynamicQuantOperation end, ret:" << ret); + return ret; +} + +} +} \ No newline at end of file diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/dynamic_quant_operation.h b/tests/proftest/layer_test_framework/operations/aclnn/ops/dynamic_quant_operation.h new file mode 100644 index 00000000..d5154ae8 --- /dev/null +++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/dynamic_quant_operation.h @@ -0,0 +1,68 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef ATB_SPEED_PLUGIN_ACLNN_DYNAMIC_QUANT_OPERATION_H +#define ATB_SPEED_PLUGIN_ACLNN_DYNAMIC_QUANT_OPERATION_H +#include "operations/aclnn/core/acl_nn_operation.h" +#include "atb_speed/utils/operation_util.h" + +/// This class defines an dynamic quant operator. +/// +/// This class makes uses of `aclnnDynamicQuantGetWorkspaceSize` and `aclnnDynamicQuant` +/// from the AscendCL API. +/// +/// Inputs to the operator: +/// Name | Dtype | Shape | +/// ---------------|---------------------|---------| +/// input | float16 or bfloat16 | [m,h] | +/// +/// Outputs of the operator: +/// Name | Dtype | Shape | +/// ---------------|---------------------|---------| +/// out | int8 | [m,h] | +/// tokenScales | float16 or bfloat16 | [m] | +/// +/// Example: +/// \code +/// enum TensorIdx : uint32_t { +/// IN_INPUT = 0, +/// OUT, +/// OUT_SCALE, +/// }; +/// +/// atb::Node &dynamicQuantNode = opGraph.nodes.at(nodeId++); +/// dynamicQuantNode.operation = new atb_speed::common::DynamicQuantOperation("DynamicQuantOperation"); +/// dynamicQuantNode.inTensorIds = {IN_INPUT}; +/// dynamicQuantNode.outTensorIds = {OUT, OUT_SCALE}; +/// \endcode + +namespace atb_speed::common { + +class DynamicQuantOperation : public AclNNOperation { +public: + explicit DynamicQuantOperation(const std::string &name); + ~DynamicQuantOperation() override; + atb::Status InferShape(const atb::SVector &inTensorDescs, + atb::SVector &outTensorDescs) const override; + uint32_t GetInputNum() const override; + uint32_t GetOutputNum() const override; + +private: + int SetAclNNWorkspaceExecutor() override; + int ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream) override; +}; +} + +#endif // ATB_SPEED_PLUGIN_ACLNN_DYNAMIC_QUANT_OPERATION_H \ No newline at end of file diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/finalize_routing_operation.cpp b/tests/proftest/layer_test_framework/operations/aclnn/ops/finalize_routing_operation.cpp new file mode 100644 index 00000000..c5861774 --- /dev/null +++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/finalize_routing_operation.cpp @@ -0,0 +1,92 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include +#include +#include +#include +#include "acl/acl.h" +#include "atb_speed/log.h" +#include "atb_speed/utils/timer.h" +#include "operations/aclnn/utils/utils.h" +#include "aclnnop/aclnn_moe_finalize_routing.h" +#include "finalize_routing_operation.h" + +namespace atb_speed { +namespace common { + +FinalizeRoutingOperation::FinalizeRoutingOperation(const std::string &name) : AclNNOperation(name) {} +FinalizeRoutingOperation::~FinalizeRoutingOperation() {} + +atb::Status FinalizeRoutingOperation::InferShape( + const atb::SVector &inTensorDescs, atb::SVector &outTensorDescs) const +{ + ATB_SPEED_LOG_DEBUG(opName_ << "FinalizeRoutingOperation infer shape start"); + + outTensorDescs.at(DIM0).format = inTensorDescs.at(DIM1).format; + outTensorDescs.at(DIM0).dtype = inTensorDescs.at(DIM1).dtype; + outTensorDescs.at(DIM0).shape.dimNum = inTensorDescs.at(DIM1).shape.dimNum; + + ATB_SPEED_LOG_DEBUG(opName_ << "FinalizeRoutingOperation infer shape origin inTensorDescs.at(DIM0).shape.dims[DIM0]" + << inTensorDescs.at(DIM0).shape.dims[DIM0]); + outTensorDescs.at(DIM0).shape.dims[DIM0] = inTensorDescs.at(DIM1).shape.dims[DIM0]; + outTensorDescs.at(DIM0).shape.dims[DIM1] = inTensorDescs.at(DIM1).shape.dims[DIM1]; + + ATB_SPEED_LOG_DEBUG(opName_ << "FinalizeRoutingOperation infer shape end"); + return 0; +} +uint32_t FinalizeRoutingOperation::GetInputNum() const +{ + return NUM7; +} + +uint32_t FinalizeRoutingOperation::GetOutputNum() const +{ + return DIM1; +} + +int FinalizeRoutingOperation::SetAclNNWorkspaceExecutor() +{ + ATB_SPEED_LOG_DEBUG(opName_ << " FinalizeRoutingOperation start"); + AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack; + int ret = aclnnMoeFinalizeRoutingGetWorkspaceSize( + aclnnVariantPack.aclInTensors.at(DIM0)->tensor, + aclnnVariantPack.aclInTensors.at(DIM1)->tensor, + aclnnVariantPack.aclInTensors.at(DIM2)->tensor, + aclnnVariantPack.aclInTensors.at(DIM3)->tensor, + aclnnVariantPack.aclInTensors.at(NUM4)->tensor, + aclnnVariantPack.aclInTensors.at(NUM5)->tensor, + aclnnVariantPack.aclInTensors.at(NUM6)->tensor, + aclnnVariantPack.aclOutTensors.at(DIM0)->tensor, + &this->aclnnOpCache_->workspaceSize, + &this->aclnnOpCache_->aclExecutor); + ATB_SPEED_LOG_DEBUG(opName_ << " SetAclNNWorkspaceExecutor end, ret:" << ret + << ", workspaceSize:" << this->aclnnOpCache_->workspaceSize + << ", aclExecutor:" << this->aclnnOpCache_->aclExecutor); + return ret; +} + +int FinalizeRoutingOperation::ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream) +{ + ATB_SPEED_LOG_DEBUG(opName_ << " FinalizeRoutingOperation start"); + + int ret = aclnnMoeFinalizeRouting( + workspace, this->aclnnOpCache_->workspaceSize, this->aclnnOpCache_->aclExecutor, stream); + ATB_SPEED_LOG_DEBUG(opName_ << " FinalizeRoutingOperation end, ret:" << ret); + return ret; +} + +} // namespace common +} // namespace atb_speed \ No newline at end of file diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/finalize_routing_operation.h b/tests/proftest/layer_test_framework/operations/aclnn/ops/finalize_routing_operation.h new file mode 100644 index 00000000..ddf4892f --- /dev/null +++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/finalize_routing_operation.h @@ -0,0 +1,94 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef ATB_SPEED_PLUGIN_ACLNN_FINALIZE_ROUTING_OPERATION_H +#define ATB_SPEED_PLUGIN_ACLNN_FINALIZE_ROUTING_OPERATION_H +#include "operations/aclnn/core/acl_nn_operation.h" +#include "atb_speed/utils/operation_util.h" + +namespace atb_speed { +namespace common { + +/// This class defines an operator that performs scaling, sorting, and reducing by summation +/// on the input. +/// +/// This class makes uses of `aclnnMoeFinalizeRoutingGetWorkspaceSize` and `aclnnMoeFinalizeRouting` +/// from the AscendCL API. +/// +/// Inputs to the operator: +/// Name | Dtype | Shape | +/// -------------------|---------------------|---------| +/// input1 | float16 or bfloat16 | [m*k,h] | +/// input2 | the same as input1 | [m,h] | +/// input3 | the same as input1 | [m,h] | +/// bias | the same as input1 | [e,h] | +/// scales | the same as input1 | [m,k] | +/// expandedRowIdx | int32 | [m*k] | +/// expandedExpertIdx | int32 | [m,k] | +/// +/// Outputs of the operator: +/// Name | Dtype | Shape | +/// --------------|--------------------|---------| +/// output | the same as input1 | [m,h] | + +/// Note: e is the total number of experts utilized by the model +/// k is the number of experts selected for each token +/// +/// Example: +/// \code +/// enum TensorIdx : uint32_t { +/// IN_INPUT = 0, +/// IN_INPUT_TWO, +/// IN_INPUT_THREE +/// IN_BIAS, +/// IN_SCALES, +/// IN_EXPANDED_ROW_IDX, +/// IN_EXPANDED_EXPERT_IDX, +/// OUT +/// }; +/// +/// atb::Node &finalizeRoutingNode = opGraph.nodes.at(nodeId++); +/// atb_speed::common::MoefinalizeRoutingParam finalizeRoutingParam; +/// finalizeRoutingParam.topkNum = param.topk; +/// finalizeRoutingParam.expertNum = param.numOfExperts; +/// finalizeRoutingNode.operation = new atb_speed::common::FinalizeRoutingOperation("MoeFinalizeRoutingOperation", +/// initRoutingParam); +/// initRoutingNode.inTensorIds = {IN_INPUT = 0, +/// IN_INPUT_TWO, +/// IN_INPUT_THREE +/// IN_BIAS, +/// IN_SCALES, +/// IN_EXPANDED_ROW_IDX, +/// IN_EXPANDED_EXPERT_IDX}; +/// initRoutingNode.outTensorIds = {OUT}; +/// \endcode + +class FinalizeRoutingOperation : public AclNNOperation { +public: + explicit FinalizeRoutingOperation(const std::string &name); + ~FinalizeRoutingOperation() override; + atb::Status InferShape(const atb::SVector &inTensorDescs, + atb::SVector &outTensorDescs) const override; + uint32_t GetInputNum() const override; + uint32_t GetOutputNum() const override; + +private: + int SetAclNNWorkspaceExecutor() override; + int ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream) override; +}; + +} // namespace common +} // namespace atb_speed +#endif // ATB_SPEED_PLUGIN_ACLNN_MOE_TOPK_SOFTMAX_OPERATION_H \ No newline at end of file diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/gelu_operation.cpp b/tests/proftest/layer_test_framework/operations/aclnn/ops/gelu_operation.cpp new file mode 100644 index 00000000..f3f37702 --- /dev/null +++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/gelu_operation.cpp @@ -0,0 +1,214 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "operations/aclnn/utils/utils.h" +#include "acl/acl.h" +#include "atb_speed/log.h" +#include "aclnnop/aclnn_gelu.h" +#include "aclnnop/aclnn_gelu_v2.h" +#include "gelu_operation.h" + + +namespace atb_speed::common { + + GeluOperation::GeluOperation( + const std::string &name, + atb_speed::common::AclNNGeluParam param + ) : AclNNOperation(name), param_(param) + { + this->opName_ = name; + this->param_ = param; + } + + GeluOperation::~GeluOperation() + { + ATB_SPEED_LOG_DEBUG("GeluOperation deconstruct"); + this->DestroyOperation(); + } + + /** + * + * @param[in] inTensorDesc: FA: [batchSize, seqLen, hiddenSize]; PA: [seqLen, hiddenSize] + * @param[in] outTensorDesc: FA: [batchSize, seqLen, hiddenSize]; PA: [seqLen, hiddenSize] + * @return atb::Status + */ + atb::Status GeluOperation::InferShape( + const atb::SVector &inTensorDesc, + atb::SVector &outTensorDesc + ) const + { + ATB_SPEED_LOG_DEBUG(opName_ << " InferShape start"); + outTensorDesc.at(0).format = inTensorDesc.at(0).format; + outTensorDesc.at(0).dtype = inTensorDesc.at(0).dtype; + outTensorDesc.at(0).shape.dimNum = inTensorDesc.at(0).shape.dimNum; + + ATB_SPEED_LOG_DEBUG("Check " << opName_ << " input dimNum=" << inTensorDesc.at(0).shape.dimNum); + for (uint64_t dim = 0; dim < inTensorDesc.at(0).shape.dimNum; ++dim) { + ATB_SPEED_LOG_DEBUG("input dim" << dim << " shape=" << inTensorDesc.at(0).shape.dims[dim]); + outTensorDesc.at(0).shape.dims[dim] = inTensorDesc.at(0).shape.dims[dim]; + } + + ATB_SPEED_LOG_DEBUG(opName_ << " InferShape end"); + return atb::NO_ERROR; + } + + uint32_t GeluOperation::GetInputNum() const + { + return NUM1; // inputTensorNum = 1 + } + + uint32_t GeluOperation::GetOutputNum() const + { + return NUM1; // outputTensorNum = 1 + } + + atb::Status GeluOperation::CreateAclNNVariantPack(const atb::VariantPack &variantPack) + { + ATB_SPEED_LOG_DEBUG(opName_ << " CreateAclNNVariantPack start"); + int ret; + + ret = CreateAclNNInTensorVariantPack(variantPack); + if (ret != 0) { + std::stringstream ss; + ss << this->opName_ << " AclnnTensor CreateAclNNInTensorVariantPack fail, error: " << ret; + ATB_SPEED_LOG_ERROR(this->opName_ << " AclnnTensor CreateAclNNInTensorVariantPack fail, error: " << ret); + throw std::runtime_error(ss.str()); + } + + ret = CreateAclNNOutTensorVariantPack(variantPack); + if (ret != 0) { + std::stringstream ss; + ss << this->opName_ << " AclnnTensor CreateAclNNOutTensorVariantPack fail, error: " << ret; + ATB_SPEED_LOG_ERROR(this->opName_ << " AclnnTensor CreateAclNNOutTensorVariantPack fail, error: " << ret); + throw std::runtime_error(ss.str()); + } + + ATB_SPEED_LOG_DEBUG(opName_ << " CreateAclNNVariantPack end"); + return atb::NO_ERROR; + } + + atb::Status GeluOperation::CreateAclNNInTensorVariantPack(const atb::VariantPack &variantPack) + { + AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack; + aclnnVariantPack.aclInTensors.resize(GetInputNum()); + for (size_t i = 0; i < 
aclnnVariantPack.aclInTensors.size(); ++i) { + std::shared_ptr aclnnTensor = CreateTensor(variantPack.inTensors.at(i), i); + if (aclnnTensor->tensor == nullptr) { + ATB_SPEED_LOG_ERROR(this->opName_ << " InTensor aclCreateTensor index " << i << " fail"); + return atb::ERROR_INTERNAL_ERROR; + } + aclnnVariantPack.aclInTensors[i] = aclnnTensor; + } + return atb::NO_ERROR; + } + + atb::Status GeluOperation::CreateAclNNOutTensorVariantPack(const atb::VariantPack &variantPack) + { + AclNNVariantPack &aclNnVariantPack = this->aclnnOpCache_->aclnnVariantPack; + aclNnVariantPack.aclOutTensors.resize(GetOutputNum()); + for (size_t i = 0; i < aclNnVariantPack.aclOutTensors.size(); ++i) { + std::shared_ptr aclnnTensor = CreateTensor(variantPack.outTensors.at(i), i); + if (aclnnTensor->tensor == nullptr) { + ATB_SPEED_LOG_ERROR(this->opName_ << " outTensor aclCreateTensor index " << i << " fail"); + return atb::ERROR_INTERNAL_ERROR; + } + aclNnVariantPack.aclOutTensors[i] = aclnnTensor; + } + return atb::NO_ERROR; + } + + std::shared_ptr GeluOperation::CreateTensor(atb::Tensor atbTensor, size_t tensorIdx) + { + ATB_SPEED_LOG_DEBUG(opName_ << " CreateTensor start"); + std::shared_ptr aclnnTensor = std::make_shared(); + aclnnTensor->tensorIdx = static_cast(tensorIdx); + aclnnTensor->needUpdateTensorDataPtr = true; + aclnnTensor->atbTensor = atbTensor; + atb::Tensor squeezedAtbTensor = SqueezeBatchSeq(atbTensor); + aclnnTensor->strides = GetCopyTensorStride(squeezedAtbTensor.desc.shape); + aclnnTensor->tensor = aclCreateTensor( + squeezedAtbTensor.desc.shape.dims, + squeezedAtbTensor.desc.shape.dimNum, + squeezedAtbTensor.desc.dtype, + aclnnTensor->strides.data(), + 0, + squeezedAtbTensor.desc.format, + squeezedAtbTensor.desc.shape.dims, + squeezedAtbTensor.desc.shape.dimNum, + squeezedAtbTensor.deviceData); + ATB_SPEED_LOG_DEBUG(opName_ << " CreateTensor end"); + return aclnnTensor; + } + + int GeluOperation::SetAclNNWorkspaceExecutor() + { + ATB_SPEED_LOG_DEBUG( + opName_ << " SetAclNNWorkspaceExecutor start, geluApproximate: " << param_.geluApproximate + ); + AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack; + if (param_.geluApproximate == -1) { + int ret = aclnnGeluGetWorkspaceSize( + aclnnVariantPack.aclInTensors.at(0)->tensor, // self + aclnnVariantPack.aclOutTensors.at(0)->tensor, // out + &this->aclnnOpCache_->workspaceSize, + &this->aclnnOpCache_->aclExecutor); + ATB_SPEED_LOG_DEBUG( + opName_ << " SetAclNNWorkspaceExecutor end" + << ", ret: " << ret + << ", workspaceSize: " << this->aclnnOpCache_->workspaceSize + << ", aclExecutor: " << this->aclnnOpCache_->aclExecutor + ); + return ret; + } else { + int ret = aclnnGeluV2GetWorkspaceSize( + aclnnVariantPack.aclInTensors.at(0)->tensor, // x + param_.geluApproximate, // approximate + aclnnVariantPack.aclOutTensors.at(0)->tensor, // y + &this->aclnnOpCache_->workspaceSize, + &this->aclnnOpCache_->aclExecutor); + ATB_SPEED_LOG_DEBUG( + opName_ << " SetAclNNWorkspaceExecutor end" + << ", ret: " << ret + << ", workspaceSize: " << this->aclnnOpCache_->workspaceSize + << ", aclExecutor: " << this->aclnnOpCache_->aclExecutor + ); + return ret; + } + } + + int GeluOperation::ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream) + { + ATB_SPEED_LOG_DEBUG(opName_ << " ExecuteAclNNOp start"); + if (param_.geluApproximate == -1) { + int ret = aclnnGelu( + workspace, + this->aclnnOpCache_->workspaceSize, + this->aclnnOpCache_->aclExecutor, + stream); + ATB_SPEED_LOG_DEBUG(opName_ << " ExecuteAclNNOp end, ret: " << ret); + return 
ret; + } else { + int ret = aclnnGeluV2( + workspace, + this->aclnnOpCache_->workspaceSize, + this->aclnnOpCache_->aclExecutor, + stream); + ATB_SPEED_LOG_DEBUG(opName_ << " ExecuteAclNNOp end, ret: " << ret); + return ret; + } + } + +} // namespace atb_speed::common diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/gelu_operation.h b/tests/proftest/layer_test_framework/operations/aclnn/ops/gelu_operation.h new file mode 100644 index 00000000..acf2a918 --- /dev/null +++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/gelu_operation.h @@ -0,0 +1,91 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef ATB_SPEED_PLUGIN_ACLNN_GELU_OPERATION_H +#define ATB_SPEED_PLUGIN_ACLNN_GELU_OPERATION_H + +#include "operations/aclnn/core/acl_nn_operation.h" + + +namespace atb_speed::common { + + /// A struct defines `aclnnGelu` and `aclnnGelvV2` operation parameter. + struct AclNNGeluParam { + /// Indicates the gelu approximation algorithm to use. + /// + /// -1: use `aclnnGelu` operation, and use Tanh approximation approach to calculate Gelu. + /// 0: use `aclnnGelvV2` operation, and use Cumulative Distribution Function for Gaussian Distribution. + /// 1: use `aclnnGelvV2` operation, and use Tanh approximation approach to calculate Gelu. + int64_t geluApproximate = -1; + }; + + /// This class defines a matrix operation that applies the Gaussian Error Linear Units function. + /// + /// This class makes use of `aclnnGeluGetWorkspaceSize` and `aclnnGeluV2GetWorkspaceSize` from AscendCL API. 
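+    /// Dispatch note: when `geluApproximate` is -1 the plain `aclnnGelu` entry points
+    /// are used; any other value is forwarded to `aclnnGeluV2` as its approximate mode
+    /// (see SetAclNNWorkspaceExecutor and ExecuteAclNNOp above).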
+ /// + /// Operation's Inputs: \n + /// | Name | Dtype | Shape | \n + /// |--------|--------------------------|-----------| \n + /// | x | float32/float16/bfloat16 | [-1,…,-1] | \n + /// + /// Operation's Outputs: \n + /// | Name | Dtype | Shape | \n + /// |--------|--------------------------|-----------| \n + /// | output | float32/float16/bfloat16 | [-1,…,-1] | \n + /// + /// Example: + /// \code + /// enum TensorIdx : uint32_t { + /// IN_INPUT = 0, + /// OUT, + /// }; + /// + /// atb::Node geluNode; + /// AclNNGeluParam geluParam; + /// geluParam.geluApproximate = 1; + /// geluNode.inTensorIds = { IN_INPUT }; + /// geluNode.outTensorIds = { OUT }; + /// geluNode.operation = new atb_speed::common::GeluOperation("geluNode", geluParam); + /// + /// atb::GraphParam opGraph; + /// opGraph.nodes.push_back(geluNode); + /// \endcode + class GeluOperation : public AclNNOperation { + public: + explicit GeluOperation(const std::string &name, AclNNGeluParam param); + ~GeluOperation() override; + atb::Status InferShape( + const atb::SVector &inTensorDesc, + atb::SVector &outTensorDesc + ) const override; + [[nodiscard]] uint32_t GetInputNum() const override; + [[nodiscard]] uint32_t GetOutputNum() const override; + + protected: + int SetAclNNWorkspaceExecutor() override; + int ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream) override; + atb::Status CreateAclNNVariantPack(const atb::VariantPack &variantPack) override; + atb::Status CreateAclNNInTensorVariantPack(const atb::VariantPack &variantPack) override; + atb::Status CreateAclNNOutTensorVariantPack(const atb::VariantPack &variantPack) override; + virtual std::shared_ptr CreateTensor(atb::Tensor atbTensor, size_t tensorIdx); + + private: + AclNNGeluParam param_; + std::string opName_; + }; +} // namespace atb_speed::common + +#endif // ATB_SPEED_PLUGIN_ACLNN_GELU_OPERATION_H diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/grouped_matmul_operation.cpp b/tests/proftest/layer_test_framework/operations/aclnn/ops/grouped_matmul_operation.cpp new file mode 100644 index 00000000..f69f24fe --- /dev/null +++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/grouped_matmul_operation.cpp @@ -0,0 +1,335 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#include "grouped_matmul_operation.h"
+#include
+#include
+#include
+#include
+#include
+#include
+#include "acl/acl.h"
+#include "aclnnop/aclnn_grouped_matmul.h"
+#include "atb_speed/log.h"
+#include "atb_speed/utils/timer.h"
+#include "aclnnop/aclnn_grouped_matmul_v4.h"
+#include "operations/aclnn/utils/utils.h"
+#include "atb_speed/utils/check_util.h"
+
+namespace atb_speed {
+namespace common {
+
+GroupedMatmulOperation::GroupedMatmulOperation(
+    const std::string &name,
+    AclNNGroupedMatmulParam param) : AclNNOperation(name), param_(param) {
+}
+
+GroupedMatmulOperation::~GroupedMatmulOperation() {
+}
+
+atb::Status GroupedMatmulOperation::InferShape(
+    const atb::SVector<atb::TensorDesc> &inTensorDescs, atb::SVector<atb::TensorDesc> &outTensorDescs) const
+{
+    ATB_SPEED_LOG_DEBUG(opName_ << "GroupedMatmulOperation infer shape start");
+    ATB_SPEED_LOG_DEBUG(opName_ << " experts: " << inTensorDescs.at(DIM2).shape.dims[DIM0]);
+
+    outTensorDescs.at(DIM0).format = inTensorDescs.at(DIM0).format;
+    outTensorDescs.at(DIM0).dtype = param_.outDataType == ACL_BF16 || inTensorDescs.at(DIM0).dtype == ACL_BF16 ? \
+        ACL_BF16 : ACL_FLOAT16;
+    outTensorDescs.at(DIM0).shape.dimNum = inTensorDescs.at(DIM0).shape.dimNum;
+
+    int nDim = param_.transposeB ? DIM1 : DIM2;
+    ATB_SPEED_LOG_DEBUG(opName_ << "GroupedMatmulOperation infer shape origin inTensorDescs.at(DIM0).shape.dims[DIM0]"
+                  << inTensorDescs.at(DIM0).shape.dims[DIM0]);
+    ATB_SPEED_LOG_DEBUG(opName_ << "GroupedMatmulOperation infer shape origin inTensorDescs.at(DIM1).shape.dims[nDim]"
+                  << inTensorDescs.at(DIM1).shape.dims[nDim]);
+
+    outTensorDescs.at(DIM0).shape.dims[DIM0] = inTensorDescs.at(DIM0).shape.dims[DIM0];
+    outTensorDescs.at(DIM0).shape.dims[DIM1] = inTensorDescs.at(DIM1).shape.dims[nDim];
+    bool isW4 = param_.quantType == GmmQuantType::W4A16_CHANNEL or param_.quantType == GmmQuantType::W4A8_GROUP;
+    if (isW4 && !this->param_.transposeB) {
+        outTensorDescs.at(DIM0).shape.dims[DIM1] = \
+            CheckIntMulOverFlow(outTensorDescs.at(DIM0).shape.dims[DIM1], 2); // 2: last dim * 2
+    }
+
+    ATB_SPEED_LOG_DEBUG(opName_ << "GroupedMatmulOperation infer shape end");
+    return 0;
+}
+
+uint32_t GroupedMatmulOperation::GetInputNum() const
+{
+    uint32_t inputNum = DIM3;
+    if (param_.hasBias) {
+        inputNum += DIM1;
+    }
+    if (param_.quantType != NONE) {
+        inputNum += DIM2;
+    }
+    return inputNum;
+}
+
+uint32_t GroupedMatmulOperation::GetOutputNum() const
+{
+    return DIM1;
+}
+
+atb::Dims GetWeightStorageShape(const atb::TensorDesc atbTensorDesc)
+{
+    atb::Dims storageTensorDims = atbTensorDesc.shape; // in ND format, storageShape matches originalShape
+    if (atbTensorDesc.format == ACL_FORMAT_FRACTAL_NZ) {
+        // NZ format
+        storageTensorDims.dimNum = 5; // 5: five dimensions
+        // (group_size, k, n) => (group_size, k / 16, n / 16, 16, 16)
+        // (group_size, n, k) => (group_size, n / 16, k / 16, 16, 16)
+        storageTensorDims.dims[0] = atbTensorDesc.shape.dims[0];
+        storageTensorDims.dims[1] = 1 + ((atbTensorDesc.shape.dims[1] - 1) / 16); // 1: dim index; 16: padding block
+        storageTensorDims.dims[2] = 1 + ((atbTensorDesc.shape.dims[2] - 1) / 16); // 2: dim index; 16: padding block
+        storageTensorDims.dims[3] = 16; // 3, 16: required by the NZ format
+        storageTensorDims.dims[4] = 16; // 4, 16: required by the NZ format
+    }
+    return storageTensorDims;
+}
+
+atb::Dims GetWeightStorageW4Shape(const atb::TensorDesc atbTensorDesc)
+{
+    atb::Dims storageTensorDims = atbTensorDesc.shape; // in ND format, storageShape matches originalShape
+    if (atbTensorDesc.format == ACL_FORMAT_FRACTAL_NZ) {
+        storageTensorDims.dimNum = 5; // 5: five dimensions
+        // (group_size, k, n) => (group_size, k / 64, n / 16, 16, 32)
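+        // note: for 4-bit weights the innermost NZ block is 16x32 rather than the 16x16
+        // used in GetWeightStorageShape above, since two int4 values pack into one byte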
16, 32)
+        // (group_size, n, k) => (group_size, n / 64, k / 16, 16, 32)
+        storageTensorDims.dims[0] = atbTensorDesc.shape.dims[0];
+        storageTensorDims.dims[1] = 1 + ((atbTensorDesc.shape.dims[DIM2] - 1) / 64); // dims[1]: last weight dim ceil-divided by padding size 64
+        storageTensorDims.dims[2] = 1 + ((atbTensorDesc.shape.dims[1] - 1) / 16); // dims[2]: dims[1] ceil-divided by padding size 16
+        storageTensorDims.dims[3] = 16; // dims[3] = 16: required by the NZ layout
+        storageTensorDims.dims[4] = 32; // dims[4] = 32: required by the NZ layout
+    }
+    return storageTensorDims;
+}
+
+atb::Status GroupedMatmulOperation::CreateAclNNInTensorVariantPack(const atb::VariantPack &variantPack)
+{
+    inputVectorOfTensor.resize(GetInputNum());
+    AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack;
+    aclnnVariantPack.aclInTensors.resize(GetInputNum());
+    uint32_t inTensorCount = aclnnVariantPack.aclInTensors.size();
+    for (size_t i = 0; i < inTensorCount; i++) {
+        std::shared_ptr<AclNNTensor> aclnnTensor = std::make_shared<AclNNTensor>();
+        if (i == inTensorCount - 1) {
+            aclnnTensor->tensorIdx = 7; // 7: parameter slot of the last input tensor
+        } else {
+            aclnnTensor->tensorListidx = i;
+            aclnnTensor->tensorIdx = 0;
+        }
+        aclnnTensor->needUpdateTensorDataPtr = true;
+        aclnnTensor->atbTensor = variantPack.inTensors.at(i);
+        atb::Tensor squeezedAtbTensor = variantPack.inTensors.at(i);
+
+        // int8 to int4
+        bool isW4 = param_.quantType == GmmQuantType::W4A16_CHANNEL or param_.quantType == GmmQuantType::W4A8_GROUP;
+        if (i == 1 && isW4) { // 1: weight
+            squeezedAtbTensor.desc.dtype = ACL_INT4;
+            squeezedAtbTensor.desc.shape.dims[DIM2] = CheckIntMulOverFlow(
+                squeezedAtbTensor.desc.shape.dims[DIM2], 2); // 2: last dim * 2 (two int4 values per int8 element)
+        }
+
+        // StorageShape
+        atb::Dims storageTensorDims;
+        if (i == DIM3 && param_.quantType == GmmQuantType::W4A8_GROUP) {
+            squeezedAtbTensor.desc.dtype = ACL_UINT64;
+            storageTensorDims = GetWeightStorageW4Shape(squeezedAtbTensor.desc);
+        } else {
+            // StorageShape
+            storageTensorDims = GetWeightStorageShape(squeezedAtbTensor.desc);
+        }
+
+        // ViewShape and Stride
+        atb::Dims viewDims = squeezedAtbTensor.desc.shape;
+        if (squeezedAtbTensor.desc.shape.dimNum >= 3 && this->param_.transposeB) { // 3: at least three dims
+            aclnnTensor->strides = GetTransposeTensorStride(viewDims);
+            viewDims.dims[0] = squeezedAtbTensor.desc.shape.dims[0];
+            viewDims.dims[1] = squeezedAtbTensor.desc.shape.dims[2]; // dims 1 and 2: transpose the last two dims
+            viewDims.dims[2] = squeezedAtbTensor.desc.shape.dims[1]; // dims 1 and 2: transpose the last two dims
+        } else {
+            aclnnTensor->strides = GetCopyTensorStride(viewDims);
+        }
+
+        CHECK_OPERATION_STATUS_RETURN(CallAclCreateTensor(viewDims, storageTensorDims, squeezedAtbTensor, aclnnTensor));
+        aclnnVariantPack.aclInTensors[i] = aclnnTensor;
+    }
+
+    aclnnVariantPack.aclInTensorList.clear();
+
+    for (size_t i = 0; i < aclnnVariantPack.aclInTensors.size() - 1; i++) {
+        inputVectorOfTensor.at(i).clear();
+        inputVectorOfTensor.at(i).push_back(aclnnVariantPack.aclInTensors.at(i)->tensor);
+        aclnnVariantPack.aclInTensorList.push_back(aclCreateTensorList(
+            inputVectorOfTensor.at(i).data(), inputVectorOfTensor.at(i).size()));
+    }
+    return atb::NO_ERROR;
+}
+
+atb::Status GroupedMatmulOperation::CreateAclNNOutTensorVariantPack(const atb::VariantPack &variantPack)
+{
+    AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack;
+    aclnnVariantPack.aclOutTensors.resize(GetOutputNum());
+    for (size_t i = 0; i < aclnnVariantPack.aclOutTensors.size(); ++i) {
+        std::shared_ptr<AclNNTensor> aclnnTensor = std::make_shared<AclNNTensor>();
+        aclnnTensor->tensorListidx = i;
+        aclnnTensor->tensorIdx = 0;
+        aclnnTensor->needUpdateTensorDataPtr = true;
+        aclnnTensor->atbTensor =
variantPack.outTensors.at(i);
+        atb::Tensor squeezedAtbTensor = variantPack.outTensors.at(i);
+        aclnnTensor->strides = GetCopyTensorStride(squeezedAtbTensor.desc.shape);
+        CHECK_OPERATION_STATUS_RETURN(CallAclCreateTensor(squeezedAtbTensor.desc.shape, squeezedAtbTensor.desc.shape,
+            squeezedAtbTensor, aclnnTensor));
+        aclnnVariantPack.aclOutTensors[i] = aclnnTensor;
+    }
+    yTensorVector.clear();
+    yTensorVector.push_back(aclnnVariantPack.aclOutTensors.at(DIM0)->tensor);
+    aclnnVariantPack.aclOutTensorList.clear();
+    aclnnVariantPack.aclOutTensorList.push_back(aclCreateTensorList(yTensorVector.data(), yTensorVector.size()));
+    return atb::NO_ERROR;
+}
+
+int GroupedMatmulOperation::CreateW8A8(AclNNVariantPack &aclnnVariantPack)
+{
+    int ret = aclnnGroupedMatmulV4GetWorkspaceSize(aclnnVariantPack.aclInTensorList.at(DIM0),
+        aclnnVariantPack.aclInTensorList.at(DIM1),
+        param_.hasBias ? aclnnVariantPack.aclInTensorList.at(DIM2) : nullptr,
+        param_.hasBias ? aclnnVariantPack.aclInTensorList.at(DIM3) :
+            aclnnVariantPack.aclInTensorList.at(DIM2),
+        param_.hasBias ? aclnnVariantPack.aclInTensorList.at(4) : // 4 : index of input tensor
+            aclnnVariantPack.aclInTensorList.at(DIM3),
+        nullptr, nullptr, nullptr,
+        param_.hasBias ? aclnnVariantPack.aclInTensors.at(5)->tensor : // 5 : index of input tensor
+            aclnnVariantPack.aclInTensors.at(4)->tensor, // 4 : index of input tensor
+        nullptr, nullptr, nullptr,
+        splitItem, groupType, groupListType, actType,
+        aclnnVariantPack.aclOutTensorList.at(DIM0),
+        nullptr, nullptr,
+        &this->aclnnOpCache_->workspaceSize,
+        &this->aclnnOpCache_->aclExecutor);
+    return ret;
+}
+
+int GroupedMatmulOperation::CreateW4A8(AclNNVariantPack &aclnnVariantPack)
+{
+    int ret = aclnnGroupedMatmulV4GetWorkspaceSize(aclnnVariantPack.aclInTensorList.at(DIM0), // x
+        aclnnVariantPack.aclInTensorList.at(DIM1), // weight
+        param_.hasBias ? aclnnVariantPack.aclInTensorList.at(DIM2) : nullptr, // biasOptional
+        param_.hasBias ? aclnnVariantPack.aclInTensorList.at(DIM3) : // scaleOptional
+            aclnnVariantPack.aclInTensorList.at(DIM2),
+        nullptr, nullptr, nullptr, // offsetOptional antiquantScaleOptional antiquantOffsetOptional
+        param_.hasBias ? aclnnVariantPack.aclInTensorList.at(4) : // 4: the caller's offset input is passed as perTokenScaleOptional
+            aclnnVariantPack.aclInTensorList.at(DIM3),
+        param_.hasBias ? aclnnVariantPack.aclInTensors.at(5)->tensor : // 5 : index of input tensor groupListOptional
+            aclnnVariantPack.aclInTensors.at(4)->tensor, // 4 : index of input tensor
+        nullptr, nullptr, nullptr,
+        3, 0, 1, 0, // splitItem, groupType, groupListType, actType
+        aclnnVariantPack.aclOutTensorList.at(DIM0),
+        nullptr, nullptr,
+        &this->aclnnOpCache_->workspaceSize,
+        &this->aclnnOpCache_->aclExecutor);
+    return ret;
+}
+
+int GroupedMatmulOperation::CreateA16(AclNNVariantPack &aclnnVariantPack)
+{
+    int ret = aclnnGroupedMatmulV4GetWorkspaceSize(aclnnVariantPack.aclInTensorList.at(DIM0),
+        aclnnVariantPack.aclInTensorList.at(DIM1),
+        param_.hasBias ? aclnnVariantPack.aclInTensorList.at(DIM2) : nullptr,
+        nullptr, nullptr,
+        param_.hasBias ? aclnnVariantPack.aclInTensorList.at(DIM3) :
+            aclnnVariantPack.aclInTensorList.at(DIM2),
+        param_.hasBias ? aclnnVariantPack.aclInTensorList.at(4) : // 4 : index of input tensor
+            aclnnVariantPack.aclInTensorList.at(DIM3),
+        nullptr,
+        param_.hasBias ?
aclnnVariantPack.aclInTensors.at(5)->tensor : // 5 : index of input tensor
+            aclnnVariantPack.aclInTensors.at(4)->tensor, // 4 : index of input tensor
+        nullptr, nullptr, nullptr,
+        splitItem, groupType, groupListType, actType,
+        aclnnVariantPack.aclOutTensorList.at(DIM0),
+        nullptr, nullptr,
+        &this->aclnnOpCache_->workspaceSize,
+        &this->aclnnOpCache_->aclExecutor);
+    return ret;
+}
+
+int GroupedMatmulOperation::CreateW8A8Token(AclNNVariantPack &aclnnVariantPack)
+{
+    int ret = aclnnGroupedMatmulV4GetWorkspaceSize(aclnnVariantPack.aclInTensorList.at(DIM0),
+        aclnnVariantPack.aclInTensorList.at(DIM1),
+        param_.hasBias ? aclnnVariantPack.aclInTensorList.at(DIM2) : nullptr,
+        param_.hasBias ? aclnnVariantPack.aclInTensorList.at(DIM3) :
+            aclnnVariantPack.aclInTensorList.at(DIM2),
+        nullptr, nullptr, nullptr,
+        param_.hasBias ? aclnnVariantPack.aclInTensorList.at(4) : // 4 : index of input tensor
+            aclnnVariantPack.aclInTensorList.at(3), // 3 : index of input tensor
+        param_.hasBias ? aclnnVariantPack.aclInTensors.at(5)->tensor : // 5 : index of input tensor
+            aclnnVariantPack.aclInTensors.at(4)->tensor, // 4 : index of input tensor
+        nullptr, nullptr, nullptr,
+        splitItem, groupType, groupListType, actType,
+        aclnnVariantPack.aclOutTensorList.at(DIM0),
+        nullptr, nullptr,
+        &this->aclnnOpCache_->workspaceSize,
+        &this->aclnnOpCache_->aclExecutor);
+    return ret;
+}
+
+int GroupedMatmulOperation::SetAclNNWorkspaceExecutor()
+{
+    ATB_SPEED_LOG_DEBUG(opName_ << " SetAclNNWorkspaceExecutor start");
+    AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack;
+    int ret = 0;
+    if (param_.quantType == GmmQuantType::NONE) {
+        ret = aclnnGroupedMatmulV4GetWorkspaceSize(aclnnVariantPack.aclInTensorList.at(DIM0),
+            aclnnVariantPack.aclInTensorList.at(DIM1),
+            param_.hasBias ? aclnnVariantPack.aclInTensorList.at(DIM2) : nullptr,
+            nullptr, nullptr, nullptr, nullptr, nullptr,
+            param_.hasBias ?
aclnnVariantPack.aclInTensors.at(DIM3)->tensor : + aclnnVariantPack.aclInTensors.at(DIM2)->tensor, + nullptr, nullptr, nullptr, + splitItem, groupType, groupListType, actType, + aclnnVariantPack.aclOutTensorList.at(DIM0), + nullptr, nullptr, + &this->aclnnOpCache_->workspaceSize, + &this->aclnnOpCache_->aclExecutor); + } else if (param_.quantType == GmmQuantType::W8A8_CHANNEL) { + ret = CreateW8A8(aclnnVariantPack); + } else if (param_.quantType == GmmQuantType::W8A16_CHANNEL || param_.quantType == GmmQuantType::W4A16_CHANNEL) { + ret = CreateA16(aclnnVariantPack); + } else if (param_.quantType == GmmQuantType::W4A8_GROUP) { + ret = CreateW4A8(aclnnVariantPack); + } else { + ret = CreateW8A8Token(aclnnVariantPack); + } + + ATB_SPEED_LOG_DEBUG(opName_ << " SetAclNNWorkspaceExecutor end, ret:" << ret + << ", workspaceSize:" << this->aclnnOpCache_->workspaceSize + << ", aclExecutor:" << this->aclnnOpCache_->aclExecutor); + return ret; +} + +int GroupedMatmulOperation::ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream) +{ + ATB_SPEED_LOG_DEBUG(opName_ << " aclnnGroupedMatmul start"); + int ret = aclnnGroupedMatmulV4( + workspace, this->aclnnOpCache_->workspaceSize, this->aclnnOpCache_->aclExecutor, stream); + ATB_SPEED_LOG_DEBUG(opName_ << " aclnnGroupedMatmul end, ret:" << ret); + return ret; +} + +} // namespace common +} // namespace atb_speed \ No newline at end of file diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/grouped_matmul_operation.h b/tests/proftest/layer_test_framework/operations/aclnn/ops/grouped_matmul_operation.h new file mode 100644 index 00000000..e752463f --- /dev/null +++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/grouped_matmul_operation.h @@ -0,0 +1,120 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef ATB_SPEED_PLUGIN_ACLNN_GROUPED_MATMUL_OPERATION_H +#define ATB_SPEED_PLUGIN_ACLNN_GROUPED_MATMUL_OPERATION_H +#include "operations/aclnn/core/acl_nn_operation.h" +#include "operations/aclnn/core/acl_nn_operation_cache.h" +#include "aclnnop/aclnn_grouped_matmul_v4.h" +#include "atb_speed/utils/operation_util.h" + +namespace atb_speed { +namespace common { +enum GmmQuantType : int { + NONE = 0, + W8A8_CHANNEL, + W8A16_CHANNEL, + W4A16_CHANNEL, + W8A8_TOKEN, + W4A8_GROUP +}; + +struct AclNNGroupedMatmulParam { + /// A flag indicating whether the second input matrix needs to be transposed + bool transposeB = false; + /// The quantization type of the operation + int quantType = NONE; + /// A flag indicating whether the matmul operation includes a bias tensor + bool hasBias = false; + /// The data type of the output of the operation + aclDataType outDataType = ACL_FLOAT16; +}; + +/// This class defines an operator that consists of a group of matrix multiplications. +/// Meanwhile, this operator supports different quantization types. 
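+///
+/// The configured `GmmQuantType` decides which optional quantization inputs (scales,
+/// offsets, antiquant tensors, or a per-token scale) are forwarded to the underlying
+/// aclnn call; see the Create* helpers in the accompanying .cpp file. A minimal sketch
+/// of choosing a mode from the tensor dtypes (this helper is illustrative only and not
+/// part of this class):
+/// \code
+/// GmmQuantType PickGmmQuantType(aclDataType weightDtype, aclDataType activationDtype)
+/// {
+///     if (weightDtype == ACL_INT4) {
+///         return W4A16_CHANNEL; // int4 weights, float16/bfloat16 activations
+///     }
+///     if (weightDtype == ACL_INT8) {
+///         return activationDtype == ACL_INT8 ? W8A8_CHANNEL : W8A16_CHANNEL;
+///     }
+///     return NONE; // floating-point weights: plain grouped matmul
+/// }
+/// \endcode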
+///
+/// This class makes use of `aclnnGroupedMatmulV4GetWorkspaceSize` and `aclnnGroupedMatmulV4`
+/// from the AscendCL API.
+///
+/// Inputs to the operator:
+/// Name                    | Dtype | Shape |
+/// ------------------------|-------|-------|
+/// input                   | *     | [m,k] |
+/// weight                  | *     | [e,n,k] if `transposeB` is true; otherwise, [e,k,n] |
+/// biasOptional            | *     | [e,n] |
+/// scaleOptional           | *     | [e,n] |
+/// offsetOptional          | *     | [e,n] |
+/// antiquantScaleOptional  | *     | [e,n] |
+/// antiquantOffsetOptional | *     | [e,n] |
+/// groupList               | int64 | [e]   |
+/// * Note: the data type of each input is specific to the quantization type/technique chosen for the model
+///
+/// Outputs of the operator:
+/// Name   | Dtype               | Shape |
+/// -------|---------------------|-------|
+/// output | float16 or bfloat16 | [m,n] |
+///
+/// Example:
+/// \code
+/// enum TensorIdx : uint32_t {
+///     IN_INPUT = 0,
+///     IN_WEIGHT,
+///     IN_GROUP_LIST,
+///     OUT,
+/// };
+///
+/// atb::Node &gmmNode = opGraph.nodes.at(nodeId++);
+/// atb_speed::common::AclNNGroupedMatmulParam gmmParam;
+/// gmmParam.quantType = gmmQuantType;
+/// gmmParam.outDataType = param.outDataType;
+/// gmmParam.transposeB = param.transposeB;
+/// gmmNode.operation = new atb_speed::common::GroupedMatmulOperation("gmmNode", gmmParam);
+/// gmmNode.inTensorIds = {IN_INPUT, IN_WEIGHT, IN_GROUP_LIST};
+/// gmmNode.outTensorIds = {OUT};
+/// \endcode
+
+class GroupedMatmulOperation : public AclNNOperation {
+public:
+    explicit GroupedMatmulOperation(const std::string &name, AclNNGroupedMatmulParam param);
+    ~GroupedMatmulOperation() override;
+    atb::Status InferShape(const atb::SVector<atb::TensorDesc> &inTensorDescs,
+                           atb::SVector<atb::TensorDesc> &outTensorDescs) const override;
+    uint32_t GetInputNum() const override;
+    uint32_t GetOutputNum() const override;
+
+private:
+    int SetAclNNWorkspaceExecutor() override;
+    int ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream) override;
+
+    atb::Status CreateAclNNInTensorVariantPack(const atb::VariantPack &variantPack) override;
+    atb::Status CreateAclNNOutTensorVariantPack(const atb::VariantPack &variantPack) override;
+
+    int CreateW8A8(AclNNVariantPack &aclnnVariantPack);
+    int CreateA16(AclNNVariantPack &aclnnVariantPack);
+    int CreateW8A8Token(AclNNVariantPack &aclnnVariantPack);
+    int CreateW4A8(AclNNVariantPack &aclnnVariantPack);
+
+    std::vector<aclTensor *> yTensorVector;
+    std::vector<std::vector<aclTensor *>> inputVectorOfTensor;
+    std::vector<aclTensor *> weightTensorVector;
+    int64_t splitItem = 2;
+    int64_t groupType = 0;
+    int64_t groupListType = 0;
+    int64_t actType = 0; // 0 : GMMActType::GMM_ACT_TYPE_NONE
+    AclNNGroupedMatmulParam param_;
+};
+} // namespace common
+} // namespace atb_speed
+#endif // ATB_SPEED_PLUGIN_ACLNN_GROUPED_MATMUL_OPERATION_H
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/grouped_matmul_swiglu_operation.cpp b/tests/proftest/layer_test_framework/operations/aclnn/ops/grouped_matmul_swiglu_operation.cpp
new file mode 100644
index 00000000..f474a757
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/grouped_matmul_swiglu_operation.cpp
@@ -0,0 +1,173 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "grouped_matmul_swiglu_operation.h"
+#include
+#include
+#include
+#include
+#include
+#include
+#include "acl/acl.h"
+#include "atb_speed/log.h"
+#include "atb_speed/utils/timer.h"
+#include "aclnnop/aclnn_grouped_matmul_swiglu_quant.h"
+#include "operations/aclnn/utils/utils.h"
+
+namespace atb_speed {
+namespace common {
+
+GroupedMatmulSwigluOperation::GroupedMatmulSwigluOperation(
+    const std::string &name,
+    AclNNGroupedSwigluMatmulParam param) : AclNNOperation(name), param_(param) {
+}
+
+GroupedMatmulSwigluOperation::~GroupedMatmulSwigluOperation() {
+}
+
+atb::Status GroupedMatmulSwigluOperation::InferShape(
+    const atb::SVector<atb::TensorDesc> &inTensorDescs, atb::SVector<atb::TensorDesc> &outTensorDescs) const
+{
+    ATB_SPEED_LOG_DEBUG(opName_ << "GroupedMatmulSwigluOperation infer shape start");
+    ATB_SPEED_LOG_DEBUG(opName_ << "experts" << inTensorDescs.at(DIM2).shape.dims[DIM0]);
+    outTensorDescs.at(DIM0).format = inTensorDescs.at(DIM0).format;
+    outTensorDescs.at(DIM0).dtype = inTensorDescs.at(DIM0).dtype;
+    outTensorDescs.at(DIM0).shape.dimNum = inTensorDescs.at(DIM0).shape.dimNum;
+    outTensorDescs.at(DIM0).shape.dims[DIM0] = inTensorDescs.at(DIM0).shape.dims[DIM0];
+    int nDim = DIM1;
+    if (inTensorDescs.at(DIM1).shape.dims[1] == inTensorDescs.at(DIM0).shape.dims[1]) {
+        nDim = DIM2;
+    }
+    outTensorDescs.at(DIM0).shape.dims[DIM1] = inTensorDescs.at(DIM1).shape.dims[nDim] / 2; // 2: SwiGLU fuses the gate/up halves, halving the width
+    outTensorDescs.at(DIM1).format = inTensorDescs.at(DIM0).format;
+    outTensorDescs.at(DIM1).dtype = ACL_FLOAT;
+    outTensorDescs.at(DIM1).shape.dimNum = 1;
+    outTensorDescs.at(DIM1).shape.dims[DIM0] = inTensorDescs.at(DIM0).shape.dims[DIM0];
+
+    ATB_SPEED_LOG_DEBUG(opName_ << "GroupedMatmulSwigluOperation infer shape end");
+    return 0;
+}
+
+uint32_t GroupedMatmulSwigluOperation::GetInputNum() const
+{
+    return INPUT_NUM;
+}
+
+uint32_t GroupedMatmulSwigluOperation::GetOutputNum() const
+{
+    return OUTPUT_NUM;
+}
+
+atb::Dims SetWeightStorageShape(const atb::TensorDesc& atbTensorDesc)
+{
+    atb::Dims storageTensorDims = atbTensorDesc.shape; // in ND format, storageShape matches originalShape
+    if (atbTensorDesc.format == ACL_FORMAT_FRACTAL_NZ) {
+        // NZ format
+        storageTensorDims.dimNum = 5; // 5: five dimensions
+        // (group_size, k, n) => (group_size, n / 16, k / 32, 16, 32)
+        storageTensorDims.dims[0] = atbTensorDesc.shape.dims[0];
+        storageTensorDims.dims[2] = 1 + ((atbTensorDesc.shape.dims[1] - 1) / 16); // dims[2]: dims[1] ceil-divided by padding size 16
+        storageTensorDims.dims[1] = 1 + ((atbTensorDesc.shape.dims[2] - 1) / 32); // dims[1]: dims[2] ceil-divided by padding size 32
+        storageTensorDims.dims[3] = 16; // dims[3] = 16: required by the NZ layout
+        storageTensorDims.dims[4] = 32; // dims[4] = 32: required by the NZ layout
+    }
+    return storageTensorDims;
+}
+
+atb::Status GroupedMatmulSwigluOperation::CreateAclNNInTensorVariantPack(const atb::VariantPack &variantPack)
+{
+    AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack;
+    aclnnVariantPack.aclInTensors.resize(INPUT_NUM);
+    const int aclnnTensorIndex[INPUT_NUM] = {0, 1, 4, 5, 6}; // valid input index
+    for (size_t i = 0; i < INPUT_NUM; i++) {
+        std::shared_ptr<AclNNTensor> aclnnTensor = std::make_shared<AclNNTensor>();
+        aclnnTensor->tensorIdx =
aclnnTensorIndex[i]; + aclnnTensor->needUpdateTensorDataPtr = true; + aclnnTensor->atbTensor = variantPack.inTensors.at(i); + atb::Tensor squeezedAtbTensor = variantPack.inTensors.at(i); + + // StorageShape + atb::Dims storageTensorDims = SetWeightStorageShape(squeezedAtbTensor.desc); + + // ViewShape and Stride + atb::Dims viewDims = squeezedAtbTensor.desc.shape; + if (squeezedAtbTensor.desc.shape.dimNum >= 3 && this->param_.transposeB) { // 3: 维度 + aclnnTensor->strides = GetTransposeTensorStride(viewDims); + viewDims.dims[0] = squeezedAtbTensor.desc.shape.dims[0]; + viewDims.dims[1] = squeezedAtbTensor.desc.shape.dims[2]; // 1, 2: 后两维转置 + viewDims.dims[2] = squeezedAtbTensor.desc.shape.dims[1]; // 1, 2: 后两维转置 + } else { + aclnnTensor->strides = GetCopyTensorStride(viewDims); + } + + CHECK_OPERATION_STATUS_RETURN(CallAclCreateTensor(storageTensorDims, storageTensorDims, + squeezedAtbTensor, aclnnTensor)); + aclnnVariantPack.aclInTensors[i] = aclnnTensor; + } + return atb::NO_ERROR; +} + +atb::Status GroupedMatmulSwigluOperation::CreateAclNNOutTensorVariantPack(const atb::VariantPack &variantPack) +{ + AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack; + aclnnVariantPack.aclOutTensors.resize(OUTPUT_NUM); + for (size_t i = 0; i < aclnnVariantPack.aclOutTensors.size(); ++i) { + std::shared_ptr aclnnTensor = std::make_shared(); + aclnnTensor->tensorIdx = i; + aclnnTensor->needUpdateTensorDataPtr = true; + aclnnTensor->atbTensor = variantPack.outTensors.at(i); + atb::Tensor squeezedAtbTensor = variantPack.outTensors.at(i); + aclnnTensor->strides = GetCopyTensorStride(squeezedAtbTensor.desc.shape); + CHECK_OPERATION_STATUS_RETURN(CallAclCreateTensor(squeezedAtbTensor.desc.shape, squeezedAtbTensor.desc.shape, + squeezedAtbTensor, aclnnTensor)); + aclnnVariantPack.aclOutTensors[i] = aclnnTensor; + } + return atb::NO_ERROR; +} + +int GroupedMatmulSwigluOperation::SetAclNNWorkspaceExecutor() +{ + ATB_SPEED_LOG_DEBUG(opName_ << " SetAclNNWorkspaceExecutor start"); + AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack; + int ret = aclnnGroupedMatmulSwigluQuantGetWorkspaceSize(aclnnVariantPack.aclInTensors.at(DIM0)->tensor, // 0: x + aclnnVariantPack.aclInTensors.at(DIM1)->tensor, // 1: weight + nullptr, // bias + nullptr, // offset + aclnnVariantPack.aclInTensors.at(DIM2)->tensor, // 2: weight_scale + aclnnVariantPack.aclInTensors.at(DIM3)->tensor, // 3: x_scale + aclnnVariantPack.aclInTensors.at(4)->tensor, // 4: group_list IDX + aclnnVariantPack.aclOutTensors.at(DIM0)->tensor, // out0: output + aclnnVariantPack.aclOutTensors.at(DIM1)->tensor, // out1: output_scale + nullptr, // out2: output_offset + &this->aclnnOpCache_->workspaceSize, + &this->aclnnOpCache_->aclExecutor); + + ATB_SPEED_LOG_DEBUG(opName_ << " SetAclNNWorkspaceExecutor end, ret:" << ret + << ", workspaceSize:" << this->aclnnOpCache_->workspaceSize + << ", aclExecutor:" << this->aclnnOpCache_->aclExecutor); + return ret; +} + +int GroupedMatmulSwigluOperation::ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream) +{ + ATB_SPEED_LOG_DEBUG(opName_ << " aclnnGroupedMatmul start"); + int ret = aclnnGroupedMatmulSwigluQuant( + workspace, this->aclnnOpCache_->workspaceSize, this->aclnnOpCache_->aclExecutor, stream); + ATB_SPEED_LOG_DEBUG(opName_ << " aclnnGroupedMatmul end, ret:" << ret); + return ret; +} + +} // namespace common +} // namespace atb_speed diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/grouped_matmul_swiglu_operation.h 
b/tests/proftest/layer_test_framework/operations/aclnn/ops/grouped_matmul_swiglu_operation.h
new file mode 100644
index 00000000..3acfb716
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/grouped_matmul_swiglu_operation.h
@@ -0,0 +1,105 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef ATB_SPEED_PLUGIN_ACLNN_GROUPED_MATMUL_SWIGLU_OPERATION_H
+#define ATB_SPEED_PLUGIN_ACLNN_GROUPED_MATMUL_SWIGLU_OPERATION_H
+#include "operations/aclnn/core/acl_nn_operation.h"
+#include "operations/aclnn/core/acl_nn_operation_cache.h"
+#include "atb_speed/utils/operation_util.h"
+
+namespace atb_speed {
+namespace common {
+enum GmmQuantSwigluType : int {
+    NONE1 = 0,
+    W8A8_CHANNEL1,
+    W8A16_CHANNEL1,
+    W8A8_TOKEN1
+};
+
+struct AclNNGroupedSwigluMatmulParam {
+    bool transposeB = false; /// A flag indicating whether the second input matrix needs to be transposed
+    int quantType = 0; /// The quantization type of the operation
+    bool hasBias = false; /// A flag indicating whether the matmul operation includes a bias tensor
+    aclDataType outDataType = ACL_FLOAT16; /// The data type of the output of the operation
+};
+
+/// This class defines an operator that fuses a group of matrix multiplications with a
+/// SwiGLU activation and quantization of the result.
+/// Meanwhile, this operator supports different quantization types.
+///
+/// This class makes use of `aclnnGroupedMatmulSwigluQuantGetWorkspaceSize` and
+/// `aclnnGroupedMatmulSwigluQuant` from the AscendCL API.
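+///
+/// The weight tensor packs the SwiGLU gate and up projections along its n axis, which
+/// is why `InferShape` halves the second output dimension: the fused kernel multiplies,
+/// applies the SwiGLU across the two halves, and quantizes the result to int8 with a
+/// per-token scale. A shape-only sketch of that contract (illustrative assumptions:
+/// ND layout, `transposeB == false`):
+/// \code
+/// // x: [m,k], weight: [e,k,n], groupList: [e]
+/// // -> quant_output: [m,n/2] int8, quant_scale_output: [m] float
+/// int64_t ExpectedOutWidth(int64_t n) { return n / 2; } // 2: gate/up halves are fused
+/// \endcode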
+///
+/// Inputs to the operator:
+/// Name            | Dtype | Shape |
+/// ----------------|-------|-------|
+/// input           | *     | [m,k] |
+/// weight          | *     | [e,n,k] if `transposeB` is true; otherwise, [e,k,n] |
+/// PerChannelscale | *     | [e,k] |
+/// PerTokenscale   | *     | [m]   |
+/// groupList       | int64 | [e]   |
+/// * Note: the data type of each input is specific to the quantization type/technique chosen for the model
+///
+/// Outputs of the operator:
+/// Name               | Dtype | Shape   |
+/// -------------------|-------|---------|
+/// quant_output       | int8  | [m,n/2] |
+/// quant_scale_output | float | [m]     |
+///
+/// Example:
+/// \code
+/// enum InTensorIdx : uint32_t {
+///     IN_INPUT = 0,
+///     IN_WEIGHT,
+///     IN_SCALE_EXPERT,
+///     IN_DYNAMIC_SCALE,
+///     IN_GROUP_LIST,
+/// };
+///
+/// enum OutTensorIdx : uint32_t {
+///     QUANT_OUT = 0,
+///     QUANT_SCALE
+/// };
+///
+/// atb::Node &gmmNode = opGraph.nodes.at(nodeId++);
+/// atb_speed::common::AclNNGroupedSwigluMatmulParam gmmParam;
+/// gmmParam.quantType = gmmQuantType;
+/// gmmParam.outDataType = param.outDataType;
+/// gmmParam.transposeB = param.transposeB;
+/// gmmNode.operation = new atb_speed::common::GroupedMatmulSwigluOperation("gmmNode", gmmParam);
+/// gmmNode.inTensorIds = {IN_INPUT, IN_WEIGHT, IN_SCALE_EXPERT, IN_DYNAMIC_SCALE, IN_GROUP_LIST};
+/// gmmNode.outTensorIds = {QUANT_OUT, QUANT_SCALE};
+/// \endcode
+
+class GroupedMatmulSwigluOperation : public AclNNOperation {
+public:
+    explicit GroupedMatmulSwigluOperation(const std::string &name, AclNNGroupedSwigluMatmulParam param);
+    ~GroupedMatmulSwigluOperation() override;
+    atb::Status InferShape(const atb::SVector<atb::TensorDesc> &inTensorDescs,
+                           atb::SVector<atb::TensorDesc> &outTensorDescs) const override;
+    uint32_t GetInputNum() const override;
+    uint32_t GetOutputNum() const override;
+
+private:
+    int SetAclNNWorkspaceExecutor() override;
+    int ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream) override;
+    atb::Status CreateAclNNInTensorVariantPack(const atb::VariantPack &variantPack) override;
+    atb::Status CreateAclNNOutTensorVariantPack(const atb::VariantPack &variantPack) override;
+    AclNNGroupedSwigluMatmulParam param_;
+    static constexpr uint32_t INPUT_NUM = 5U;
+    static constexpr uint32_t OUTPUT_NUM = 2U;
+};
+} // namespace common
+} // namespace atb_speed
+#endif // ATB_SPEED_PLUGIN_ACLNN_GROUPED_MATMUL_SWIGLU_OPERATION_H
diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/index_select_operation.cpp b/tests/proftest/layer_test_framework/operations/aclnn/ops/index_select_operation.cpp
new file mode 100644
index 00000000..9bdc21bd
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/index_select_operation.cpp
@@ -0,0 +1,117 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "operations/aclnn/utils/utils.h"
+#include "acl/acl.h"
+#include "atb_speed/log.h"
+#include "aclnnop/aclnn_index_select.h"
+#include "index_select_operation.h"
+
+
+namespace atb_speed::common {
+    IndexSelectOperation::IndexSelectOperation(
+        const std::string &name,
+        atb_speed::common::IndexSelectParam param
+    ) : AclNNOperation(name), param_(param)
+    {
+        ATB_SPEED_LOG_DEBUG("IndexSelectOperation construct");
+        this->opName_ = name;
+    }
+
+    IndexSelectOperation::~IndexSelectOperation()
+    {
+        ATB_SPEED_LOG_DEBUG("IndexSelectOperation deconstruct");
+        this->DestroyOperation();
+    }
+
+    uint32_t IndexSelectOperation::GetInputNum() const
+    {
+        return NUM2; // inputTensorNum = 2
+    }
+
+    uint32_t IndexSelectOperation::GetOutputNum() const
+    {
+        return NUM1; // outputTensorNum = 1
+    }
+
+    /**
+     *
+     * @param[in] inTensorDescs: [self, indices]
+     * @param[out] outTensorDescs: out
+     * @return atb::Status
+     */
+    atb::Status IndexSelectOperation::InferShape(
+        const atb::SVector<atb::TensorDesc> &inTensorDescs,
+        atb::SVector<atb::TensorDesc> &outTensorDescs
+    ) const
+    {
+        ATB_SPEED_LOG_DEBUG(opName_ << " InferShape start");
+        outTensorDescs.at(0) = inTensorDescs.at(0);
+
+        if (inTensorDescs.at(0).shape.dimNum > 8) { // 8: tensor max dim num
+            ATB_SPEED_LOG_ERROR(opName_ << " [input0 dimNum should <= 8] CHECK input0 dimNum = "
+                << inTensorDescs.at(0).shape.dimNum);
+        }
+
+        int64_t selfDimNum = static_cast<int64_t>(inTensorDescs.at(0).shape.dimNum);
+        if ((param_.dim >= selfDimNum) || (param_.dim < -selfDimNum)) {
+            ATB_SPEED_LOG_ERROR(opName_ << " [param dim should in [-input0 dimNum, input0 dimNum)) "
+                << "CHECK param dim = " << param_.dim << ", input0 dimNum = " << selfDimNum);
+        }
+
+        if (inTensorDescs.at(1).shape.dimNum != DIM1) {
+            ATB_SPEED_LOG_ERROR(opName_ << " [input1 dimNum should == 1] CHECK input1 dimNum = "
+                << inTensorDescs.at(1).shape.dimNum);
+        }
+
+        outTensorDescs.at(0).shape.dims[param_.dim] = inTensorDescs.at(1).shape.dims[DIM0];
+
+        ATB_SPEED_LOG_DEBUG(opName_ << " InferShape end");
+        return atb::NO_ERROR;
+    }
+
+    int IndexSelectOperation::SetAclNNWorkspaceExecutor()
+    {
+        ATB_SPEED_LOG_DEBUG(opName_ << " SetAclNNWorkspaceExecutor start");
+        AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack;
+        int ret = aclnnIndexSelectGetWorkspaceSize(
+            aclnnVariantPack.aclInTensors.at(0)->tensor,  // self
+            param_.dim,                                   // dim
+            aclnnVariantPack.aclInTensors.at(1)->tensor,  // index
+            aclnnVariantPack.aclOutTensors.at(0)->tensor, // out
+            &this->aclnnOpCache_->workspaceSize,
+            &this->aclnnOpCache_->aclExecutor);
+        ATB_SPEED_LOG_DEBUG(opName_ << " SetAclNNWorkspaceExecutor end"
+            << ", ret: " << ret
+            << ", workspaceSize: " << this->aclnnOpCache_->workspaceSize
+            << ", aclExecutor: " << this->aclnnOpCache_->aclExecutor);
+        return ret;
+    }
+
+    int IndexSelectOperation::ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream)
+    {
+        ATB_SPEED_LOG_DEBUG(opName_ << " ExecuteAclNNOp start");
+        int ret = aclnnIndexSelect(
+            workspace,
+            this->aclnnOpCache_->workspaceSize,
+            this->aclnnOpCache_->aclExecutor,
+            stream);
+        ATB_SPEED_LOG_DEBUG(opName_ << " ExecuteAclNNOp end"
+            << ", ret: " << ret);
+        return ret;
+    }
+
+} // namespace atb_speed::common
diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/index_select_operation.h b/tests/proftest/layer_test_framework/operations/aclnn/ops/index_select_operation.h
new file mode 100644
index 00000000..ee8a9cca
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/index_select_operation.h
@@ -0,0 +1,89 @@
+/*
+ * Copyright
(c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef ATB_SPEED_PLUGIN_ACLNN_INDEX_SELECT_OPERATION_H
+#define ATB_SPEED_PLUGIN_ACLNN_INDEX_SELECT_OPERATION_H
+
+#include "operations/aclnn/core/acl_nn_operation.h"
+#include "atb_speed/utils/operation_util.h"
+
+
+namespace atb_speed::common {
+
+/// A struct defines `IndexSelect`'s parameter.
+struct IndexSelectParam {
+    /// The dimension of the input tensor to select along;
+    /// the valid range is [-input.dim(), input.dim() - 1].
+    int64_t dim = 0;
+};
+
+/// This class defines an operation that extracts elements from the specified dimension
+/// `dim` of the input tensor according to the index sequence numbers and saves them to
+/// the output tensor.
+///
+/// This class makes use of `aclnnIndexSelectGetWorkspaceSize` and `aclnnIndexSelect` from the AscendCL API.
+///
+/// Operation's Inputs:
+/// Name   | Dtype                      | Shape |
+/// -------|----------------------------|-------|
+/// input  | float32, float16, bfloat16 | at most 8 dimensions |
+/// index  | int32, int64               | [n]   |
+///
+/// Operation's Outputs:
+/// Name   | Dtype         | Shape |
+/// -------|---------------|-------|
+/// output | same as input | same rank as input; the length of dimension `dim` equals the length of index |
+///
+/// Example:
+/// \code
+/// enum TensorIdx : uint32_t {
+///     IN_INPUT = 0,
+///     IN_INDEX,
+///     OUT,
+/// };
+///
+/// atb::Node indexSelectNode;
+/// IndexSelectParam indexSelectParam;
+/// indexSelectParam.dim = 0;
+/// indexSelectNode.inTensorIds = {IN_INPUT, IN_INDEX};
+/// indexSelectNode.outTensorIds = {OUT};
+/// indexSelectNode.operation = new atb_speed::common::IndexSelectOperation("IndexSelectNode", indexSelectParam);
+///
+/// // Add the operation node to the graph as required
+/// atb::GraphParam opGraph;
+/// opGraph.nodes.push_back(indexSelectNode);
+/// \endcode
+
+class IndexSelectOperation : public AclNNOperation {
+public:
+    explicit IndexSelectOperation(const std::string &name, IndexSelectParam param);
+    ~IndexSelectOperation() override;
+    atb::Status InferShape(
+        const atb::SVector<atb::TensorDesc> &inTensorDescs,
+        atb::SVector<atb::TensorDesc> &outTensorDescs
+    ) const override;
+    uint32_t GetInputNum() const override;
+    uint32_t GetOutputNum() const override;
+
+private:
+    int SetAclNNWorkspaceExecutor() override;
+    int ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream) override;
+    IndexSelectParam param_;
+    std::string opName_;
+};
+} // namespace atb_speed::common
+
+#endif // ATB_SPEED_PLUGIN_ACLNN_INDEX_SELECT_OPERATION_H
diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/indexput_operation.cpp b/tests/proftest/layer_test_framework/operations/aclnn/ops/indexput_operation.cpp
new file mode 100644
index 00000000..c22722ab
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/indexput_operation.cpp
@@ -0,0 +1,104 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "acl/acl.h" +#include "atb_speed/log.h" + +#include "aclnnop/aclnn_index_put_impl.h" + +#include "operations/aclnn/utils/utils.h" +#include "indexput_operation.h" + +namespace atb_speed { +namespace common { +IndexputOperation::IndexputOperation(const std::string &name, AclNNIndexputParam param) + : AclNNOperation(name), param_(param) +{ + ATB_SPEED_LOG_DEBUG("IndexputOperation, param: " << param_.ToString()); +} + +IndexputOperation::~IndexputOperation() +{ + ATB_SPEED_LOG_DEBUG("~IndexputOperation"); + this->DestroyOperation(); +} + +uint32_t IndexputOperation::GetInputNum() const { return NUM3; } + +uint32_t IndexputOperation::GetOutputNum() const { return NUM1; } + +atb::Status IndexputOperation::InferShape(const atb::SVector &inTensorDescs, + atb::SVector &outTensorDescs) const +{ + ATB_SPEED_LOG_DEBUG(opName_ << " infer shape start"); + outTensorDescs.at(0) = inTensorDescs.at(0); + ATB_SPEED_LOG_DEBUG(opName_ << " infer shape end"); + return atb::NO_ERROR; +} + +atb::Status IndexputOperation::CreateAclNNVariantPack(const atb::VariantPack &variantPack) +{ + ATB_SPEED_LOG_DEBUG(opName_ << " CreateAclNNVariantPack start"); + AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack; + uint32_t inputNum = this->GetInputNum(); + aclnnVariantPack.aclInTensors.resize(inputNum); + for (uint32_t i = 0; i < inputNum; ++i) { + aclnnVariantPack.aclInTensors[i] = CreateTensor(variantPack.inTensors.at(i), static_cast(i)); + if (i == 1) { + aclnnVariantPack.aclInTensors.at(i)->tensorListidx = 0; + aclnnVariantPack.aclInTensors.at(i)->tensorIdx = 0; + } + } + + vectorList.clear(); + vectorList.push_back(aclnnVariantPack.aclInTensors.at(1)->tensor); + aclnnVariantPack.aclInTensorList.clear(); + aclnnVariantPack.aclInTensorList.push_back(aclCreateTensorList(vectorList.data(), vectorList.size())); + + aclnnVariantPack.aclOutTensors.clear(); + aclnnVariantPack.aclOutTensors.push_back(CreateTensor(variantPack.outTensors.at(0), 0)); + + ATB_SPEED_LOG_DEBUG(opName_ << " CreateAclNNVariantPack end"); + return 0; +} + +int IndexputOperation::SetAclNNWorkspaceExecutor() +{ + ATB_SPEED_LOG_DEBUG(opName_ << " SetAclNNWorkspaceExecutor start"); + AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack; + + int ret = aclnnIndexPutImplGetWorkspaceSize( + aclnnVariantPack.aclInTensors.at(0)->tensor, aclnnVariantPack.aclInTensorList.at(0), + aclnnVariantPack.aclInTensors.at(2)->tensor, param_.accumulate, param_.unsafe, + &this->aclnnOpCache_->workspaceSize, &this->aclnnOpCache_->aclExecutor); + ATB_SPEED_LOG_DEBUG(opName_ << " end, ret:" << ret << ", workspaceSize:" << this->aclnnOpCache_->workspaceSize + << ", aclExecutor:" << this->aclnnOpCache_->aclExecutor); + + ATB_SPEED_LOG_DEBUG(opName_ << " SetAclNNWorkspaceExecutor end"); + return ret; +} + +int IndexputOperation::ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream) +{ + ATB_SPEED_LOG_DEBUG(opName_ << " ExecuteAclNNOp start"); + int ret = + 
aclnnIndexPutImpl(workspace, this->aclnnOpCache_->workspaceSize, this->aclnnOpCache_->aclExecutor, stream); + ATB_SPEED_LOG_DEBUG(opName_ << " ExecuteAclNNOp end, ret:" << ret); + return ret; +} + +} // namespace common +} // namespace atb_speed \ No newline at end of file diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/indexput_operation.h b/tests/proftest/layer_test_framework/operations/aclnn/ops/indexput_operation.h new file mode 100644 index 00000000..73d51055 --- /dev/null +++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/indexput_operation.h @@ -0,0 +1,96 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef ATB_SPEED_PLUGIN_ACLNN_INDEXPUT_OPERATION_H +#define ATB_SPEED_PLUGIN_ACLNN_INDEXPUT_OPERATION_H +#include +#include +#include + +#include "operations/aclnn/core/acl_nn_operation.h" + +namespace atb_speed { +namespace common { +/// A struct defines `Indexput`'s parameter. +struct AclNNIndexputParam { + /// A flag indicating whether accumulation or update. + bool accumulate = false; + /// Check whether the index is within the valid range flag. + bool unsafe = true; + + std::string ToString() const + { + std::ostringstream oss; + oss << "AclNNIndexputParam {" << std::endl; + oss << " accumulate: " << accumulate << std::endl; + oss << " unsafe: " << unsafe << std::endl; + oss << "}"; + return oss.str(); + } +}; + +/// This class defines a matrix operation that supports +/// update or accumulate the data at the corresponding coordinates of the input x +/// with the input value according to the indices. +/// +/// This class makes use of `aclnnIndexPutImplGetWorkspaceSize` and `aclnnIndexPutImpl` from the AscendCL API. +/// +/// Operation's Inputs: +/// Name | Dtype | Shape | +/// ----------------|---------|-------| +/// input | float32, float16, bfloat16 | The dimension is not greater than 8 | +/// indices | int32, int64, bool | [n] | +/// values | same as input | The dimension is the same as input. The first dimension is equal to the indices. 
|
+///
+/// Example:
+/// \code
+/// enum TensorIdx : uint32_t {
+///     IN_INPUT = 0,
+///     IN_INDICES,
+///     IN_VALUES,
+/// };
+///
+/// atb::Node indexPutNode;
+/// AclNNIndexputParam indexPutParam;
+/// indexPutParam.accumulate = false;
+/// indexPutNode.inTensorIds = {IN_INPUT, IN_INDICES, IN_VALUES};
+/// indexPutNode.outTensorIds = {IN_INPUT};
+/// indexPutNode.operation = new atb_speed::common::IndexputOperation("IndexPutNode", indexPutParam);
+///
+/// // Add the operation node to the graph as required
+/// atb::GraphParam opGraph;
+/// opGraph.nodes.push_back(indexPutNode);
+/// \endcode
+
+class IndexputOperation : public AclNNOperation {
+public:
+    explicit IndexputOperation(const std::string &name, AclNNIndexputParam param);
+    ~IndexputOperation() override;
+    uint32_t GetInputNum() const override;
+    uint32_t GetOutputNum() const override;
+    atb::Status InferShape(const atb::SVector<atb::TensorDesc> &inTensorDescs,
+                           atb::SVector<atb::TensorDesc> &outTensorDescs) const override;
+
+private:
+    atb::Status CreateAclNNVariantPack(const atb::VariantPack &variantPack) override;
+    int SetAclNNWorkspaceExecutor() override;
+    int ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream) override;
+
+    std::vector<aclTensor *> vectorList;
+    AclNNIndexputParam param_;
+};
+} // namespace common
+} // namespace atb_speed
+#endif
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/inplace_nan_to_num_operation.cpp b/tests/proftest/layer_test_framework/operations/aclnn/ops/inplace_nan_to_num_operation.cpp
new file mode 100644
index 00000000..32712a26
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/inplace_nan_to_num_operation.cpp
@@ -0,0 +1,125 @@
+/*
+* Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/ + +#include "inplace_nan_to_num_operation.h" +#include "acl/acl.h" +#include "aclnnop/aclnn_nan_to_num.h" +#include "atb_speed/log.h" +#include "operations/aclnn/utils/utils.h" + +namespace atb_speed::common { + +InplaceNanToNumOperation::InplaceNanToNumOperation( + const std::string &name, atb_speed::common::AclNNNanToNumParam param) : AclNNOperation(name), param_(param) +{ + this->opName_ = name; + this->param_ = param; +} + +InplaceNanToNumOperation::~InplaceNanToNumOperation() +{ + ATB_SPEED_LOG_DEBUG("InplaceNanToNumOperation deconstruct"); + this->DestroyOperation(); +} + +/** + * + * @param[in] inTensorDesc: FA: [batchSize, seqLen, hiddenSize]; PA: [seqLen, hiddenSize] + * @param[in] outTensorDesc: FA: [batchSize, seqLen, hiddenSize]; PA: [seqLen, hiddenSize] + * @return atb::Status +*/ +atb::Status InplaceNanToNumOperation::InferShape( + const atb::SVector &inTensorDesc, atb::SVector &outTensorDesc) const +{ + ATB_SPEED_LOG_DEBUG(opName_ << "InplaceNanToNumOperation infer shape start"); + outTensorDesc.at(0).format = inTensorDesc.at(0).format; + outTensorDesc.at(0).dtype = inTensorDesc.at(0).dtype; + outTensorDesc.at(0).shape.dimNum = inTensorDesc.at(0).shape.dimNum; + for (uint64_t i = 0; i < inTensorDesc.at(0).shape.dimNum; ++i) { + outTensorDesc.at(0).shape.dims[i] = inTensorDesc.at(0).shape.dims[i]; + } + + ATB_SPEED_LOG_DEBUG(opName_ << "InplaceNanToNumOperation infer shape end" + << " format: " << inTensorDesc.at(0).format << " dimNum: " << inTensorDesc.at(0).shape.dimNum + << " dims: " << inTensorDesc.at(0).shape.dims[0]); + return atb::NO_ERROR; +} + +uint32_t InplaceNanToNumOperation::GetInputNum() const +{ + return NUM1; // inputTensorNum = 1 +} + +uint32_t InplaceNanToNumOperation::GetOutputNum() const +{ + return NUM1; // outputTensorNum = 1 +} + + +atb::Status InplaceNanToNumOperation::CreateAclNNInTensorVariantPack(const atb::VariantPack &variantPack) +{ + AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack; + aclnnVariantPack.aclInTensors.resize(GetInputNum()); + for (size_t i = 0; i < aclnnVariantPack.aclInTensors.size(); ++i) { + std::shared_ptr aclnnTensor = CreateTensor(variantPack.inTensors.at(i), i); + if (aclnnTensor->tensor == nullptr) { + ATB_SPEED_LOG_ERROR(this->opName_ << " InTensor aclCreateTensor index " << i << " fail"); + return atb::ERROR_INTERNAL_ERROR; + } + aclnnVariantPack.aclInTensors[i] = aclnnTensor; + } + return atb::NO_ERROR; +} + +int InplaceNanToNumOperation::SetAclNNWorkspaceExecutor() +{ + ATB_SPEED_LOG_DEBUG( + opName_ << " SetAclNNWorkspaceExecutor start, nanValue: " << + param_.nanValue << " posInfValue: " << + param_.posInfValue << " negInfValue: " << + param_.negInfValue + ); + AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack; + int ret = aclnnInplaceNanToNumGetWorkspaceSize( + aclnnVariantPack.aclInTensors.at(0)->tensor, // self + param_.nanValue, + param_.posInfValue, + param_.negInfValue, + &this->aclnnOpCache_->workspaceSize, + &this->aclnnOpCache_->aclExecutor); + ATB_SPEED_LOG_DEBUG( + opName_ << " SetAclNNWorkspaceExecutor end" + << ", ret: " << ret + << ", workspaceSize: " << this->aclnnOpCache_->workspaceSize + << ", aclExecutor: " << this->aclnnOpCache_->aclExecutor + ); + return ret; +} + +int InplaceNanToNumOperation::ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream) +{ + ATB_SPEED_LOG_DEBUG(opName_ << " ExecuteAclNNOp start"); + int ret = aclnnInplaceNanToNum( + workspace, + this->aclnnOpCache_->workspaceSize, + this->aclnnOpCache_->aclExecutor, + stream); + 
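+    // aclnnInplaceNanToNum rewrites the input buffer in place on the given stream;
+    // there is no separate output tensor to populate.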
ATB_SPEED_LOG_DEBUG(opName_ << " ExecuteAclNNOp end, ret: " << ret); + return ret; +} + +} // namespace atb_speed::common + diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/inplace_nan_to_num_operation.h b/tests/proftest/layer_test_framework/operations/aclnn/ops/inplace_nan_to_num_operation.h new file mode 100644 index 00000000..efef89a5 --- /dev/null +++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/inplace_nan_to_num_operation.h @@ -0,0 +1,87 @@ +/* +* Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +#ifndef ATB_SPEED_PLUGIN_ACLNN_INPLACE_NAN_TO_NUM_OPERATION_H +#define ATB_SPEED_PLUGIN_ACLNN_INPLACE_NAN_TO_NUM_OPERATION_H + +#include "operations/aclnn/core/acl_nn_operation.h" + + +namespace atb_speed::common { + + +struct AclNNNanToNumParam { + /// nanValue: Input parameter that replaces NaN values in tensor elements. The data type supports FLOAT. + /// posInfValue: Input parameter that replaces positive infinity values in tensor elements. + /// The data type supports FLOAT. + /// negInfValue: Input parameter that replaces negative infinity values in tensor elements. + /// The data type supports FLOAT. + float nanValue = 0.0; + float posInfValue = 65504.0; + float negInfValue = -65504.0; +}; + +/// Replace NaN, positive infinity, and negative infinity values in the input with the +/// values specified by nan, posinf, and neginf, respectively. 
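+///
+/// The defaults clamp infinities to ±65504, the largest finite float16 value, so the
+/// replacements stay representable in half precision. For float32 or bfloat16 inputs,
+/// wider bounds can be passed instead; an illustrative (non-default) parameterization:
+/// \code
+/// atb_speed::common::AclNNNanToNumParam fp32Param;
+/// fp32Param.nanValue = 0.0;
+/// fp32Param.posInfValue = 3.0e38;  // near FLT_MAX; example value, not a default
+/// fp32Param.negInfValue = -3.0e38;
+/// \endcode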
+/// +/// Operation's Inputs: \n +/// | Name | Dtype | Shape | \n +/// |--------|--------------------------|-----------| \n +/// | x | FLOAT16、FLOAT32、INT8、INT16、INT32、INT64、UINT8、BOOL、BFLOAT16 | [-1,…,-1] | \n +/// +/// Operation's Outputs: it is inplace replace.\n +/// +/// Example: +/// \code +/// enum TensorIdx : uint32_t { +/// IN_INPUT = 0, +/// }; +/// +/// atb::GraphParam opGraph; +/// atb::Node nanToNumNode; +/// atb_speed::common::AclNNNanToNumParam NanToNumParam; +/// NanToNumParam.posInfValue = 50000.0; +/// NanToNumParam.negInfValue = -50000.0; +/// nanToNumNode.operation = new atb_speed::common::InplaceNanToNumOperation("nanToNumNode", NanToNumParam); +/// nanToNumNode.inTensorIds = { IN_INPUT }; +/// nanToNumNode.outTensorIds = { IN_INPUT }; +/// opGraph.nodes.push_back(nanToNumNode); +/// +/// \endcode +class InplaceNanToNumOperation : public AclNNOperation { +public: + explicit InplaceNanToNumOperation(const std::string &name, AclNNNanToNumParam param); + ~InplaceNanToNumOperation() override; + atb::Status InferShape( + const atb::SVector &inTensorDesc, + atb::SVector &outTensorDesc + ) const override; + [[nodiscard]] uint32_t GetInputNum() const override; + [[nodiscard]] uint32_t GetOutputNum() const override; + +protected: + int SetAclNNWorkspaceExecutor() override; + int ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream) override; + atb::Status CreateAclNNInTensorVariantPack(const atb::VariantPack &variantPack) override; + +private: + AclNNNanToNumParam param_; + std::string opName_; +}; +} // namespace atb_speed::common + +#endif // ATB_SPEED_PLUGIN_ACLNN_INPLACE_NAN_TO_NUM_OPERATION_H + diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/inplacemasked_filltensor_operation.cpp b/tests/proftest/layer_test_framework/operations/aclnn/ops/inplacemasked_filltensor_operation.cpp new file mode 100644 index 00000000..04d92ae6 --- /dev/null +++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/inplacemasked_filltensor_operation.cpp @@ -0,0 +1,123 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "inplacemasked_filltensor_operation.h" +#include "operations/aclnn/utils/utils.h" +#include "acl/acl.h" +#include "atb_speed/log.h" +#include "aclnnop/aclnn_masked_fill_scalar.h" + +namespace atb_speed::common { + +InplaceMaskedFillTensorOperation::InplaceMaskedFillTensorOperation( + const std::string &name, + atb_speed::common::InplaceMaskedFillTensorParam param +) : AclNNOperation(name), param_(param) +{ + this->opName_ = name; + this->param_ = param; +} + +InplaceMaskedFillTensorOperation::~InplaceMaskedFillTensorOperation() +{ + ATB_SPEED_LOG_DEBUG("InplaceMaskedFillTensorOperation deconstruct"); + this->DestroyOperation(); +} + +atb::Status InplaceMaskedFillTensorOperation::InferShape(const atb::SVector &inTensorDescs, + atb::SVector &outTensorDescs) const +{ + ATB_SPEED_LOG_DEBUG(opName_ << "InplaceMaskedFillTensorOperation infer shape start"); + outTensorDescs.at(0).format = inTensorDescs.at(0).format; + outTensorDescs.at(0).dtype = inTensorDescs.at(0).dtype; + outTensorDescs.at(0).shape.dimNum = inTensorDescs.at(0).shape.dimNum; + for (uint64_t i = 0; i < inTensorDescs.at(0).shape.dimNum; ++i) { + outTensorDescs.at(0).shape.dims[i] = inTensorDescs.at(0).shape.dims[i]; + } + + ATB_SPEED_LOG_DEBUG(opName_ << "InplaceMaskedFillTensorOperation infer shape end" + << " format: " << inTensorDescs.at(0).format << " dimNum: " << inTensorDescs.at(0).shape.dimNum + << " dims: " << inTensorDescs.at(0).shape.dims[0]); + return atb::NO_ERROR; +} + + +uint32_t InplaceMaskedFillTensorOperation::GetInputNum() const +{ + return DIM2; +} + +uint32_t InplaceMaskedFillTensorOperation::GetOutputNum() const +{ + return DIM1; +} + +atb::Status InplaceMaskedFillTensorOperation::CreateAclNNInTensorVariantPack(const atb::VariantPack &variantPack) +{ + AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack; + aclnnVariantPack.aclInTensors.resize(GetInputNum()); + + for (size_t i = 0; i < aclnnVariantPack.aclInTensors.size(); ++i) { + std::shared_ptr aclnnTensor = std::make_shared(); + aclnnTensor->tensorIdx = i; + aclnnTensor->needUpdateTensorDataPtr = true; + aclnnTensor->atbTensor = variantPack.inTensors.at(i); + atb::Tensor squeezedAtbTensor = variantPack.inTensors.at(i); + aclnnTensor->strides = GetCopyTensorStride(squeezedAtbTensor.desc.shape); + if (i == 1) { + squeezedAtbTensor.desc.dtype = aclDataType::ACL_BOOL; + } + CallAclCreateTensor(squeezedAtbTensor.desc.shape, squeezedAtbTensor.desc.shape, + squeezedAtbTensor, aclnnTensor); + if (aclnnTensor->tensor == nullptr) { + ATB_SPEED_LOG_ERROR(this->opName_ << " OutTensor aclCreateTensor index " << i << " fail"); + return atb::ERROR_INTERNAL_ERROR; + } + aclnnVariantPack.aclInTensors[i] = aclnnTensor; + } + return atb::NO_ERROR; +} + +int InplaceMaskedFillTensorOperation::SetAclNNWorkspaceExecutor() +{ + ATB_SPEED_LOG_DEBUG(opName_ << " SetAclNNWorkspaceExecutor start"); + AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack; + aclScalar* value = aclCreateScalar(¶m_.value, param_.outDataType); + int ret = aclnnInplaceMaskedFillScalarGetWorkspaceSize( + aclnnVariantPack.aclInTensors.at(DIM0)->tensor, // input + aclnnVariantPack.aclInTensors.at(DIM1)->tensor, // input + value, + &this->aclnnOpCache_->workspaceSize, + &this->aclnnOpCache_->aclExecutor); + ATB_SPEED_LOG_DEBUG(opName_ << " SetAclNNWorkspaceExecutor end" + << ", ret: " << ret + << ", workspaceSize: " << this->aclnnOpCache_->workspaceSize + << ", aclExecutor: " << this->aclnnOpCache_->aclExecutor); + return ret; +} + +int 
InplaceMaskedFillTensorOperation::ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream)
+{
+    ATB_SPEED_LOG_DEBUG(opName_ << " ExecuteAclNNOp start");
+    int ret = aclnnInplaceMaskedFillScalar(
+        workspace,
+        this->aclnnOpCache_->workspaceSize,
+        this->aclnnOpCache_->aclExecutor,
+        stream);
+    ATB_SPEED_LOG_DEBUG(opName_ << " ExecuteAclNNOp end"
+        << ", ret: " << ret);
+    return ret;
+}
+} // namespace atb_speed::common
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/inplacemasked_filltensor_operation.h b/tests/proftest/layer_test_framework/operations/aclnn/ops/inplacemasked_filltensor_operation.h
new file mode 100644
index 00000000..96884402
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/inplacemasked_filltensor_operation.h
@@ -0,0 +1,81 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef ATB_SPEED_PLUGIN_ACLNN_MASKEDFILL_OPERATION_H
+#define ATB_SPEED_PLUGIN_ACLNN_MASKEDFILL_OPERATION_H
+
+#include "operations/aclnn/core/acl_nn_operation.h"
+
+namespace atb_speed::common {
+struct InplaceMaskedFillTensorParam {
+    float value = 0;
+    aclDataType outDataType = ACL_FLOAT16;
+};
+
+/// This class defines an operator that replaces masked elements of a tensor with a
+/// specified scalar value.
+///
+/// This class makes use of `aclnnInplaceMaskedFillScalarGetWorkspaceSize` and `aclnnInplaceMaskedFillScalar`
+/// from the AscendCL API.
+///
+/// Inputs to the operator:
+/// Name   | Dtype               | Shape |
+/// -------|---------------------|-------|
+/// input  | float16 or bfloat16 | [m]   |
+/// mask   | bool (the tensor is reinterpreted as ACL_BOOL) | [m] |
+///
+/// Outputs of the operator:
+/// Name   | Dtype               | Shape |
+/// -------|---------------------|-------|
+/// output | float16 or bfloat16 | [m]   |
+/// Note: the output is a placeholder that isn't written during execution; the result is
+/// stored in place in the input.
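+///
+/// The second input is reinterpreted as an ACL_BOOL mask of the same shape before the
+/// aclnn call (see `CreateAclNNInTensorVariantPack`), so an int8 tensor holding 0/1
+/// values can serve as the mask. The fill value itself is materialized as an aclScalar
+/// of `outDataType`, as sketched below (mirrors `SetAclNNWorkspaceExecutor`):
+/// \code
+/// float fillValue = 1.0f; // illustrative value
+/// aclScalar *value = aclCreateScalar(&fillValue, ACL_FLOAT16);
+/// // aclnnInplaceMaskedFillScalarGetWorkspaceSize(input, mask, value, &workspaceSize, &executor);
+/// \endcode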
+///
+/// Example:
+/// \code
+/// enum InTensorIdx : uint32_t {INPUT = 0, MASK};
+///
+/// enum OutTensorIdx : uint32_t {OUT = 0};
+///
+/// atb::Node &maskedFillNode = opGraph.nodes.at(nodeId++);
+/// atb_speed::common::InplaceMaskedFillTensorParam fillParam;
+/// fillParam.value = param.fillValue;
+/// fillParam.outDataType = param.outDataType;
+/// maskedFillNode.operation = new atb_speed::common::InplaceMaskedFillTensorOperation("MaskedFill", fillParam);
+/// maskedFillNode.inTensorIds = {INPUT, MASK};
+/// maskedFillNode.outTensorIds = {OUT};
+/// \endcode
+
+class InplaceMaskedFillTensorOperation : public AclNNOperation {
+public:
+    explicit InplaceMaskedFillTensorOperation(const std::string &name, InplaceMaskedFillTensorParam param);
+    ~InplaceMaskedFillTensorOperation() override;
+    atb::Status InferShape(
+        const atb::SVector<atb::TensorDesc> &inTensorDescs,
+        atb::SVector<atb::TensorDesc> &outTensorDescs
+    ) const override;
+    [[nodiscard]] uint32_t GetInputNum() const override;
+    [[nodiscard]] uint32_t GetOutputNum() const override;
+
+protected:
+    int SetAclNNWorkspaceExecutor() override;
+    int ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream) override;
+    atb::Status CreateAclNNInTensorVariantPack(const atb::VariantPack &variantPack) override;
+
+private:
+    InplaceMaskedFillTensorParam param_;
+    std::string opName_;
+};
+} // namespace atb_speed::common
+
+#endif // ATB_SPEED_PLUGIN_ACLNN_MASKEDFILL_OPERATION_H
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/layer_norm_operation.cpp b/tests/proftest/layer_test_framework/operations/aclnn/ops/layer_norm_operation.cpp
new file mode 100644
index 00000000..5fa3a6c4
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/layer_norm_operation.cpp
@@ -0,0 +1,216 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "operations/aclnn/utils/utils.h" +#include "acl/acl.h" +#include "atb_speed/log.h" +#include "aclnnop/aclnn_layer_norm.h" +#include "layer_norm_operation.h" + + +namespace atb_speed::common { + + LayerNormOperation::LayerNormOperation( + const std::string &name, + atb_speed::common::AclNNLayerNormParam param + ) : AclNNOperation(name), param_(param) + { + this->opName_ = name; + this->param_ = param; + } + + LayerNormOperation::~LayerNormOperation() + { + ATB_SPEED_LOG_DEBUG("LayerNormOperation deconstruct"); + this->DestroyOperation(); + } + + /** + * + * @param[in] inTensorDesc: FA: [batchSize, seqLen, hiddenSize]; PA: [seqLen, hiddenSize] + * @param[in] outTensorDesc: FA: [batchSize, seqLen, hiddenSize]; PA: [seqLen, hiddenSize] + * @return atb::Status + */ + atb::Status LayerNormOperation::InferShape( + const atb::SVector &inTensorDesc, + atb::SVector &outTensorDesc + ) const + { + ATB_SPEED_LOG_DEBUG(opName_ << " InferShape start"); + outTensorDesc.at(0).format = inTensorDesc.at(0).format; + outTensorDesc.at(0).dtype = inTensorDesc.at(0).dtype; + outTensorDesc.at(0).shape.dimNum = inTensorDesc.at(0).shape.dimNum; + + ATB_SPEED_LOG_DEBUG("Check " << opName_ << " input dimNum=" << inTensorDesc.at(0).shape.dimNum); + for (uint64_t dim = 0; dim < inTensorDesc.at(0).shape.dimNum; ++dim) { + ATB_SPEED_LOG_DEBUG("input dim" << dim << " shape=" << inTensorDesc.at(0).shape.dims[dim]); + outTensorDesc.at(0).shape.dims[dim] = inTensorDesc.at(0).shape.dims[dim]; + } + + ATB_SPEED_LOG_DEBUG(opName_ << " InferShape end"); + return atb::NO_ERROR; + } + + uint32_t LayerNormOperation::GetInputNum() const + { + return NUM3; // inputTensorNum = 3 + } + + uint32_t LayerNormOperation::GetOutputNum() const + { + return NUM1; // outputTensorNum = 1 + } + + atb::Status LayerNormOperation::CreateAclNNVariantPack(const atb::VariantPack &variantPack) + { + ATB_SPEED_LOG_DEBUG(opName_ << " CreateAclNNVariantPack start"); + int ret; + + ret = CreateAclNNInTensorVariantPack(variantPack); + if (ret != 0) { + ATB_SPEED_LOG_ERROR(this->opName_ << " AclNNTensor CreateAclNNInTensorVariantPack fail"); + return ret; + } + + ret = CreateAclNNOutTensorVariantPack(variantPack); + if (ret != 0) { + ATB_SPEED_LOG_ERROR(this->opName_ << " AclNNTensor CreateAclNNOutTensorVariantPack fail"); + } + + ATB_SPEED_LOG_DEBUG(opName_ << " CreateAclNNVariantPack end"); + return atb::NO_ERROR; + } + + atb::Status LayerNormOperation::CreateAclNNInTensorVariantPack(const atb::VariantPack &variantPack) + { + AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack; + aclnnVariantPack.aclInTensors.resize(GetInputNum()); + for (size_t i = 0; i < aclnnVariantPack.aclInTensors.size(); ++i) { + std::shared_ptr aclnnTensor = CreateTensor(variantPack.inTensors.at(i), i); + if (aclnnTensor->tensor == nullptr) { + ATB_SPEED_LOG_ERROR(this->opName_ << " InTensor aclCreateTensor index " << i << " fail"); + return atb::ERROR_INTERNAL_ERROR; + } + aclnnVariantPack.aclInTensors[i] = aclnnTensor; + } + return atb::NO_ERROR; + } + + atb::Status LayerNormOperation::CreateAclNNOutTensorVariantPack(const atb::VariantPack &variantPack) + { + AclNNVariantPack &aclNnVariantPack = this->aclnnOpCache_->aclnnVariantPack; + aclNnVariantPack.aclOutTensors.resize(GetOutputNum()); + for (size_t i = 0; i < aclNnVariantPack.aclOutTensors.size(); ++i) { + std::shared_ptr aclnnTensor = CreateTensor(variantPack.outTensors.at(i), i); + if (aclnnTensor->tensor == nullptr) { + ATB_SPEED_LOG_ERROR(this->opName_ << " outTensor aclCreateTensor 
index " << i << " fail"); + return atb::ERROR_INTERNAL_ERROR; + } + aclNnVariantPack.aclOutTensors[i] = aclnnTensor; + } + return atb::NO_ERROR; + } + + std::shared_ptr LayerNormOperation::CreateTensor(atb::Tensor atbTensor, size_t tensorIdx) + { + ATB_SPEED_LOG_DEBUG(opName_ << " CreateTensor start"); + std::shared_ptr aclnnTensor = std::make_shared(); + aclnnTensor->tensorIdx = static_cast(tensorIdx); + aclnnTensor->needUpdateTensorDataPtr = true; + aclnnTensor->atbTensor = atbTensor; + atb::Tensor squeezedAtbTensor = SqueezeBatchSeq(atbTensor); + ATB_SPEED_LOG_DEBUG(opName_ << " tensor dtype: " << squeezedAtbTensor.desc.dtype); + aclnnTensor->strides = GetCopyTensorStride(squeezedAtbTensor.desc.shape); + aclnnTensor->tensor = aclCreateTensor( + squeezedAtbTensor.desc.shape.dims, + squeezedAtbTensor.desc.shape.dimNum, + squeezedAtbTensor.desc.dtype, + aclnnTensor->strides.data(), + 0, + squeezedAtbTensor.desc.format, + squeezedAtbTensor.desc.shape.dims, + squeezedAtbTensor.desc.shape.dimNum, + squeezedAtbTensor.deviceData); + ATB_SPEED_LOG_DEBUG(opName_ << " CreateTensor end"); + return aclnnTensor; + } + + int LayerNormOperation::SetAclNNWorkspaceExecutor() + { + ATB_SPEED_LOG_DEBUG( + opName_ << " SetAclNNWorkspaceExecutor start" + << ", layerNormEps: " << param_.layerNormEps + << ", beginNormAxis: " << param_.beginNormAxis + << ", layerNormImplMode: " << param_.layerNormImplMode + ); + AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack; + int beginNormAxis = + param_.beginNormAxis < 0 ? + aclnnVariantPack.aclInTensors.at(0)->atbTensor.desc.shape.dimNum + param_.beginNormAxis : + param_.beginNormAxis; + uint64_t normalizedMaxDimNum = static_cast(beginNormAxis + param_.normAxes); + uint64_t inTensorDimNum = aclnnVariantPack.aclInTensors.at(0)->atbTensor.desc.shape.dimNum; + if (normalizedMaxDimNum > inTensorDimNum) { + std::stringstream ss; + ss << this->opName_ << " normalized max dimNum " << normalizedMaxDimNum + << " > inTensor dimNum " << inTensorDimNum; + ATB_SPEED_LOG_ERROR( + this->opName_ << " normalized max dimNum " << normalizedMaxDimNum + << " > inTensor dimNum " << inTensorDimNum; + ); + throw std::runtime_error(ss.str()); + } + int64_t normalizedShapeValue[param_.normAxes]; + for (int i = 0; i < param_.normAxes; ++i) { + normalizedShapeValue[i] = aclnnVariantPack.aclInTensors.at(0)->atbTensor.desc.shape.dims[ + beginNormAxis + i + ]; + } + aclIntArray *normalizedShape = aclCreateIntArray(normalizedShapeValue, param_.normAxes); + int ret = aclnnLayerNormWithImplModeGetWorkspaceSize( + aclnnVariantPack.aclInTensors.at(0)->tensor, // input + normalizedShape, // normalizedShape + aclnnVariantPack.aclInTensors.at(1)->tensor, // weight + param_.hasBias ? 
aclnnVariantPack.aclInTensors.at(2)->tensor : nullptr, // bias + param_.layerNormEps, // eps + aclnnVariantPack.aclOutTensors.at(0)->tensor, // out + nullptr, // meanOut + nullptr, // rstdOut + static_cast(param_.layerNormImplMode), // implMode + &this->aclnnOpCache_->workspaceSize, + &this->aclnnOpCache_->aclExecutor); + ATB_SPEED_LOG_DEBUG( + opName_ << " SetAclNNWorkspaceExecutor end" + << ", ret: " << ret + << ", workspaceSize: " << this->aclnnOpCache_->workspaceSize + << ", aclExecutor: " << this->aclnnOpCache_->aclExecutor + ); + return ret; + } + + int LayerNormOperation::ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream) + { + ATB_SPEED_LOG_DEBUG(opName_ << " ExecuteAclNNOp start"); + int ret = aclnnLayerNormWithImplMode( + workspace, + this->aclnnOpCache_->workspaceSize, + this->aclnnOpCache_->aclExecutor, + stream); + ATB_SPEED_LOG_DEBUG(opName_ << " ExecuteAclNNOp end" << ", ret: " << ret); + return ret; + } + +} // namespace atb_speed::common diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/layer_norm_operation.h b/tests/proftest/layer_test_framework/operations/aclnn/ops/layer_norm_operation.h new file mode 100644 index 00000000..8a0a75f2 --- /dev/null +++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/layer_norm_operation.h @@ -0,0 +1,107 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef ATB_SPEED_PLUGIN_ACLNN_LAYER_NORM_OPERATION_H +#define ATB_SPEED_PLUGIN_ACLNN_LAYER_NORM_OPERATION_H + +#include "operations/aclnn/core/acl_nn_operation.h" + + +namespace atb_speed::common { + + /// A struct defines `aclnnLayerNormWithImplModeGetWorkspaceSize` operation parameter. + struct AclNNLayerNormParam { + /// Indicates a value added to the denominator for numerical stability. + float layerNormEps = 0; + /// Indicates the start of normalization axis. + int beginNormAxis = 0; + /// Indicates the number of normalization axes. + int normAxes = 1; + /// Indicates the accuracy implementation mode in execution. + /// + /// 0: high accuracy mode. + /// 1: high performance mode. + /// 2: keep dtype of `float16` in execution. + int64_t layerNormImplMode = 0; + /// Indicates whether inputs include a bias tensor. + bool hasBias = true; + }; + + /// This class defines a matrix operation that applies Layer Normalization over a mini-batch of inputs. + /// + /// This class makes use of `aclnnLayerNormGetWorkspaceSize` and `aclnnLayerNormWithImplModeGetWorkspaceSize` + /// from the AscendCL API. 
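+    ///
+    /// Note: a negative `beginNormAxis` counts from the back of the input shape (for example, -1 selects
+    /// the last axis). The normalized shape handed to AscendCL is read from the `normAxes` trailing
+    /// dimensions starting at that axis, and `beginNormAxis + normAxes` must not exceed the input rank,
+    /// otherwise SetAclNNWorkspaceExecutor throws.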
+ /// + /// Operation's Inputs: \n + /// | Name | Dtype | Shape | \n + /// |--------|--------------------------|-------------------------| \n + /// | input | float32/float16/bfloat16 | [-1,…,-1] | \n + /// | weight | float32/float16/bfloat16 | [beginNormAxis:]/[1:-1] | \n + /// | bias | float32/float16/bfloat16 | [beginNormAxis:]/[1:-1] | \n + /// + /// Operation's Outputs: \n + /// | Name | Dtype | Shape | \n + /// |--------|--------------------------|-------------------------| \n + /// | output | float32/float16/bfloat16 | [-1,…,-1] | \n + /// + /// Example: + /// \code + /// enum TensorIdx : uint32_t { + /// IN_INPUT = 0, + /// IN_WEIGHT, + /// IN_BIAS, + /// OUT, + /// }; + /// + /// atb::Node layerNormNode; + /// AclNNLayerNormParam layerNormParam; + /// layerNormParam.layerNormEps = 1e-5; + /// layerNormParam.beginNormAxis = -1; + /// layerNormParam.normAxes = 1; + /// layerNormParam.hasBias = true; + /// layerNormNode.inTensorIds = { IN_INPUT, IN_WEIGHT, IN_BIAS }; + /// layerNormNode.outTensorIds = { OUT }; + /// layerNormNode.operation = new atb_speed::common::LayerNormOperation("layerNormNode", layerNormParam); + /// + /// atb::GraphParam opGraph; + /// opGraph.nodes.push_back(layerNormNode); + /// \endcode + class LayerNormOperation : public AclNNOperation { + public: + explicit LayerNormOperation(const std::string &name, AclNNLayerNormParam param); + ~LayerNormOperation() override; + atb::Status InferShape( + const atb::SVector &inTensorDesc, + atb::SVector &outTensorDesc + ) const override; + [[nodiscard]] uint32_t GetInputNum() const override; + [[nodiscard]] uint32_t GetOutputNum() const override; + + protected: + int SetAclNNWorkspaceExecutor() override; + int ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream) override; + atb::Status CreateAclNNVariantPack(const atb::VariantPack &variantPack) override; + atb::Status CreateAclNNInTensorVariantPack(const atb::VariantPack &variantPack) override; + atb::Status CreateAclNNOutTensorVariantPack(const atb::VariantPack &variantPack) override; + virtual std::shared_ptr CreateTensor(atb::Tensor atbTensor, size_t tensorIdx); + + private: + AclNNLayerNormParam param_; + std::string opName_; + }; +} // namespace atb_speed::common + +#endif // ATB_SPEED_PLUGIN_ACLNN_LAYER_NORM_OPERATION_H diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/len_operation.cpp b/tests/proftest/layer_test_framework/operations/aclnn/ops/len_operation.cpp new file mode 100644 index 00000000..f375799a --- /dev/null +++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/len_operation.cpp @@ -0,0 +1,144 @@ +/* +* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
+*/ + +#include "len_operation.h" +#include "operations/aclnn/utils/utils.h" +#include "acl/acl.h" +#include "aclnnop/aclnn_range.h" +#include "atb_speed/log.h" + +namespace atb_speed { +namespace common { + +LenOperation::LenOperation(const std::string &name) : AclNNOperation(name) {} + +LenOperation::~LenOperation() +{ + ATB_SPEED_LOG_DEBUG("LenOperation deconstruct"); + this->DestroyOperation(); +} + +uint32_t LenOperation::GetInputNum() const { return NUM1; } + +uint32_t LenOperation::GetOutputNum() const { return NUM1; } + +atb::Status LenOperation::InferShape( + const atb::SVector &inTensorDesc, + atb::SVector &outTensorDesc) const +{ + ATB_SPEED_LOG_DEBUG(opName_ << " LenOperation infer shape start"); + outTensorDesc.at(0).format = inTensorDesc.at(0).format; + outTensorDesc.at(0).dtype = aclDataType::ACL_INT32; + outTensorDesc.at(0).shape.dimNum = 1; + outTensorDesc.at(0).shape.dims[0] = 1; + ATB_SPEED_LOG_DEBUG(opName_ << "LenOperation InferShape end"); + + return atb::NO_ERROR; +} + +int LenOperation::CreateAclNNVariantPack(const atb::VariantPack &variantPack) +{ + ATB_SPEED_LOG_DEBUG(opName_ << " CreateAclNNVariantPack start"); + int ret; + + ret = CreateAclNNInTensorVariantPack(variantPack); + if (ret != 0) { + ATB_SPEED_LOG_ERROR(this->opName_ << " AclNNTensor CreateAclNNInTensorVariantPack fail"); + return ret; + } + + ret = CreateAclNNOutTensorVariantPack(variantPack); + if (ret != 0) { + ATB_SPEED_LOG_ERROR(this->opName_ << " AclNNTensor CreateAclNNOutTensorVariantPack fail"); + } + + ATB_SPEED_LOG_DEBUG(opName_ << " CreateAclNNVariantPack end"); + return atb::NO_ERROR; +} + +int LenOperation::CreateAclNNInTensorVariantPack(const atb::VariantPack &variantPack) +{ + AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack; + aclnnVariantPack.aclInTensors.resize(GetInputNum()); + for (size_t i = 0; i < aclnnVariantPack.aclInTensors.size(); ++i) { + std::shared_ptr aclnnTensor = CreateTensor(variantPack.inTensors.at(i), i); + if (aclnnTensor->tensor == nullptr) { + ATB_SPEED_LOG_ERROR(this->opName_ << " InTensor aclCreateTensor index " << i << " fail"); + return atb::ERROR_INTERNAL_ERROR; + } + aclnnVariantPack.aclInTensors[i] = aclnnTensor; + } + return atb::NO_ERROR; +} + + +int LenOperation::CreateAclNNOutTensorVariantPack(const atb::VariantPack &variantPack) +{ + AclNNVariantPack &aclNnVariantPack = this->aclnnOpCache_->aclnnVariantPack; + aclNnVariantPack.aclOutTensors.resize(GetOutputNum()); + for (size_t i = 0; i < aclNnVariantPack.aclOutTensors.size(); ++i) { + std::shared_ptr aclnnTensor = CreateTensor(variantPack.outTensors.at(i), i); + if (aclnnTensor->tensor == nullptr) { + ATB_SPEED_LOG_ERROR(this->opName_ << " outTensor aclCreateTensor index " << i << " fail"); + return atb::ERROR_INTERNAL_ERROR; + } + aclNnVariantPack.aclOutTensors[i] = aclnnTensor; + } + return atb::NO_ERROR; +} + +std::shared_ptr LenOperation::CreateTensor(atb::Tensor atbTensor, int tensorIdx) const +{ + std::shared_ptr aclnnTensor = std::make_shared(); + aclnnTensor->needUpdateTensorDataPtr = true; + aclnnTensor->atbTensor = atbTensor; + aclnnTensor->tensorIdx = tensorIdx; + aclnnTensor->strides = GetCopyTensorStride(atbTensor.desc.shape); + aclnnTensor->tensor = aclCreateTensor( + atbTensor.desc.shape.dims, + atbTensor.desc.shape.dimNum, + atbTensor.desc.dtype, + aclnnTensor->strides.data(), + 0, + atbTensor.desc.format, + atbTensor.desc.shape.dims, + atbTensor.desc.shape.dimNum, + atbTensor.deviceData); + return aclnnTensor; +} + +int 
LenOperation::SetAclNNWorkspaceExecutor() +{ + ATB_SPEED_LOG_DEBUG(opName_ << " SetAclNNWorkspaceExecutor start"); + + AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack; + auto start = aclnnVariantPack.aclInTensors.at(DIM0)->atbTensor.desc.shape.dims[DIM0]; + auto end = start + 1; + auto step = 1; + + int ret = aclnnRangeGetWorkspaceSize( + aclCreateScalar(&start, aclDataType::ACL_INT32), + aclCreateScalar(&end, aclDataType::ACL_INT32), + aclCreateScalar(&step, aclDataType::ACL_INT32), + aclnnVariantPack.aclOutTensors.at(DIM0)->tensor, + &this->aclnnOpCache_->workspaceSize, + &this->aclnnOpCache_->aclExecutor); + ATB_SPEED_LOG_DEBUG( + opName_ << " SetAclNNWorkspaceExecutor end" + << ", ret: " << ret + << ", workspaceSize: " << this->aclnnOpCache_->workspaceSize + << ", aclExecutor: " << this->aclnnOpCache_->aclExecutor + ); + return ret; +} + +int LenOperation::ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream) +{ + ATB_SPEED_LOG_DEBUG(opName_ << " ExecuteAclNNOp start"); + int ret = aclnnRange(workspace, this->aclnnOpCache_->workspaceSize, this->aclnnOpCache_->aclExecutor, stream); + ATB_SPEED_LOG_DEBUG(opName_ << " ExecuteAclNNOp end, ret:" << ret); + return ret; +} +} // namespace common +} // namespace atb_speed \ No newline at end of file diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/len_operation.h b/tests/proftest/layer_test_framework/operations/aclnn/ops/len_operation.h new file mode 100644 index 00000000..54efdb18 --- /dev/null +++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/len_operation.h @@ -0,0 +1,33 @@ +/* +* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +*/ + +#ifndef ATB_SPEED_PLUGIN_ACLNN_LEN_OPERATION_H +#define ATB_SPEED_PLUGIN_ACLNN_LEN_OPERATION_H + +#include "aclnn/core/acl_nn_operation.h" + +namespace atb_speed { +namespace common { + +class LenOperation : public AclNNOperation { +public: + explicit LenOperation(const std::string &name); + ~LenOperation() override; + atb::Status InferShape(const atb::SVector &inTensorDesc, + atb::SVector &outTensorDesc) const override; + uint32_t GetInputNum() const override; + uint32_t GetOutputNum() const override; + +protected: + int CreateAclNNVariantPack(const atb::VariantPack &variantPack) override; + int SetAclNNWorkspaceExecutor() override; + int ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream) override; + int CreateAclNNInTensorVariantPack(const atb::VariantPack &variantPack) override; + int CreateAclNNOutTensorVariantPack(const atb::VariantPack &variantPack) override; + std::shared_ptr CreateTensor(atb::Tensor atbTensor, int tensorIdx) const; +}; +} // namespace common +} // namespace atb_speed + +#endif // ATB_SPEED_PLUGIN_ACLNN_Len_OPERATION_H diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/matmul_allreduce_operation.cpp b/tests/proftest/layer_test_framework/operations/aclnn/ops/matmul_allreduce_operation.cpp new file mode 100644 index 00000000..13ec83cf --- /dev/null +++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/matmul_allreduce_operation.cpp @@ -0,0 +1,186 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "matmul_allreduce_operation.h" +#include +#include +#include +#include +#include +#include "acl/acl.h" +#include "aclnnop/aclnn_matmul_all_reduce.h" +#include "atb_speed/log.h" +#include "operations/aclnn/utils/utils.h" + +namespace atb_speed { +namespace common { + +MatmulAllreduceOperation::MatmulAllreduceOperation(const std::string &name, HcclComm hcommInfo) + : AclNNOperation(name), hcommInfo_(hcommInfo) +{ + HcclGetCommName(hcommInfo_, this->hcommName); +} + +MatmulAllreduceOperation::~MatmulAllreduceOperation() {} + +atb::Status MatmulAllreduceOperation::InferShape(const atb::SVector &inTensorDescs, + atb::SVector &outTensorDescs) const +{ + ATB_SPEED_LOG_DEBUG(opName_ << " infer shape start"); + outTensorDescs.at(0).format = inTensorDescs.at(0).format; + outTensorDescs.at(0).dtype = inTensorDescs.at(0).dtype; + outTensorDescs.at(0).shape.dimNum = inTensorDescs.at(0).shape.dimNum; + + if (inTensorDescs.at(0).shape.dimNum == DIM3) { + outTensorDescs.at(DIM0).shape.dims[DIM0] = inTensorDescs.at(DIM0).shape.dims[DIM0]; + outTensorDescs.at(DIM0).shape.dims[DIM1] = inTensorDescs.at(DIM0).shape.dims[DIM1]; + outTensorDescs.at(DIM0).shape.dims[DIM2] = inTensorDescs.at(DIM1).shape.dims[DIM0]; + } else if (inTensorDescs.at(0).shape.dimNum == DIM2) { + outTensorDescs.at(0).shape.dims[DIM0] = inTensorDescs.at(DIM0).shape.dims[DIM0]; + outTensorDescs.at(0).shape.dims[DIM1] = inTensorDescs.at(DIM1).shape.dims[DIM1]; + } else { + ATB_SPEED_LOG_ERROR(opName_ << " invalid dim num:" << inTensorDescs.at(DIM0).shape.dimNum); + } + ATB_SPEED_LOG_DEBUG(opName_ << " infer shape end"); + return 0; +} + +uint32_t MatmulAllreduceOperation::GetInputNum() const { return NUM2; } + +uint32_t MatmulAllreduceOperation::GetOutputNum() const { return NUM1; } + +int MatmulAllreduceOperation::CreateAclNNVariantPack(const atb::VariantPack &variantPack) +{ + ATB_SPEED_LOG_DEBUG(opName_ << " CreateAclNNVariantPack start"); + int ret = 0; + ret = CreateAclNNInTensorVariantPack(variantPack); + if (ret != 0) { + ATB_SPEED_LOG_ERROR(this->opName_ << " AclNNTensor CreateAclNNInTensorVariantPack fail"); + return ret; + } + ret = CreateAclNNOutTensorVariantPack(variantPack); + if (ret != 0) { + ATB_SPEED_LOG_ERROR(this->opName_ << " AclNNTensor CreateAclNNOutTensorVariantPack fail"); + return ret; + } + ATB_SPEED_LOG_DEBUG(opName_ << " CreateAclNNVariantPack end"); + return atb::NO_ERROR; +} + +int MatmulAllreduceOperation::CreateAclNNInTensorVariantPack(const atb::VariantPack &variantPack) +{ + AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack; + aclnnVariantPack.aclInTensors.resize(GetInputNum()); + for (size_t i = 0; i < aclnnVariantPack.aclInTensors.size(); ++i) { + std::shared_ptr aclnnTensor = std::make_shared(); + aclnnTensor->tensorIdx = i; + aclnnTensor->tensorListidx = -1; + aclnnTensor->needUpdateTensorDataPtr = true; + aclnnTensor->atbTensor = variantPack.inTensors.at(i); + atb::Tensor squeezedAtbTensor = variantPack.inTensors.at(i); + + if (false) { + aclnnTensor->strides = GetCopyTensorStride(squeezedAtbTensor.desc.shape); + std::vector shapeT(4); // 4: 
dimNu + shapeT[DIM0] = squeezedAtbTensor.desc.shape.dims[1] / 16; // 16: NZ_FORMAT + shapeT[DIM1] = squeezedAtbTensor.desc.shape.dims[0] / 16; // 16: NZ_FORMAT + shapeT[DIM2] = 16; // 16: NZ_FORMAT + shapeT[DIM3] = 16; // 16: NZ_FORMAT + aclnnTensor->tensor = aclCreateTensor(shapeT.data(), + 4, // 4: dimNum + squeezedAtbTensor.desc.dtype, + aclnnTensor->strides.data(), + 0, + squeezedAtbTensor.desc.format, + shapeT.data(), + 4, // 4: dimNum + squeezedAtbTensor.deviceData); + } else { + aclnnTensor->strides = GetCopyTensorStride(squeezedAtbTensor.desc.shape); + aclnnTensor->tensor = aclCreateTensor(squeezedAtbTensor.desc.shape.dims, + squeezedAtbTensor.desc.shape.dimNum, + squeezedAtbTensor.desc.dtype, + aclnnTensor->strides.data(), + 0, + squeezedAtbTensor.desc.format, + squeezedAtbTensor.desc.shape.dims, + squeezedAtbTensor.desc.shape.dimNum, + squeezedAtbTensor.deviceData); + } + if (aclnnTensor->tensor == nullptr) { + ATB_SPEED_LOG_ERROR(this->opName_ << " InTensor aclCreateTensor index " << i << " fail"); + return atb::ERROR_INTERNAL_ERROR; + } + aclnnVariantPack.aclInTensors[i] = aclnnTensor; + } + return atb::NO_ERROR; +} + +int MatmulAllreduceOperation::CreateAclNNOutTensorVariantPack(const atb::VariantPack &variantPack) +{ + AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack; + aclnnVariantPack.aclOutTensors.resize(GetOutputNum()); + for (size_t i = 0; i < aclnnVariantPack.aclOutTensors.size(); ++i) { + std::shared_ptr aclnnTensor = std::make_shared(); + aclnnTensor->tensorIdx = i; + aclnnTensor->tensorListidx = -1; + aclnnTensor->needUpdateTensorDataPtr = true; + aclnnTensor->atbTensor = variantPack.outTensors.at(i); + atb::Tensor atbTensor = variantPack.outTensors.at(i); + aclnnTensor->strides = GetCopyTensorStride(atbTensor.desc.shape); + aclnnTensor->tensor = aclCreateTensor( + atbTensor.desc.shape.dims, atbTensor.desc.shape.dimNum, atbTensor.desc.dtype, + aclnnTensor->strides.data(), 0, atbTensor.desc.format, atbTensor.desc.shape.dims, + atbTensor.desc.shape.dimNum, atbTensor.deviceData); + if (aclnnTensor->tensor == nullptr) { + ATB_SPEED_LOG_ERROR(this->opName_ << " OutTensor aclCreateTensor index " << i << " fail"); + return atb::ERROR_INTERNAL_ERROR; + } + aclnnVariantPack.aclOutTensors[i] = aclnnTensor; + } + return atb::NO_ERROR; +} + +int MatmulAllreduceOperation::SetAclNNWorkspaceExecutor() +{ + ATB_SPEED_LOG_DEBUG(opName_ << " aclnnMatmulAllReduceGetWorkspaceSize start"); + AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack; + int ret = aclnnMatmulAllReduceGetWorkspaceSize(aclnnVariantPack.aclInTensors.at(DIM0)->tensor, + aclnnVariantPack.aclInTensors.at(DIM1)->tensor, + nullptr, + this->hcommName, + "sum", + 0, + 1, + aclnnVariantPack.aclOutTensors.at(DIM0)->tensor, + &this->aclnnOpCache_->workspaceSize, + &this->aclnnOpCache_->aclExecutor); + ATB_SPEED_LOG_DEBUG(opName_ << " aclnnMatmulAllReduceGetWorkspaceSize end, ret:" << ret + << ", workspaceSize:" << this->aclnnOpCache_->workspaceSize + << ", aclExecutor:" << this->aclnnOpCache_->aclExecutor); + return ret; +} + +int MatmulAllreduceOperation::ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream) +{ + ATB_SPEED_LOG_DEBUG(opName_ << " aclnnMatmulAllReduce start"); + int ret = aclnnMatmulAllReduce(workspace, this->aclnnOpCache_->workspaceSize, + this->aclnnOpCache_->aclExecutor, stream); + ATB_SPEED_LOG_DEBUG(opName_ << " aclnnMatmulAllReduce end, ret:" << ret); + return ret; +} + +} // namespace common +} // namespace atb_speed diff --git 
a/tests/proftest/layer_test_framework/operations/aclnn/ops/matmul_allreduce_operation.h b/tests/proftest/layer_test_framework/operations/aclnn/ops/matmul_allreduce_operation.h
new file mode 100644
index 00000000..99e2cd7f
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/matmul_allreduce_operation.h
@@ -0,0 +1,43 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef ATB_SPEED_PLUGIN_ACLNN_MATMUL_ALLREDUCE_OPERATION_H
+#define ATB_SPEED_PLUGIN_ACLNN_MATMUL_ALLREDUCE_OPERATION_H
+
+#include "operations/aclnn/core/acl_nn_operation.h"
+
+namespace atb_speed {
+namespace common {
+class MatmulAllreduceOperation : public AclNNOperation {
+public:
+    explicit MatmulAllreduceOperation(const std::string &name, HcclComm hcommInfo);
+    ~MatmulAllreduceOperation() override;
+    atb::Status InferShape(const atb::SVector<atb::TensorDesc> &inTensorDescs,
+                           atb::SVector<atb::TensorDesc> &outTensorDescs) const override;
+    uint32_t GetInputNum() const override;
+    uint32_t GetOutputNum() const override;
+
+private:
+    int CreateAclNNVariantPack(const atb::VariantPack &variantPack) override;
+    int SetAclNNWorkspaceExecutor() override;
+    int ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream) override;
+    int CreateAclNNInTensorVariantPack(const atb::VariantPack &variantPack) override;
+    int CreateAclNNOutTensorVariantPack(const atb::VariantPack &variantPack) override;
+    HcclComm hcommInfo_;
+    char hcommName[128];
+};
+} // namespace common
+} // namespace atb_speed
+#endif // ATB_SPEED_PLUGIN_ACLNN_MATMUL_ALLREDUCE_OPERATION_H
diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/matmul_operation.cpp b/tests/proftest/layer_test_framework/operations/aclnn/ops/matmul_operation.cpp
new file mode 100644
index 00000000..ac62c22a
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/matmul_operation.cpp
@@ -0,0 +1,227 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include +#include +#include +#include +#include +#include +#include "acl/acl.h" +#include "atb_speed/log.h" +#include "atb_speed/utils/timer.h" +#include "operations/aclnn/utils/utils.h" +#include "operations/aclnn/core/acl_nn_operation.h" +#include "aclnnop/aclnn_addmm.h" +#include "matmul_operation.h" + +namespace atb_speed { +namespace common { + +MatmulOperation::MatmulOperation( + const std::string &name, + AclNNMatmulParam param) : AclNNOperation(name), param_(param) {} + +MatmulOperation::~MatmulOperation() +{ + ATB_SPEED_LOG_DEBUG("MatmulOperation deconstructor"); + this->DestroyOperation(); +} + +atb::Status MatmulOperation::InferShape(const atb::SVector &inTensorDescs, + atb::SVector &outTensorDescs) const +{ + ATB_SPEED_LOG_DEBUG(opName_ << "infer shape start"); + + outTensorDescs.at(DIM0).format = inTensorDescs.at(DIM0).format; + if (param_.outDataType == ACL_BF16 || inTensorDescs.at(DIM0).dtype == ACL_BF16) { + outTensorDescs.at(DIM0).dtype = ACL_BF16; + } else { + outTensorDescs.at(DIM0).dtype = inTensorDescs.at(DIM0).dtype; + } + outTensorDescs.at(DIM0).shape.dimNum = inTensorDescs.at(DIM0).shape.dimNum; + + int nDim = param_.transposeB ? DIM0 : DIM1; // inTensorDescs.at(DIM1).shape.dimNum 为 2 + if (inTensorDescs.at(0).shape.dimNum == DIM3) { + ATB_SPEED_LOG_DEBUG("[input0 dimNum = 3] CHECK " << opName_ << " inputs shape: [input0]" + << inTensorDescs.at(DIM0).shape.dims[DIM0] << ", " + << inTensorDescs.at(DIM0).shape.dims[DIM1] << ", " << inTensorDescs.at(DIM0).shape.dims[DIM2]); + ATB_SPEED_LOG_DEBUG("[input0 dimNum = 3] CHECK " << opName_ << " inputs shape: [input1]" + << inTensorDescs.at(DIM1).shape.dims[DIM0] << ", " << inTensorDescs.at(DIM1).shape.dims[DIM1]); + outTensorDescs.at(DIM0).shape.dims[DIM0] = inTensorDescs.at(DIM0).shape.dims[DIM0]; + outTensorDescs.at(DIM0).shape.dims[DIM1] = inTensorDescs.at(DIM0).shape.dims[DIM1]; + outTensorDescs.at(DIM0).shape.dims[DIM2] = inTensorDescs.at(DIM1).shape.dims[nDim]; + } else if (inTensorDescs.at(0).shape.dimNum == DIM2) { + ATB_SPEED_LOG_DEBUG("[input0 dimNum = 2] CHECK " << opName_ << " inputs shape: [input0]" + << inTensorDescs.at(DIM0).shape.dims[DIM0] << ", " << inTensorDescs.at(DIM0).shape.dims[DIM1]); + ATB_SPEED_LOG_DEBUG("[input0 dimNum = 2] CHECK " << opName_ << " inputs shape: [input1]" + << inTensorDescs.at(DIM1).shape.dims[DIM0] << ", " << inTensorDescs.at(DIM1).shape.dims[DIM1]); + outTensorDescs.at(DIM0).shape.dims[DIM0] = inTensorDescs.at(DIM0).shape.dims[DIM0]; + outTensorDescs.at(DIM0).shape.dims[DIM1] = inTensorDescs.at(DIM1).shape.dims[nDim]; + } else { + ATB_SPEED_LOG_ERROR(opName_ << " invalid dim num:" << inTensorDescs.at(DIM0).shape.dimNum); + } + + ATB_SPEED_LOG_DEBUG(opName_ << " infer shape end"); + return 0; +} + +uint32_t MatmulOperation::GetInputNum() const +{ + uint32_t inputNum = DIM2; + if (param_.hasBias) { + inputNum += DIM1; + } + return inputNum; +} + +uint32_t MatmulOperation::GetOutputNum() const +{ + return DIM1; +} + +int MatmulOperation::CreateAclNNVariantPack(const atb::VariantPack &variantPack) +{ + ATB_SPEED_LOG_DEBUG(opName_ << " CreateAclNNVariantPack start"); + int ret = 0; + ret = CreateAclNNInTensorVariantPack(variantPack); + if (ret != 0) { + ATB_SPEED_LOG_ERROR(this->opName_ << " AclNNTensor CreateAclNNInTensorVariantPack fail"); + return ret; + } + ret = CreateAclNNOutTensorVariantPack(variantPack); + if (ret != 0) { + ATB_SPEED_LOG_ERROR(this->opName_ << " AclNNTensor CreateAclNNOutTensorVariantPack fail"); + return ret; + } + ATB_SPEED_LOG_DEBUG(opName_ << " 
CreateAclNNVariantPack end"); + return atb::NO_ERROR; +} + +atb::Dims MatmulOperation::GetWeightStorageShape(const atb::TensorDesc atbTensorDesc) const +{ + atb::Dims storageTensorDims = atbTensorDesc.shape; // ND格式下,storageShape和originalShape一致 + if (atbTensorDesc.format == ACL_FORMAT_FRACTAL_NZ) { + // nz格式 (k, n) => (n / 16, k / 16, 16, 16) + // nz格式 (n, k) => (k / 16, n / 16, 16, 16) + storageTensorDims.dimNum = NUM4; // 4维 + auto dim0 = atbTensorDesc.shape.dims[DIM0]; + uint32_t blockSize = 16; + storageTensorDims.dims[DIM0] = atbTensorDesc.shape.dims[DIM1] / blockSize; + storageTensorDims.dims[DIM1] = dim0 / blockSize; + storageTensorDims.dims[DIM2] = blockSize; + storageTensorDims.dims[DIM3] = blockSize; + } + return storageTensorDims; +} + +int MatmulOperation::CreateAclNNInTensorVariantPack(const atb::VariantPack &variantPack) +{ + AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack; + aclnnVariantPack.aclInTensors.resize(GetInputNum()); + for (size_t i = 0; i < aclnnVariantPack.aclInTensors.size(); i++) { + std::shared_ptr aclnnTensor = std::make_shared(); + aclnnTensor->tensorIdx = i; + aclnnTensor->needUpdateTensorDataPtr = true; + aclnnTensor->atbTensor = variantPack.inTensors.at(i); + atb::Tensor squeezedAtbTensor = SqueezeBatchSeq(variantPack.inTensors.at(i)); + + // StorageShape + atb::Dims storageTensorDims = GetWeightStorageShape(squeezedAtbTensor.desc); + + // ViewShape and Stride + atb::Dims viewDims = squeezedAtbTensor.desc.shape; + if (i == 1 && this->param_.transposeB) { + aclnnTensor->strides = GetTransposeTensorStride(viewDims); + viewDims.dims[DIM0] = squeezedAtbTensor.desc.shape.dims[DIM1]; + viewDims.dims[DIM1] = squeezedAtbTensor.desc.shape.dims[DIM0]; + } else { + aclnnTensor->strides = GetCopyTensorStride(viewDims); + } + + aclnnTensor->tensor = aclCreateTensor( + viewDims.dims, viewDims.dimNum, squeezedAtbTensor.desc.dtype, + aclnnTensor->strides.data(), 0, squeezedAtbTensor.desc.format, + storageTensorDims.dims, storageTensorDims.dimNum, squeezedAtbTensor.deviceData); + if (aclnnTensor->tensor == nullptr) { + ATB_SPEED_LOG_ERROR(this->opName_ << " InTensor aclCreateTensor index " << i << " fail"); + return atb::ERROR_INTERNAL_ERROR; + } + aclnnVariantPack.aclInTensors[i] = aclnnTensor; + } + return atb::NO_ERROR; +} + +int MatmulOperation::CreateAclNNOutTensorVariantPack(const atb::VariantPack &variantPack) +{ + AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack; + aclnnVariantPack.aclOutTensors.resize(GetOutputNum()); + for (size_t i = 0; i < aclnnVariantPack.aclOutTensors.size(); ++i) { + std::shared_ptr aclnnTensor = std::make_shared(); + aclnnTensor->tensorIdx = i; + aclnnTensor->needUpdateTensorDataPtr = true; + aclnnTensor->atbTensor = variantPack.outTensors.at(i); + atb::Tensor squeezedAtbTensor = SqueezeBatchSeq(variantPack.outTensors.at(i)); + aclnnTensor->strides = GetCopyTensorStride(squeezedAtbTensor.desc.shape); + aclnnTensor->tensor = aclCreateTensor( + squeezedAtbTensor.desc.shape.dims, squeezedAtbTensor.desc.shape.dimNum, + squeezedAtbTensor.desc.dtype, aclnnTensor->strides.data(), 0, + squeezedAtbTensor.desc.format, squeezedAtbTensor.desc.shape.dims, + squeezedAtbTensor.desc.shape.dimNum, squeezedAtbTensor.deviceData); + if (aclnnTensor->tensor == nullptr) { + ATB_SPEED_LOG_ERROR(this->opName_ << " OutTensor aclCreateTensor index " << i << " fail"); + return atb::ERROR_INTERNAL_ERROR; + } + aclnnVariantPack.aclOutTensors[i] = aclnnTensor; + } + return atb::NO_ERROR; +} + +int 
MatmulOperation::SetAclNNWorkspaceExecutor() +{ + ATB_SPEED_LOG_DEBUG(opName_ << " SetAclNNWorkspaceExecutor start"); + AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack; + float zeroValue = 0.0f; + float oneValue = 1.0f; + aclScalar* betaZero = aclCreateScalar(&zeroValue, aclDataType::ACL_FLOAT); + aclScalar* betaOne = aclCreateScalar(&oneValue, aclDataType::ACL_FLOAT); + + int ret = aclnnAddmmGetWorkspaceSize( + param_.hasBias ? aclnnVariantPack.aclInTensors.at(DIM2)->tensor + : aclnnVariantPack.aclOutTensors.at(DIM0)->tensor, + aclnnVariantPack.aclInTensors.at(DIM0)->tensor, + aclnnVariantPack.aclInTensors.at(DIM1)->tensor, + param_.hasBias ? betaOne : betaZero, betaOne, + aclnnVariantPack.aclOutTensors.at(DIM0)->tensor, 0, + &this->aclnnOpCache_->workspaceSize, &this->aclnnOpCache_->aclExecutor); + + ATB_SPEED_LOG_DEBUG(opName_ << " SetAclNNWorkspaceExecutor end, ret:" << ret + << ", workspaceSize:" << this->aclnnOpCache_->workspaceSize + << ", aclExecutor:" << this->aclnnOpCache_->aclExecutor); + return ret; +} + +int MatmulOperation::ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream) +{ + ATB_SPEED_LOG_DEBUG(opName_ << " aclnnAddmm start"); + int ret = aclnnAddmm( + workspace, this->aclnnOpCache_->workspaceSize, this->aclnnOpCache_->aclExecutor, stream); + ATB_SPEED_LOG_DEBUG(opName_ << " aclnnAddmm end, ret:" << ret); + return ret; +} + +} // namespace common +} // namespace atb_speed diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/matmul_operation.h b/tests/proftest/layer_test_framework/operations/aclnn/ops/matmul_operation.h new file mode 100644 index 00000000..44c90e06 --- /dev/null +++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/matmul_operation.h @@ -0,0 +1,54 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef ATB_SPEED_PLUGIN_ACLNN_MATMUL_OPERATION_H +#define ATB_SPEED_PLUGIN_ACLNN_MATMUL_OPERATION_H +#include "operations/aclnn/core/acl_nn_operation.h" +#include "operations/aclnn/core/acl_nn_operation_cache.h" +#include "aclnnop/aclnn_addmm.h" + +namespace atb_speed { +namespace common { + +struct AclNNMatmulParam { + bool transposeB = true; // atb LinearParam transposeB 默认为 true + bool hasBias = false; + aclDataType outDataType = ACL_FLOAT16; +}; + +class MatmulOperation : public AclNNOperation { +public: + explicit MatmulOperation(const std::string &name, AclNNMatmulParam param); + ~MatmulOperation() override; + atb::Status InferShape(const atb::SVector &inTensorDescs, + atb::SVector &outTensorDescs) const override; + uint32_t GetInputNum() const override; + uint32_t GetOutputNum() const override; + +private: + int CreateAclNNVariantPack(const atb::VariantPack &variantPack) override; + int SetAclNNWorkspaceExecutor() override; + int ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream) override; + + int CreateAclNNInTensorVariantPack(const atb::VariantPack &variantPack) override; + int CreateAclNNOutTensorVariantPack(const atb::VariantPack &variantPack) override; + + atb::Dims GetWeightStorageShape(const atb::TensorDesc atbTensorDesc) const; + + AclNNMatmulParam param_; +}; +} // namespace common +} // namespace atb_speed +#endif // ATB_SPEED_PLUGIN_ACLNN_MATMUL_OPERATION_H diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/max_v2_operation.cpp b/tests/proftest/layer_test_framework/operations/aclnn/ops/max_v2_operation.cpp new file mode 100644 index 00000000..a064bc7f --- /dev/null +++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/max_v2_operation.cpp @@ -0,0 +1,165 @@ +/* +* Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ +#include "max_v2_operation.h" +#include "operations/aclnn/utils/utils.h" +#include "acl/acl.h" +#include "aclnnop/aclnn_max_v2.h" +#include "atb_speed/log.h" + +namespace atb_speed { +namespace common { + +MaxV2Operation::MaxV2Operation(const std::string &name) : AclNNOperation(name) {} + +MaxV2Operation::MaxV2Operation(const std::string &name, atb_speed::common::AclNNMaxV2Param param) + : AclNNOperation(name), param_(param) +{ +} + +MaxV2Operation::~MaxV2Operation() +{ + ATB_SPEED_LOG_DEBUG("MaxV2Operation deconstruct"); + this->DestroyOperation(); +} + +uint32_t MaxV2Operation::GetInputNum() const { return NUM1; } + +uint32_t MaxV2Operation::GetOutputNum() const { return NUM1; } + +atb::Status MaxV2Operation::InferShape(const atb::SVector &inTensorDesc, + atb::SVector &outTensorDesc) const +{ + ATB_SPEED_LOG_DEBUG(opName_ << " MaxV2Operation infer shape start"); + outTensorDesc.at(0).format = inTensorDesc.at(0).format; + outTensorDesc.at(0).dtype = inTensorDesc.at(0).dtype; + uint32_t inputDimNum = inTensorDesc.at(0).shape.dimNum; + uint32_t outputDimNum = inputDimNum; + uint32_t realDim = this->param_.dims[0] < 0 ? 
this->param_.dims[0] + inputDimNum : this->param_.dims[0]; + + if (!param_.keepdim) { + outputDimNum -= 1; + } + outTensorDesc.at(0).shape.dimNum = outputDimNum; + + uint32_t j = 0; + for (uint32_t i = 0; i < outputDimNum; ++i) { + if (i == realDim && param_.keepdim) { + outTensorDesc.at(0).shape.dims[i] = 1; + j++; + } else { + outTensorDesc.at(0).shape.dims[j++] = inTensorDesc.at(0).shape.dims[i]; + } + } + + ATB_SPEED_LOG_DEBUG(opName_ << "MaxV2Operation InferShape end"); + + return atb::NO_ERROR; +} + +int MaxV2Operation::CreateAclNNVariantPack(const atb::VariantPack &variantPack) +{ + ATB_SPEED_LOG_DEBUG(opName_ << " CreateAclNNVariantPack start"); + int ret; + + ret = CreateAclNNInTensorVariantPack(variantPack); + if (ret != 0) { + ATB_SPEED_LOG_ERROR(this->opName_ << " AclNNTensor CreateAclNNInTensorVariantPack fail"); + return ret; + } + + ret = CreateAclNNOutTensorVariantPack(variantPack); + if (ret != 0) { + ATB_SPEED_LOG_ERROR(this->opName_ << " AclNNTensor CreateAclNNOutTensorVariantPack fail"); + } + + ATB_SPEED_LOG_DEBUG(opName_ << " CreateAclNNVariantPack end"); + return atb::NO_ERROR; +} + +int MaxV2Operation::CreateAclNNInTensorVariantPack(const atb::VariantPack &variantPack) +{ + AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack; + aclnnVariantPack.aclInTensors.resize(GetInputNum()); + for (size_t i = 0; i < aclnnVariantPack.aclInTensors.size(); ++i) { + std::shared_ptr aclnnTensor = CreateTensor(variantPack.inTensors.at(i), i); + if (aclnnTensor->tensor == nullptr) { + ATB_SPEED_LOG_ERROR(this->opName_ << " InTensor aclCreateTensor index " << i << " fail"); + return atb::ERROR_INTERNAL_ERROR; + } + aclnnVariantPack.aclInTensors[i] = aclnnTensor; + } + return atb::NO_ERROR; +} + + +int MaxV2Operation::CreateAclNNOutTensorVariantPack(const atb::VariantPack &variantPack) +{ + AclNNVariantPack &aclNnVariantPack = this->aclnnOpCache_->aclnnVariantPack; + aclNnVariantPack.aclOutTensors.resize(GetOutputNum()); + for (size_t i = 0; i < aclNnVariantPack.aclOutTensors.size(); ++i) { + std::shared_ptr aclnnTensor = CreateTensor(variantPack.outTensors.at(i), i); + if (aclnnTensor->tensor == nullptr) { + ATB_SPEED_LOG_ERROR(this->opName_ << " outTensor aclCreateTensor index " << i << " fail"); + return atb::ERROR_INTERNAL_ERROR; + } + aclNnVariantPack.aclOutTensors[i] = aclnnTensor; + } + return atb::NO_ERROR; +} + +std::shared_ptr MaxV2Operation::CreateTensor(atb::Tensor atbTensor, int tensorIdx) const +{ + std::shared_ptr aclnnTensor = std::make_shared(); + aclnnTensor->needUpdateTensorDataPtr = true; + aclnnTensor->atbTensor = atbTensor; + aclnnTensor->tensorIdx = tensorIdx; + aclnnTensor->strides = GetCopyTensorStride(atbTensor.desc.shape); + aclnnTensor->tensor = aclCreateTensor( + atbTensor.desc.shape.dims, + atbTensor.desc.shape.dimNum, + atbTensor.desc.dtype, + aclnnTensor->strides.data(), + 0, + atbTensor.desc.format, + atbTensor.desc.shape.dims, + atbTensor.desc.shape.dimNum, + atbTensor.deviceData); + return aclnnTensor; +} + +int MaxV2Operation::SetAclNNWorkspaceExecutor() +{ + ATB_SPEED_LOG_DEBUG(opName_ << " SetAclNNWorkspaceExecutor start"); + AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack; + aclIntArray* dims = aclCreateIntArray(this->param_.dims.data(), this->param_.dims.size()); + int ret = aclnnMaxV2GetWorkspaceSize(aclnnVariantPack.aclInTensors.at(0)->tensor, dims, + this->param_.keepdim, false, aclnnVariantPack.aclOutTensors.at(0)->tensor, + &this->aclnnOpCache_->workspaceSize, 
&this->aclnnOpCache_->aclExecutor); + ATB_SPEED_LOG_DEBUG(opName_ << " SetAclNNWorkspaceExecutor end, ret:" << ret + << ", workspaceSize:" << this->aclnnOpCache_->workspaceSize + << ", aclExecutor:" << this->aclnnOpCache_->aclExecutor); + return ret; +} + +int MaxV2Operation::ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream) +{ + ATB_SPEED_LOG_DEBUG(opName_ << " ExecuteAclNNOp start"); + int ret = aclnnMaxV2(workspace, this->aclnnOpCache_->workspaceSize, this->aclnnOpCache_->aclExecutor, stream); + ATB_SPEED_LOG_DEBUG(opName_ << " ExecuteAclNNOp end, ret:" << ret); + return ret; +} +} // namespace common +} // namespace atb_speed \ No newline at end of file diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/max_v2_operation.h b/tests/proftest/layer_test_framework/operations/aclnn/ops/max_v2_operation.h new file mode 100644 index 00000000..45ac2f4c --- /dev/null +++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/max_v2_operation.h @@ -0,0 +1,52 @@ +/* +* Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. + */ + +#ifndef ATB_SPEED_PLUGIN_ACLNN_MAX_V2_OPERATION_H +#define ATB_SPEED_PLUGIN_ACLNN_MAX_V2_OPERATION_H + +#include +#include "aclnn/core/acl_nn_operation.h" + +namespace atb_speed { +namespace common { +struct AclNNMaxV2Param { + std::vector dims = {-1}; + bool keepdim = false; +}; +class MaxV2Operation : public AclNNOperation { +public: + explicit MaxV2Operation(const std::string &name); + explicit MaxV2Operation(const std::string &name, AclNNMaxV2Param param); + ~MaxV2Operation() override; + atb::Status InferShape(const atb::SVector &inTensorDesc, + atb::SVector &outTensorDesc) const override; + uint32_t GetInputNum() const override; + uint32_t GetOutputNum() const override; + +protected: + int CreateAclNNVariantPack(const atb::VariantPack &variantPack) override; + int SetAclNNWorkspaceExecutor() override; + int ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream) override; + int CreateAclNNInTensorVariantPack(const atb::VariantPack &variantPack) override; + int CreateAclNNOutTensorVariantPack(const atb::VariantPack &variantPack) override; + std::shared_ptr CreateTensor(atb::Tensor atbTensor, int tensorIdx) const; +private: + AclNNMaxV2Param param_; +}; +} // namespace common +} // namespace atb_speed + +#endif // ATB_SPEED_PLUGIN_ACLNN_MAX_V2_OPERATION_H diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/minimum_operation.cpp b/tests/proftest/layer_test_framework/operations/aclnn/ops/minimum_operation.cpp new file mode 100644 index 00000000..749c99f4 --- /dev/null +++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/minimum_operation.cpp @@ -0,0 +1,140 @@ +/* +* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
+*/ + +#include "minimum_operation.h" +#include "operations/aclnn/utils/utils.h" +#include "acl/acl.h" +#include "aclnnop/aclnn_minimum.h" +#include "atb_speed/log.h" + +namespace atb_speed { +namespace common { + +MinimumOperation::MinimumOperation(const std::string &name) : AclNNOperation(name) {} + +MinimumOperation::~MinimumOperation() +{ + ATB_SPEED_LOG_DEBUG("MinimumOperation deconstruct"); + this->DestroyOperation(); +} + +uint32_t MinimumOperation::GetInputNum() const { return NUM2; } + +uint32_t MinimumOperation::GetOutputNum() const { return NUM1; } + +atb::Status MinimumOperation::InferShape( + const atb::SVector &inTensorDesc, + atb::SVector &outTensorDesc) const +{ + ATB_SPEED_LOG_DEBUG(opName_ << " MinimumOperation infer shape start"); + outTensorDesc.at(0).format = inTensorDesc.at(0).format; + outTensorDesc.at(0).dtype = inTensorDesc.at(0).dtype; + outTensorDesc.at(0).shape.dimNum = inTensorDesc.at(0).shape.dimNum; + for (uint64_t i = 0; i < inTensorDesc.at(0).shape.dimNum; ++i) { + outTensorDesc.at(0).shape.dims[i] = inTensorDesc.at(0).shape.dims[i]; + } + ATB_SPEED_LOG_DEBUG(opName_ << "MinimumOperation InferShape end"); + + return atb::NO_ERROR; +} + +int MinimumOperation::CreateAclNNVariantPack(const atb::VariantPack &variantPack) +{ + ATB_SPEED_LOG_DEBUG(opName_ << " CreateAclNNVariantPack start"); + int ret; + + ret = CreateAclNNInTensorVariantPack(variantPack); + if (ret != 0) { + ATB_SPEED_LOG_ERROR(this->opName_ << " AclNNTensor CreateAclNNInTensorVariantPack fail"); + return ret; + } + + ret = CreateAclNNOutTensorVariantPack(variantPack); + if (ret != 0) { + ATB_SPEED_LOG_ERROR(this->opName_ << " AclNNTensor CreateAclNNOutTensorVariantPack fail"); + } + + ATB_SPEED_LOG_DEBUG(opName_ << " CreateAclNNVariantPack end"); + return atb::NO_ERROR; +} + +int MinimumOperation::CreateAclNNInTensorVariantPack(const atb::VariantPack &variantPack) +{ + AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack; + aclnnVariantPack.aclInTensors.resize(GetInputNum()); + for (size_t i = 0; i < aclnnVariantPack.aclInTensors.size(); ++i) { + std::shared_ptr aclnnTensor = CreateTensor(variantPack.inTensors.at(i), i); + if (aclnnTensor->tensor == nullptr) { + ATB_SPEED_LOG_ERROR(this->opName_ << " InTensor aclCreateTensor index " << i << " fail"); + return atb::ERROR_INTERNAL_ERROR; + } + aclnnVariantPack.aclInTensors[i] = aclnnTensor; + } + return atb::NO_ERROR; +} + + +int MinimumOperation::CreateAclNNOutTensorVariantPack(const atb::VariantPack &variantPack) +{ + AclNNVariantPack &aclNnVariantPack = this->aclnnOpCache_->aclnnVariantPack; + aclNnVariantPack.aclOutTensors.resize(GetOutputNum()); + for (size_t i = 0; i < aclNnVariantPack.aclOutTensors.size(); ++i) { + std::shared_ptr aclnnTensor = CreateTensor(variantPack.outTensors.at(i), i); + if (aclnnTensor->tensor == nullptr) { + ATB_SPEED_LOG_ERROR(this->opName_ << " outTensor aclCreateTensor index " << i << " fail"); + return atb::ERROR_INTERNAL_ERROR; + } + aclNnVariantPack.aclOutTensors[i] = aclnnTensor; + } + return atb::NO_ERROR; +} + +std::shared_ptr MinimumOperation::CreateTensor(atb::Tensor atbTensor, int tensorIdx) const +{ + std::shared_ptr aclnnTensor = std::make_shared(); + aclnnTensor->needUpdateTensorDataPtr = true; + aclnnTensor->atbTensor = atbTensor; + aclnnTensor->tensorIdx = tensorIdx; + aclnnTensor->strides = GetCopyTensorStride(atbTensor.desc.shape); + aclnnTensor->tensor = aclCreateTensor( + atbTensor.desc.shape.dims, + atbTensor.desc.shape.dimNum, + atbTensor.desc.dtype, + 
aclnnTensor->strides.data(), + 0, + atbTensor.desc.format, + atbTensor.desc.shape.dims, + atbTensor.desc.shape.dimNum, + atbTensor.deviceData); + return aclnnTensor; +} + +int MinimumOperation::SetAclNNWorkspaceExecutor() +{ + ATB_SPEED_LOG_DEBUG(opName_ << " SetAclNNWorkspaceExecutor start"); + AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack; + int ret = aclnnMinimumGetWorkspaceSize( + aclnnVariantPack.aclInTensors.at(0)->tensor, // self + aclnnVariantPack.aclInTensors.at(1)->tensor, // other + aclnnVariantPack.aclOutTensors.at(0)->tensor, // out + &this->aclnnOpCache_->workspaceSize, + &this->aclnnOpCache_->aclExecutor); + ATB_SPEED_LOG_DEBUG( + opName_ << " SetAclNNWorkspaceExecutor end" + << ", ret: " << ret + << ", workspaceSize: " << this->aclnnOpCache_->workspaceSize + << ", aclExecutor: " << this->aclnnOpCache_->aclExecutor + ); + return ret; +} + +int MinimumOperation::ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream) +{ + ATB_SPEED_LOG_DEBUG(opName_ << " ExecuteAclNNOp start"); + int ret = aclnnMinimum(workspace, this->aclnnOpCache_->workspaceSize, this->aclnnOpCache_->aclExecutor, stream); + ATB_SPEED_LOG_DEBUG(opName_ << " ExecuteAclNNOp end, ret:" << ret); + return ret; +} +} // namespace common +} // namespace atb_speed \ No newline at end of file diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/minimum_operation.h b/tests/proftest/layer_test_framework/operations/aclnn/ops/minimum_operation.h new file mode 100644 index 00000000..d8238ee6 --- /dev/null +++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/minimum_operation.h @@ -0,0 +1,33 @@ +/* +* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +*/ + +#ifndef ATB_SPEED_PLUGIN_ACLNN_MINIMUM_OPERATION_H +#define ATB_SPEED_PLUGIN_ACLNN_MINIMUM_OPERATION_H + +#include "aclnn/core/acl_nn_operation.h" + +namespace atb_speed { +namespace common { + +class MinimumOperation : public AclNNOperation { +public: + explicit MinimumOperation(const std::string &name); + ~MinimumOperation() override; + atb::Status InferShape(const atb::SVector &inTensorDesc, + atb::SVector &outTensorDesc) const override; + uint32_t GetInputNum() const override; + uint32_t GetOutputNum() const override; + +protected: + int CreateAclNNVariantPack(const atb::VariantPack &variantPack) override; + int SetAclNNWorkspaceExecutor() override; + int ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream) override; + int CreateAclNNInTensorVariantPack(const atb::VariantPack &variantPack) override; + int CreateAclNNOutTensorVariantPack(const atb::VariantPack &variantPack) override; + std::shared_ptr CreateTensor(atb::Tensor atbTensor, int tensorIdx) const; +}; +} // namespace common +} // namespace atb_speed + +#endif // ATB_SPEED_PLUGIN_ACLNN_MINIMUM_OPERATION_H diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/moe_compute_expert_tokens_operation.cpp b/tests/proftest/layer_test_framework/operations/aclnn/ops/moe_compute_expert_tokens_operation.cpp new file mode 100644 index 00000000..25218df8 --- /dev/null +++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/moe_compute_expert_tokens_operation.cpp @@ -0,0 +1,89 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include +#include "acl/acl.h" +#include "aclnnop/aclnn_moe_compute_expert_tokens.h" +#include "atb_speed/log.h" +#include "atb_speed/utils/timer.h" +#include "operations/aclnn/utils/utils.h" +#include "moe_compute_expert_tokens_operation.h" + +namespace atb_speed { +namespace common { + +MoeComputeExpertTokensOperation::MoeComputeExpertTokensOperation( + const std::string &name, MoeComputeExpertTokensParam param) : AclNNOperation(name), param_(param) {} +MoeComputeExpertTokensOperation::~MoeComputeExpertTokensOperation() {} + +atb::Status MoeComputeExpertTokensOperation::InferShape( + const atb::SVector &inTensorDescs, atb::SVector &outTensorDescs) const +{ + ATB_SPEED_LOG_DEBUG(opName_ << "MoeComputeExpertTokensOperation infer shape start"); + + outTensorDescs.at(DIM0).format = inTensorDescs.at(DIM0).format; + outTensorDescs.at(DIM0).dtype = inTensorDescs.at(DIM0).dtype; + outTensorDescs.at(DIM0).shape.dimNum = DIM1; + + ATB_SPEED_LOG_DEBUG(opName_ + << "MoeComputeExpertTokensOperation infer shape origin " + << "inTensorDescs.at(DIM0).shape.dims[DIM0]" + << inTensorDescs.at(DIM0).shape.dims[DIM0]); + outTensorDescs.at(DIM0).shape.dims[DIM0] = param_.expertNum; + + ATB_SPEED_LOG_DEBUG(opName_ << "MoeComputeExpertTokensOperation infer shape end"); + return 0; +} +uint32_t MoeComputeExpertTokensOperation::GetInputNum() const +{ + return DIM1; +} + +uint32_t MoeComputeExpertTokensOperation::GetOutputNum() const +{ + return DIM1; +} + +int MoeComputeExpertTokensOperation::SetAclNNWorkspaceExecutor() +{ + ATB_SPEED_LOG_DEBUG(opName_ << " MoeComputeExpertTokensOperation start"); + AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack; + int ret = aclnnMoeComputeExpertTokensGetWorkspaceSize( + aclnnVariantPack.aclInTensors.at(DIM0)->tensor, + param_.expertNum, + aclnnVariantPack.aclOutTensors.at(DIM0)->tensor, + &this->aclnnOpCache_->workspaceSize, + &this->aclnnOpCache_->aclExecutor); + ATB_SPEED_LOG_DEBUG(opName_ << " SetAclNNWorkspaceExecutor end, ret:" << ret + << ", workspaceSize:" << this->aclnnOpCache_->workspaceSize + << ", aclExecutor:" << this->aclnnOpCache_->aclExecutor); + return ret; +} + +int MoeComputeExpertTokensOperation::ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream) +{ + ATB_SPEED_LOG_DEBUG(opName_ << " MoeComputeExpertTokensOperation start"); + + int ret = aclnnMoeComputeExpertTokens( + workspace, this->aclnnOpCache_->workspaceSize, this->aclnnOpCache_->aclExecutor, stream); + ATB_SPEED_LOG_DEBUG(opName_ << " MoeComputeExpertTokensOperation end, ret:" << ret); + return ret; +} + +} // namespace common +} // namespace atb_speed \ No newline at end of file diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/moe_compute_expert_tokens_operation.h b/tests/proftest/layer_test_framework/operations/aclnn/ops/moe_compute_expert_tokens_operation.h new file mode 100644 index 00000000..488c27f6 --- /dev/null +++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/moe_compute_expert_tokens_operation.h @@ -0,0 +1,78 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024. 
All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef ATB_SPEED_PLUGIN_ACLNN_MOE_COMPUTE_EXPERT_TOKENS_OPERATION_H
+#define ATB_SPEED_PLUGIN_ACLNN_MOE_COMPUTE_EXPERT_TOKENS_OPERATION_H
+#include "operations/aclnn/core/acl_nn_operation.h"
+#include "operations/aclnn/core/acl_nn_operation_cache.h"
+#include "atb_speed/utils/operation_util.h"
+
+namespace atb_speed {
+namespace common {
+
+struct MoeComputeExpertTokensParam {
+    /// The total number of experts utilized by the model
+    int32_t expertNum = 8;
+};
+
+/// This class defines an operator that computes the number of tokens processed by each expert.
+///
+/// This class makes use of `aclnnMoeComputeExpertTokensGetWorkspaceSize` and `aclnnMoeComputeExpertTokens`
+/// from the AscendCL API.
+///
+/// Inputs to the operator:
+/// Name         | Dtype | Shape |
+/// -------------|-------|-------|
+/// input        | int32 | [m*k] |
+///
+/// Outputs of the operator:
+/// Name         | Dtype | Shape |
+/// -------------|-------|-------|
+/// output       | int64 | [e]   |
+/// Note: m is the length of input tokens, k is the number of experts selected for each token,
+/// e is the total number of experts used by the model
+///
+/// Example:
+/// \code
+/// enum TensorIdx : uint32_t {
+///     INPUT = 0,
+///     OUT,
+/// };
+///
+/// atb::Node &expertTokenNode = opGraph.nodes.at(nodeId++);
+/// atb_speed::common::MoeComputeExpertTokensParam expertTokenParam;
+/// expertTokenNode.operation = new atb_speed::common::MoeComputeExpertTokensOperation(
+///     "MoeComputeExpertTokensNode", expertTokenParam);
+/// expertTokenNode.inTensorIds = {INPUT};
+/// expertTokenNode.outTensorIds = {OUT};
+/// \endcode
+
+class MoeComputeExpertTokensOperation : public AclNNOperation {
+public:
+    explicit MoeComputeExpertTokensOperation(const std::string &name, MoeComputeExpertTokensParam param);
+    ~MoeComputeExpertTokensOperation() override;
+    atb::Status InferShape(const atb::SVector<atb::TensorDesc> &inTensorDescs,
+                           atb::SVector<atb::TensorDesc> &outTensorDescs) const override;
+    uint32_t GetInputNum() const override;
+    uint32_t GetOutputNum() const override;
+
+private:
+    int SetAclNNWorkspaceExecutor() override;
+    int ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream) override;
+
+    MoeComputeExpertTokensParam param_;
+};
+
+} // namespace common
+} // namespace atb_speed
+#endif // ATB_SPEED_PLUGIN_ACLNN_MOE_COMPUTE_EXPERT_TOKENS_OPERATION_H
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/moe_distribute_combine_operation.cpp b/tests/proftest/layer_test_framework/operations/aclnn/ops/moe_distribute_combine_operation.cpp
new file mode 100644
index 00000000..9c430615
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/moe_distribute_combine_operation.cpp
@@ -0,0 +1,133 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "moe_distribute_combine_operation.h" +#include +#include +#include +#include +#include +#include "acl/acl.h" +#include "aclnnop/aclnn_moe_distribute_combine.h" +#include "atb_speed/log.h" +#include "atb_speed/utils/timer.h" +#include "operations/aclnn/utils/utils.h" + +namespace atb_speed { +namespace common { + +MoeDistributeCombineOperation::MoeDistributeCombineOperation( + const std::string &name, MoeDistributeCombineParam param) : AclNNOperation(name), param_(param) {} +MoeDistributeCombineOperation::~MoeDistributeCombineOperation() {} + +atb::Status MoeDistributeCombineOperation::InferShape( + const atb::SVector &inTensorDescs, atb::SVector &outTensorDescs) const +{ + ATB_SPEED_LOG_DEBUG(opName_ << "MoeDistributeCombineOperation infer shape start"); + + outTensorDescs.at(DIM0).format = inTensorDescs.at(DIM0).format; + outTensorDescs.at(DIM0).dtype = inTensorDescs.at(DIM0).dtype; + outTensorDescs.at(DIM0).shape.dimNum = inTensorDescs.at(DIM0).shape.dimNum; + + ATB_SPEED_LOG_DEBUG(opName_ + << "MoeDistributeCombineOperation infer shape origin inTensorDescs.at(DIM1).shape.dims[DIM0]" + << inTensorDescs.at(DIM1).shape.dims[DIM0]); + ATB_SPEED_LOG_DEBUG(opName_ + << "MoeDistributeCombineOperation infer shape origin inTensorDescs.at(DIM0).shape.dims[DIM1]" + << inTensorDescs.at(DIM0).shape.dims[DIM1]); + outTensorDescs.at(DIM0).shape.dims[DIM0] = inTensorDescs.at(DIM1).shape.dims[DIM0]; + outTensorDescs.at(DIM0).shape.dims[DIM1] = inTensorDescs.at(DIM0).shape.dims[DIM1]; + ATB_SPEED_LOG_DEBUG(opName_ << "MoeDistributeCombineOperation infer shape end"); + return 0; +} +uint32_t MoeDistributeCombineOperation::GetInputNum() const +{ + return NUM7; +} + +uint32_t MoeDistributeCombineOperation::GetOutputNum() const +{ + return DIM1; +} + +int32_t MoeDistributeCombineOperation::GetGlobalBS(const atb::TensorDesc &inTensorDesc) const +{ + int32_t worldSize = param_.epRankSize * std::max(param_.tpRankSize, 1); + if (param_.globalBS > 0) { + return param_.globalBS; + } + int32_t maxDecodeDpTokenSize = param_.maxDecodeDpTokenSize; + // if param_.maxDecodeDpTokenSize is not available,use in_padding_idx's DIM0 + if (maxDecodeDpTokenSize == 0) { + maxDecodeDpTokenSize = inTensorDesc.shape.dims[DIM0] / \ + std::min(param_.localMoeExpertNum, param_.topk) / worldSize; + } + return maxDecodeDpTokenSize * worldSize; +} + +int MoeDistributeCombineOperation::SetAclNNWorkspaceExecutor() +{ + ATB_SPEED_LOG_DEBUG(opName_ << " MoeDistributeCombineOperation start"); + AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack; + + aclnnVariantPack.aclInTensors.at(NUM6)->tensorIdx = NUM10; + int64_t globalBS = GetGlobalBS(aclnnVariantPack.aclInTensors.at(DIM0)->atbTensor.desc); + int ret = aclnnMoeDistributeCombineGetWorkspaceSize( + aclnnVariantPack.aclInTensors.at(DIM0)->tensor, + aclnnVariantPack.aclInTensors.at(DIM1)->tensor, + aclnnVariantPack.aclInTensors.at(DIM2)->tensor, + aclnnVariantPack.aclInTensors.at(NUM3)->tensor, + aclnnVariantPack.aclInTensors.at(NUM4)->tensor, + aclnnVariantPack.aclInTensors.at(NUM5)->tensor, + nullptr, + nullptr, + nullptr, + nullptr, 
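+        // the four optional inputs above stay nullptr on this path; the next tensor was
+        // remapped to aclnn argument index 10 via tensorIdx = NUM10 earlier in this function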
+ aclnnVariantPack.aclInTensors.at(NUM6)->tensor, + param_.epCommName.data(), + param_.epRankSize, + param_.epRankId, + param_.moeExpertNum, + param_.tpCommName.data(), + param_.tpRankSize, + param_.tpRankId, + param_.expertSharedType, + 1, + param_.sharedExpertRankNum, + globalBS, + 0, + param_.commQuantMode, + 0, + aclnnVariantPack.aclOutTensors.at(DIM0)->tensor, + &this->aclnnOpCache_->workspaceSize, + &this->aclnnOpCache_->aclExecutor); + ATB_SPEED_LOG_DEBUG(opName_ << " SetAclNNWorkspaceExecutor end, ret:" << ret + << ", workspaceSize:" << this->aclnnOpCache_->workspaceSize + << ", aclExecutor:" << this->aclnnOpCache_->aclExecutor); + return ret; +} + +int MoeDistributeCombineOperation::ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream) +{ + ATB_SPEED_LOG_DEBUG(opName_ << " MoeDistributeCombineOperation start"); + + int ret = aclnnMoeDistributeCombine( + workspace, this->aclnnOpCache_->workspaceSize, this->aclnnOpCache_->aclExecutor, stream); + ATB_SPEED_LOG_DEBUG(opName_ << " MoeDistributeCombineOperation end, ret:" << ret); + return ret; +} + +} // namespace common +} // namespace atb_speed \ No newline at end of file diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/moe_distribute_combine_operation.h b/tests/proftest/layer_test_framework/operations/aclnn/ops/moe_distribute_combine_operation.h new file mode 100644 index 00000000..0bf1c273 --- /dev/null +++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/moe_distribute_combine_operation.h @@ -0,0 +1,63 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#ifndef ATB_SPEED_PLUGIN_ACLNN_MOE_DESTRIBUTE_COMBINE_OPERATION_H
+#define ATB_SPEED_PLUGIN_ACLNN_MOE_DESTRIBUTE_COMBINE_OPERATION_H
+#include "operations/aclnn/core/acl_nn_operation.h"
+#include "operations/aclnn/core/acl_nn_operation_cache.h"
+#include "atb_speed/utils/operation_util.h"
+
+namespace atb_speed {
+namespace common {
+
+struct MoeDistributeCombineParam {
+    int32_t epRankId = 0;
+    int32_t epRankSize = 1;
+    int32_t tpRankId = 0;
+    int32_t tpRankSize = 1;
+    int32_t expertSharedType = 0;
+    int32_t maxDecodeDpTokenSize = 0;
+    int64_t sharedExpertRankNum = 0;
+    int64_t moeExpertNum = 1;
+    int64_t localMoeExpertNum = 1;
+    int64_t topk = 8;
+    int64_t globalBS = 0; // expanded to BS * world_size inside tiling
+    std::string tpCommName;
+    std::string epCommName;
+    std::string rankTableFile = "";
+    HcclComm hcclComm = nullptr;
+    int64_t commQuantMode = 0;
+};
+
+class MoeDistributeCombineOperation : public AclNNOperation {
+public:
+    explicit MoeDistributeCombineOperation(const std::string &name, MoeDistributeCombineParam param);
+    ~MoeDistributeCombineOperation() override;
+    atb::Status InferShape(const atb::SVector<atb::TensorDesc> &inTensorDescs,
+                           atb::SVector<atb::TensorDesc> &outTensorDescs) const override;
+    uint32_t GetInputNum() const override;
+    uint32_t GetOutputNum() const override;
+    int32_t GetGlobalBS(const atb::TensorDesc &inTensorDesc) const;
+
+private:
+    int SetAclNNWorkspaceExecutor() override;
+    int ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream) override;
+
+    MoeDistributeCombineParam param_;
+};
+
+} // namespace common
+} // namespace atb_speed
+#endif // ATB_SPEED_PLUGIN_ACLNN_MOE_DESTRIBUTE_COMBINE_OPERATION_H
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/moe_distribute_combine_v2_operation.cpp b/tests/proftest/layer_test_framework/operations/aclnn/ops/moe_distribute_combine_v2_operation.cpp
new file mode 100644
index 00000000..5038f878
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/moe_distribute_combine_v2_operation.cpp
@@ -0,0 +1,123 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
+ */
+#include "moe_distribute_combine_v2_operation.h"
+#include
+#include
+#include
+#include
+#include
+#include "acl/acl.h"
+#include "aclnnop/aclnn_moe_distribute_combine_v2.h"
+#include "atb_speed/log.h"
+#include "atb_speed/utils/timer.h"
+#include "operations/aclnn/utils/utils.h"
+
+namespace atb_speed {
+namespace common {
+
+MoeDistributeCombineV2Operation::MoeDistributeCombineV2Operation(
+    const std::string &name, MoeDistributeCombineV2Param param) : AclNNOperation(name), param_(param) {}
+MoeDistributeCombineV2Operation::~MoeDistributeCombineV2Operation() {}
+
+atb::Status MoeDistributeCombineV2Operation::InferShape(
+    const atb::SVector<atb::TensorDesc> &inTensorDescs, atb::SVector<atb::TensorDesc> &outTensorDescs) const
+{
+    ATB_SPEED_LOG_DEBUG(opName_ << " MoeDistributeCombineV2Operation infer shape start");
+
+    outTensorDescs.at(DIM0).format = inTensorDescs.at(DIM0).format;
+    outTensorDescs.at(DIM0).dtype = inTensorDescs.at(DIM0).dtype;
+    outTensorDescs.at(DIM0).shape.dimNum = inTensorDescs.at(DIM0).shape.dimNum;
+
+    ATB_SPEED_LOG_DEBUG(opName_
+        << " MoeDistributeCombineV2Operation infer shape origin inTensorDescs.at(DIM1).shape.dims[DIM0] "
+        << inTensorDescs.at(DIM1).shape.dims[DIM0]);
+    ATB_SPEED_LOG_DEBUG(opName_
+        << " MoeDistributeCombineV2Operation infer shape origin inTensorDescs.at(DIM0).shape.dims[DIM1] "
+        << inTensorDescs.at(DIM0).shape.dims[DIM1]);
+    outTensorDescs.at(DIM0).shape.dims[DIM0] = inTensorDescs.at(DIM1).shape.dims[DIM0];
+    outTensorDescs.at(DIM0).shape.dims[DIM1] = inTensorDescs.at(DIM0).shape.dims[DIM1];
+    ATB_SPEED_LOG_DEBUG(opName_ << " MoeDistributeCombineV2Operation infer shape end");
+    return 0;
+}
+uint32_t MoeDistributeCombineV2Operation::GetInputNum() const
+{
+    return NUM7; // seven input tensors, consistent with combine v1
+}
+
+uint32_t MoeDistributeCombineV2Operation::GetOutputNum() const
+{
+    return DIM1;
+}
+
+int32_t MoeDistributeCombineV2Operation::GetGlobalBS(const atb::TensorDesc &inTensorDesc) const
+{
+    int32_t worldSize = param_.epRankSize * std::max(param_.tpRankSize, 1);
+    if (param_.globalBS > 0) {
+        return param_.globalBS;
+    }
+    int32_t maxDecodeDpTokenSize = param_.maxDecodeDpTokenSize;
+    // if param_.maxDecodeDpTokenSize is not available, use in_padding_idx's DIM0
+    if (maxDecodeDpTokenSize == 0) {
+        maxDecodeDpTokenSize = inTensorDesc.shape.dims[DIM0] / \
+            std::min(param_.localMoeExpertNum, param_.topk) / worldSize;
+    }
+    return maxDecodeDpTokenSize * worldSize;
+}
+
+int MoeDistributeCombineV2Operation::SetAclNNWorkspaceExecutor()
+{
+    ATB_SPEED_LOG_DEBUG(opName_ << " MoeDistributeCombineV2Operation start");
+    AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack;
+
+    aclnnVariantPack.aclInTensors.at(NUM6)->tensorIdx = NUM10;
+    int64_t globalBS = GetGlobalBS(aclnnVariantPack.aclInTensors.at(DIM0)->atbTensor.desc);
+    int ret = aclnnMoeDistributeCombineV2GetWorkspaceSize(
+        aclnnVariantPack.aclInTensors.at(DIM0)->tensor,
+        aclnnVariantPack.aclInTensors.at(DIM1)->tensor,
+        aclnnVariantPack.aclInTensors.at(DIM2)->tensor,
+        aclnnVariantPack.aclInTensors.at(NUM3)->tensor,
+        aclnnVariantPack.aclInTensors.at(NUM4)->tensor,
+        aclnnVariantPack.aclInTensors.at(NUM5)->tensor,
+        nullptr,
+        nullptr,
+        nullptr,
+        nullptr,
+        aclnnVariantPack.aclInTensors.at(NUM6)->tensor,
+        nullptr,
+        param_.epCommName.data(),
+        param_.epRankSize,
+        param_.epRankId,
+        param_.moeExpertNum,
+        param_.tpCommName.data(),
+        param_.tpRankSize,
+        param_.tpRankId,
+        param_.expertSharedType,
+        1,
+        param_.sharedExpertRankNum,
+        globalBS,
+        0,
+        param_.commQuantMode,
+        0,
+        param_.commAlg.data(),
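+        // remaining arguments: the output tensor, then the workspace-size and executor
+        // out-parameters filled in by the aclnn call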
+        aclnnVariantPack.aclOutTensors.at(DIM0)->tensor,
+        &this->aclnnOpCache_->workspaceSize,
+        &this->aclnnOpCache_->aclExecutor);
+    ATB_SPEED_LOG_DEBUG(opName_ << " SetAclNNWorkspaceExecutor end, ret:" << ret
+                  << ", workspaceSize:" << this->aclnnOpCache_->workspaceSize
+                  << ", aclExecutor:" << this->aclnnOpCache_->aclExecutor);
+    return ret;
+}
+
+int MoeDistributeCombineV2Operation::ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream)
+{
+    ATB_SPEED_LOG_DEBUG(opName_ << " MoeDistributeCombineV2Operation start");
+
+    int ret = aclnnMoeDistributeCombineV2(
+        workspace, this->aclnnOpCache_->workspaceSize, this->aclnnOpCache_->aclExecutor, stream);
+    ATB_SPEED_LOG_DEBUG(opName_ << " MoeDistributeCombineV2Operation end, ret:" << ret);
+    return ret;
+}
+
+} // namespace common
+} // namespace atb_speed
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/moe_distribute_combine_v2_operation.h b/tests/proftest/layer_test_framework/operations/aclnn/ops/moe_distribute_combine_v2_operation.h
new file mode 100644
index 00000000..67887f9d
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/moe_distribute_combine_v2_operation.h
@@ -0,0 +1,52 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
+ */
+#ifndef ATB_SPEED_PLUGIN_ACLNN_MOE_DESTRIBUTE_COMBINE_OPERATION_V2_H
+#define ATB_SPEED_PLUGIN_ACLNN_MOE_DESTRIBUTE_COMBINE_OPERATION_V2_H
+#include "operations/aclnn/core/acl_nn_operation.h"
+#include "operations/aclnn/core/acl_nn_operation_cache.h"
+#include "atb_speed/utils/operation_util.h"
+
+namespace atb_speed {
+namespace common {
+
+struct MoeDistributeCombineV2Param {
+    int32_t epRankId = 0;
+    int32_t epRankSize = 1;
+    int32_t tpRankId = 0;
+    int32_t tpRankSize = 1;
+    int32_t expertSharedType = 0;
+    int32_t maxDecodeDpTokenSize = 0;
+    int64_t sharedExpertRankNum = 0;
+    int64_t moeExpertNum = 1;
+    int64_t localMoeExpertNum = 1;
+    int64_t topk = 8;
+    int64_t globalBS = 0; // expanded to BS * world_size inside tiling
+    std::string tpCommName;
+    std::string epCommName;
+    std::string commAlg;
+    std::string rankTableFile = "";
+    HcclComm hcclComm = nullptr;
+    int64_t commQuantMode = 0;
+};
+
+class MoeDistributeCombineV2Operation : public AclNNOperation {
+public:
+    explicit MoeDistributeCombineV2Operation(const std::string &name, MoeDistributeCombineV2Param param);
+    ~MoeDistributeCombineV2Operation() override;
+    atb::Status InferShape(const atb::SVector<atb::TensorDesc> &inTensorDescs,
+                           atb::SVector<atb::TensorDesc> &outTensorDescs) const override;
+    uint32_t GetInputNum() const override;
+    uint32_t GetOutputNum() const override;
+    int32_t GetGlobalBS(const atb::TensorDesc &inTensorDesc) const;
+
+private:
+    int SetAclNNWorkspaceExecutor() override;
+    int ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream) override;
+
+    MoeDistributeCombineV2Param param_;
+};
+
+} // namespace common
+} // namespace atb_speed
+#endif // ATB_SPEED_PLUGIN_ACLNN_MOE_DESTRIBUTE_COMBINE_OPERATION_V2_H
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/moe_distribute_dispatch_operation.cpp b/tests/proftest/layer_test_framework/operations/aclnn/ops/moe_distribute_dispatch_operation.cpp
new file mode 100644
index 00000000..5d27bb97
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/moe_distribute_dispatch_operation.cpp
@@ -0,0 +1,194 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "moe_distribute_dispatch_operation.h"
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include "acl/acl.h"
+#include "aclnnop/aclnn_moe_distribute_dispatch.h"
+#include "atb_speed/log.h"
+#include "atb_speed/utils/timer.h"
+#include "operations/aclnn/utils/utils.h"
+
+namespace atb_speed {
+namespace common {
+
+MoeDistributeDispatchOperation::MoeDistributeDispatchOperation(
+    const std::string &name, MoeDistributeDispatchParam param) : AclNNOperation(name), param_(param) {}
+MoeDistributeDispatchOperation::~MoeDistributeDispatchOperation() {}
+
+atb::Status MoeDistributeDispatchOperation::InferShape(
+    const atb::SVector<atb::TensorDesc> &inTensorDescs, atb::SVector<atb::TensorDesc> &outTensorDescs) const
+{
+    ATB_SPEED_LOG_DEBUG(opName_ << " MoeDistributeDispatchOperation infer shape start");
+
+    outTensorDescs.at(DIM0).format = inTensorDescs.at(DIM0).format;
+    outTensorDescs.at(DIM0).dtype = param_.isQuant ? aclDataType::ACL_INT8 : inTensorDescs.at(DIM0).dtype;
+    outTensorDescs.at(DIM0).shape.dimNum = inTensorDescs.at(DIM0).shape.dimNum;
+
+    outTensorDescs.at(DIM1).format = inTensorDescs.at(DIM0).format;
+    outTensorDescs.at(DIM1).dtype = aclDataType::ACL_FLOAT;
+    outTensorDescs.at(DIM1).shape.dimNum = DIM1;
+
+    outTensorDescs.at(DIM2).format = inTensorDescs.at(DIM0).format;
+    outTensorDescs.at(DIM2).dtype = aclDataType::ACL_INT32;
+    outTensorDescs.at(DIM2).shape.dimNum = DIM1;
+
+    outTensorDescs.at(DIM3).format = inTensorDescs.at(DIM0).format;
+    outTensorDescs.at(DIM3).dtype = aclDataType::ACL_INT64;
+    outTensorDescs.at(DIM3).shape.dimNum = DIM1;
+
+    outTensorDescs.at(NUM4).format = inTensorDescs.at(DIM0).format;
+    outTensorDescs.at(NUM4).dtype = aclDataType::ACL_INT32;
+    outTensorDescs.at(NUM4).shape.dimNum = DIM1;
+
+    outTensorDescs.at(NUM5).format = inTensorDescs.at(DIM0).format;
+    outTensorDescs.at(NUM5).dtype = aclDataType::ACL_INT32;
+    outTensorDescs.at(NUM5).shape.dimNum = DIM1;
+
+    outTensorDescs.at(NUM6).format = inTensorDescs.at(DIM0).format;
+    outTensorDescs.at(NUM6).dtype = aclDataType::ACL_FLOAT;
+    outTensorDescs.at(NUM6).shape.dimNum = DIM1;
+
+    ATB_SPEED_LOG_DEBUG(opName_
+        << " MoeDistributeDispatchOperation infer shape origin inTensorDescs.at(DIM0).shape.dims[DIM0] "
+        << inTensorDescs.at(DIM0).shape.dims[DIM0]);
+
+    int32_t globalBS = GetGlobalBS(inTensorDescs.at(NUM3));
+    int32_t globalTokenNum = globalBS * std::min(param_.localMoeExpertNum, param_.topk);
+
+    outTensorDescs.at(DIM0).shape.dims[DIM0] = param_.epRankId < param_.sharedExpertRankNum ? \
+        globalTokenNum / param_.sharedExpertRankNum : globalTokenNum; // sliced for the downstream matmul
+    outTensorDescs.at(DIM0).shape.dims[DIM1] = inTensorDescs.at(DIM0).shape.dims[DIM1];
+
+    outTensorDescs.at(DIM1).shape.dims[DIM0] =
+        param_.epRankId < param_.sharedExpertRankNum ?
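+        // ranks below sharedExpertRankNum host shared experts and take an even share of the tokens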
globalTokenNum / param_.sharedExpertRankNum : globalTokenNum; + + outTensorDescs.at(DIM2).shape.dims[DIM0] = inTensorDescs.at(DIM1).shape.dims[DIM0] * \ + inTensorDescs.at(DIM1).shape.dims[DIM1]; + + outTensorDescs.at(DIM3).shape.dims[DIM0] = param_.localMoeExpertNum; + + outTensorDescs.at(NUM4).shape.dims[DIM0] = param_.epRankSize * param_.localMoeExpertNum + \ + globalBS * param_.topk * (param_.epRankSize / NUM8) * NUM2; + + outTensorDescs.at(NUM5).shape.dims[DIM0] = 1; + + outTensorDescs.at(NUM6).shape.dims[DIM0] = + param_.epRankId < param_.sharedExpertRankNum ? globalTokenNum / param_.sharedExpertRankNum : globalTokenNum; + + ATB_SPEED_LOG_DEBUG(opName_ << "MoeDistributeDispatchOperation infer shape end"); + return 0; +} + +uint32_t MoeDistributeDispatchOperation::GetInputNum() const +{ + if (param_.quantSmooth) { + return NUM5; + } else { + return NUM4; + } +} + +uint32_t MoeDistributeDispatchOperation::GetOutputNum() const +{ + return NUM7; +} + +int32_t MoeDistributeDispatchOperation::GetGlobalBS(const atb::TensorDesc &inTensorDesc) const +{ + int32_t worldSize = param_.epRankSize * std::max(param_.tpRankSize, 1); + if (param_.globalBS > 0) { + return param_.globalBS; + } + int32_t maxDecodeDpTokenSize = param_.maxDecodeDpTokenSize; + // if param_.maxDecodeDpTokenSize is not available,use in_padding_idx's DIM0 + if (maxDecodeDpTokenSize == 0) { + maxDecodeDpTokenSize = inTensorDesc.shape.dims[DIM0]; + } + return maxDecodeDpTokenSize * worldSize; +} + +int MoeDistributeDispatchOperation::SetAclNNWorkspaceExecutor() +{ + ATB_SPEED_LOG_DEBUG(opName_ << " MoeDistributeDispatchOperation start"); + ATB_SPEED_LOG_DEBUG(opName_ << " MoeDistributeDispatchOperation create hcclComm"); + ATB_SPEED_LOG_DEBUG("param_.epCommName " << param_.epCommName << "param_.tpCommName " << param_.tpCommName + << "param_.epRankSize " << param_.epRankSize + << " param_.tpRankSize " << param_.tpRankSize + << " param_.epRankId " << param_.epRankId + << " param_.tpRankId " << param_.tpRankId + << " param_.expertSharedType " << param_.expertSharedType + << " param_.sharedExpertRankNum " << param_.sharedExpertRankNum << " param_.moeExpertNum " + << param_.moeExpertNum << "param_.quantMode " << param_.quantMode + << " param_.globalBS " << param_.globalBS); + + AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack; + + aclnnVariantPack.aclInTensors.at(NUM2)->tensorIdx = NUM4; + aclnnVariantPack.aclInTensors.at(NUM3)->needUpdateTensorDataPtr = false; + int32_t globalBS = GetGlobalBS(aclnnVariantPack.aclInTensors.at(NUM3)->atbTensor.desc); + int ret = aclnnMoeDistributeDispatchGetWorkspaceSize( + aclnnVariantPack.aclInTensors.at(DIM0)->tensor, + aclnnVariantPack.aclInTensors.at(DIM1)->tensor, + param_.quantSmooth ? 
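+        // smooth-quant scales (the DIM2 input) are only bound when quantSmooth is set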
aclnnVariantPack.aclInTensors.at(DIM2)->tensor : nullptr, + nullptr, + aclnnVariantPack.aclInTensors.at(NUM2)->tensor, + param_.epCommName.data(), + param_.epRankSize, + param_.epRankId, + param_.moeExpertNum, + param_.tpCommName.data(), + param_.tpRankSize, + param_.tpRankId, + param_.expertSharedType, + 1, + param_.sharedExpertRankNum, + param_.quantMode, + globalBS, + param_.expertTokenNumsType, + aclnnVariantPack.aclOutTensors.at(DIM0)->tensor, + aclnnVariantPack.aclOutTensors.at(DIM1)->tensor, + aclnnVariantPack.aclOutTensors.at(DIM2)->tensor, + aclnnVariantPack.aclOutTensors.at(NUM3)->tensor, + aclnnVariantPack.aclOutTensors.at(NUM4)->tensor, + aclnnVariantPack.aclOutTensors.at(NUM5)->tensor, + aclnnVariantPack.aclOutTensors.at(NUM6)->tensor, + &this->aclnnOpCache_->workspaceSize, + &this->aclnnOpCache_->aclExecutor); + ATB_SPEED_LOG_DEBUG(opName_ << " SetAclNNWorkspaceExecutor end, ret:" << ret + << ", workspaceSize:" << this->aclnnOpCache_->workspaceSize + << ", aclExecutor:" << this->aclnnOpCache_->aclExecutor); + return ret; +} + +int MoeDistributeDispatchOperation::ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream) +{ + ATB_SPEED_LOG_DEBUG(opName_ << " MoeDistributeDispatchOperation start"); + + int ret = aclnnMoeDistributeDispatch( + workspace, this->aclnnOpCache_->workspaceSize, this->aclnnOpCache_->aclExecutor, stream); + ATB_SPEED_LOG_DEBUG(opName_ << " MoeDistributeDispatchOperation end, ret:" << ret); + return ret; +} + +} // namespace common +} // namespace atb_speed \ No newline at end of file diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/moe_distribute_dispatch_operation.h b/tests/proftest/layer_test_framework/operations/aclnn/ops/moe_distribute_dispatch_operation.h new file mode 100644 index 00000000..b0fdcc2d --- /dev/null +++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/moe_distribute_dispatch_operation.h @@ -0,0 +1,66 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#ifndef ATB_SPEED_PLUGIN_ACLNN_MOE_DESTRIBUTE_DISPATCH_OPERATION_H
+#define ATB_SPEED_PLUGIN_ACLNN_MOE_DESTRIBUTE_DISPATCH_OPERATION_H
+#include "operations/aclnn/core/acl_nn_operation.h"
+#include "operations/aclnn/core/acl_nn_operation_cache.h"
+#include "atb_speed/utils/operation_util.h"
+
+namespace atb_speed {
+namespace common {
+
+struct MoeDistributeDispatchParam {
+    int32_t epRankId = 0;
+    int32_t epRankSize = 1;
+    int32_t tpRankId = 0;
+    int32_t tpRankSize = 1;
+    int32_t expertSharedType = 0;
+    int32_t maxDecodeDpTokenSize = 0;
+    int64_t sharedExpertRankNum = 0;
+    int64_t moeExpertNum = 1;
+    int64_t localMoeExpertNum = 1;
+    int64_t topk = 8;
+    int64_t quantMode = 2;
+    int64_t globalBS = 0; // expanded to BS * world_size inside tiling
+    int64_t expertTokenNumsType = 0;
+    bool isQuant = false;
+    bool isSharedExpert = false;
+    bool quantSmooth = false;
+    std::string tpCommName;
+    std::string epCommName;
+    std::string rankTableFile = "";
+};
+
+class MoeDistributeDispatchOperation : public AclNNOperation {
+public:
+    explicit MoeDistributeDispatchOperation(const std::string &name, MoeDistributeDispatchParam param);
+    ~MoeDistributeDispatchOperation() override;
+    atb::Status InferShape(const atb::SVector<atb::TensorDesc> &inTensorDescs,
+                           atb::SVector<atb::TensorDesc> &outTensorDescs) const override;
+    uint32_t GetInputNum() const override;
+    uint32_t GetOutputNum() const override;
+    int32_t GetGlobalBS(const atb::TensorDesc &inTensorDesc) const;
+
+private:
+    int SetAclNNWorkspaceExecutor() override;
+    int ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream) override;
+
+    MoeDistributeDispatchParam param_;
+};
+
+} // namespace common
+} // namespace atb_speed
+#endif // ATB_SPEED_PLUGIN_ACLNN_MOE_DESTRIBUTE_DISPATCH_OPERATION_H
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/moe_distribute_dispatch_v2_operation.cpp b/tests/proftest/layer_test_framework/operations/aclnn/ops/moe_distribute_dispatch_v2_operation.cpp
new file mode 100644
index 00000000..6112ef94
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/moe_distribute_dispatch_v2_operation.cpp
@@ -0,0 +1,184 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
+ */
+#include "moe_distribute_dispatch_v2_operation.h"
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include "acl/acl.h"
+#include "aclnnop/aclnn_moe_distribute_dispatch_v2.h"
+#include "atb_speed/log.h"
+#include "atb_speed/utils/timer.h"
+#include "operations/aclnn/utils/utils.h"
+
+namespace atb_speed {
+namespace common {
+
+MoeDistributeDispatchV2Operation::MoeDistributeDispatchV2Operation(
+    const std::string &name, MoeDistributeDispatchV2Param param) : AclNNOperation(name), param_(param) {}
+MoeDistributeDispatchV2Operation::~MoeDistributeDispatchV2Operation() {}
+
+atb::Status MoeDistributeDispatchV2Operation::InferShape(
+    const atb::SVector<atb::TensorDesc> &inTensorDescs, atb::SVector<atb::TensorDesc> &outTensorDescs) const
+{
+    ATB_SPEED_LOG_DEBUG(opName_ << " MoeDistributeDispatchV2Operation infer shape start");
+
+    outTensorDescs.at(DIM0).format = inTensorDescs.at(DIM0).format;
+    outTensorDescs.at(DIM0).dtype = param_.isQuant ?
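+        // quantized dispatch emits int8 activations; otherwise the input dtype is kept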
+        aclDataType::ACL_INT8 : inTensorDescs.at(DIM0).dtype;
+    outTensorDescs.at(DIM0).shape.dimNum = inTensorDescs.at(DIM0).shape.dimNum;
+
+    outTensorDescs.at(DIM1).format = inTensorDescs.at(DIM0).format;
+    outTensorDescs.at(DIM1).dtype = aclDataType::ACL_FLOAT;
+    outTensorDescs.at(DIM1).shape.dimNum = DIM1;
+
+    outTensorDescs.at(DIM2).format = inTensorDescs.at(DIM0).format;
+    outTensorDescs.at(DIM2).dtype = aclDataType::ACL_INT32;
+    outTensorDescs.at(DIM2).shape.dimNum = DIM1;
+
+    outTensorDescs.at(DIM3).format = inTensorDescs.at(DIM0).format;
+    outTensorDescs.at(DIM3).dtype = aclDataType::ACL_INT64;
+    outTensorDescs.at(DIM3).shape.dimNum = DIM1;
+
+    outTensorDescs.at(NUM4).format = inTensorDescs.at(DIM0).format;
+    outTensorDescs.at(NUM4).dtype = aclDataType::ACL_INT32;
+    outTensorDescs.at(NUM4).shape.dimNum = DIM1;
+
+    outTensorDescs.at(NUM5).format = inTensorDescs.at(DIM0).format;
+    outTensorDescs.at(NUM5).dtype = aclDataType::ACL_INT32;
+    outTensorDescs.at(NUM5).shape.dimNum = DIM1;
+
+    outTensorDescs.at(NUM6).format = inTensorDescs.at(DIM0).format;
+    outTensorDescs.at(NUM6).dtype = aclDataType::ACL_FLOAT;
+    outTensorDescs.at(NUM6).shape.dimNum = DIM1;
+
+    ATB_SPEED_LOG_DEBUG(opName_
+        << " MoeDistributeDispatchV2Operation infer shape origin inTensorDescs.at(DIM0).shape.dims[DIM0] "
+        << inTensorDescs.at(DIM0).shape.dims[DIM0]);
+
+    int32_t globalBS = GetGlobalBS(inTensorDescs.at(NUM3));
+    int32_t globalTokenNum = globalBS * std::min(param_.localMoeExpertNum, param_.topk);
+
+    outTensorDescs.at(DIM0).shape.dims[DIM0] = param_.epRankId < param_.sharedExpertRankNum ? \
+        globalTokenNum / param_.sharedExpertRankNum : globalTokenNum; // sliced for the downstream matmul
+    outTensorDescs.at(DIM0).shape.dims[DIM1] = inTensorDescs.at(DIM0).shape.dims[DIM1];
+
+    outTensorDescs.at(DIM1).shape.dims[DIM0] =
+        param_.epRankId < param_.sharedExpertRankNum ? globalTokenNum / param_.sharedExpertRankNum : globalTokenNum;
+
+    outTensorDescs.at(DIM2).shape.dims[DIM0] = std::max(inTensorDescs.at(DIM1).shape.dims[DIM0] * \
+        inTensorDescs.at(DIM1).shape.dims[DIM1], static_cast<int64_t>(globalTokenNum) * 128); // A3 shape: A * 128
+
+    outTensorDescs.at(DIM3).shape.dims[DIM0] = param_.localMoeExpertNum;
+
+    outTensorDescs.at(NUM4).shape.dims[DIM0] = param_.epRankSize * param_.localMoeExpertNum + \
+        globalBS * param_.topk * (param_.epRankSize / NUM8) * NUM2;
+
+    outTensorDescs.at(NUM5).shape.dims[DIM0] = 1;
+
+    outTensorDescs.at(NUM6).shape.dims[DIM0] =
+        param_.epRankId < param_.sharedExpertRankNum ?
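+        // one float entry per row of output 0, presumably the per-token dynamic-quant scales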
+        globalTokenNum / param_.sharedExpertRankNum : globalTokenNum;
+
+    ATB_SPEED_LOG_DEBUG(opName_ << " MoeDistributeDispatchV2Operation infer shape end");
+    return 0;
+}
+
+uint32_t MoeDistributeDispatchV2Operation::GetInputNum() const
+{
+    if (param_.quantSmooth) {
+        return NUM5;
+    } else {
+        return NUM4; // four input tensors: hiddenstates, selected_experts, expert_weight, padding_idx
+    }
+}
+
+uint32_t MoeDistributeDispatchV2Operation::GetOutputNum() const
+{
+    return NUM7; // seven output tensors, kept consistent with dispatch v1
+}
+
+int32_t MoeDistributeDispatchV2Operation::GetGlobalBS(const atb::TensorDesc &inTensorDesc) const
+{
+    int32_t worldSize = param_.epRankSize * std::max(param_.tpRankSize, 1);
+    if (param_.globalBS > 0) {
+        return param_.globalBS;
+    }
+    int32_t maxDecodeDpTokenSize = param_.maxDecodeDpTokenSize;
+    // if param_.maxDecodeDpTokenSize is not available, use in_padding_idx's DIM0
+    if (maxDecodeDpTokenSize == 0) {
+        maxDecodeDpTokenSize = inTensorDesc.shape.dims[DIM0];
+    }
+    return maxDecodeDpTokenSize * worldSize;
+}
+
+int MoeDistributeDispatchV2Operation::SetAclNNWorkspaceExecutor()
+{
+    ATB_SPEED_LOG_DEBUG(opName_ << " MoeDistributeDispatchV2Operation start");
+    ATB_SPEED_LOG_DEBUG(opName_ << " MoeDistributeDispatchV2Operation create hcclComm");
+    ATB_SPEED_LOG_DEBUG("param_.epCommName " << param_.epCommName << " param_.tpCommName " << param_.tpCommName
+        << " param_.commAlg " << param_.commAlg
+        << " param_.epRankSize " << param_.epRankSize
+        << " param_.tpRankSize " << param_.tpRankSize
+        << " param_.epRankId " << param_.epRankId
+        << " param_.tpRankId " << param_.tpRankId
+        << " param_.expertSharedType " << param_.expertSharedType
+        << " param_.sharedExpertRankNum " << param_.sharedExpertRankNum << " param_.moeExpertNum "
+        << param_.moeExpertNum << " param_.quantMode " << param_.quantMode
+        << " param_.globalBS " << param_.globalBS);
+
+    AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack;
+
+    aclnnVariantPack.aclInTensors.at(NUM2)->tensorIdx = NUM4;
+    aclnnVariantPack.aclInTensors.at(NUM3)->needUpdateTensorDataPtr = false;
+    int32_t globalBS = GetGlobalBS(aclnnVariantPack.aclInTensors.at(NUM3)->atbTensor.desc);
+    int ret = aclnnMoeDistributeDispatchV2GetWorkspaceSize(
+        aclnnVariantPack.aclInTensors.at(DIM0)->tensor,
+        aclnnVariantPack.aclInTensors.at(DIM1)->tensor,
+        param_.quantSmooth ?
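+        // as in the v1 dispatch, smooth-quant scales (the DIM2 input) are only bound when quantSmooth is set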
aclnnVariantPack.aclInTensors.at(DIM2)->tensor : nullptr, + nullptr, + aclnnVariantPack.aclInTensors.at(NUM2)->tensor, + param_.epCommName.data(), + param_.epRankSize, + param_.epRankId, + param_.moeExpertNum, + param_.tpCommName.data(), + param_.tpRankSize, + param_.tpRankId, + param_.expertSharedType, + 1, + param_.sharedExpertRankNum, + param_.quantMode, + globalBS, + param_.expertTokenNumsType, + param_.commAlg.data(), + aclnnVariantPack.aclOutTensors.at(DIM0)->tensor, + aclnnVariantPack.aclOutTensors.at(DIM1)->tensor, + aclnnVariantPack.aclOutTensors.at(DIM2)->tensor, + aclnnVariantPack.aclOutTensors.at(NUM3)->tensor, + aclnnVariantPack.aclOutTensors.at(NUM4)->tensor, + aclnnVariantPack.aclOutTensors.at(NUM5)->tensor, + aclnnVariantPack.aclOutTensors.at(NUM6)->tensor, + &this->aclnnOpCache_->workspaceSize, + &this->aclnnOpCache_->aclExecutor); + ATB_SPEED_LOG_DEBUG(opName_ << " SetAclNNWorkspaceExecutor end, ret:" << ret + << ", workspaceSize:" << this->aclnnOpCache_->workspaceSize + << ", aclExecutor:" << this->aclnnOpCache_->aclExecutor); + return ret; +} + +int MoeDistributeDispatchV2Operation::ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream) +{ + ATB_SPEED_LOG_DEBUG(opName_ << " MoeDistributeDispatchV2Operation start"); + + int ret = aclnnMoeDistributeDispatchV2( + workspace, this->aclnnOpCache_->workspaceSize, this->aclnnOpCache_->aclExecutor, stream); + ATB_SPEED_LOG_DEBUG(opName_ << " MoeDistributeDispatchV2Operation end, ret:" << ret); + return ret; +} + +} // namespace common +} // namespace atb_speed \ No newline at end of file diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/moe_distribute_dispatch_v2_operation.h b/tests/proftest/layer_test_framework/operations/aclnn/ops/moe_distribute_dispatch_v2_operation.h new file mode 100644 index 00000000..0e9abf71 --- /dev/null +++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/moe_distribute_dispatch_v2_operation.h @@ -0,0 +1,55 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
+ */
+#ifndef ATB_SPEED_PLUGIN_ACLNN_MOE_DESTRIBUTE_DISPATCH_OPERATION_V2_H
+#define ATB_SPEED_PLUGIN_ACLNN_MOE_DESTRIBUTE_DISPATCH_OPERATION_V2_H
+#include "operations/aclnn/core/acl_nn_operation.h"
+#include "operations/aclnn/core/acl_nn_operation_cache.h"
+#include "atb_speed/utils/operation_util.h"
+
+namespace atb_speed {
+namespace common {
+
+struct MoeDistributeDispatchV2Param {
+    int32_t epRankId = 0;
+    int32_t epRankSize = 1;
+    int32_t tpRankId = 0;
+    int32_t tpRankSize = 1;
+    int32_t expertSharedType = 0;
+    int32_t maxDecodeDpTokenSize = 0;
+    int64_t sharedExpertRankNum = 0;
+    int64_t moeExpertNum = 1;
+    int64_t localMoeExpertNum = 1;
+    int64_t topk = 8;
+    int64_t quantMode = 2;
+    int64_t globalBS = 0; // expanded to BS * world_size inside tiling
+    int64_t expertTokenNumsType = 0;
+    bool isQuant = false;
+    bool isSharedExpert = false;
+    bool quantSmooth = false;
+    std::string tpCommName;
+    std::string epCommName;
+    std::string commAlg;
+    std::string rankTableFile = "";
+};
+
+class MoeDistributeDispatchV2Operation : public AclNNOperation {
+public:
+    explicit MoeDistributeDispatchV2Operation(const std::string &name, MoeDistributeDispatchV2Param param);
+    ~MoeDistributeDispatchV2Operation() override;
+    atb::Status InferShape(const atb::SVector<atb::TensorDesc> &inTensorDescs,
+                           atb::SVector<atb::TensorDesc> &outTensorDescs) const override;
+    uint32_t GetInputNum() const override;
+    uint32_t GetOutputNum() const override;
+    int32_t GetGlobalBS(const atb::TensorDesc &inTensorDesc) const;
+
+private:
+    int SetAclNNWorkspaceExecutor() override;
+    int ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream) override;
+
+    MoeDistributeDispatchV2Param param_;
+};
+
+} // namespace common
+} // namespace atb_speed
+#endif
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/moe_init_routing_operation.cpp b/tests/proftest/layer_test_framework/operations/aclnn/ops/moe_init_routing_operation.cpp
new file mode 100644
index 00000000..5f98668c
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/moe_init_routing_operation.cpp
@@ -0,0 +1,105 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include
+#include
+#include
+#include
+#include "acl/acl.h"
+#include "aclnnop/aclnn_moe_init_routing_v2.h"
+#include "atb_speed/log.h"
+#include "atb_speed/utils/timer.h"
+#include "operations/aclnn/utils/utils.h"
+#include "moe_init_routing_operation.h"
+
+namespace atb_speed {
+namespace common {
+
+MoeInitRoutingOperation::MoeInitRoutingOperation(
+    const std::string &name, MoeInitRoutingParam param) : AclNNOperation(name), param_(param) {}
+MoeInitRoutingOperation::~MoeInitRoutingOperation() {}
+
+atb::Status MoeInitRoutingOperation::InferShape(
+    const atb::SVector<atb::TensorDesc> &inTensorDescs, atb::SVector<atb::TensorDesc> &outTensorDescs) const
+{
+    ATB_SPEED_LOG_DEBUG(opName_ << " MoeInitRoutingOperation infer shape start");
+
+    outTensorDescs.at(DIM0).format = inTensorDescs.at(DIM0).format;
+    outTensorDescs.at(DIM0).dtype = inTensorDescs.at(DIM0).dtype;
+    outTensorDescs.at(DIM0).shape.dimNum = inTensorDescs.at(DIM0).shape.dimNum;
+
+    outTensorDescs.at(DIM1).format = inTensorDescs.at(DIM0).format;
+    outTensorDescs.at(DIM1).dtype = aclDataType::ACL_INT32;
+    outTensorDescs.at(DIM1).shape.dimNum = DIM1;
+
+    outTensorDescs.at(DIM2).format = inTensorDescs.at(DIM0).format;
+    outTensorDescs.at(DIM2).dtype = aclDataType::ACL_INT32;
+    outTensorDescs.at(DIM2).shape.dimNum = DIM1;
+
+    ATB_SPEED_LOG_DEBUG(opName_ << " MoeInitRoutingOperation infer shape origin inTensorDescs.at(DIM0).shape.dims[DIM0] "
+        << inTensorDescs.at(DIM0).shape.dims[DIM0]);
+    int scaledTopk = param_.enableInitRoutingCutoff ? param_.scaledTopk : param_.topkNum;
+    outTensorDescs.at(DIM0).shape.dims[DIM0] = inTensorDescs.at(DIM0).shape.dims[DIM0] * scaledTopk;
+    outTensorDescs.at(DIM0).shape.dims[DIM1] = inTensorDescs.at(DIM0).shape.dims[DIM1];
+    outTensorDescs.at(DIM1).shape.dims[DIM0] = inTensorDescs.at(DIM0).shape.dims[DIM0] * param_.topkNum;
+    outTensorDescs.at(DIM2).shape.dims[DIM0] = param_.expertNum;
+
+    ATB_SPEED_LOG_DEBUG(opName_ << " MoeInitRoutingOperation infer shape end");
+    return 0;
+}
+uint32_t MoeInitRoutingOperation::GetInputNum() const
+{
+    return DIM2;
+}
+
+uint32_t MoeInitRoutingOperation::GetOutputNum() const
+{
+    return DIM3;
+}
+
+int MoeInitRoutingOperation::SetAclNNWorkspaceExecutor()
+{
+    int scaledTopk = param_.enableInitRoutingCutoff ?
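+        // with the cutoff enabled, only scaledTopk rows per token are budgeted instead of the full topkNum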
+        param_.scaledTopk : param_.topkNum;
+    ATB_SPEED_LOG_DEBUG(opName_ << " MoeInitRoutingOperation start");
+    AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack;
+    int ret = aclnnMoeInitRoutingV2GetWorkspaceSize(
+        aclnnVariantPack.aclInTensors.at(DIM0)->tensor,
+        aclnnVariantPack.aclInTensors.at(DIM1)->tensor,
+        aclnnVariantPack.aclInTensors.at(DIM0)->atbTensor.desc.shape.dims[DIM0] * scaledTopk,
+        0, param_.expertNum, 0, param_.expertTokensCoutOrCumsumFlag, false,
+        aclnnVariantPack.aclOutTensors.at(DIM0)->tensor,
+        aclnnVariantPack.aclOutTensors.at(DIM1)->tensor,
+        aclnnVariantPack.aclOutTensors.at(DIM2)->tensor,
+        nullptr,
+        &this->aclnnOpCache_->workspaceSize,
+        &this->aclnnOpCache_->aclExecutor);
+    ATB_SPEED_LOG_DEBUG(opName_ << " SetAclNNWorkspaceExecutor end, ret:" << ret
+                  << ", workspaceSize:" << this->aclnnOpCache_->workspaceSize
+                  << ", aclExecutor:" << this->aclnnOpCache_->aclExecutor);
+    return ret;
+}
+
+int MoeInitRoutingOperation::ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream)
+{
+    ATB_SPEED_LOG_DEBUG(opName_ << " MoeInitRoutingOperation start");
+
+    int ret = aclnnMoeInitRoutingV2(
+        workspace, this->aclnnOpCache_->workspaceSize, this->aclnnOpCache_->aclExecutor, stream);
+    ATB_SPEED_LOG_DEBUG(opName_ << " MoeInitRoutingOperation end, ret:" << ret);
+    return ret;
+}
+
+} // namespace common
+} // namespace atb_speed
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/moe_init_routing_operation.h b/tests/proftest/layer_test_framework/operations/aclnn/ops/moe_init_routing_operation.h
new file mode 100644
index 00000000..825cfbcb
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/moe_init_routing_operation.h
@@ -0,0 +1,97 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef ATB_SPEED_PLUGIN_ACLNN_MOE_INIT_ROUTING_OPERATION_H
+#define ATB_SPEED_PLUGIN_ACLNN_MOE_INIT_ROUTING_OPERATION_H
+#include "operations/aclnn/core/acl_nn_operation.h"
+#include "operations/aclnn/core/acl_nn_operation_cache.h"
+#include "atb_speed/utils/operation_util.h"
+
+namespace atb_speed {
+namespace common {
+
+struct MoeInitRoutingParam {
+    /// The number of experts selected for each token
+    int32_t topkNum = 2;
+    /// Non-DeepSeek models do not enable the scaledTopk feature by default
+    int scaledTopk = -1;
+    bool enableInitRoutingCutoff = false;
+    /// The total number of experts utilized by the model
+    int32_t expertNum = 8;
+    int expertTokensCoutOrCumsumFlag = 1;
+};
+
+/// This class defines an operator that is used to gather and rearrange hidden states based
+/// on the given list of selected experts of each token.
+///
+/// This class makes use of `aclnnMoeInitRoutingV2GetWorkspaceSize` and `aclnnMoeInitRoutingV2`
+/// from the AscendCL API.
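+///
+/// When `enableInitRoutingCutoff` is set, `InferShape` and the workspace setup size the
+/// first output with `scaledTopk` rows per token instead of `topkNum`. A minimal sketch of
+/// these parameters (the concrete values below are hypothetical):
+/// \code
+/// atb_speed::common::MoeInitRoutingParam cutoffParam;
+/// cutoffParam.topkNum = 8;                    // experts routed per token
+/// cutoffParam.scaledTopk = 4;                 // hypothetical reduced row budget
+/// cutoffParam.enableInitRoutingCutoff = true; // shrink the first output accordingly
+/// \endcode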
+///
+/// Inputs to the operator:
+/// Name         | Dtype               | Shape |
+/// -------------|---------------------|-------|
+/// input        | float16 or bfloat16 | [m,h] |
+/// expertIdx    | int32               | [m,k] |
+///
+/// Outputs of the operator:
+/// Name                         | Dtype               | Shape   |
+/// -----------------------------|---------------------|---------|
+/// expandedXOut                 | float16 or bfloat16 | [m*k,h] |
+/// expandedRowIdxOut            | int32               | [m*k]   |
+/// expertTokensCountOrCumsumOut | int32               | [e]     |
+/// Note: e is the total number of experts utilized by the model
+///       k is the number of experts selected for each token
+///
+/// Example:
+/// \code
+/// enum TensorIdx : uint32_t {
+///     IN_INPUT = 0,
+///     IN_EXPERTIDX,
+///     OUT_SORTED_HIDDENSTATES,
+///     OUT_ROWIDX,
+///     OUT_GROUP_LIST
+/// };
+///
+/// atb::Node &initRoutingNode = opGraph.nodes.at(nodeId++);
+/// atb_speed::common::MoeInitRoutingParam initRoutingParam;
+/// initRoutingParam.topkNum = param.topk;
+/// initRoutingParam.expertNum = param.numOfExperts;
+/// initRoutingNode.operation = new atb_speed::common::MoeInitRoutingOperation("MoeInitRoutingOperation",
+///                                                                            initRoutingParam);
+/// initRoutingNode.inTensorIds = {IN_INPUT, IN_EXPERTIDX};
+/// initRoutingNode.outTensorIds = {OUT_SORTED_HIDDENSTATES,
+///                                 OUT_ROWIDX,
+///                                 OUT_GROUP_LIST};
+/// \endcode
+
+class MoeInitRoutingOperation : public AclNNOperation {
+public:
+    explicit MoeInitRoutingOperation(const std::string &name, MoeInitRoutingParam param);
+    ~MoeInitRoutingOperation() override;
+    atb::Status InferShape(const atb::SVector<atb::TensorDesc> &inTensorDescs,
+                           atb::SVector<atb::TensorDesc> &outTensorDescs) const override;
+    uint32_t GetInputNum() const override;
+    uint32_t GetOutputNum() const override;
+
+private:
+    int SetAclNNWorkspaceExecutor() override;
+    int ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream) override;
+
+    MoeInitRoutingParam param_;
+};
+
+} // namespace common
+} // namespace atb_speed
+#endif // ATB_SPEED_PLUGIN_ACLNN_MOE_INIT_ROUTING_OPERATION_H
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/moe_init_routing_quant_operation.cpp b/tests/proftest/layer_test_framework/operations/aclnn/ops/moe_init_routing_quant_operation.cpp
new file mode 100644
index 00000000..85c35a5c
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/moe_init_routing_quant_operation.cpp
@@ -0,0 +1,119 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include
+#include
+#include
+#include
+#include "acl/acl.h"
+#include "aclnnop/aclnn_moe_init_routing_quant_v2.h"
+#include "atb_speed/log.h"
+#include "atb_speed/utils/timer.h"
+#include "operations/aclnn/utils/utils.h"
+#include "moe_init_routing_quant_operation.h"
+
+namespace atb_speed {
+namespace common {
+
+MoeInitRoutingQuantOperation::MoeInitRoutingQuantOperation(
+    const std::string &name, MoeInitRoutingQuantParam param) : AclNNOperation(name), param_(param) {}
+MoeInitRoutingQuantOperation::~MoeInitRoutingQuantOperation() {}
+
+atb::Status MoeInitRoutingQuantOperation::InferShape(
+    const atb::SVector<atb::TensorDesc> &inTensorDescs, atb::SVector<atb::TensorDesc> &outTensorDescs) const
+{
+    ATB_SPEED_LOG_DEBUG(opName_ << " MoeInitRoutingQuantOperation infer shape start");
+
+    outTensorDescs.at(DIM0).format = inTensorDescs.at(DIM0).format;
+    outTensorDescs.at(DIM0).dtype = aclDataType::ACL_INT8;
+    outTensorDescs.at(DIM0).shape.dimNum = inTensorDescs.at(DIM0).shape.dimNum;
+
+    outTensorDescs.at(DIM1).format = inTensorDescs.at(DIM0).format;
+    outTensorDescs.at(DIM1).dtype = aclDataType::ACL_INT32;
+    outTensorDescs.at(DIM1).shape.dimNum = DIM1;
+
+    outTensorDescs.at(DIM2).format = inTensorDescs.at(DIM0).format;
+    outTensorDescs.at(DIM2).dtype = aclDataType::ACL_INT32;
+    outTensorDescs.at(DIM2).shape.dimNum = DIM1;
+
+    outTensorDescs.at(DIM3).format = inTensorDescs.at(DIM0).format;
+    outTensorDescs.at(DIM3).dtype = aclDataType::ACL_INT32;
+    outTensorDescs.at(DIM3).shape.dimNum = DIM1;
+
+    outTensorDescs.at(NUM4).format = inTensorDescs.at(DIM0).format;
+    outTensorDescs.at(NUM4).dtype = aclDataType::ACL_FLOAT;
+    outTensorDescs.at(NUM4).shape.dimNum = DIM1;
+
+    ATB_SPEED_LOG_DEBUG(opName_ << " MoeInitRoutingQuantOperation infer shape origin \
+        inTensorDescs.at(DIM0).shape.dims[DIM0]"
+        << inTensorDescs.at(DIM0).shape.dims[DIM0]);
+    int scaledTopk = param_.enableInitRoutingCutoff ? param_.scaledTopk : param_.topkNum;
+    outTensorDescs.at(DIM0).shape.dims[DIM0] = inTensorDescs.at(DIM0).shape.dims[DIM0] * scaledTopk;
+    outTensorDescs.at(DIM0).shape.dims[DIM1] = inTensorDescs.at(DIM0).shape.dims[DIM1];
+    outTensorDescs.at(DIM1).shape.dims[DIM0] = inTensorDescs.at(DIM0).shape.dims[DIM0] * param_.topkNum;
+    outTensorDescs.at(DIM2).shape.dims[DIM0] = param_.expertNum;
+    outTensorDescs.at(DIM3).shape.dims[DIM0] = param_.expertNum;
+    outTensorDescs.at(NUM4).shape.dims[DIM0] = inTensorDescs.at(DIM0).shape.dims[DIM0] * scaledTopk;
+
+    ATB_SPEED_LOG_DEBUG(opName_ << " MoeInitRoutingQuantOperation infer shape end");
+    return 0;
+}
+uint32_t MoeInitRoutingQuantOperation::GetInputNum() const
+{
+    return DIM2;
+}
+
+uint32_t MoeInitRoutingQuantOperation::GetOutputNum() const
+{
+    return NUM5;
+}
+
+int MoeInitRoutingQuantOperation::SetAclNNWorkspaceExecutor()
+{
+    int scaledTopk = param_.enableInitRoutingCutoff ?
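+        // same cutoff rule as the non-quant init-routing: scaledTopk rows per token when enabled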
param_.scaledTopk : param_.topkNum; + ATB_SPEED_LOG_DEBUG(opName_ << " MoeInitRoutingQuantOperation start"); + AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack; + int ret = aclnnMoeInitRoutingQuantV2GetWorkspaceSize( + aclnnVariantPack.aclInTensors.at(DIM0)->tensor, + aclnnVariantPack.aclInTensors.at(DIM1)->tensor, + nullptr, + nullptr, + aclnnVariantPack.aclInTensors.at(DIM0)->atbTensor.desc.shape.dims[DIM0] * scaledTopk, + 0, param_.expertNum, 0, param_.expertTokensCoutOrCumsumFlag, false, param_.quantMode, + aclnnVariantPack.aclOutTensors.at(DIM0)->tensor, + aclnnVariantPack.aclOutTensors.at(DIM1)->tensor, + aclnnVariantPack.aclOutTensors.at(DIM2)->tensor, + aclnnVariantPack.aclOutTensors.at(DIM3)->tensor, + aclnnVariantPack.aclOutTensors.at(NUM4)->tensor, + &this->aclnnOpCache_->workspaceSize, + &this->aclnnOpCache_->aclExecutor); + ATB_SPEED_LOG_DEBUG(opName_ << " SetAclNNWorkspaceExecutor end, ret:" << ret + << ", workspaceSize:" << this->aclnnOpCache_->workspaceSize + << ", aclExecutor:" << this->aclnnOpCache_->aclExecutor); + return ret; +} + +int MoeInitRoutingQuantOperation::ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream) +{ + ATB_SPEED_LOG_DEBUG(opName_ << " MoeInitRoutingQuantOperation start"); + + int ret = aclnnMoeInitRoutingQuantV2( + workspace, this->aclnnOpCache_->workspaceSize, this->aclnnOpCache_->aclExecutor, stream); + ATB_SPEED_LOG_DEBUG(opName_ << " MoeInitRoutingQuantOperation end, ret:" << ret); + return ret; +} + +} // namespace common +} // namespace atb_speed \ No newline at end of file diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/moe_init_routing_quant_operation.h b/tests/proftest/layer_test_framework/operations/aclnn/ops/moe_init_routing_quant_operation.h new file mode 100644 index 00000000..ad2078a9 --- /dev/null +++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/moe_init_routing_quant_operation.h @@ -0,0 +1,104 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#ifndef ATB_SPEED_PLUGIN_ACLNN_MOE_INIT_ROUTING_QUANT_OPERATION_H
+#define ATB_SPEED_PLUGIN_ACLNN_MOE_INIT_ROUTING_QUANT_OPERATION_H
+#include "operations/aclnn/core/acl_nn_operation.h"
+#include "operations/aclnn/core/acl_nn_operation_cache.h"
+#include "atb_speed/utils/operation_util.h"
+
+namespace atb_speed {
+namespace common {
+
+struct MoeInitRoutingQuantParam {
+    int32_t topkNum = 2;  /// The number of experts selected for each token
+    int scaledTopk = -1;  /// Non-DeepSeek models leave the scaledTopk feature disabled by default
+    bool enableInitRoutingCutoff = false;
+    int32_t expertNum = 8;  /// The total number of experts utilized by the model
+    int32_t quantMode = 1;  /// The quant mode: 0 is static quant and 1 is dynamic quant
+    int expertTokensCoutOrCumsumFlag = 1;
+};
+
+/// This class defines an operator that is used to gather, rearrange and quantize hidden states based
+/// on the given list of selected experts of each token.
+///
+/// This class makes use of `aclnnMoeInitRoutingQuantV2GetWorkspaceSize` and `aclnnMoeInitRoutingQuantV2`
+/// from the AscendCL API.
+///
+/// Inputs to the operator:
+/// Name         | Dtype               | Shape |
+/// -------------|---------------------|-------|
+/// input        | float16 or bfloat16 | [m,h] |
+/// expertIdx    | int32               | [m,k] |
+///
+/// Outputs of the operator:
+/// Name                          | Dtype   | Shape   |
+/// ------------------------------|---------|---------|
+/// expandedXOut                  | int8    | [m*k,h] |
+/// expandedRowIdxOut             | int32   | [m*k]   |
+/// expertTokensCoutOrCumsumOut   | int32   | [e]     |
+/// expertTokensBeforeCapacityOut | int32   | [e]     |
+/// dynamicQuantScaleOut          | float32 | [m*k]   |
+/// Note: e is the total number of experts utilized by the model
+///       k is the number of experts selected for each token
+///
+/// Example:
+/// \code
+/// enum InTensorIdx : uint32_t {
+///     IN_INPUT = 0,
+///     IN_EXPERTIDX,
+/// };
+///
+/// enum OutTensorIdx : uint32_t {
+///     OUT_SORTED_HIDDENSTATES = 0,
+///     OUT_ROWIDX,
+///     OUT_GROUP_LIST,
+///     OUT_EXPERT_TOKENS_BEFORE_CAPACITY,
+///     OUT_DYNAMIC_QUANT_SCALE
+/// };
+///
+/// atb::Node &initRoutingNode = opGraph.nodes.at(nodeId++);
+/// atb_speed::common::MoeInitRoutingQuantParam initRoutingParam;
+/// initRoutingParam.topkNum = param.topk;
+/// initRoutingParam.expertNum = param.numOfExperts;
+/// initRoutingNode.operation = new atb_speed::common::MoeInitRoutingQuantOperation("MoeInitRoutingQuantOperation",
+///     initRoutingParam);
+/// initRoutingNode.inTensorIds = {IN_INPUT, IN_EXPERTIDX};
+/// initRoutingNode.outTensorIds = {OUT_SORTED_HIDDENSTATES,
+///                                 OUT_ROWIDX,
+///                                 OUT_GROUP_LIST,
+///                                 OUT_EXPERT_TOKENS_BEFORE_CAPACITY,
+///                                 OUT_DYNAMIC_QUANT_SCALE};
+/// \endcode
+
+class MoeInitRoutingQuantOperation : public AclNNOperation {
+public:
+    explicit MoeInitRoutingQuantOperation(const std::string &name, MoeInitRoutingQuantParam param);
+    ~MoeInitRoutingQuantOperation() override;
+    atb::Status InferShape(const atb::SVector<atb::TensorDesc> &inTensorDescs,
+                           atb::SVector<atb::TensorDesc> &outTensorDescs) const override;
+    uint32_t GetInputNum() const override;
+    uint32_t GetOutputNum() const override;
+
+private:
+    int SetAclNNWorkspaceExecutor() override;
+    int ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream) override;
+
+    MoeInitRoutingQuantParam param_;
+};
+
+} // namespace common
+} // namespace atb_speed
+#endif // ATB_SPEED_PLUGIN_ACLNN_MOE_INIT_ROUTING_QUANT_OPERATION_H
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/moe_topk_softmax_operation.cpp b/tests/proftest/layer_test_framework/operations/aclnn/ops/moe_topk_softmax_operation.cpp
new file mode 100644
index 00000000..de9cbe01
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/moe_topk_softmax_operation.cpp
@@ -0,0 +1,103 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "moe_topk_softmax_operation.h"
+// NOTE: the original targets of the next four includes were lost during extraction;
+// the standard headers below are assumptions, not taken from the source.
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <vector>
+#include "acl/acl.h"
+#include "aclnnop/aclnn_moe_gating_top_k_softmax.h"
+#include "atb_speed/log.h"
+#include "atb_speed/utils/timer.h"
+#include "operations/aclnn/utils/utils.h"
+
+namespace atb_speed {
+namespace common {
+
+MoeTopkSoftmaxOperation::MoeTopkSoftmaxOperation(
+    const std::string &name, MoeTopkSoftmaxParam param) : AclNNOperation(name), param_(param) {}
+MoeTopkSoftmaxOperation::~MoeTopkSoftmaxOperation() {}
+
+atb::Status MoeTopkSoftmaxOperation::InferShape(
+    const atb::SVector<atb::TensorDesc> &inTensorDescs, atb::SVector<atb::TensorDesc> &outTensorDescs) const
+{
+    ATB_SPEED_LOG_DEBUG(opName_ << "MoeTopkSoftmaxOperation infer shape start");
+
+    outTensorDescs.at(DIM0).format = inTensorDescs.at(DIM0).format;
+    outTensorDescs.at(DIM0).dtype = inTensorDescs.at(DIM0).dtype;
+    outTensorDescs.at(DIM0).shape.dimNum = inTensorDescs.at(DIM0).shape.dimNum;
+
+    outTensorDescs.at(DIM1).format = inTensorDescs.at(DIM0).format;
+    outTensorDescs.at(DIM1).dtype = aclDataType::ACL_INT32;
+    outTensorDescs.at(DIM1).shape.dimNum = inTensorDescs.at(DIM0).shape.dimNum;
+
+    outTensorDescs.at(DIM2).format = inTensorDescs.at(DIM0).format;
+    outTensorDescs.at(DIM2).dtype = aclDataType::ACL_INT32;
+    outTensorDescs.at(DIM2).shape.dimNum = inTensorDescs.at(DIM0).shape.dimNum;
+
+    ATB_SPEED_LOG_DEBUG(opName_ << "MoeTopkSoftmaxOperation infer shape origin inTensorDescs.at(DIM0).shape.dims[DIM0]"
+        << inTensorDescs.at(DIM0).shape.dims[DIM0]);
+    outTensorDescs.at(DIM0).shape.dims[DIM0] = inTensorDescs.at(DIM0).shape.dims[DIM0];
+    outTensorDescs.at(DIM0).shape.dims[DIM1] = param_.topkNum;
+    outTensorDescs.at(DIM1).shape.dims[DIM0] = inTensorDescs.at(DIM0).shape.dims[DIM0];
+    outTensorDescs.at(DIM1).shape.dims[DIM1] = param_.topkNum;
+    outTensorDescs.at(DIM2).shape.dims[DIM0] = inTensorDescs.at(DIM0).shape.dims[DIM0];
+    outTensorDescs.at(DIM2).shape.dims[DIM1] = param_.topkNum;
+
+    ATB_SPEED_LOG_DEBUG(opName_ << "MoeTopkSoftmaxOperation infer shape end");
+    return 0;
+}
+uint32_t MoeTopkSoftmaxOperation::GetInputNum() const
+{
+    return DIM1;
+}
+
+uint32_t MoeTopkSoftmaxOperation::GetOutputNum() const
+{
+    return DIM3;
+}
+
+int MoeTopkSoftmaxOperation::SetAclNNWorkspaceExecutor()
+{
+    ATB_SPEED_LOG_DEBUG(opName_ << " MoeTopkSoftmaxOperation start");
+    AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack;
+    int ret = aclnnMoeGatingTopKSoftmaxGetWorkspaceSize(
+        aclnnVariantPack.aclInTensors.at(DIM0)->tensor,
+        nullptr,
+        param_.topkNum,
+        aclnnVariantPack.aclOutTensors.at(DIM0)->tensor,
+        aclnnVariantPack.aclOutTensors.at(DIM1)->tensor,
+        aclnnVariantPack.aclOutTensors.at(DIM2)->tensor,
+        &this->aclnnOpCache_->workspaceSize,
+        &this->aclnnOpCache_->aclExecutor);
+    ATB_SPEED_LOG_DEBUG(opName_ << " SetAclNNWorkspaceExecutor end, ret:" << ret
+                  << ", workspaceSize:" << this->aclnnOpCache_->workspaceSize
+                  << ", aclExecutor:" << this->aclnnOpCache_->aclExecutor);
+    return ret;
+}
+
+int MoeTopkSoftmaxOperation::ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream)
+{
+    ATB_SPEED_LOG_DEBUG(opName_ << " MoeTopkSoftmaxOperation start");
+
+    int ret = aclnnMoeGatingTopKSoftmax(
+        workspace, this->aclnnOpCache_->workspaceSize, this->aclnnOpCache_->aclExecutor, stream);
+    ATB_SPEED_LOG_DEBUG(opName_ << " MoeTopkSoftmaxOperation end, ret:" << ret);
+    return ret;
+}
+
+} // namespace common
+} // namespace atb_speed
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/moe_topk_softmax_operation.h b/tests/proftest/layer_test_framework/operations/aclnn/ops/moe_topk_softmax_operation.h
new file mode 100644
index 00000000..cd58b6dc
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/moe_topk_softmax_operation.h
@@ -0,0 +1,85 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef ATB_SPEED_PLUGIN_ACLNN_MOE_TOPK_SOFTMAX_OPERATION_H
+#define ATB_SPEED_PLUGIN_ACLNN_MOE_TOPK_SOFTMAX_OPERATION_H
+#include "operations/aclnn/core/acl_nn_operation.h"
+#include "atb_speed/utils/operation_util.h"
+
+namespace atb_speed {
+namespace common {
+
+struct MoeTopkSoftmaxParam {
+    /// The number of experts selected for each token
+    int64_t topkNum = 2;
+};
+
+/// This class defines an operator that first applies softmax to each row of the input, and then
+/// selects the top k greatest values.
+///
+/// This class makes use of `aclnnMoeGatingTopKSoftmaxGetWorkspaceSize` and `aclnnMoeGatingTopKSoftmax`
+/// from the AscendCL API.
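+///
+/// As a worked shape example (illustrative values, not from the kernel spec): with
+/// m = 4 tokens, e = 8 experts and topkNum = 2, each of the four rows of the [4,8]
+/// input is softmaxed and its two largest probabilities are kept, so output,
+/// expertIdx and rowIdx all come out with shape [4,2], matching the InferShape
+/// implementation in the corresponding .cpp file.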
+///
+/// Inputs to the operator:
+/// Name            | Dtype                 | Shape |
+/// ----------------|-----------------------|-------|
+/// input           | float16 or bfloat16   | [m,e] |
+///
+/// Outputs of the operator:
+/// Name            | Dtype                 | Shape |
+/// ----------------|-----------------------|-------|
+/// output          | float16 or bfloat16   | [m,k] |
+/// expertIdx       | int32                 | [m,k] |
+/// rowIdx          | int32                 | [m,k] |
+///
+/// Example:
+/// \code
+/// enum TensorIdx : uint32_t {
+///     IN_INPUT = 0,
+///     OUT,
+///     OUT_EXPERTIDX,
+///     OUT_ROWIDX
+/// };
+///
+/// atb::Node &topKNode = opGraph.nodes.at(nodeId++);
+/// atb_speed::common::MoeTopkSoftmaxParam moeTopkSoftmaxParam;
+/// moeTopkSoftmaxParam.topkNum = int64_t(param.num.at(0));
+/// topKNode.operation = new atb_speed::common::MoeTopkSoftmaxOperation("MoeTopkSoftmaxOperation", moeTopkSoftmaxParam);
+/// topKNode.inTensorIds = {IN_INPUT};
+/// topKNode.outTensorIds = {OUT,
+///                          OUT_EXPERTIDX,
+///                          OUT_ROWIDX};
+///
+/// \endcode
+
+class MoeTopkSoftmaxOperation : public AclNNOperation {
+public:
+    explicit MoeTopkSoftmaxOperation(const std::string &name, MoeTopkSoftmaxParam param);
+    ~MoeTopkSoftmaxOperation() override;
+    atb::Status InferShape(const atb::SVector<atb::TensorDesc> &inTensorDescs,
+                           atb::SVector<atb::TensorDesc> &outTensorDescs) const override;
+    uint32_t GetInputNum() const override;
+    uint32_t GetOutputNum() const override;
+
+private:
+    int SetAclNNWorkspaceExecutor() override;
+    int ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream) override;
+
+    MoeTopkSoftmaxParam param_;
+};
+
+} // namespace common
+} // namespace atb_speed
+#endif // ATB_SPEED_PLUGIN_ACLNN_MOE_TOPK_SOFTMAX_OPERATION_H
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/moetoken_umpermute_operation.cpp b/tests/proftest/layer_test_framework/operations/aclnn/ops/moetoken_umpermute_operation.cpp
new file mode 100644
index 00000000..e4d966d3
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/moetoken_umpermute_operation.cpp
@@ -0,0 +1,111 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "operations/aclnn/utils/utils.h" +#include "acl/acl.h" +#include "atb_speed/log.h" +#include "aclnnop/aclnn_moe_token_unpermute.h" +#include "moetoken_unpermute_operation.h" + +namespace atb_speed::common { + +MoeTokenUnpermuteOperation::MoeTokenUnpermuteOperation(const std::string &name) : AclNNOperation(name) {} + +MoeTokenUnpermuteOperation::~MoeTokenUnpermuteOperation() +{ + ATB_SPEED_LOG_DEBUG("MoeTokenPermuteOperation deconstruct"); + this->DestroyOperation(); +} + +/** + * + * @param[in] inTensorDesc: dimNum <= 8, + * @param[in] outTensorDesc: dimNum <= 8 + * @return atb::Status + */ +atb::Status MoeTokenUnpermuteOperation::InferShape(const atb::SVector &inTensorDescs, + atb::SVector &outTensorDescs) const +{ + ATB_SPEED_LOG_DEBUG(opName_ << "MoeTokenUnpermuteOperation infer shape start"); + outTensorDescs.at(DIM0).format = inTensorDescs.at(DIM0).format; + outTensorDescs.at(DIM0).dtype = inTensorDescs.at(DIM0).dtype; + outTensorDescs.at(DIM0).shape.dimNum = inTensorDescs.at(DIM0).shape.dimNum; + outTensorDescs.at(DIM0).shape.dims[DIM0] = inTensorDescs.at(DIM2).shape.dims[DIM0]; + outTensorDescs.at(DIM0).shape.dims[DIM1] = inTensorDescs.at(DIM0).shape.dims[DIM1]; + + ATB_SPEED_LOG_DEBUG(opName_ << "MoeTokenUnpermuteOperation infer shape end, inTensor0:" + << " format: " << inTensorDescs.at(DIM0).format << " dimNum: " << inTensorDescs.at(DIM0).shape.dimNum + << " dims: " << inTensorDescs.at(DIM0).shape.dims[DIM0] << ", " + << inTensorDescs.at(DIM0).shape.dims[DIM1]); + ATB_SPEED_LOG_DEBUG(opName_ << "MoeTokenUnpermuteOperation infer shape end, inTensor1:" + << " format: " << inTensorDescs.at(DIM1).format << " dimNum: " << inTensorDescs.at(DIM1).shape.dimNum + << " dims: " << inTensorDescs.at(DIM1).shape.dims[DIM0]); + ATB_SPEED_LOG_DEBUG(opName_ << "MoeTokenUnpermuteOperation infer shape end, inTensor2:" + << " format: " << inTensorDescs.at(DIM2).format << " dimNum: " << inTensorDescs.at(DIM2).shape.dimNum + << " dims: " << inTensorDescs.at(DIM2).shape.dims[DIM0] << ", " + << inTensorDescs.at(DIM2).shape.dims[DIM1]); + ATB_SPEED_LOG_DEBUG(opName_ << "MoeTokenUnpermuteOperation infer shape end, outTensor0:" + << " format: " << outTensorDescs.at(DIM0).format << " dimNum: " << outTensorDescs.at(DIM0).shape.dimNum + << " dims: " << outTensorDescs.at(DIM0).shape.dims[DIM0] << ", " + << outTensorDescs.at(DIM0).shape.dims[DIM1]); + + return atb::NO_ERROR; +} + + +uint32_t MoeTokenUnpermuteOperation::GetInputNum() const +{ + return NUM3; +} + +uint32_t MoeTokenUnpermuteOperation::GetOutputNum() const +{ + return NUM1; +} + +int MoeTokenUnpermuteOperation::SetAclNNWorkspaceExecutor() +{ + ATB_SPEED_LOG_DEBUG(opName_ << " SetAclNNWorkspaceExecutor start"); + AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack; + int ret = aclnnMoeTokenUnpermuteGetWorkspaceSize( + aclnnVariantPack.aclInTensors.at(0)->tensor, // permutedTokens + aclnnVariantPack.aclInTensors.at(1)->tensor, // sortedIndices + aclnnVariantPack.aclInTensors.at(2)->tensor, // probsOptional + false, // paddedMode + nullptr, // restoreShape + aclnnVariantPack.aclOutTensors.at(0)->tensor, // out + &this->aclnnOpCache_->workspaceSize, + &this->aclnnOpCache_->aclExecutor); + ATB_SPEED_LOG_DEBUG(opName_ << " SetAclNNWorkspaceExecutor end" + << ", ret: " << ret + << ", workspaceSize: " << this->aclnnOpCache_->workspaceSize + << ", aclExecutor: " << this->aclnnOpCache_->aclExecutor); + return ret; +} + +int MoeTokenUnpermuteOperation::ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream) +{ + 
+    ATB_SPEED_LOG_DEBUG(opName_ << " ExecuteAclNNOp start");
+    int ret = aclnnMoeTokenUnpermute(
+        workspace,
+        this->aclnnOpCache_->workspaceSize,
+        this->aclnnOpCache_->aclExecutor,
+        stream);
+    ATB_SPEED_LOG_DEBUG(opName_ << " ExecuteAclNNOp end"
+                  << ", ret: " << ret);
+    return ret;
+}
+} // namespace atb_speed::common
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/moetoken_unpermute_operation.h b/tests/proftest/layer_test_framework/operations/aclnn/ops/moetoken_unpermute_operation.h
new file mode 100644
index 00000000..0d42c7fe
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/moetoken_unpermute_operation.h
@@ -0,0 +1,77 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef ATB_SPEED_PLUGIN_ACLNN_MOE_TOKEN_UNPERMUTE_OPERATION_H
+#define ATB_SPEED_PLUGIN_ACLNN_MOE_TOKEN_UNPERMUTE_OPERATION_H
+
+#include "operations/aclnn/core/acl_nn_operation.h"
+#include "atb_speed/utils/operation_util.h"
+
+
+/// This class defines an operator that is used to gather and reduce hidden states based on sortedIndices.
+///
+/// This class makes use of `aclnnMoeTokenUnpermuteGetWorkspaceSize` and `aclnnMoeTokenUnpermute`
+/// from the AscendCL API.
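+///
+/// As a worked shape example (illustrative values, not from the kernel spec): with
+/// m = 4 tokens, k = 2 experts per token and hidden size h = 16, permutedTokens is
+/// [8,16], sortedIndices is [8] and expertsWeights is [4,2]; the operator gathers
+/// the 8 permuted rows back to their source tokens and weight-sums them, producing
+/// an out tensor of shape [4,16], matching the InferShape implementation above.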
+///
+/// Inputs to the operator:
+/// Name           | Dtype               | Shape   |
+/// ---------------|---------------------|---------|
+/// permutedTokens | float16 or bfloat16 | [m*k,h] |
+/// sortedIndices  | int32               | [m*k]   |
+/// expertsWeights | float16 or bfloat16 | [m,k]   |
+///
+/// Outputs of the operator:
+/// Name | Dtype               | Shape |
+/// -----|---------------------|-------|
+/// out  | float16 or bfloat16 | [m,h] |
+/// Note: k is the number of experts selected for each token
+///
+/// Example:
+/// \code
+/// enum TensorIdx : uint32_t {
+///     IN_INPUT = 0,
+///     IN_IDX,
+///     IN_EXPERT_WEIGHT,
+///     OUT_MOE_MLP_RESULT,
+/// };
+///
+/// atb::Node &unpermuteNode = opGraph.nodes.at(nodeId++);
+/// unpermuteNode.operation = new atb_speed::common::MoeTokenUnpermuteOperation("MoeTokenUnpermuteNode");
+/// unpermuteNode.inTensorIds = {IN_INPUT,
+///                              IN_IDX,
+///                              IN_EXPERT_WEIGHT};
+/// unpermuteNode.outTensorIds = {OUT_MOE_MLP_RESULT};
+/// \endcode
+
+namespace atb_speed::common {
+    class MoeTokenUnpermuteOperation : public AclNNOperation {
+    public:
+        explicit MoeTokenUnpermuteOperation(const std::string &name);
+        ~MoeTokenUnpermuteOperation() override;
+        atb::Status InferShape(
+            const atb::SVector<atb::TensorDesc> &inTensorDescs,
+            atb::SVector<atb::TensorDesc> &outTensorDescs
+        ) const override;
+        [[nodiscard]] uint32_t GetInputNum() const override;
+        [[nodiscard]] uint32_t GetOutputNum() const override;
+
+    protected:
+        int SetAclNNWorkspaceExecutor() override;
+        int ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream) override;
+    };
+} // namespace atb_speed::common
+
+#endif // ATB_SPEED_PLUGIN_ACLNN_MOE_TOKEN_UNPERMUTE_OPERATION_H
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/obfuscation_calculate_operation.cpp b/tests/proftest/layer_test_framework/operations/aclnn/ops/obfuscation_calculate_operation.cpp
new file mode 100644
index 00000000..2c1d56a3
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/obfuscation_calculate_operation.cpp
@@ -0,0 +1,104 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +#include "acl/acl.h" +#include "aclnnop/aclnn_obfuscation_calculate.h" +#include "atb_speed/log.h" +#include "operations/aclnn/utils/utils.h" +#include "obfuscation_calculate_operation.h" + +namespace atb_speed { +namespace common { + +ObfuscationCalculateOperation::ObfuscationCalculateOperation( + const std::string &name, ObfuscationCalculateParam param) : AclNNOperation(name), param_(param) {} + +ObfuscationCalculateOperation:: ~ObfuscationCalculateOperation() +{ + ATB_SPEED_LOG_DEBUG("ObfuscationCalculateOperation deconstructor"); + this->DestroyOperation(); +} + +atb::Status ObfuscationCalculateOperation::InferShape( + const atb::SVector &inTensorDescs, + atb::SVector &outTensorDescs) const +{ + ATB_SPEED_LOG_DEBUG("ObfuscationCalculateOperation infer shape start"); + outTensorDescs.at(DIM0).format = inTensorDescs.at(DIM0).format; + outTensorDescs.at(DIM0).dtype = inTensorDescs.at(DIM0).dtype; + outTensorDescs.at(DIM0).shape.dimNum = inTensorDescs.at(DIM0).shape.dimNum; + for (uint32_t i = 0; i < inTensorDescs.at(0).shape.dimNum; ++i) { + outTensorDescs.at(DIM0).shape.dims[i] = inTensorDescs.at(DIM0).shape.dims[i]; + } + ATB_SPEED_LOG_DEBUG("ObfuscationCalculateOperation infer shape end"); + return 0; +} + +uint32_t ObfuscationCalculateOperation::GetInputNum() const { return NUM1; } + +uint32_t ObfuscationCalculateOperation::GetOutputNum() const { return NUM1; } + +int ObfuscationCalculateOperation::CreateAclNNInTensorVariantPack(const atb::VariantPack &variantPack) +{ + AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack; + uint32_t inputNum = GetInputNum(); + aclnnVariantPack.aclInTensors.resize(inputNum); + atb::Tensor atbTensor = variantPack.inTensors.at(0); + std::shared_ptr aclnnTensor = std::make_shared(); + aclnnTensor->needUpdateTensorDataPtr = true; + aclnnTensor->atbTensor = atbTensor; + aclnnTensor->tensorIdx = 1; + aclnnTensor->strides = GetCopyTensorStride(atbTensor.desc.shape); + aclnnTensor->tensor = aclCreateTensor( + atbTensor.desc.shape.dims, atbTensor.desc.shape.dimNum, + atbTensor.desc.dtype, aclnnTensor->strides.data(), 0, + atbTensor.desc.format, atbTensor.desc.shape.dims, + atbTensor.desc.shape.dimNum, atbTensor.deviceData); + aclnnVariantPack.aclInTensors.at(0) = aclnnTensor; + return atb::NO_ERROR; +} + +int ObfuscationCalculateOperation::SetAclNNWorkspaceExecutor() +{ + ATB_SPEED_LOG_DEBUG("aclnnObfuscationCalculateGetWorkspaceSize start"); + AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack; + int ret = aclnnObfuscationCalculateGetWorkspaceSize( + param_.fd, + aclnnVariantPack.aclInTensors.at(0)->tensor, + param_.hiddenSizePerRank, + param_.cmd, + aclnnVariantPack.aclOutTensors.at(0)->tensor, + &this->aclnnOpCache_->workspaceSize, + &this->aclnnOpCache_->aclExecutor); + ATB_SPEED_LOG_DEBUG("aclnnObfuscationCalculateGetWorkspaceSize end, ret:" << + ret << ", workspaceSize:" << this->aclnnOpCache_->workspaceSize << + ", aclExecutor:" << this->aclnnOpCache_->aclExecutor); + + return ret; +} + +int ObfuscationCalculateOperation::ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream) +{ + ATB_SPEED_LOG_DEBUG("aclnnObfuscationCalculate start"); + int ret = aclnnObfuscationCalculate( + workspace, + this->aclnnOpCache_->workspaceSize, + this->aclnnOpCache_->aclExecutor, + stream); + ATB_SPEED_LOG_DEBUG("aclnnObfuscationCalculate end, ret:" << ret); + return ret; +} +} // namespace common +} // namespace atb_speed \ No newline at end of file diff --git 
new file mode 100644
index 00000000..56e9c696
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/obfuscation_calculate_operation.h
@@ -0,0 +1,46 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef ATB_SPEED_PLUGIN_ACLNN_OBFUSCATION_CALCULATE_OPERATION_H
+#define ATB_SPEED_PLUGIN_ACLNN_OBFUSCATION_CALCULATE_OPERATION_H
+#include "operations/aclnn/core/acl_nn_operation.h"
+
+namespace atb_speed {
+namespace common {
+
+struct ObfuscationCalculateParam {
+    int32_t fd = 0;
+    int32_t cmd = 1;
+    uint32_t hiddenSizePerRank = 0;
+};
+
+class ObfuscationCalculateOperation : public AclNNOperation {
+public:
+    explicit ObfuscationCalculateOperation(const std::string &name, ObfuscationCalculateParam param);
+    ~ObfuscationCalculateOperation() override;
+    atb::Status InferShape(const atb::SVector<atb::TensorDesc> &inTensorDescs,
+                           atb::SVector<atb::TensorDesc> &outTensorDescs) const override;
+    uint32_t GetInputNum() const override;
+    uint32_t GetOutputNum() const override;
+
+private:
+    int SetAclNNWorkspaceExecutor() override;
+    int CreateAclNNInTensorVariantPack(const atb::VariantPack &variantPack) override;
+    int ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream) override;
+    ObfuscationCalculateParam param_;
+};
+} // namespace common
+} // namespace atb_speed
+#endif
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/obfuscation_setup_operation.cpp b/tests/proftest/layer_test_framework/operations/aclnn/ops/obfuscation_setup_operation.cpp
new file mode 100644
index 00000000..891d5bc6
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/obfuscation_setup_operation.cpp
@@ -0,0 +1,93 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "acl/acl.h" +#include "aclnnop/aclnn_obfuscation_setup.h" +#include "atb_speed/log.h" +#include "operations/aclnn/utils/utils.h" +#include "obfuscation_setup_operation.h" + +namespace atb_speed { +namespace common { + +ObfuscationSetupOperation::ObfuscationSetupOperation(const std::string &name, + ObfuscationSetupParam param) : AclNNOperation(name), param_(param) {} + +ObfuscationSetupOperation:: ~ObfuscationSetupOperation() +{ + ATB_SPEED_LOG_DEBUG("ObfuscationSetupOperation deconstructor"); + this->DestroyOperation(); +} + +atb::Status ObfuscationSetupOperation::InferShape( + const atb::SVector &inTensorDescs, + atb::SVector &outTensorDescs) const +{ + ATB_SPEED_LOG_DEBUG("ObfuscationSetupOperation infer shape start"); + if (inTensorDescs.size() != 0) { + ATB_SPEED_LOG_ERROR("ObfuscationSetupOperation intensors should be 0, but get " << + inTensorDescs.size()); + return atb::ERROR_INVALID_TENSOR_SIZE; + } + outTensorDescs.at(DIM0).format = aclFormat::ACL_FORMAT_ND; + outTensorDescs.at(DIM0).dtype = aclDataType::ACL_INT32; + outTensorDescs.at(DIM0).shape.dimNum = NUM1; + outTensorDescs.at(DIM0).shape.dims[DIM0] = NUM1; + ATB_SPEED_LOG_DEBUG("ObfuscationSetupOperation infer shape end"); + return 0; +} + +uint32_t ObfuscationSetupOperation::GetInputNum() const { return DIM0; } + +uint32_t ObfuscationSetupOperation::GetOutputNum() const { return NUM1; } + +int ObfuscationSetupOperation::SetAclNNWorkspaceExecutor() +{ + ATB_SPEED_LOG_DEBUG("aclnnObfuscationSetupGetWorkspaceSize start"); + AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack; + int ret = aclnnObfuscationSetupGetWorkspaceSize( + param_.fdtoClose, + param_.dataType, + param_.hiddenSizePerRank, + param_.tpRank, + 0, + 0, + param_.cmd, + param_.threadNum, + aclnnVariantPack.aclOutTensors.at(0)->tensor, + &this->aclnnOpCache_->workspaceSize, + &this->aclnnOpCache_->aclExecutor); + ATB_SPEED_LOG_DEBUG("aclnnObfuscationSetupGetWorkspaceSize end, ret:" << + ret << ", workspaceSize:" << this->aclnnOpCache_->workspaceSize << + ", aclExecutor:" << this->aclnnOpCache_->aclExecutor); + + return ret; +} + +int ObfuscationSetupOperation::ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream) +{ + ATB_SPEED_LOG_DEBUG("aclnnObfuscationSetup start"); + int ret = aclnnObfuscationSetup( + workspace, + this->aclnnOpCache_->workspaceSize, + this->aclnnOpCache_->aclExecutor, + stream); + ATB_SPEED_LOG_DEBUG("aclnnObfuscationSetup end, ret:" << ret); + return ret; +} + +} // namespace common +} // namespace atb_speed \ No newline at end of file diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/obfuscation_setup_operation.h b/tests/proftest/layer_test_framework/operations/aclnn/ops/obfuscation_setup_operation.h new file mode 100644 index 00000000..c32a49b0 --- /dev/null +++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/obfuscation_setup_operation.h @@ -0,0 +1,48 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef ATB_SPEED_PLUGIN_ACLNN_OBFUSCATION_SETUP_OPERATION_H
+#define ATB_SPEED_PLUGIN_ACLNN_OBFUSCATION_SETUP_OPERATION_H
+#include "operations/aclnn/core/acl_nn_operation.h"
+
+namespace atb_speed {
+namespace common {
+
+struct ObfuscationSetupParam {
+    int32_t fdtoClose = 0;
+    int32_t dataType = 1;  // 0: float32; 1: float16; 27: bfloat16
+    int32_t hiddenSizePerRank = 1;
+    int32_t tpRank = 0;
+    int32_t cmd = 1;  // 1: Normal mode; 3: Exit mode
+    int32_t threadNum = 6;  // number of threads on the AI CPU
+};
+
+class ObfuscationSetupOperation : public AclNNOperation {
+public:
+    explicit ObfuscationSetupOperation(const std::string &name, ObfuscationSetupParam param);
+    ~ObfuscationSetupOperation() override;
+    atb::Status InferShape(const atb::SVector<atb::TensorDesc> &inTensorDescs,
+                           atb::SVector<atb::TensorDesc> &outTensorDescs) const override;
+    uint32_t GetInputNum() const override;
+    uint32_t GetOutputNum() const override;
+
+private:
+    int SetAclNNWorkspaceExecutor() override;
+    int ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream) override;
+    ObfuscationSetupParam param_;
+};
+} // namespace common
+} // namespace atb_speed
+#endif
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/prompt_flash_attention_operation.cpp b/tests/proftest/layer_test_framework/operations/aclnn/ops/prompt_flash_attention_operation.cpp
new file mode 100644
index 00000000..a799db46
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/prompt_flash_attention_operation.cpp
@@ -0,0 +1,168 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <memory>  // assumed: the original include target was lost during extraction
+#include "acl/acl.h"
+#include "atb_speed/log.h"
+#include "operations/aclnn/utils/utils.h"
+#include "aclnnop/aclnn_prompt_flash_attention_v3.h"
+
+#include "prompt_flash_attention_operation.h"
+
+namespace atb_speed {
+namespace common {
+PromptFlashAttentionOperation::PromptFlashAttentionOperation(const std::string &name,
+    AclNNFlashAttentionParam param)
+    : AclNNOperation(name), param_(param)
+{
+    ATB_SPEED_LOG_DEBUG("PromptFlashAttentionOperation, param: " << param_.ToString());
+}
+
+PromptFlashAttentionOperation::~PromptFlashAttentionOperation()
+{
+    ATB_SPEED_LOG_DEBUG("~PromptFlashAttentionOperation");
+}
+
+uint32_t PromptFlashAttentionOperation::GetInputNum() const
+{
+    return param_.needMask ? NUM6 : NUM5;
+}
+
+uint32_t PromptFlashAttentionOperation::GetOutputNum() const
+{
+    return NUM1;
+}
+
+atb::Status PromptFlashAttentionOperation::InferShape(const atb::SVector<atb::TensorDesc> &inTensorDescs,
+                                                      atb::SVector<atb::TensorDesc> &outTensorDescs) const
+{
+    ATB_SPEED_LOG_DEBUG(opName_ << " infer shape start");
+    outTensorDescs.at(0) = inTensorDescs.at(0);
+    // if the input layout is BNSD_BSND, the input shape is BNSD and the output shape is BSND;
+    // otherwise, the output shape equals the input shape.
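+    // Illustrative example (assumed values, not from the source): with a query of
+    // [B,N,S,D] = [2,32,1024,128] and inputLayout == "BNSD_BSND", dims[1] and dims[2]
+    // are swapped below, so the output descriptor becomes [2,1024,32,128], i.e. BSND.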
+    if (param_.inputLayout == "BNSD_BSND") {
+        outTensorDescs.at(0).shape.dims[DIM1] = inTensorDescs.at(0).shape.dims[DIM2];
+        outTensorDescs.at(0).shape.dims[DIM2] = inTensorDescs.at(0).shape.dims[DIM1];
+    }
+    ATB_SPEED_LOG_DEBUG(opName_ << " infer shape end");
+    return 0;
+}
+
+int PromptFlashAttentionOperation::CreateAclNNVariantPack(const atb::VariantPack &variantPack)
+{
+    ATB_SPEED_LOG_DEBUG(opName_ << " CreateAclNNVariantPack start");
+    int ret = 0;
+    ret = CreateAclNNInTensorVariantPack(variantPack);
+    if (ret != 0) {
+        ATB_SPEED_LOG_ERROR(this->opName_ << " AclNNTensor CreateAclNNInTensorVariantPack fail");
+        return ret;
+    }
+    ret = CreateAclNNOutTensorVariantPack(variantPack);
+    if (ret != 0) {
+        ATB_SPEED_LOG_ERROR(this->opName_ << " AclNNTensor CreateAclNNOutTensorVariantPack fail");
+        return ret;
+    }
+
+    ATB_SPEED_LOG_DEBUG(opName_ << " CreateAclNNVariantPack end");
+    return atb::NO_ERROR;
+}
+
+int PromptFlashAttentionOperation::CreateAclNNInTensorVariantPack(const atb::VariantPack &variantPack)
+{
+    AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack;
+    uint32_t inputNum = GetInputNum();
+    aclnnVariantPack.aclInTensors.resize(inputNum);
+    int inTensorIdx = 0;
+    for (size_t i = 0; i < inputNum; i++) {
+        std::shared_ptr<AclNNTensor> aclnnTensor = std::make_shared<AclNNTensor>();
+        aclnnTensor->atbTensor = variantPack.inTensors.at(i);
+        if (i == 3) {  // executor slot 3 (pseShift) is not supplied; skip its index
+            inTensorIdx++;
+        }
+        aclnnTensor->tensorIdx = inTensorIdx++;
+        if (i == inputNum - 2) {  // qSeqLens is 2nd last input tensor
+            aclnnTensor->needUpdateTensorDataPtr = false;
+            ConvertTensorToSeqLengths(aclnnTensor->atbTensor, actualSeqLengths_);
+        } else if (i == inputNum - 1) {  // kvSeqLens is the last input tensor
+            aclnnTensor->needUpdateTensorDataPtr = false;
+            ConvertTensorToSeqLengths(aclnnTensor->atbTensor, actualSeqLengthsKv_);
+        } else {
+            aclnnTensor->needUpdateTensorDataPtr = true;
+            atb::Tensor atbTensor = variantPack.inTensors.at(i);
+            aclnnTensor->strides = GetCopyTensorStride(atbTensor.desc.shape);
+            aclnnTensor->tensor = aclCreateTensor(atbTensor.desc.shape.dims, atbTensor.desc.shape.dimNum,
+                atbTensor.desc.dtype, aclnnTensor->strides.data(), 0, atbTensor.desc.format, atbTensor.desc.shape.dims,
+                atbTensor.desc.shape.dimNum, atbTensor.deviceData);
+            if (aclnnTensor->tensor == nullptr) {
+                ATB_SPEED_LOG_ERROR(this->opName_ << " InTensor index " << i << " create fail");
+                return atb::ERROR_INTERNAL_ERROR;
+            }
+        }
+        aclnnVariantPack.aclInTensors[i] = aclnnTensor;
+    }
+    return atb::NO_ERROR;
+}
+
+int PromptFlashAttentionOperation::CreateAclNNOutTensorVariantPack(const atb::VariantPack &variantPack)
+{
+    AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack;
+    aclnnVariantPack.aclOutTensors.resize(GetOutputNum());
+    for (size_t i = 0; i < aclnnVariantPack.aclOutTensors.size(); i++) {
+        std::shared_ptr<AclNNTensor> aclnnTensor = std::make_shared<AclNNTensor>();
+        aclnnTensor->tensorIdx = i;
+        aclnnTensor->needUpdateTensorDataPtr = true;
+        aclnnTensor->atbTensor = variantPack.outTensors.at(i);
+        atb::Tensor squeezedAtbTensor = SqueezeBatchSeq(variantPack.outTensors.at(i));
+        aclnnTensor->strides = GetCopyTensorStride(squeezedAtbTensor.desc.shape);
+        aclnnTensor->tensor = aclCreateTensor(squeezedAtbTensor.desc.shape.dims, squeezedAtbTensor.desc.shape.dimNum,
+            squeezedAtbTensor.desc.dtype, aclnnTensor->strides.data(), 0, squeezedAtbTensor.desc.format,
+            squeezedAtbTensor.desc.shape.dims, squeezedAtbTensor.desc.shape.dimNum, squeezedAtbTensor.deviceData);
+        if (aclnnTensor->tensor == nullptr) {
+            ATB_SPEED_LOG_ERROR(this->opName_ << " OutTensor index " << i << " create fail");
+            return atb::ERROR_INTERNAL_ERROR;
+        }
+        aclnnVariantPack.aclOutTensors[i] = aclnnTensor;
+    }
+    return atb::NO_ERROR;
+}
+
+int PromptFlashAttentionOperation::SetAclNNWorkspaceExecutor()
+{
+    ATB_SPEED_LOG_DEBUG(opName_ << "GetWorkspaceSize start");
+    AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack;
+    aclTensor *query = aclnnVariantPack.aclInTensors.at(0)->tensor;
+    aclTensor *key = aclnnVariantPack.aclInTensors.at(1)->tensor;
+    aclTensor *value = aclnnVariantPack.aclInTensors.at(2)->tensor;
+    aclTensor *pseShift = nullptr;
+    aclTensor *attenMask = param_.needMask ? aclnnVariantPack.aclInTensors.at(3)->tensor : nullptr;
+
+    int ret = aclnnPromptFlashAttentionV3GetWorkspaceSize(query, key, value, pseShift, attenMask, actualSeqLengths_,
+        actualSeqLengthsKv_, nullptr, nullptr, nullptr, nullptr, nullptr, param_.numHeads, param_.scaleValue,
+        param_.preTokens, param_.nextTokens, const_cast<char *>(param_.inputLayout.c_str()), param_.numKeyValueHeads,
+        param_.sparseMode, param_.innerPrecise, aclnnVariantPack.aclOutTensors.at(0)->tensor,
+        &this->aclnnOpCache_->workspaceSize, &this->aclnnOpCache_->aclExecutor);
+
+    ATB_SPEED_LOG_DEBUG(opName_ << " end, ret:" << ret << ", workspaceSize:" << this->aclnnOpCache_->workspaceSize <<
+        ", aclExecutor:" << this->aclnnOpCache_->aclExecutor);
+    return ret;
+}
+
+int PromptFlashAttentionOperation::ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream)
+{
+    return aclnnPromptFlashAttentionV3(workspace, this->aclnnOpCache_->workspaceSize, this->aclnnOpCache_->aclExecutor,
+        stream);
+}
+} // namespace common
+} // namespace atb_speed
diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/prompt_flash_attention_operation.h b/tests/proftest/layer_test_framework/operations/aclnn/ops/prompt_flash_attention_operation.h
new file mode 100644
index 00000000..9de939a2
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/prompt_flash_attention_operation.h
@@ -0,0 +1,135 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef ATB_SPEED_PLUGIN_ACLNN_FLASH_ATTENTION_OPERATION_H
+#define ATB_SPEED_PLUGIN_ACLNN_FLASH_ATTENTION_OPERATION_H
+// NOTE: the original targets of the next three includes were lost during extraction;
+// the standard headers below are inferred from the use of std::ostringstream and std::string.
+#include <cstdint>
+#include <sstream>
+#include <string>
+
+#include "operations/aclnn/core/acl_nn_operation.h"
+
+namespace atb_speed {
+namespace common {
+/// A struct defining `aclnnPromptFlashAttentionV3` operation's parameters
+struct AclNNFlashAttentionParam {
+    /// A flag indicating whether the attention uses a mask
+    bool needMask = false;
+    /// The number of query's heads
+    int64_t numHeads;
+    /// The scaling value
+    double scaleValue = 1.0;
+    /// The number of previous tokens related to attention's calculation, defaults to 214748647
+    int64_t preTokens = 214748647;
+    /// The number of post tokens related to attention's calculation, defaults to 65535
+    int64_t nextTokens = 65535;
+    /// The parameter indicating the layout of q, k, v input tensors (BSH, BSND, BNSD, BNSD_BSND)
+    std::string inputLayout = "BSND";
+    /// The number of key and value's heads
+    int64_t numKeyValueHeads;
+    /// The parameter for the sparse mode
+    /// 0: default mask, support no mask or full mask input (S1*S2)
+    /// 1: allMask, only support full mask (S1*S2)
+    /// 2: leftUpCausal mask, support improved mask (2048*2048)
+    /// 3: rightDownCausal mask, support improved mask (2048*2048)
+    /// 4: band mode mask, support improved mask (2048*2048)
+    int64_t sparseMode = 0;
+    /// The parameter for high precision/performance mode, defaults to 1
+    /// 0: high precision mode without correction
+    /// 1: high performance mode without correction
+    /// 2: high precision mode with correction
+    /// 3: high performance mode with correction
+    int64_t innerPrecise = 1;
+    std::string ToString() const
+    {
+        std::ostringstream oss;
+        oss << "AclNNFlashAttentionParam {" << std::endl;
+        oss << "  needMask: " << needMask << std::endl;
+        oss << "  numHeads: " << numHeads << std::endl;
+        oss << "  scaleValue: " << scaleValue << std::endl;
+        oss << "  preTokens: " << preTokens << std::endl;
+        oss << "  nextTokens: " << nextTokens << std::endl;
+        oss << "  inputLayout: " << inputLayout << std::endl;
+        oss << "  numKeyValueHeads: " << numKeyValueHeads << std::endl;
+        oss << "  sparseMode: " << sparseMode << std::endl;
+        oss << "  innerPrecise: " << innerPrecise << std::endl;
+        oss << "}";
+        return oss.str();
+    }
+};
+
+/// This class defines an operator that calculates the flash attention in the encoding stage.
+///
+/// This class makes use of `aclnnPromptFlashAttentionV3GetWorkspaceSize` and
+/// `aclnnPromptFlashAttentionV3` from the AscendCL API.
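+///
+/// As a worked shape example (illustrative values, not from the kernel spec): with
+/// inputLayout "BSND", B = 2, S = 1024, N = 32 and D = 128, query/key/value are all
+/// [2,1024,32,128] and the output keeps that same shape; only the "BNSD_BSND" layout
+/// transposes the output from BNSD to BSND, as described in InferShape.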
+///
+/// Inputs to the operator:
+/// Name                      | Dtype                     | Shape                                    |
+/// --------------------------|---------------------------|------------------------------------------|
+/// query                     | float16, bfloat16 or int8 | [B,S,H], [B,S,N,D] or [B,N,S,D]          |
+/// key                       | float16, bfloat16 or int8 | [B,S,H], [B,S,N,D] or [B,N,S,D]          |
+/// value                     | float16, bfloat16 or int8 | [B,S,H], [B,S,N,D] or [B,N,S,D]          |
+/// attenMaskOptional         | bool, int8, uint8         | [Q_S,KV_S], [B,Q_S,KV_S] or [1,Q_S,KV_S] |
+/// actualSeqLengthsOptional  | int64                     | [B]                                      |
+/// actualSeqLengthsKvOptional| int64                     | [B]                                      |
+///
+/// Outputs of the operator:
+/// Name         | Dtype                     | Shape                           |
+/// -------------|---------------------------|---------------------------------|
+/// output       | float16, bfloat16 or int8 | [B,S,H], [B,S,N,D] or [B,N,S,D] |
+///
+/// Note: B: batch; S: seqlen; H: hidden size; N: head-num; D: head-dim
+///
+/// Example:
+/// \code
+/// enum TensorIdx : uint32_t {
+///     QUERY,
+///     KEY,
+///     VALUE,
+///     Q_SEQ_LEN,
+///     KV_SEQ_LEN,
+///     OUT,
+/// };
+///
+/// atb::Node &promptFANode = opGraph.nodes.at(nodeId++);
+/// atb_speed::common::AclNNFlashAttentionParam faParam;
+/// faParam.numHeads = 32;  // illustrative value
+/// promptFANode.operation = new atb_speed::common::PromptFlashAttentionOperation("PromptFlashAttentionNode",
+///     faParam);
+/// promptFANode.inTensorIds = {QUERY, KEY, VALUE, Q_SEQ_LEN, KV_SEQ_LEN};
+/// promptFANode.outTensorIds = {OUT};
+/// \endcode
+
+class PromptFlashAttentionOperation : public AclNNOperation {
+public:
+    explicit PromptFlashAttentionOperation(const std::string &name, AclNNFlashAttentionParam param);
+    ~PromptFlashAttentionOperation() override;
+    uint32_t GetInputNum() const override;
+    uint32_t GetOutputNum() const override;
+    atb::Status InferShape(const atb::SVector<atb::TensorDesc> &inTensorDescs,
+                           atb::SVector<atb::TensorDesc> &outTensorDescs) const override;
+
+protected:
+    int CreateAclNNVariantPack(const atb::VariantPack &variantPack) override;
+    int SetAclNNWorkspaceExecutor() override;
+    int ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream) override;
+    int CreateAclNNInTensorVariantPack(const atb::VariantPack &variantPack) override;
+    int CreateAclNNOutTensorVariantPack(const atb::VariantPack &variantPack) override;
+
+private:
+    AclNNFlashAttentionParam param_;
+    aclIntArray *actualSeqLengths_ = nullptr;
+    aclIntArray *actualSeqLengthsKv_ = nullptr;
+};
+} // namespace common
+} // namespace atb_speed
+#endif
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/quant_batch_matmul_operation.cpp b/tests/proftest/layer_test_framework/operations/aclnn/ops/quant_batch_matmul_operation.cpp
new file mode 100644
index 00000000..8f4c3e50
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/quant_batch_matmul_operation.cpp
@@ -0,0 +1,186 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +#include "acl/acl.h" +#include "atb_speed/log.h" +#include "operations/aclnn/utils/utils.h" +#include "aclnnop/aclnn_weight_quant_batch_matmul_v2.h" +#include "quant_batch_matmul_operation.h" + +namespace atb_speed { +namespace common { + +QuantBatchMatmulOperation::QuantBatchMatmulOperation( + const std::string &name, + AclNNWeightQuantBatchMatmulParam param) : AclNNOperation(name), param_(param) {} + +QuantBatchMatmulOperation::~QuantBatchMatmulOperation() +{ + ATB_SPEED_LOG_DEBUG("QuantBatchMatmulOperation deconstructor"); + this->DestroyOperation(); +} + +atb::Status QuantBatchMatmulOperation::InferShape( + const atb::SVector &inTensorDescs, + atb::SVector &outTensorDescs) const +{ + ATB_SPEED_LOG_DEBUG(opName_ << " infer shape start"); + outTensorDescs.at(0).format = inTensorDescs.at(0).format; + outTensorDescs.at(0).dtype = inTensorDescs.at(0).dtype; + outTensorDescs.at(0).shape.dimNum = inTensorDescs.at(0).shape.dimNum; + + int nDim = param_.transposeB ? DIM0 : DIM1; + if (inTensorDescs.at(0).shape.dimNum == DIM3) { + ATB_SPEED_LOG_DEBUG("[input0 dimNum = 3] CHECK " << opName_ << " inputs shape: [input0]" + << inTensorDescs.at(DIM0).shape.dims[DIM0] << ", " + << inTensorDescs.at(DIM0).shape.dims[DIM1] << ", " << inTensorDescs.at(DIM0).shape.dims[DIM2]); + ATB_SPEED_LOG_DEBUG("[input0 dimNum = 3] CHECK " << opName_ << " inputs shape: [input1]" + << inTensorDescs.at(DIM1).shape.dims[DIM0] << ", " << inTensorDescs.at(DIM1).shape.dims[DIM1]); + outTensorDescs.at(DIM0).shape.dims[DIM0] = inTensorDescs.at(DIM0).shape.dims[DIM0]; + outTensorDescs.at(DIM0).shape.dims[DIM1] = inTensorDescs.at(DIM0).shape.dims[DIM1]; + outTensorDescs.at(DIM0).shape.dims[DIM2] = inTensorDescs.at(DIM3).shape.dims[nDim]; + } else if (inTensorDescs.at(0).shape.dimNum == DIM2) { + ATB_SPEED_LOG_DEBUG("[input0 dimNum = 2] CHECK " << opName_ << " inputs shape: [input0]" + << inTensorDescs.at(DIM0).shape.dims[DIM0] << ", " << inTensorDescs.at(DIM0).shape.dims[DIM1]); + ATB_SPEED_LOG_DEBUG("[input0 dimNum = 2] CHECK " << opName_ << " inputs shape: [input1]" + << inTensorDescs.at(DIM1).shape.dims[DIM0] << ", " << inTensorDescs.at(DIM1).shape.dims[DIM1]); + outTensorDescs.at(0).shape.dims[DIM0] = inTensorDescs.at(DIM0).shape.dims[DIM0]; + outTensorDescs.at(0).shape.dims[DIM1] = inTensorDescs.at(DIM3).shape.dims[nDim]; + } else { + ATB_SPEED_LOG_ERROR(opName_ << " invalid dim num:" << inTensorDescs.at(DIM0).shape.dimNum); + } + ATB_SPEED_LOG_DEBUG(opName_ << " infer shape end"); + return 0; +} + +uint32_t QuantBatchMatmulOperation::GetInputNum() const { return param_.hasBias ? 
NUM5 : NUM4; } + +uint32_t QuantBatchMatmulOperation::GetOutputNum() const { return NUM1; } + +atb::Dims QuantBatchMatmulOperation::GetWeightStorageShape(const atb::TensorDesc atbTensorDesc) +{ + if (atbTensorDesc.format != ACL_FORMAT_FRACTAL_NZ) { + // nd格式下,storageShape和originalShape一致 + return atbTensorDesc.shape; + } + // nz格式 + atb::Dims storageTensorDims = atbTensorDesc.shape; + storageTensorDims.dimNum = 4; // 4: 4维 + if (param_.transposeB) { + uint32_t kPadding = 16; + uint32_t nPadding = 32; + // (n, k) => (k1, n1, n0, k0) + storageTensorDims.dims[0] = 1 + ((atbTensorDesc.shape.dims[1] - 1) / kPadding); + storageTensorDims.dims[1] = 1 + ((atbTensorDesc.shape.dims[0] - 1) / nPadding); + storageTensorDims.dims[2] = nPadding; // 2: 维度 + storageTensorDims.dims[3] = kPadding; // 3: 维度 + } else { + uint32_t kPadding = 32; + uint32_t nPadding = 16; + // (k, n) => (n1, k1, k0, n0) + storageTensorDims.dims[0] = 1 + ((atbTensorDesc.shape.dims[1] - 1) / nPadding); + storageTensorDims.dims[1] = 1 + ((atbTensorDesc.shape.dims[0] - 1) / kPadding); + storageTensorDims.dims[2] = kPadding; // 2: 维度 + storageTensorDims.dims[3] = nPadding; // 3: 维度 + } + return storageTensorDims; +} + +atb::Status QuantBatchMatmulOperation::CreateAclNNInTensorVariantPack(const atb::VariantPack &variantPack) +{ + AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack; + aclnnVariantPack.aclInTensors.resize(GetInputNum()); + for (size_t i = 0; i < aclnnVariantPack.aclInTensors.size(); ++i) { + std::shared_ptr aclnnTensor = std::make_shared(); + aclnnTensor->tensorIdx = (i == 4) ? (i + 2) : i; // 4, 2: bias在aclExecutor中的idx为6 + aclnnTensor->needUpdateTensorDataPtr = true; + aclnnTensor->atbTensor = variantPack.inTensors.at(i); + atb::Tensor preprocessedATBTensor = this->PreprocessATBInTensor(variantPack.inTensors.at(i), i); + if ((i == 1) || (i == 2) || (i == 3)) { // 1, 2, 3: weight, weight_scale, weight_offset + if (preprocessedATBTensor.desc.shape.dimNum != NUM2) { + ATB_SPEED_LOG_ERROR(this->opName_ << " weight tensor dimNum after combine batch size " + << "and seq len axis should be 2, but got " << preprocessedATBTensor.desc.shape.dimNum); + return atb::ERROR_INTERNAL_ERROR; + } + // StorageShape + atb::Dims storageDims = preprocessedATBTensor.desc.shape; + if (i == 1) { // weight的storageShape会根据NZ和ND格式而有所不同 + storageDims = GetWeightStorageShape(preprocessedATBTensor.desc); + } + // ViewShape and Stride + atb::Dims viewDims = preprocessedATBTensor.desc.shape; + if (IsA2() && this->param_.transposeB) { + aclnnTensor->strides = GetTransposeTensorStride(viewDims); + viewDims.dims[0] = preprocessedATBTensor.desc.shape.dims[1]; + viewDims.dims[1] = preprocessedATBTensor.desc.shape.dims[0]; + } else { + aclnnTensor->strides = GetCopyTensorStride(viewDims); + } + CHECK_OPERATION_STATUS_RETURN(CallAclCreateTensor( + viewDims, storageDims, preprocessedATBTensor, aclnnTensor)); + } else { + aclnnTensor->strides = GetCopyTensorStride(preprocessedATBTensor.desc.shape); + CHECK_OPERATION_STATUS_RETURN(CallAclCreateTensor(preprocessedATBTensor.desc.shape, + preprocessedATBTensor.desc.shape, preprocessedATBTensor, aclnnTensor)); + } + aclnnVariantPack.aclInTensors[i] = aclnnTensor; + } + return atb::NO_ERROR; +} + +atb::Status QuantBatchMatmulOperation::CreateAclNNOutTensorVariantPack(const atb::VariantPack &variantPack) +{ + AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack; + aclnnVariantPack.aclOutTensors.resize(NUM1); + for (size_t i = 0; i < aclnnVariantPack.aclOutTensors.size(); 
+        std::shared_ptr<AclNNTensor> aclnnTensor = std::make_shared<AclNNTensor>();
+        aclnnTensor->tensorIdx = i;
+        aclnnTensor->needUpdateTensorDataPtr = true;
+        aclnnTensor->atbTensor = variantPack.outTensors.at(i);
+        atb::Tensor squeezedAtbTensor = SqueezeBatchSeq(variantPack.outTensors.at(i));
+        aclnnTensor->strides = GetCopyTensorStride(squeezedAtbTensor.desc.shape);
+        CHECK_OPERATION_STATUS_RETURN(CallAclCreateTensor(squeezedAtbTensor.desc.shape,
+            squeezedAtbTensor.desc.shape, squeezedAtbTensor, aclnnTensor));
+        aclnnVariantPack.aclOutTensors[i] = aclnnTensor;
+    }
+    return atb::NO_ERROR;
+}
+
+int QuantBatchMatmulOperation::SetAclNNWorkspaceExecutor()
+{
+    ATB_SPEED_LOG_DEBUG(opName_ << " start");
+    AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack;
+    int ret = aclnnWeightQuantBatchMatmulV2GetWorkspaceSize(
+        aclnnVariantPack.aclInTensors.at(0)->tensor,  // 0: x
+        aclnnVariantPack.aclInTensors.at(1)->tensor,  // 1: weight
+        aclnnVariantPack.aclInTensors.at(2)->tensor,  // 2: antiquantScale
+        aclnnVariantPack.aclInTensors.at(3)->tensor, nullptr, nullptr,  // 3: antiquantOffset
+        param_.hasBias ? aclnnVariantPack.aclInTensors.at(4)->tensor : nullptr,  // 4: bias
+        param_.quantGroupSize, aclnnVariantPack.aclOutTensors.at(0)->tensor,
+        &this->aclnnOpCache_->workspaceSize, &this->aclnnOpCache_->aclExecutor);
+    ATB_SPEED_LOG_DEBUG(opName_ << " end, ret:"
+                  << ret << ", workspaceSize:" << this->aclnnOpCache_->workspaceSize
+                  << ", aclExecutor:" << this->aclnnOpCache_->aclExecutor);
+    return ret;
+}
+
+int QuantBatchMatmulOperation::ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream)
+{
+    return aclnnWeightQuantBatchMatmulV2(
+        workspace, this->aclnnOpCache_->workspaceSize,
+        this->aclnnOpCache_->aclExecutor, stream);
+}
+} // namespace common
+} // namespace atb_speed
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/quant_batch_matmul_operation.h b/tests/proftest/layer_test_framework/operations/aclnn/ops/quant_batch_matmul_operation.h
new file mode 100644
index 00000000..f4415bec
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/quant_batch_matmul_operation.h
@@ -0,0 +1,62 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef ATB_SPEED_PLUGIN_ACLNN_QUANT_BATCH_MATMUL_OPERATION_H
+#define ATB_SPEED_PLUGIN_ACLNN_QUANT_BATCH_MATMUL_OPERATION_H
+#include "operations/aclnn/core/acl_nn_operation.h"
+#include "atb_speed/utils/operation_util.h"
+
+namespace atb_speed {
+namespace common {
+/// A struct defining the parameters of `W8A16Operation` and `W4A16Operation`.
+struct AclNNWeightQuantBatchMatmulParam {
+    /// A flag indicating whether the matmul operation includes a bias tensor.
+    bool hasBias = false;
+    /// The group size used for dequantizing the weight tensor in the per-group quantization approach.
+    int quantGroupSize = 0;
+    /// A flag indicating whether the second matrix in the matmul operation is transposed.
+    bool transposeB = false;
+};
+
+/// This class defines a matrix operation that supports per-channel and per-group weight quantization
+/// while keeping activations in floating-point format.
+///
+/// This class makes use of `aclnnWeightQuantBatchMatmulV2GetWorkspaceSize` and `aclnnWeightQuantBatchMatmulV2`
+/// from the AscendCL API.
+/// This class declares a pure virtual function, `PreprocessATBInTensor`, so it cannot be instantiated directly.
+/// The `W8A16Operation` and `W4A16Operation` classes inherit from this base class
+/// and implement the `PreprocessATBInTensor` function to handle various tensor data types.
+class QuantBatchMatmulOperation : public AclNNOperation {
+public:
+    explicit QuantBatchMatmulOperation(const std::string &name, AclNNWeightQuantBatchMatmulParam param);
+    ~QuantBatchMatmulOperation() override;
+    uint32_t GetInputNum() const override;
+    uint32_t GetOutputNum() const override;
+    atb::Status InferShape(const atb::SVector<atb::TensorDesc> &inTensorDescs,
+                           atb::SVector<atb::TensorDesc> &outTensorDescs) const override;
+
+protected:
+    int SetAclNNWorkspaceExecutor() override;
+    int ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream) override;
+    virtual atb::Tensor PreprocessATBInTensor(atb::Tensor atbTensor, int index) = 0;
+    atb::Status CreateAclNNInTensorVariantPack(const atb::VariantPack &variantPack) override;
+    atb::Status CreateAclNNOutTensorVariantPack(const atb::VariantPack &variantPack) override;
+    atb::Dims GetWeightStorageShape(const atb::TensorDesc atbTensorDesc);
+
+private:
+    AclNNWeightQuantBatchMatmulParam param_;
+};
+} // namespace common
+} // namespace atb_speed
+#endif
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/quant_gmm_dequant_operation.cpp b/tests/proftest/layer_test_framework/operations/aclnn/ops/quant_gmm_dequant_operation.cpp
new file mode 100644
index 00000000..8582415a
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/quant_gmm_dequant_operation.cpp
@@ -0,0 +1,198 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "atb_speed/log.h"
+#include "operations/aclnn/utils/utils.h"
+#include "acl/acl.h"
+#include "aclnnop/aclnn_quant_grouped_matmul_dequant.h"
+#include "quant_gmm_dequant_operation.h"
+
+namespace atb_speed {
+namespace common {
+
+
+QuantGMMDequantOperation::QuantGMMDequantOperation(
+    const std::string &name,
+    AclNNQuantGMMDequantParam param) : AclNNOperation(name), param_(param) {}
+
+
+QuantGMMDequantOperation::~QuantGMMDequantOperation()
+{
+    ATB_SPEED_LOG_DEBUG("QuantGMMDequantOperation destructor");
+    this->DestroyOperation();
+}
+
+atb::Status QuantGMMDequantOperation::InferShape(const atb::SVector<atb::TensorDesc> &inTensorDescs,
+    atb::SVector<atb::TensorDesc> &outTensorDescs) const
+{
+    ATB_SPEED_LOG_DEBUG(opName_ << "QuantGMMDequantOperation infer shape start");
+    outTensorDescs.at(0).format = inTensorDescs.at(0).format;  // FORMAT_ND
+    outTensorDescs.at(0).dtype = param_.outDataType;  // e.g. ACL_FLOAT16
+    // in1 (8192, 7168); in2 [64, 2048, 7168]; out0 (8192, 2048)
+    outTensorDescs.at(0).shape.dimNum = inTensorDescs.at(0).shape.dimNum;  // dimNum = 2
+    outTensorDescs.at(0).shape.dims[DIM0] = inTensorDescs.at(0).shape.dims[DIM0];  // DIM0 = inTensor0.DIM0
+    outTensorDescs.at(0).shape.dims[DIM1] = inTensorDescs.at(NUM1).shape.dims[param_.transposeB ? DIM1 : DIM2];
+
+    for (uint64_t i = 0; i < GetInputNum(); ++i) {
+        ATB_SPEED_LOG_DEBUG(opName_ << " QuantGMMDequantOperation infer shape input " << i <<
+            " format: " << inTensorDescs.at(i).format <<
+            " dtype: " << inTensorDescs.at(i).dtype <<
+            " dimNum: " << inTensorDescs.at(i).shape.dimNum <<
+            " dim0: " << inTensorDescs.at(i).shape.dims[0] <<
+            " dim1: " << inTensorDescs.at(i).shape.dims[1]
+        );
+        if (i == 1) {
+            ATB_SPEED_LOG_DEBUG(" dim2: " << inTensorDescs.at(1).shape.dims[2]);
+        }
+    }
+
+    ATB_SPEED_LOG_DEBUG(opName_ << " QuantGMMDequantOperation infer shape end" <<
+        " format: " << outTensorDescs.at(0).format <<
+        " dtype: " << outTensorDescs.at(0).dtype <<
+        " dimNum: " << outTensorDescs.at(0).shape.dimNum <<
+        " dims: " << outTensorDescs.at(0).shape.dims[0] <<
+        " dims: " << outTensorDescs.at(0).shape.dims[1]);
+    return 0;
+}
+
+uint32_t QuantGMMDequantOperation::GetInputNum() const
+{
+    return NUM4;
+}
+
+uint32_t QuantGMMDequantOperation::GetOutputNum() const
+{
+    return NUM1;
+}
+
+atb::Dims QuantGMMDequantOperation::GetWeightStorageShape(const atb::TensorDesc atbTensorDesc) const
+{
+    atb::Dims storageTensorDims = atbTensorDesc.shape;  // In ND format, storageShape matches originalShape
+    ATB_SPEED_LOG_DEBUG(opName_ << " GetWeightStorageShape inWeightTensor dim: " <<
+        atbTensorDesc.shape.dims[0] << ", " << atbTensorDesc.shape.dims[1] << ", " << atbTensorDesc.shape.dims[2]
+    );
+
+    if (atbTensorDesc.format == ACL_FORMAT_FRACTAL_NZ) {
+        // NZ format
+        storageTensorDims.dimNum = 5;  // 5 dimensions
+        // (group_size, n, k) => (group_size, k / 32, n / 16, 16, 32)
+        storageTensorDims.dims[0] = atbTensorDesc.shape.dims[0];
+        storageTensorDims.dims[3] = 16;  // dims[3] = 16: required by the NZ layout
+        storageTensorDims.dims[4] = 32;  // dims[4] = 32: required by the NZ layout
+        if (param_.transposeB) {
+            storageTensorDims.dims[1] = ((atbTensorDesc.shape.dims[2] + 32 - 1) / 32);  // dims[1]: padded up to a multiple of 32
+            storageTensorDims.dims[2] = ((atbTensorDesc.shape.dims[1] + 16 - 1) / 16);  // dims[2]: padded up to a multiple of 16
+        } else {
+            storageTensorDims.dims[1] = ((atbTensorDesc.shape.dims[1] + 32 - 1) / 32);  // dims[1]: padded up to a multiple of 32
+            storageTensorDims.dims[2] = ((atbTensorDesc.shape.dims[2] + 16 - 1) / 16);  // dims[2]: padded up to a multiple of 16
+        }
+    }
+    ATB_SPEED_LOG_DEBUG(opName_ << " GetWeightStorageShape storageTensorDims dims: " <<
+        storageTensorDims.dims[0] << ", " << storageTensorDims.dims[1] << ", " << storageTensorDims.dims[2] << ", " <<
+        storageTensorDims.dims[3] << ", " << storageTensorDims.dims[4]
+    );
+    return storageTensorDims;
+}
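+
+// Worked example for GetWeightStorageShape above (values taken from the shape
+// comment in InferShape): with transposeB = true, an NZ-format weight of
+// originalShape (64, 2048, 7168) is stored as
+// (64, 7168 / 32, 2048 / 16, 16, 32) = (64, 224, 128, 16, 32).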
+
+atb::Status QuantGMMDequantOperation::CreateAclNNInTensorVariantPack(const atb::VariantPack &variantPack)
+{
+    AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack;
+    aclnnVariantPack.aclInTensors.resize(GetInputNum());
+    for (size_t i = 0; i < aclnnVariantPack.aclInTensors.size(); ++i) {
+        std::shared_ptr<AclNNTensor> aclnnTensor = std::make_shared<AclNNTensor>();
+        aclnnTensor->tensorIdx = i;
+        aclnnTensor->tensorListidx = AclNNTensor::notInTensorList;
+        aclnnTensor->needUpdateTensorDataPtr = true;
+        aclnnTensor->atbTensor = variantPack.inTensors.at(i);
+        atb::Tensor atbTensor = variantPack.inTensors.at(i);
+        // StorageShape
+        if (i == 1) {
+            atb::Tensor storageATBTensor = variantPack.inTensors.at(i);
+            atb::Dims storageTensorDims = GetWeightStorageShape(storageATBTensor.desc);
+            aclnnTensor->strides = GetCopyTensorStride(atbTensor.desc.shape);
+            CHECK_OPERATION_STATUS_RETURN(CallAclCreateTensor(storageTensorDims, storageTensorDims,
+                atbTensor, aclnnTensor));  // GMM dispatches based on viewDims.dimNum
+        } else {
+            aclnnTensor->strides = GetCopyTensorStride(atbTensor.desc.shape);
+            CHECK_OPERATION_STATUS_RETURN(CallAclCreateTensor(atbTensor.desc.shape, atbTensor.desc.shape,
+                atbTensor, aclnnTensor));
+        }
+
+        aclnnVariantPack.aclInTensors[i] = aclnnTensor;
+    }
+    return atb::NO_ERROR;
+}
+
+atb::Status QuantGMMDequantOperation::CreateAclNNOutTensorVariantPack(const atb::VariantPack &variantPack)
+{
+    AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack;
+    aclnnVariantPack.aclOutTensors.resize(GetOutputNum());
+    for (size_t i = 0; i < aclnnVariantPack.aclOutTensors.size(); ++i) {
+        std::shared_ptr<AclNNTensor> aclnnTensor = std::make_shared<AclNNTensor>();
+        aclnnTensor->tensorIdx = i;
+        aclnnTensor->tensorListidx = AclNNTensor::notInTensorList;
+        aclnnTensor->needUpdateTensorDataPtr = true;
+        aclnnTensor->atbTensor = variantPack.outTensors.at(i);
+        atb::Tensor atbTensor = variantPack.outTensors.at(i);
+        aclnnTensor->strides = GetCopyTensorStride(atbTensor.desc.shape);
+        CHECK_OPERATION_STATUS_RETURN(CallAclCreateTensor(atbTensor.desc.shape, atbTensor.desc.shape,
+            atbTensor, aclnnTensor));
+        aclnnVariantPack.aclOutTensors[i] = aclnnTensor;
+    }
+    return atb::NO_ERROR;
+}
+
+int QuantGMMDequantOperation::SetAclNNWorkspaceExecutor()
+{
+    ATB_SPEED_LOG_DEBUG(opName_ << " SetAclNNWorkspaceExecutor start");
+    AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack;
+
+    int ret = aclnnQuantGroupedMatmulDequantGetWorkspaceSize(
+        aclnnVariantPack.aclInTensors.at(0)->tensor,     // 0: x
+        aclnnVariantPack.aclInTensors.at(NUM1)->tensor,  // 1: weight
+        aclnnVariantPack.aclInTensors.at(NUM2)->tensor,  // 2: weightScale
+        aclnnVariantPack.aclInTensors.at(NUM3)->tensor,  // 3: groupList, int64
+        nullptr, nullptr, nullptr, nullptr,  // 4: bias; 5: xScale; 6: xOffset; 7: smoothScale;
+        param_.quantMode.data(),  // 8: xQuantMode
+        param_.transposeB,  // 9: transposeWeight
+        aclnnVariantPack.aclOutTensors.at(0)->tensor,  // 10: out
+        &this->aclnnOpCache_->workspaceSize,
+        &this->aclnnOpCache_->aclExecutor);
+
+    ATB_SPEED_LOG_DEBUG(opName_ << " SetAclNNWorkspaceExecutor end, ret:" << ret
+        << ", workspaceSize:" << this->aclnnOpCache_->workspaceSize
+        << ", aclExecutor:" << this->aclnnOpCache_->aclExecutor);
+    return ret;
+}
+
+
+int QuantGMMDequantOperation::ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream)
+{
+    ATB_SPEED_LOG_DEBUG(opName_ << " QuantGMMDequantOperation start");
+    int ret = aclnnQuantGroupedMatmulDequant(workspace, this->aclnnOpCache_->workspaceSize,
+        this->aclnnOpCache_->aclExecutor, stream);
+    ATB_SPEED_LOG_DEBUG(opName_ << " QuantGMMDequantOperation end, ret:" << ret);
+    return ret;
+}
+
+} // namespace common
+} // namespace atb_speed
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/quant_gmm_dequant_operation.h b/tests/proftest/layer_test_framework/operations/aclnn/ops/quant_gmm_dequant_operation.h
new file mode 100644
index 00000000..1cc3e78f
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/quant_gmm_dequant_operation.h
@@ -0,0 +1,52 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef ATB_SPEED_PLUGIN_ACLNN_QUANT_GMM_DEQUANT_OPERATION_H
+#define ATB_SPEED_PLUGIN_ACLNN_QUANT_GMM_DEQUANT_OPERATION_H
+
+#include "operations/aclnn/core/acl_nn_operation.h"
+#include "atb_speed/utils/operation_util.h"
+
+namespace atb_speed {
+namespace common {
+struct AclNNQuantGMMDequantParam {
+    std::string quantMode = "pertoken";  // "pertoken" or "pertensor"
+    bool transposeB = false;
+    aclDataType outDataType = ACL_INT64;
+};
+
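+/// Example (a minimal usage sketch in the style of the other operation wrappers
+/// in this framework; the tensor ids are hypothetical, and the input order
+/// follows the argument comments in SetAclNNWorkspaceExecutor):
+/// \code
+/// enum TensorIdx : uint32_t {
+///     IN_INPUT = 0,     // quantized activations
+///     IN_WEIGHT,        // grouped weight
+///     IN_WEIGHT_SCALE,  // per-group dequant scale
+///     IN_GROUP_LIST,    // int64 group boundaries
+///     OUT,
+/// };
+///
+/// atb::Node &gmmNode = opGraph.nodes.at(nodeId++);
+/// atb_speed::common::AclNNQuantGMMDequantParam gmmParam;
+/// gmmParam.quantMode = "pertoken";
+/// gmmParam.transposeB = true;
+/// gmmParam.outDataType = ACL_FLOAT16;
+/// gmmNode.operation = new atb_speed::common::QuantGMMDequantOperation("QuantGMMDequantOperation", gmmParam);
+/// gmmNode.inTensorIds = {IN_INPUT, IN_WEIGHT, IN_WEIGHT_SCALE, IN_GROUP_LIST};
+/// gmmNode.outTensorIds = {OUT};
+/// \endcode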
+class QuantGMMDequantOperation : public AclNNOperation {
+public:
+    explicit QuantGMMDequantOperation(const std::string &name, AclNNQuantGMMDequantParam param);
+    ~QuantGMMDequantOperation() override;
+    atb::Status InferShape(const atb::SVector<atb::TensorDesc> &inTensorDescs,
+        atb::SVector<atb::TensorDesc> &outTensorDescs) const override;
+    uint32_t GetInputNum() const override;
+    uint32_t GetOutputNum() const override;
+
+private:
+    int SetAclNNWorkspaceExecutor() override;
+    int ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream) override;
+
+    atb::Status CreateAclNNInTensorVariantPack(const atb::VariantPack &variantPack) override;
+    atb::Status CreateAclNNOutTensorVariantPack(const atb::VariantPack &variantPack) override;
+    atb::Dims GetWeightStorageShape(const atb::TensorDesc atbTensorDesc) const;
+
+    AclNNQuantGMMDequantParam param_;
+};
+} // namespace common
+} // namespace atb_speed
+#endif
diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/repeat_operation.cpp b/tests/proftest/layer_test_framework/operations/aclnn/ops/repeat_operation.cpp
new file mode 100644
index 00000000..5db4fdb5
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/repeat_operation.cpp
@@ -0,0 +1,105 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "operations/aclnn/utils/utils.h"
+#include "acl/acl.h"
+#include "atb_speed/log.h"
+#include "aclnnop/aclnn_repeat.h"
+#include "repeat_operation.h"
+
+namespace atb_speed::common {
+
+RepeatOperation::RepeatOperation(
+    const std::string &name,
+    atb_speed::common::AclNNRepeatParam param
+) : AclNNOperation(name), param_(param)
+{
+    this->opName_ = name;
+    this->param_ = param;
+}
+
+RepeatOperation::~RepeatOperation()
+{
+    ATB_SPEED_LOG_DEBUG("RepeatOperation destructor");
+    this->DestroyOperation();
+}
+
+/**
+ *
+ * @param[in] inTensorDescs: dimNum <= 8
+ * @param[out] outTensorDescs: dimNum <= 8
+ * @return atb::Status
+ */
+atb::Status RepeatOperation::InferShape(const atb::SVector<atb::TensorDesc> &inTensorDescs,
+    atb::SVector<atb::TensorDesc> &outTensorDescs) const
+{
+    ATB_SPEED_LOG_DEBUG(opName_ << "RepeatOperation infer shape start");
+    outTensorDescs.at(0).format = inTensorDescs.at(0).format;
+    outTensorDescs.at(0).dtype = inTensorDescs.at(0).dtype;
+    outTensorDescs.at(0).shape.dimNum = inTensorDescs.at(0).shape.dimNum;
+    for (uint64_t i = 0; i < inTensorDescs.at(0).shape.dimNum; ++i) {
+        outTensorDescs.at(0).shape.dims[i] = inTensorDescs.at(0).shape.dims[i] * param_.repeatsArray[i];
+    }
+
+    ATB_SPEED_LOG_DEBUG(opName_ << "RepeatOperation infer shape end"
+        << " format: " << inTensorDescs.at(0).format << " dimNum: " << inTensorDescs.at(0).shape.dimNum
+        << " dims: " << inTensorDescs.at(0).shape.dims[0]);
+    return 0;
+}
+
+
+uint32_t RepeatOperation::GetInputNum() const
+{
+    return DIM1;
+}
+
+uint32_t RepeatOperation::GetOutputNum() const
+{
+    return DIM1;
+}
+
+int RepeatOperation::SetAclNNWorkspaceExecutor()
+{
+    ATB_SPEED_LOG_DEBUG(opName_ << " SetAclNNWorkspaceExecutor start");
+    AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack;
+    // Create the repeats aclIntArray
+    aclIntArray *repeats = aclCreateIntArray(param_.repeatsArray.data(), param_.repeatsArray.size());
+    int ret = aclnnRepeatGetWorkspaceSize(
+        aclnnVariantPack.aclInTensors.at(0)->tensor,   // input
+        repeats,                                       // repeatShape
+        aclnnVariantPack.aclOutTensors.at(0)->tensor,  // out
+        &this->aclnnOpCache_->workspaceSize,
+        &this->aclnnOpCache_->aclExecutor);
+    ATB_SPEED_LOG_DEBUG(opName_ << " SetAclNNWorkspaceExecutor end"
+        << ", ret: " << ret
+        << ", workspaceSize: " << this->aclnnOpCache_->workspaceSize
+        << ", aclExecutor: " << this->aclnnOpCache_->aclExecutor);
+    return ret;
+}
+
+int RepeatOperation::ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream)
+{
+    ATB_SPEED_LOG_DEBUG(opName_ << " ExecuteAclNNOp start");
+    int ret = aclnnRepeat(
+        workspace,
+        this->aclnnOpCache_->workspaceSize,
+        this->aclnnOpCache_->aclExecutor,
+        stream);
+    ATB_SPEED_LOG_DEBUG(opName_ << " ExecuteAclNNOp end"
+        << ", ret: " << ret);
+    return ret;
+}
+} // namespace atb_speed::common
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/repeat_operation.h b/tests/proftest/layer_test_framework/operations/aclnn/ops/repeat_operation.h
new file mode 100644
index 00000000..eb2dacba
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/repeat_operation.h
@@ -0,0 +1,80 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef ATB_SPEED_PLUGIN_ACLNN_REPEAT_OPERATION_H
+#define ATB_SPEED_PLUGIN_ACLNN_REPEAT_OPERATION_H
+
+#include "operations/aclnn/core/acl_nn_operation.h"
+#include "atb_speed/utils/operation_util.h"
+
+namespace atb_speed::common {
+struct AclNNRepeatParam {
+    std::vector<int64_t> repeatsArray;
+};
+
+/// This class defines a repeat operator.
+///
+/// This class makes use of `aclnnRepeatGetWorkspaceSize` and `aclnnRepeat`
+/// from the AscendCL API.
+///
+/// Inputs to the operator:
+/// Name           | Dtype               | Shape   |
+/// ---------------|---------------------|---------|
+/// input          | float16 or bfloat16 | [m,h]   |
+///
+/// Outputs of the operator:
+/// Name           | Dtype               | Shape   |
+/// ---------------|---------------------|---------|
+/// out            | float16 or bfloat16 |[m*k,h*n]|
+/// Note: k, n are the repetition times.
+///
+/// Example:
+/// \code
+/// enum TensorIdx : uint32_t {
+///     IN_INPUT = 0,
+///     OUT,
+/// };
+///
+/// atb::Node &repeatNode = opGraph.nodes.at(nodeId++);
+/// atb_speed::common::AclNNRepeatParam repeatParam;
+/// repeatParam.repeatsArray = param.repeatsArray;
+/// repeatNode.operation = new atb_speed::common::RepeatOperation("RepeatOperation", repeatParam);
+/// repeatNode.inTensorIds = {IN_INPUT};
+/// repeatNode.outTensorIds = {OUT};
+/// \endcode
+
+class RepeatOperation : public AclNNOperation {
+public:
+    explicit RepeatOperation(const std::string &name, AclNNRepeatParam param);
+    ~RepeatOperation() override;
+    atb::Status InferShape(
+        const atb::SVector<atb::TensorDesc> &inTensorDescs,
+        atb::SVector<atb::TensorDesc> &outTensorDescs
+    ) const override;
+    [[nodiscard]] uint32_t GetInputNum() const override;
+    [[nodiscard]] uint32_t GetOutputNum() const override;
+
+protected:
+    int SetAclNNWorkspaceExecutor() override;
+    int ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream) override;
+
+private:
+    AclNNRepeatParam param_;
+    std::string opName_;
+};
+} // namespace atb_speed::common
+
+#endif // ATB_SPEED_PLUGIN_ACLNN_REPEAT_OPERATION_H
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/rms_norm_operation.cpp b/tests/proftest/layer_test_framework/operations/aclnn/ops/rms_norm_operation.cpp
new file mode 100644
index 00000000..814c5205
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/rms_norm_operation.cpp
@@ -0,0 +1,92 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
+ */
+#include
+#include
+#include
+#include
+#include
+
+#include "acl/acl.h"
+#include "aclnnop/aclnn_rms_norm.h"
+#include "atb_speed/log.h"
+#include "operations/aclnn/utils/utils.h"
+#include "rms_norm_operation.h"
+
+namespace atb_speed {
+namespace common {
+
+RmsNormOperation::RmsNormOperation(const std::string &name, float epsilon) : AclNNOperation(name)
+{
+    this->opName_ = name;
+    this->epsilon = epsilon;
+}
+
+atb::Status RmsNormOperation::InferShape(const atb::SVector<atb::TensorDesc> &inTensorDescs,
+    atb::SVector<atb::TensorDesc> &outTensorDescs) const
+{
+    ATB_SPEED_LOG_DEBUG(opName_ << " infer shape start");
+    for (size_t i = 0; i < outTensorDescs.size(); i++) {
+        outTensorDescs.at(i).format = inTensorDescs.at(0).format;
+        if (i == NUM1) {
+            outTensorDescs.at(i).dtype = aclDataType::ACL_FLOAT;
+        } else {
+            outTensorDescs.at(i).dtype = inTensorDescs.at(0).dtype;
+        }
+
+        outTensorDescs.at(i).shape.dimNum = inTensorDescs.at(0).shape.dimNum;
+
+        if (inTensorDescs.at(0).shape.dimNum == DIM3) {
+            ATB_SPEED_LOG_DEBUG("[input0 dimNum = 3] CHECK aclnn rmsnorm inputs shape: [input0]"
+                << inTensorDescs.at(0).shape.dims[DIM0] << ", " << inTensorDescs.at(0).shape.dims[DIM1]
+                << ", " << inTensorDescs.at(0).shape.dims[DIM2]);
+            outTensorDescs.at(i).shape.dims[DIM0] = inTensorDescs.at(0).shape.dims[DIM0];
+            outTensorDescs.at(i).shape.dims[DIM1] = inTensorDescs.at(0).shape.dims[DIM1];
+            outTensorDescs.at(i).shape.dims[DIM2] = inTensorDescs.at(0).shape.dims[DIM2];
+        } else if (inTensorDescs.at(0).shape.dimNum == DIM2) {
+            ATB_SPEED_LOG_DEBUG("[input0 dimNum = 2] CHECK aclnn rmsnorm inputs shape: [input0]"
+                << inTensorDescs.at(0).shape.dims[DIM0] << ", "
+                << inTensorDescs.at(0).shape.dims[DIM1]);
+            outTensorDescs.at(i).shape.dims[DIM0] = inTensorDescs.at(0).shape.dims[DIM0];
+            outTensorDescs.at(i).shape.dims[DIM1] = inTensorDescs.at(0).shape.dims[DIM1];
+        } else {
+            ATB_SPEED_LOG_ERROR(opName_ << " invalid dim num:" << inTensorDescs.at(DIM0).shape.dimNum);
+        }
+    }
+
+    ATB_SPEED_LOG_DEBUG(opName_ << " infer shape end");
+    return 0;
+}
+
+uint32_t RmsNormOperation::GetInputNum() const { return NUM2; }
+
+uint32_t RmsNormOperation::GetOutputNum() const { return NUM2; }
+
+int RmsNormOperation::SetAclNNWorkspaceExecutor()
+{
+    ATB_SPEED_LOG_DEBUG(opName_ << " aclnnRmsNormGetWorkspaceSize start");
+    AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack;
+    int ret = aclnnRmsNormGetWorkspaceSize(aclnnVariantPack.aclInTensors.at(0)->tensor,
+        aclnnVariantPack.aclInTensors.at(1)->tensor,
+        this->epsilon,
+        aclnnVariantPack.aclOutTensors.at(0)->tensor,
+        aclnnVariantPack.aclOutTensors.at(1)->tensor,
+        &this->aclnnOpCache_->workspaceSize,
+        &this->aclnnOpCache_->aclExecutor);
+    ATB_SPEED_LOG_DEBUG(opName_ << " aclnnRmsNormGetWorkspaceSize end, ret:" << ret
+        << ", workspaceSize:" << this->aclnnOpCache_->workspaceSize << ", aclExecutor:"
+        << this->aclnnOpCache_->aclExecutor);
+
+    return ret;
+}
+
+int RmsNormOperation::ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream)
+{
+    ATB_SPEED_LOG_DEBUG(opName_ << " aclnnRmsNorm start");
+    int ret = aclnnRmsNorm(workspace, this->aclnnOpCache_->workspaceSize, this->aclnnOpCache_->aclExecutor, stream);
+    ATB_SPEED_LOG_DEBUG(opName_ << " aclnnRmsNorm end, ret:" << ret);
+    return ret;
+}
+
+} // namespace common
+} // namespace atb_speed
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/rms_norm_operation.h b/tests/proftest/layer_test_framework/operations/aclnn/ops/rms_norm_operation.h
new file mode 100644
index 00000000..57916c73
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/rms_norm_operation.h
@@ -0,0 +1,26 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
+ */
+#ifndef ATB_SPEED_PLUGIN_ACLNN_RMSNORM_OPERATION_H
+#define ATB_SPEED_PLUGIN_ACLNN_RMSNORM_OPERATION_H
+#include "operations/aclnn/core/acl_nn_operation.h"
+#include "atb_speed/utils/operation_util.h"
+
+namespace atb_speed {
+namespace common {
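+/// Example (a minimal usage sketch in the style of the other operation wrappers
+/// in this framework; the tensor ids and the gamma/rstd naming are hypothetical,
+/// inferred from the two-input/two-output layout in the implementation):
+/// \code
+/// enum TensorIdx : uint32_t {
+///     IN_INPUT = 0,
+///     IN_GAMMA,
+///     OUT,       // normalized output, same dtype as IN_INPUT
+///     OUT_RSTD,  // float, see InferShape
+/// };
+///
+/// atb::Node &rmsNormNode = opGraph.nodes.at(nodeId++);
+/// rmsNormNode.operation = new atb_speed::common::RmsNormOperation("RmsNormOperation", 1e-5);
+/// rmsNormNode.inTensorIds = {IN_INPUT, IN_GAMMA};
+/// rmsNormNode.outTensorIds = {OUT, OUT_RSTD};
+/// \endcode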
+class RmsNormOperation : public AclNNOperation {
+public:
+    explicit RmsNormOperation(const std::string &name, float epsilon);
+    atb::Status InferShape(const atb::SVector<atb::TensorDesc> &inTensorDescs,
+        atb::SVector<atb::TensorDesc> &outTensorDescs) const override;
+    uint32_t GetInputNum() const override;
+    uint32_t GetOutputNum() const override;
+
+private:
+    float epsilon = 1e-5;
+    int SetAclNNWorkspaceExecutor() override;
+    int ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream) override;
+};
+} // namespace common
+} // namespace atb_speed
+#endif
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/scatter_operation.cpp b/tests/proftest/layer_test_framework/operations/aclnn/ops/scatter_operation.cpp
new file mode 100644
index 00000000..1e8ec107
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/scatter_operation.cpp
@@ -0,0 +1,208 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "acl/acl.h"
+#include "atb_speed/log.h"
+#include "operations/aclnn/utils/utils.h"
+#include "operations/aclnn/core/acl_nn_operation.h"
+#include "aclnnop/aclnn_scatter.h"
+#include "scatter_operation.h"
+
+namespace atb_speed {
+namespace common {
+
+ScatterOperation::ScatterOperation(
+    const std::string &name,
+    AclNNScatterParam param, bool isInplace) : AclNNOperation(name), param_(param), isInplace_(isInplace) {}
+
+ScatterOperation::~ScatterOperation()
+{
+    ATB_SPEED_LOG_DEBUG("ScatterOperation destructor");
+    this->DestroyOperation();
+}
+
+constexpr int MAX_DIMENSION = 8;
+
+atb::Status ScatterOperation::InferShape(const atb::SVector<atb::TensorDesc> &inTensorDescs,
+    atb::SVector<atb::TensorDesc> &outTensorDescs) const
+{
+    ATB_SPEED_LOG_DEBUG(opName_ << " infer shape start");
+
+    if (inTensorDescs.at(0).shape.dimNum > MAX_DIMENSION) {
+        ATB_SPEED_LOG_ERROR(opName_ << " self tensor dim num exceeds limit");
+        return atb::ERROR_INVALID_PARAM;
+    }
+
+    outTensorDescs.at(0).format = inTensorDescs.at(0).format;
+    outTensorDescs.at(0).shape.dimNum = inTensorDescs.at(0).shape.dimNum;
+    outTensorDescs.at(0).dtype = inTensorDescs.at(0).dtype;
+
+    for (uint32_t i = 0; i < inTensorDescs.at(0).shape.dimNum; i++) {
+        outTensorDescs.at(0).shape.dims[i] = inTensorDescs.at(0).shape.dims[i];
+    }
+
+    ATB_SPEED_LOG_DEBUG(opName_ << " infer shape end");
+    return atb::NO_ERROR;
+}
+
+uint32_t ScatterOperation::GetInputNum() const { return 3; }
+
+uint32_t ScatterOperation::GetOutputNum() const { return 1; }
+
+int ScatterOperation::CreateAclNNVariantPack(const atb::VariantPack &variantPack)
+{
+    ATB_SPEED_LOG_DEBUG(opName_ << " CreateAclNNVariantPack start");
+
+    int ret = CreateAclNNInTensorVariantPack(variantPack);
+    if (ret != atb::NO_ERROR) {
+        ATB_SPEED_LOG_ERROR(opName_ << " CreateAclNNInTensorVariantPack failed");
+        return ret;
+    }
+
+    ret = CreateAclNNOutTensorVariantPack(variantPack);
+    if (ret != atb::NO_ERROR) {
+        ATB_SPEED_LOG_ERROR(opName_ << " CreateAclNNOutTensorVariantPack failed");
+        return ret;
+    }
+
+    aclnnOpCache_->executorRepeatable = false;
+
+    ATB_SPEED_LOG_DEBUG(opName_ << " CreateAclNNVariantPack end");
+    return atb::NO_ERROR;
+}
+
+atb::Status ScatterOperation::CreateAclNNInTensorVariantPack(const atb::VariantPack &variantPack)
+{
+    AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack;
+    aclnnVariantPack.aclInTensors.resize(GetInputNum());
+
+    for (size_t i = 0; i < aclnnVariantPack.aclInTensors.size(); ++i) {  // self, index, src
+        std::shared_ptr<AclNNTensor> aclnnTensor = std::make_shared<AclNNTensor>();
+        aclnnTensor->tensorIdx = i;
+        aclnnTensor->needUpdateTensorDataPtr = true;
+        aclnnTensor->atbTensor = variantPack.inTensors.at(i);
+
+        aclnnTensor->strides = GetCopyTensorStride(aclnnTensor->atbTensor.desc.shape);
+        aclnnTensor->tensor = aclCreateTensor(
+            aclnnTensor->atbTensor.desc.shape.dims, aclnnTensor->atbTensor.desc.shape.dimNum,
+            aclnnTensor->atbTensor.desc.dtype, aclnnTensor->strides.data(), 0,
+            aclnnTensor->atbTensor.desc.format, aclnnTensor->atbTensor.desc.shape.dims,
+            aclnnTensor->atbTensor.desc.shape.dimNum, aclnnTensor->atbTensor.deviceData);
+
+        if (aclnnTensor->tensor == nullptr) {
+            ATB_SPEED_LOG_ERROR(this->opName_ << " InTensor aclCreateTensor index " << i << " failed");
+            return atb::ERROR_INTERNAL_ERROR;
+        }
+        aclnnVariantPack.aclInTensors[i] = aclnnTensor;
+    }
+    return atb::NO_ERROR;
+}
+
+atb::Status ScatterOperation::CreateAclNNOutTensorVariantPack(const atb::VariantPack &variantPack)
+{
+    AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack;
+    aclnnVariantPack.aclOutTensors.resize(GetOutputNum());
+
+    std::shared_ptr<AclNNTensor> aclnnTensor = std::make_shared<AclNNTensor>();
+    aclnnTensor->tensorIdx = 0;
+    aclnnTensor->needUpdateTensorDataPtr = true;
+    aclnnTensor->atbTensor = variantPack.outTensors.at(0);
+
+    aclnnTensor->strides = GetCopyTensorStride(aclnnTensor->atbTensor.desc.shape);
+    aclnnTensor->tensor = aclCreateTensor(
+        aclnnTensor->atbTensor.desc.shape.dims, aclnnTensor->atbTensor.desc.shape.dimNum,
+        aclnnTensor->atbTensor.desc.dtype, aclnnTensor->strides.data(), 0,
+        aclnnTensor->atbTensor.desc.format, aclnnTensor->atbTensor.desc.shape.dims,
+        aclnnTensor->atbTensor.desc.shape.dimNum, aclnnTensor->atbTensor.deviceData);
+
+    if (aclnnTensor->tensor == nullptr) {
+        ATB_SPEED_LOG_ERROR(this->opName_ << " OutTensor aclCreateTensor failed");
+        return atb::ERROR_INTERNAL_ERROR;
+    }
+    aclnnVariantPack.aclOutTensors[0] = aclnnTensor;
+    return atb::NO_ERROR;
+}
+
+int ScatterOperation::SetAclNNWorkspaceExecutor()
+{
+    ATB_SPEED_LOG_DEBUG(opName_ << " SetAclNNWorkspaceExecutor start");
+    AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack;
+
+    int64_t reduceType = static_cast<int64_t>(param_.reduce);
+    if (!isInplace_) {
+        int ret = aclnnScatterGetWorkspaceSize(
+            aclnnVariantPack.aclInTensors[0]->tensor,   // self
+            param_.dim,                                 // scatter dim
+            aclnnVariantPack.aclInTensors[1]->tensor,   // index
+            aclnnVariantPack.aclInTensors[2]->tensor,   // src
+            reduceType,                                 // reduce type
+            aclnnVariantPack.aclOutTensors[0]->tensor,  // out
+            &this->aclnnOpCache_->workspaceSize,
+            &this->aclnnOpCache_->aclExecutor);
+        if (ret != atb::NO_ERROR) {
+            ATB_SPEED_LOG_ERROR(opName_ << " aclnnScatterGetWorkspaceSize failed with error code: " << ret);
+            return ret;
+        }
+        ATB_SPEED_LOG_DEBUG(opName_ << " aclnnScatter SetAclNNWorkspaceExecutor end, ret: " << ret
+            << ", workspaceSize: " << this->aclnnOpCache_->workspaceSize
+            << ", aclExecutor: " << this->aclnnOpCache_->aclExecutor);
+        return ret;
+    } else {
+        int ret = aclnnInplaceScatterGetWorkspaceSize(
+            aclnnVariantPack.aclInTensors[0]->tensor,  // self
+            param_.dim,                                // scatter dim
+            aclnnVariantPack.aclInTensors[1]->tensor,  // index
+            aclnnVariantPack.aclInTensors[2]->tensor,  // src
+            reduceType,                                // reduce type
+            &this->aclnnOpCache_->workspaceSize,
+            &this->aclnnOpCache_->aclExecutor);
+        if (ret != atb::NO_ERROR) {
+            ATB_SPEED_LOG_ERROR(opName_ << " aclnnInplaceScatterGetWorkspaceSize failed with error code: " << ret);
+            return ret;
+        }
+        ATB_SPEED_LOG_DEBUG(opName_ << " aclnnScatterInplace SetAclNNWorkspaceExecutor end, ret: " << ret
+            << ", workspaceSize: " << this->aclnnOpCache_->workspaceSize
+            << ", aclExecutor: " << this->aclnnOpCache_->aclExecutor);
+        return ret;
+    }
+}
+
+int ScatterOperation::ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream)
+{
+    if (!isInplace_) {
+        int ret = aclnnScatter(
+            workspace,
+            this->aclnnOpCache_->workspaceSize,
+            this->aclnnOpCache_->aclExecutor,
+            stream);
+        if (ret != 0) {
+            ATB_SPEED_LOG_ERROR("aclnnScatter ExecuteAclNNOp failed, ret: " << ret);
+        }
+        return ret;
+    } else {
+        int ret = aclnnInplaceScatter(
+            workspace,
+            this->aclnnOpCache_->workspaceSize,
+            this->aclnnOpCache_->aclExecutor,
+            stream);
+        if (ret != 0) {
+            ATB_SPEED_LOG_ERROR("aclnnInplaceScatter ExecuteAclNNOp failed, ret: " << ret);
+        }
+        return ret;
+    }
+}
+
+} // namespace common
+} // namespace atb_speed
diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/scatter_operation.h b/tests/proftest/layer_test_framework/operations/aclnn/ops/scatter_operation.h
new file mode 100644
index 00000000..381f9635
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/scatter_operation.h
@@ -0,0 +1,57 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef ATB_SPEED_PLUGIN_ACLNN_SCATTER_OPERATION_H
+#define ATB_SPEED_PLUGIN_ACLNN_SCATTER_OPERATION_H
+
+#include "operations/aclnn/core/acl_nn_operation.h"
+
+namespace atb_speed {
+namespace common {
+
+enum class ReduceType {
+    REPLACE = 0,
+    ADD = 1,
+    MULTIPLY = 2
+};
+
+struct AclNNScatterParam {
+    int64_t dim = 0;
+    ReduceType reduce = ReduceType::REPLACE;
+};
+
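+/// Example (a minimal usage sketch in the style of the other operation wrappers
+/// in this framework; the tensor ids are hypothetical):
+/// \code
+/// enum TensorIdx : uint32_t {
+///     IN_SELF = 0,
+///     IN_INDEX,
+///     IN_SRC,
+///     OUT,
+/// };
+///
+/// atb::Node &scatterNode = opGraph.nodes.at(nodeId++);
+/// atb_speed::common::AclNNScatterParam scatterParam;
+/// scatterParam.dim = 0;
+/// scatterParam.reduce = atb_speed::common::ReduceType::ADD;
+/// // isInplace = false selects aclnnScatter; true selects aclnnInplaceScatter
+/// scatterNode.operation = new atb_speed::common::ScatterOperation("ScatterOperation", scatterParam, false);
+/// scatterNode.inTensorIds = {IN_SELF, IN_INDEX, IN_SRC};
+/// scatterNode.outTensorIds = {OUT};
+/// \endcode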
+class ScatterOperation : public AclNNOperation {
+public:
+    explicit ScatterOperation(const std::string &name, AclNNScatterParam param, bool isInplace);
+    ~ScatterOperation() override;
+    atb::Status InferShape(const atb::SVector<atb::TensorDesc> &inTensorDescs,
+        atb::SVector<atb::TensorDesc> &outTensorDescs) const override;
+    uint32_t GetInputNum() const override;
+    uint32_t GetOutputNum() const override;
+
+protected:
+    int CreateAclNNVariantPack(const atb::VariantPack &variantPack) override;
+    int SetAclNNWorkspaceExecutor() override;
+    int ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream) override;
+    atb::Status CreateAclNNInTensorVariantPack(const atb::VariantPack &variantPack) override;
+    atb::Status CreateAclNNOutTensorVariantPack(const atb::VariantPack &variantPack) override;
+
+private:
+    AclNNScatterParam param_;
+    bool isInplace_;
+};
+} // namespace common
+} // namespace atb_speed
+#endif // ATB_SPEED_PLUGIN_ACLNN_SCATTER_OPERATION_H
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/sigmoid_operation.cpp b/tests/proftest/layer_test_framework/operations/aclnn/ops/sigmoid_operation.cpp
new file mode 100644
index 00000000..72c57ffa
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/sigmoid_operation.cpp
@@ -0,0 +1,125 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "sigmoid_operation.h"
+#include "acl/acl.h"
+#include "aclnnop/aclnn_sigmoid.h"
+#include "atb_speed/log.h"
+#include "operations/aclnn/utils/utils.h"
+
+namespace atb_speed::common {
+
+SigmoidOperation::SigmoidOperation(
+    const std::string &name) : AclNNOperation(name) {}
+
+SigmoidOperation::~SigmoidOperation()
+{
+    ATB_SPEED_LOG_DEBUG("SigmoidOperation destructor");
+    this->DestroyOperation();
+}
+
+atb::Status SigmoidOperation::InferShape(const atb::SVector<atb::TensorDesc> &inTensorDescs,
+    atb::SVector<atb::TensorDesc> &outTensorDescs) const
+{
+    ATB_SPEED_LOG_DEBUG(opName_ << "SigmoidOperation infer shape start");
+    outTensorDescs.at(0).format = inTensorDescs.at(0).format;
+    outTensorDescs.at(0).dtype = inTensorDescs.at(0).dtype;
+    outTensorDescs.at(0).shape.dimNum = inTensorDescs.at(0).shape.dimNum;
+    for (uint64_t i = 0; i < inTensorDescs.at(0).shape.dimNum; ++i) {
+        outTensorDescs.at(0).shape.dims[i] = inTensorDescs.at(0).shape.dims[i];
+    }
+
+    ATB_SPEED_LOG_DEBUG(opName_ << "SigmoidOperation infer shape end"
+        << " format: " << inTensorDescs.at(0).format << " dimNum: " << inTensorDescs.at(0).shape.dimNum
+        << " dims: " << inTensorDescs.at(0).shape.dims[0]);
+    return 0;
+}
+
+uint32_t SigmoidOperation::GetInputNum() const
+{
+    return NUM1;
+}
+
+
+uint32_t SigmoidOperation::GetOutputNum() const
+{
+    return NUM1;
+}
+
+atb::Status SigmoidOperation::CreateAclNNInTensorVariantPack(const atb::VariantPack &variantPack)
+{
+    AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack;
+    aclnnVariantPack.aclInTensors.resize(GetInputNum());
+    for (size_t i = 0; i < aclnnVariantPack.aclInTensors.size(); ++i) {
+        std::shared_ptr<AclNNTensor> aclnnTensor = std::make_shared<AclNNTensor>();
+        aclnnTensor->tensorIdx = i;
+        aclnnTensor->tensorListidx = AclNNTensor::notInTensorList;
+        aclnnTensor->needUpdateTensorDataPtr = true;
+        aclnnTensor->atbTensor = variantPack.inTensors.at(i);
+        atb::Tensor atbTensor = variantPack.inTensors.at(i);
+        aclnnTensor->strides = GetCopyTensorStride(atbTensor.desc.shape);
+        CHECK_OPERATION_STATUS_RETURN(CallAclCreateTensor(atbTensor.desc.shape, atbTensor.desc.shape,
+            atbTensor, aclnnTensor));
+        aclnnVariantPack.aclInTensors[i] = aclnnTensor;
+    }
+    return atb::NO_ERROR;
+}
+
+atb::Status SigmoidOperation::CreateAclNNOutTensorVariantPack(const atb::VariantPack &variantPack)
+{
+    AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack;
+    aclnnVariantPack.aclOutTensors.resize(GetOutputNum());
+    for (size_t i = 0; i < aclnnVariantPack.aclOutTensors.size(); ++i) {
+        std::shared_ptr<AclNNTensor> aclnnTensor = std::make_shared<AclNNTensor>();
+        aclnnTensor->tensorIdx = i;
+        aclnnTensor->tensorListidx = AclNNTensor::notInTensorList;
+        aclnnTensor->needUpdateTensorDataPtr = true;
+        aclnnTensor->atbTensor = variantPack.outTensors.at(i);
+        atb::Tensor atbTensor = variantPack.outTensors.at(i);
+        aclnnTensor->strides = GetCopyTensorStride(atbTensor.desc.shape);
+        CHECK_OPERATION_STATUS_RETURN(CallAclCreateTensor(atbTensor.desc.shape, atbTensor.desc.shape,
+            atbTensor, aclnnTensor));
+        aclnnVariantPack.aclOutTensors[i] = aclnnTensor;
+    }
+    return atb::NO_ERROR;
+}
+
+int SigmoidOperation::SetAclNNWorkspaceExecutor()
+{
+    ATB_SPEED_LOG_DEBUG(opName_ << " SetAclNNWorkspaceExecutor start");
+    AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack;
+    int ret = aclnnSigmoidGetWorkspaceSize(
+        aclnnVariantPack.aclInTensors.at(0)->tensor,
+        aclnnVariantPack.aclOutTensors.at(0)->tensor,
+        &this->aclnnOpCache_->workspaceSize,
+        &this->aclnnOpCache_->aclExecutor);
+    ATB_SPEED_LOG_DEBUG(opName_ << " SetAclNNWorkspaceExecutor end, ret:" << ret
+        << ", workspaceSize:" << this->aclnnOpCache_->workspaceSize
+        << ", aclExecutor:" << this->aclnnOpCache_->aclExecutor);
+    return ret;
+}
+
+int SigmoidOperation::ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream)
+{
+    ATB_SPEED_LOG_DEBUG(opName_ << " SigmoidOperation start");
+    int ret = aclnnSigmoid(
+        workspace,
+        this->aclnnOpCache_->workspaceSize,
+        this->aclnnOpCache_->aclExecutor,
+        stream);
+    ATB_SPEED_LOG_DEBUG(opName_ << " SigmoidOperation end, ret:" << ret);
+    return ret;
+}
+} // namespace atb_speed::common
diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/sigmoid_operation.h b/tests/proftest/layer_test_framework/operations/aclnn/ops/sigmoid_operation.h
new file mode 100644
index 00000000..1975f73f
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/sigmoid_operation.h
@@ -0,0 +1,76 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef ATB_SPEED_PLUGIN_ACLNN_SIGMOID_OPERATION_H
+#define ATB_SPEED_PLUGIN_ACLNN_SIGMOID_OPERATION_H
+
+#include "operations/aclnn/core/acl_nn_operation.h"
+#include "atb_speed/utils/operation_util.h"
+
+
+/// This class defines a sigmoid operator.
+///
+/// This class makes use of `aclnnSigmoidGetWorkspaceSize` and `aclnnSigmoid`
+/// from the AscendCL API.
+///
+/// Inputs to the operator:
+/// Name           | Dtype               | Shape   |
+/// ---------------|---------------------|---------|
+/// input          | float16 or bfloat16 | [m,h]   |
+///
+/// Outputs of the operator:
+/// Name           | Dtype               | Shape   |
+/// ---------------|---------------------|---------|
+/// out            | float16 or bfloat16 | [m,h]   |
+///
+/// Example:
+/// \code
+/// enum InTensorIdx : uint32_t {
+///     IN_INPUT = 0,
+/// };
+///
+/// enum OutTensorIdx : uint32_t {
+///     OUT = 0
+/// };
+///
+/// atb::Node &sigmoidNode = opGraph.nodes.at(nodeId++);
+/// sigmoidNode.operation = new atb_speed::common::SigmoidOperation("SigmoidOperation");
+/// sigmoidNode.inTensorIds = {IN_INPUT};
+/// sigmoidNode.outTensorIds = {OUT};
+/// \endcode
+
+namespace atb_speed {
+namespace common {
+
+class SigmoidOperation : public AclNNOperation {
+public:
+    explicit SigmoidOperation(const std::string &name);
+    ~SigmoidOperation() override;
+    atb::Status InferShape(const atb::SVector<atb::TensorDesc> &inTensorDescs,
+        atb::SVector<atb::TensorDesc> &outTensorDescs) const override;
+    uint32_t GetInputNum() const override;
+    uint32_t GetOutputNum() const override;
+
+private:
+    int SetAclNNWorkspaceExecutor() override;
+    int ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream) override;
+
+    atb::Status CreateAclNNInTensorVariantPack(const atb::VariantPack &variantPack) override;
+    atb::Status CreateAclNNOutTensorVariantPack(const atb::VariantPack &variantPack) override;
+};
+
+} // namespace common
+} // namespace atb_speed
+#endif // ATB_SPEED_PLUGIN_ACLNN_SIGMOID_OPERATION_H
diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/split_with_size_operation.cpp b/tests/proftest/layer_test_framework/operations/aclnn/ops/split_with_size_operation.cpp
new file mode 100644
index 00000000..82c2d823
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/split_with_size_operation.cpp
@@ -0,0 +1,180 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ */ + +#include +#include +#include +#include +#include +#include +#include "acl/acl.h" +#include "atb_speed/log.h" +#include "atb_speed/utils/timer.h" +#include "operations/aclnn/utils/utils.h" +#include "split_with_size_operation.h" + +namespace atb_speed { +namespace common { + +SplitWithSizeOperation::SplitWithSizeOperation( + const std::string &name, + AclNNSplitWithSizeParam param) : AclNNOperation(name), param_(param) +{ + outputTensorVector.resize(param.num); +} + +SplitWithSizeOperation::~SplitWithSizeOperation() { +} + +atb::Status SplitWithSizeOperation::InferShape( + const atb::SVector &inTensorDescs, atb::SVector &outTensorDescs) const +{ + ATB_SPEED_LOG_DEBUG(opName_ << "SplitWithSizeOperation infer shape start"); + int splitSize = inTensorDescs.at(DIM0).shape.dims[param_.dim] / param_.num; + int remainSize = inTensorDescs.at(DIM0).shape.dims[param_.dim] % param_.num; + for (size_t i = 0; i < GetOutputNum(); i++) { + outTensorDescs.at(i) = inTensorDescs.at(DIM0); + if (i < static_cast(remainSize)) { + outTensorDescs.at(i).shape.dims[param_.dim] = splitSize + 1; + } else { + outTensorDescs.at(i).shape.dims[param_.dim] = splitSize; + } + } + ATB_SPEED_LOG_DEBUG(opName_ << "SplitWithSizeOperation infer shape end"); + return 0; +} + +uint32_t SplitWithSizeOperation::GetInputNum() const +{ + // 1: aclInTensors size + return 1; +} + +uint32_t SplitWithSizeOperation::GetOutputNum() const +{ + return param_.num; +} + +int SplitWithSizeOperation::CreateAclNNVariantPack(const atb::VariantPack &variantPack) +{ + ATB_SPEED_LOG_DEBUG(opName_ << " CreateAclNNVariantPack start"); + int ret = 0; + ret = CreateAclNNInTensorVariantPack(variantPack); + if (ret != 0) { + ATB_SPEED_LOG_ERROR(this->opName_ << " AclNNTensor CreateAclNNInTensorVariantPack fail"); + return ret; + } + ret = CreateAclNNOutTensorVariantPack(variantPack); + if (ret != 0) { + ATB_SPEED_LOG_ERROR(this->opName_ << " AclNNTensor CreateAclNNOutTensorVariantPack fail"); + return ret; + } + ATB_SPEED_LOG_DEBUG(opName_ << " CreateAclNNVariantPack end"); + return atb::NO_ERROR; +} + +int SplitWithSizeOperation::CreateAclNNInTensorVariantPack(const atb::VariantPack &variantPack) +{ + AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack; + aclnnVariantPack.aclInTensors.resize(GetInputNum()); + for (size_t i = 0; i < aclnnVariantPack.aclInTensors.size(); i++) { + std::shared_ptr aclnnTensor = std::make_shared(); + aclnnTensor->tensorIdx = i; + aclnnTensor->needUpdateTensorDataPtr = true; + aclnnTensor->atbTensor = variantPack.inTensors.at(i); + atb::Tensor squeezedAtbTensor = SqueezeBatchSeq(variantPack.inTensors.at(i)); + aclnnTensor->strides = GetCopyTensorStride(squeezedAtbTensor.desc.shape); + aclnnTensor->tensor = aclCreateTensor( + squeezedAtbTensor.desc.shape.dims, squeezedAtbTensor.desc.shape.dimNum, + squeezedAtbTensor.desc.dtype, aclnnTensor->strides.data(), 0, + squeezedAtbTensor.desc.format, squeezedAtbTensor.desc.shape.dims, + squeezedAtbTensor.desc.shape.dimNum, squeezedAtbTensor.deviceData); + if (aclnnTensor->tensor == nullptr) { + ATB_SPEED_LOG_ERROR(this->opName_ << " InTensor aclCreateTensor index " << i << " fail"); + return atb::ERROR_INTERNAL_ERROR; + } + aclnnVariantPack.aclInTensors[i] = aclnnTensor; + } + return atb::NO_ERROR; +} + +int SplitWithSizeOperation::CreateAclNNOutTensorVariantPack(const atb::VariantPack &variantPack) +{ + AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack; + aclnnVariantPack.aclOutTensors.resize(GetOutputNum()); + for (size_t i = 0; i 
+        std::shared_ptr<AclNNTensor> aclnnTensor = std::make_shared<AclNNTensor>();
+        aclnnTensor->tensorIdx = i;
+        aclnnTensor->needUpdateTensorDataPtr = true;
+        aclnnTensor->atbTensor = variantPack.outTensors.at(i);
+        atb::Tensor squeezedAtbTensor = SqueezeBatchSeq(variantPack.outTensors.at(i));
+        aclnnTensor->strides = GetCopyTensorStride(squeezedAtbTensor.desc.shape);
+        aclnnTensor->tensor = aclCreateTensor(
+            squeezedAtbTensor.desc.shape.dims, squeezedAtbTensor.desc.shape.dimNum,
+            squeezedAtbTensor.desc.dtype, aclnnTensor->strides.data(), 0,
+            squeezedAtbTensor.desc.format, squeezedAtbTensor.desc.shape.dims,
+            squeezedAtbTensor.desc.shape.dimNum, squeezedAtbTensor.deviceData);
+        if (aclnnTensor->tensor == nullptr) {
+            ATB_SPEED_LOG_ERROR(this->opName_ << " OutTensor aclCreateTensor index " << i << " fail");
+            return atb::ERROR_INTERNAL_ERROR;
+        }
+        aclnnVariantPack.aclOutTensors[i] = aclnnTensor;
+    }
+    for (size_t i = 0; i < aclnnVariantPack.aclOutTensors.size(); i++) {
+        outputTensorVector[i] = aclnnVariantPack.aclOutTensors.at(i)->tensor;
+    }
+    return atb::NO_ERROR;
+}
+
+int SplitWithSizeOperation::SetAclNNWorkspaceExecutor()
+{
+    ATB_SPEED_LOG_DEBUG(opName_ << " SetAclNNWorkspaceExecutor start");
+    AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack;
+    aclTensorList *out = aclCreateTensorList(outputTensorVector.data(), outputTensorVector.size());
+
+    std::vector<int64_t> splitSizeVec;
+    int splitSize = aclnnVariantPack.aclInTensors.at(DIM0)->atbTensor.desc.shape.dims[param_.dim] / param_.num;
+    int remainSize = aclnnVariantPack.aclInTensors.at(DIM0)->atbTensor.desc.shape.dims[param_.dim] % param_.num;
+    for (size_t i = 0; i < GetOutputNum(); i++) {
+        if (i < static_cast<size_t>(remainSize)) {
+            splitSizeVec.emplace_back(splitSize + 1);
+        } else {
+            splitSizeVec.emplace_back(splitSize);
+        }
+    }
+    aclIntArray *splitSizeIntArray = aclCreateIntArray(splitSizeVec.data(), splitSizeVec.size());
+
+    int ret = aclnnSplitWithSizeGetWorkspaceSize(
+        aclnnVariantPack.aclInTensors.at(DIM0)->tensor, splitSizeIntArray, param_.dim, out,
+        &this->aclnnOpCache_->workspaceSize, &this->aclnnOpCache_->aclExecutor);
+
+    ATB_SPEED_LOG_DEBUG(opName_ << " SetAclNNWorkspaceExecutor end, ret:" << ret
+        << ", workspaceSize:" << this->aclnnOpCache_->workspaceSize
+        << ", aclExecutor:" << this->aclnnOpCache_->aclExecutor);
+    return ret;
+}
+
+int SplitWithSizeOperation::ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream)
+{
+    ATB_SPEED_LOG_DEBUG(opName_ << " aclnnSplitWithSize start");
+    int ret = aclnnSplitWithSize(
+        workspace, this->aclnnOpCache_->workspaceSize, this->aclnnOpCache_->aclExecutor, stream);
+    ATB_SPEED_LOG_DEBUG(opName_ << " aclnnSplitWithSize end, ret:" << ret);
+    return ret;
+}
+
+} // namespace common
+} // namespace atb_speed
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/split_with_size_operation.h b/tests/proftest/layer_test_framework/operations/aclnn/ops/split_with_size_operation.h
new file mode 100644
index 00000000..59f63ab9
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/split_with_size_operation.h
@@ -0,0 +1,53 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef ATB_SPEED_PLUGIN_ACLNN_SPLIT_WITH_SIZE_OPERATION_H
+#define ATB_SPEED_PLUGIN_ACLNN_SPLIT_WITH_SIZE_OPERATION_H
+#include "aclnnop/aclnn_split_with_size.h"
+#include "operations/aclnn/core/acl_nn_operation.h"
+#include "operations/aclnn/core/acl_nn_operation_cache.h"
+#include "operations/aclnn/core/acl_nn_tensor.h"
+
+
+namespace atb_speed {
+namespace common {
+
+struct AclNNSplitWithSizeParam {
+    int64_t dim = 0;
+    uint64_t num = 1;
+};
+
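+/// Example (a minimal usage sketch in the style of the other operation wrappers
+/// in this framework; the tensor ids are hypothetical):
+/// \code
+/// enum TensorIdx : uint32_t {
+///     IN_INPUT = 0,
+///     OUT_0,
+///     OUT_1,
+/// };
+///
+/// atb::Node &splitNode = opGraph.nodes.at(nodeId++);
+/// atb_speed::common::AclNNSplitWithSizeParam splitParam;
+/// splitParam.dim = 1;  // non-negative axis to split along; see InferShape
+/// splitParam.num = 2;  // number of output tensors
+/// splitNode.operation = new atb_speed::common::SplitWithSizeOperation("SplitWithSizeOperation", splitParam);
+/// splitNode.inTensorIds = {IN_INPUT};
+/// splitNode.outTensorIds = {OUT_0, OUT_1};
+/// \endcode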
+class SplitWithSizeOperation : public AclNNOperation {
+public:
+    explicit SplitWithSizeOperation(const std::string &name, AclNNSplitWithSizeParam param);
+    ~SplitWithSizeOperation() override;
+    uint32_t GetInputNum() const override;
+    atb::Status InferShape(const atb::SVector<atb::TensorDesc> &inTensorDescs,
+        atb::SVector<atb::TensorDesc> &outTensorDescs) const override;
+    uint32_t GetOutputNum() const override;
+
+private:
+    int SetAclNNWorkspaceExecutor() override;
+    int CreateAclNNInTensorVariantPack(const atb::VariantPack &variantPack) override;
+    int CreateAclNNOutTensorVariantPack(const atb::VariantPack &variantPack) override;
+    int CreateAclNNVariantPack(const atb::VariantPack &variantPack) override;
+    int ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream) override;
+
+    AclNNSplitWithSizeParam param_;
+    std::vector<aclTensor *> outputTensorVector;
+};
+} // namespace common
+} // namespace atb_speed
+#endif // ATB_SPEED_PLUGIN_ACLNN_SPLIT_WITH_SIZE_OPERATION_H
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/std_operation.cpp b/tests/proftest/layer_test_framework/operations/aclnn/ops/std_operation.cpp
new file mode 100644
index 00000000..73f1d52d
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/std_operation.cpp
@@ -0,0 +1,145 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include
+#include
+#include
+#include
+#include
+
+#include "acl/acl.h"
+#include "aclnnop/aclnn_std.h"
+#include "atb_speed/log.h"
+#include "operations/aclnn/utils/utils.h"
+#include "std_operation.h"
+
+namespace atb_speed {
+namespace common {
+
+
+StdOperation::StdOperation(const std::string &name) : AclNNOperation(name)
+{
+    this->opName_ = name;
+}
+
+
+StdOperation::~StdOperation()
+{
+    if (dim != nullptr) {
+        aclDestroyIntArray(dim);
+        dim = nullptr;
+    }
+}
+
+atb::Status StdOperation::InferShape(const atb::SVector<atb::TensorDesc> &inTensorDescs,
+    atb::SVector<atb::TensorDesc> &outTensorDescs) const
+{
+    ATB_SPEED_LOG_DEBUG(opName_ << "StdOperation infer shape start");
+    outTensorDescs.at(0).format = inTensorDescs.at(0).format;
+    outTensorDescs.at(0).dtype = inTensorDescs.at(0).dtype;
+    outTensorDescs.at(0).shape.dimNum = inTensorDescs.at(0).shape.dimNum;
+    for (uint64_t i = 0; i < inTensorDescs.at(0).shape.dimNum; ++i) {
+        outTensorDescs.at(0).shape.dims[i] = inTensorDescs.at(0).shape.dims[i];
+    }
+    outTensorDescs.at(0).shape.dims[dimData.at(0)] = 1;
+
+    ATB_SPEED_LOG_DEBUG(opName_ << "StdOperation infer shape end"
+        << " format: " << inTensorDescs.at(0).format << " dimNum: " << inTensorDescs.at(0).shape.dimNum
+        << " dims: " << inTensorDescs.at(0).shape.dims[0]);
+    return 0;
+}
+
+uint32_t StdOperation::GetInputNum() const
+{
+    return NUM1;
+}
+
+
+uint32_t StdOperation::GetOutputNum() const
+{
+    return NUM1;
+}
+
+atb::Status StdOperation::CreateAclNNInTensorVariantPack(const atb::VariantPack &variantPack)
+{
+    AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack;
+    aclnnVariantPack.aclInTensors.resize(GetInputNum());
+    for (size_t i = 0; i < aclnnVariantPack.aclInTensors.size(); ++i) {
+        std::shared_ptr<AclNNTensor> aclnnTensor = std::make_shared<AclNNTensor>();
+        aclnnTensor->tensorIdx = i;
+        aclnnTensor->tensorListidx = AclNNTensor::notInTensorList;
+        aclnnTensor->needUpdateTensorDataPtr = true;
+        aclnnTensor->atbTensor = variantPack.inTensors.at(i);
+        atb::Tensor atbTensor = variantPack.inTensors.at(i);
+        aclnnTensor->strides = GetCopyTensorStride(atbTensor.desc.shape);
+        CHECK_OPERATION_STATUS_RETURN(CallAclCreateTensor(atbTensor.desc.shape, atbTensor.desc.shape,
+            atbTensor, aclnnTensor));
+        aclnnVariantPack.aclInTensors[i] = aclnnTensor;
+    }
+    return atb::NO_ERROR;
+}
+
+atb::Status StdOperation::CreateAclNNOutTensorVariantPack(const atb::VariantPack &variantPack)
+{
+    AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack;
+    aclnnVariantPack.aclOutTensors.resize(GetOutputNum());
+    for (size_t i = 0; i < aclnnVariantPack.aclOutTensors.size(); ++i) {
+        std::shared_ptr<AclNNTensor> aclnnTensor = std::make_shared<AclNNTensor>();
+        aclnnTensor->tensorIdx = i;
+        aclnnTensor->tensorListidx = AclNNTensor::notInTensorList;
+        aclnnTensor->needUpdateTensorDataPtr = true;
+        aclnnTensor->atbTensor = variantPack.outTensors.at(i);
+        atb::Tensor atbTensor = variantPack.outTensors.at(i);
+        aclnnTensor->strides = GetCopyTensorStride(atbTensor.desc.shape);
+        CHECK_OPERATION_STATUS_RETURN(CallAclCreateTensor(atbTensor.desc.shape, atbTensor.desc.shape,
+            atbTensor, aclnnTensor));
+        aclnnVariantPack.aclOutTensors[i] = aclnnTensor;
+    }
+    return atb::NO_ERROR;
+}
+
+
+int StdOperation::SetAclNNWorkspaceExecutor()
+{
+    ATB_SPEED_LOG_DEBUG(opName_ << " SetAclNNWorkspaceExecutor start");
+    if (dim == nullptr) {
+        dim = aclCreateIntArray(dimData.data(), 1);
+    }
+    AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack;
+    int ret = aclnnStdGetWorkspaceSize(aclnnVariantPack.aclInTensors.at(0)->tensor,
+        dim,
+        0,     // correction
+        true,  // keepdim
+        aclnnVariantPack.aclOutTensors.at(0)->tensor,
+        &this->aclnnOpCache_->workspaceSize,
+        &this->aclnnOpCache_->aclExecutor);
+    ATB_SPEED_LOG_DEBUG(opName_ << " SetAclNNWorkspaceExecutor end, ret:" << ret
+        << ", workspaceSize:" << this->aclnnOpCache_->workspaceSize
+        << ", aclExecutor:" << this->aclnnOpCache_->aclExecutor);
+    return ret;
+}
+
+
+int StdOperation::ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream)
+{
+    ATB_SPEED_LOG_DEBUG(opName_ << " StdOperation start");
+    int ret = aclnnStd(workspace, this->aclnnOpCache_->workspaceSize, this->aclnnOpCache_->aclExecutor, stream);
+    ATB_SPEED_LOG_DEBUG(opName_ << " StdOperation end, ret:" << ret);
+    return ret;
+}
+
+} // namespace common
+} // namespace atb_speed
diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/std_operation.h b/tests/proftest/layer_test_framework/operations/aclnn/ops/std_operation.h
new file mode 100644
index 00000000..82de5513
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/std_operation.h
@@ -0,0 +1,73 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef ATB_SPEED_PLUGIN_ACLNN_STD_OPERATION_H
+#define ATB_SPEED_PLUGIN_ACLNN_STD_OPERATION_H
+
+#include "operations/aclnn/core/acl_nn_operation.h"
+#include "atb_speed/utils/operation_util.h"
+
+namespace atb_speed {
+namespace common {
+
+/// This class defines an operator that calculates the standard deviation of the input.
+///
+/// This class makes use of `aclnnStdGetWorkspaceSize` and `aclnnStd` from the AscendCL API.
+///
+/// Inputs to the operator:
+/// Name         | Dtype               | Shape |
+/// -------------|---------------------|-------|
+/// input        | float16 or bfloat16 | [m,n] |
+///
+/// Outputs of the operator:
+/// Name         | Dtype               | Shape |
+/// -------------|---------------------|-------|
+/// output       | float16 or bfloat16 | [m,1] |
+/// Note: the standard deviation is computed along dim 1 with keepdim = true.
+///
+/// Example:
+/// \code
+/// enum TensorIdx : uint32_t {
+///     INPUT = 0,
+///     OUT,
+/// };
+/// atb::Node &stdNode = opGraph.nodes.at(nodeId++);
+/// stdNode.operation = new atb_speed::common::StdOperation("SparseMoeStdNode");
+/// stdNode.inTensorIds = {INPUT};
+/// stdNode.outTensorIds = {OUT};
+/// \endcode
+
+class StdOperation : public AclNNOperation {
+public:
+    explicit StdOperation(const std::string &name);
+    ~StdOperation() override;
+    atb::Status InferShape(const atb::SVector<atb::TensorDesc> &inTensorDescs,
+        atb::SVector<atb::TensorDesc> &outTensorDescs) const override;
+    uint32_t GetInputNum() const override;
+    uint32_t GetOutputNum() const override;
+
+private:
+    int SetAclNNWorkspaceExecutor() override;
+    int ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream) override;
+
+    atb::Status CreateAclNNInTensorVariantPack(const atb::VariantPack &variantPack) override;
+    atb::Status CreateAclNNOutTensorVariantPack(const atb::VariantPack &variantPack) override;
+
+    std::vector<int64_t> dimData = {1};
+    aclIntArray *dim = nullptr;
+};
+} // namespace common
+} // namespace atb_speed
+#endif
diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/vector_norm_operation.cpp b/tests/proftest/layer_test_framework/operations/aclnn/ops/vector_norm_operation.cpp
new file mode 100644
index 00000000..9afc67b3
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/vector_norm_operation.cpp
@@ -0,0 +1,226 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "operations/aclnn/utils/utils.h"
+#include "acl/acl.h"
+#include "atb_speed/log.h"
+#include "aclnnop/aclnn_linalg_vector_norm.h"
+#include "vector_norm_operation.h"
+
+
+namespace atb_speed::common {
+
+    VectorNormOperation::VectorNormOperation(
+        const std::string &name,
+        atb_speed::common::AclNNVectorNormParam param
+    ) : AclNNOperation(name), param_(param)
+    {
+        this->opName_ = name;
+        this->param_ = param;
+    }
+
+    VectorNormOperation::~VectorNormOperation()
+    {
+        ATB_SPEED_LOG_DEBUG("VectorNormOperation deconstruct");
+        if (dims != nullptr) {
+            aclDestroyIntArray(dims);
+        }
+        if (param_.ord != nullptr) {
+            aclDestroyScalar(param_.ord);
+        }
+
+        this->DestroyOperation();
+    }
+
+    /**
+     * @param[in] inTensorDesc: dimNum = 3, [batch_size, seq_len, hidden_size], or dimNum = 2, [m, n]
+     * @param[out] outTensorDesc: same shape as the input for 3-D inputs; [m, 1] for 2-D inputs
+     * @return atb::Status
+     */
+    atb::Status VectorNormOperation::InferShape(
+        const atb::SVector<atb::TensorDesc> &inTensorDesc,
+        atb::SVector<atb::TensorDesc> &outTensorDesc
+    ) const
+    {
+        ATB_SPEED_LOG_DEBUG(opName_ << " InferShape start");
+        outTensorDesc.at(0).format = inTensorDesc.at(0).format;
+        outTensorDesc.at(0).dtype = inTensorDesc.at(0).dtype;
+        outTensorDesc.at(0).shape.dimNum = inTensorDesc.at(0).shape.dimNum;
+
+        if (inTensorDesc.at(0).shape.dimNum == DIM3) {
+            ATB_SPEED_LOG_DEBUG("[input0 dimNum = 3] CHECK " << opName_ << " input shape: [input0] "
+                << inTensorDesc.at(0).shape.dims[DIM0] << ", "
+                << inTensorDesc.at(0).shape.dims[DIM1] << ", "
+                << inTensorDesc.at(0).shape.dims[DIM2]);
+            outTensorDesc.at(0).shape.dims[DIM0] = inTensorDesc.at(0).shape.dims[DIM0];
+            outTensorDesc.at(0).shape.dims[DIM1] = inTensorDesc.at(0).shape.dims[DIM1];
+            outTensorDesc.at(0).shape.dims[DIM2] = inTensorDesc.at(0).shape.dims[DIM2];
+        } else if (inTensorDesc.at(0).shape.dimNum == DIM2) {
+            ATB_SPEED_LOG_DEBUG("[input0 dimNum = 2] CHECK " << opName_ << " input shape: [input0] "
+                << inTensorDesc.at(0).shape.dims[DIM0] << ", "
+                << inTensorDesc.at(0).shape.dims[DIM1]);
+            outTensorDesc.at(0).shape.dims[DIM0] = inTensorDesc.at(0).shape.dims[DIM0];
+            outTensorDesc.at(0).shape.dims[DIM1] = 1;
+        } else {
+            ATB_SPEED_LOG_ERROR(opName_ << " invalid dimNum = " << inTensorDesc.at(0).shape.dimNum);
+        }
+
+        ATB_SPEED_LOG_DEBUG(opName_ << " InferShape end");
+        return atb::NO_ERROR;
+    }
+
+    uint32_t VectorNormOperation::GetInputNum() const
+    {
+        return NUM1;  // inputTensorNum = 1
+    }
+
+    uint32_t VectorNormOperation::GetOutputNum() const
+    {
+        return NUM1;  // outputTensorNum = 1
+    }
+
+    atb::Status VectorNormOperation::CreateAclNNVariantPack(const atb::VariantPack &variantPack)
+    {
+        ATB_SPEED_LOG_DEBUG(opName_ << " CreateAclTensor start");
+        AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack;
+        aclnnVariantPack.aclInTensors.resize(variantPack.inTensors.size());
+        for (size_t i = 0; i < aclnnVariantPack.aclInTensors.size(); ++i) {
+            aclnnVariantPack.aclInTensors[i] = CreateTensor(variantPack.inTensors.at(i), i);
+        }
+
+        ATB_SPEED_LOG_DEBUG(opName_ << " Create aclInTensor end");
+
+        aclnnVariantPack.aclOutTensors.resize(variantPack.outTensors.size());
+        for (size_t i = 0; i < aclnnVariantPack.aclOutTensors.size(); ++i) {
+            aclnnVariantPack.aclOutTensors[i] = CreateTensor(variantPack.outTensors.at(i), i);
+        }
+
+        ATB_SPEED_LOG_DEBUG(opName_ << " Create aclOutTensor end");
+        ATB_SPEED_LOG_DEBUG(opName_ << " CreateAclTensor end");
+        return atb::NO_ERROR;
+    }
+
+    atb::Status VectorNormOperation::CreateAclNNInTensorVariantPack(const atb::VariantPack &variantPack)
+    {
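+        // Wrap each ATB in-tensor in an AclNNTensor; the loop index doubles as the aclnn input slot.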
+        AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack;
+        aclnnVariantPack.aclInTensors.resize(GetInputNum());
+        for (size_t i = 0; i < aclnnVariantPack.aclInTensors.size(); ++i) {
+            std::shared_ptr<AclNNTensor> aclnnTensor = CreateTensor(variantPack.inTensors.at(i), i);
+            ATB_SPEED_LOG_DEBUG(opName_ << " aclnnTensor = " << aclnnTensor);
+            if (aclnnTensor->tensor == nullptr) {
+                ATB_SPEED_LOG_ERROR(this->opName_ << " InTensor aclCreateTensor index " << i << " fail");
+                return atb::ERROR_INTERNAL_ERROR;
+            }
+            aclnnVariantPack.aclInTensors[i] = aclnnTensor;
+        }
+        return atb::NO_ERROR;
+    }
+
+    atb::Status VectorNormOperation::CreateAclNNOutTensorVariantPack(const atb::VariantPack &variantPack)
+    {
+        AclNNVariantPack &aclNnVariantPack = this->aclnnOpCache_->aclnnVariantPack;
+        aclNnVariantPack.aclOutTensors.resize(GetOutputNum());
+        for (size_t i = 0; i < aclNnVariantPack.aclOutTensors.size(); ++i) {
+            std::shared_ptr<AclNNTensor> aclnnTensor = CreateTensor(variantPack.outTensors.at(i), i);
+            if (aclnnTensor->tensor == nullptr) {
+                ATB_SPEED_LOG_ERROR(this->opName_ << " outTensor aclCreateTensor index " << i << " fail");
+                return atb::ERROR_INTERNAL_ERROR;
+            }
+            aclNnVariantPack.aclOutTensors[i] = aclnnTensor;
+        }
+        return atb::NO_ERROR;
+    }
+
+    std::shared_ptr<AclNNTensor> VectorNormOperation::CreateTensor(atb::Tensor atbTensor, size_t tensorIdx)
+    {
+        ATB_SPEED_LOG_DEBUG(opName_ << " CreateTensor start");
+        std::shared_ptr<AclNNTensor> aclnnTensor = std::make_shared<AclNNTensor>();
+        aclnnTensor->tensorIdx = static_cast<int>(tensorIdx);
+        aclnnTensor->needUpdateTensorDataPtr = true;
+        aclnnTensor->atbTensor = atbTensor;
+        ATB_SPEED_LOG_DEBUG(opName_ << " atbTensor.shape0 = " << atbTensor.desc.shape.dims[0]);
+        ATB_SPEED_LOG_DEBUG(opName_ << " atbTensor.shape1 = " << atbTensor.desc.shape.dims[1]);
+        atb::Tensor squeezedAtbTensor = SqueezeBatchSeq(atbTensor);
+        ATB_SPEED_LOG_DEBUG(opName_ << " squeezedAtbTensor.shape0 = " << squeezedAtbTensor.desc.shape.dims[0]);
+        ATB_SPEED_LOG_DEBUG(opName_ << " squeezedAtbTensor.shape1 = " << squeezedAtbTensor.desc.shape.dims[1]);
+        ATB_SPEED_LOG_DEBUG(opName_ << " tensor dtype: " << squeezedAtbTensor.desc.dtype);
+        aclnnTensor->strides = GetCopyTensorStride(squeezedAtbTensor.desc.shape);
+        ATB_SPEED_LOG_DEBUG(opName_ << " squeezedAtbTensor.desc.shape.dims0 = " <<
+            squeezedAtbTensor.desc.shape.dims[0]);
+        ATB_SPEED_LOG_DEBUG(opName_ << " squeezedAtbTensor.desc.shape.dims1 = " <<
+            squeezedAtbTensor.desc.shape.dims[1]);
+        ATB_SPEED_LOG_DEBUG(opName_ << " squeezedAtbTensor.desc.shape.dimNum = " <<
+            squeezedAtbTensor.desc.shape.dimNum);
+        ATB_SPEED_LOG_DEBUG(opName_ << " squeezedAtbTensor.desc.dtype = " <<
+            squeezedAtbTensor.desc.dtype);
+        ATB_SPEED_LOG_DEBUG(opName_ << " squeezedAtbTensor.deviceData = " <<
+            squeezedAtbTensor.deviceData);
+        aclnnTensor->tensor = aclCreateTensor(
+            squeezedAtbTensor.desc.shape.dims,
+            squeezedAtbTensor.desc.shape.dimNum,
+            squeezedAtbTensor.desc.dtype,
+            aclnnTensor->strides.data(),
+            0,
+            squeezedAtbTensor.desc.format,
+            squeezedAtbTensor.desc.shape.dims,
+            squeezedAtbTensor.desc.shape.dimNum,
+            squeezedAtbTensor.deviceData);
+        ATB_SPEED_LOG_DEBUG(opName_ << " CreateTensor end");
+        return aclnnTensor;
+    }
+
+    int VectorNormOperation::SetAclNNWorkspaceExecutor()
+    {
+        ATB_SPEED_LOG_DEBUG(opName_ << " SetAclNNWorkspaceExecutor start");
+        AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack;
+        float ord = 1.0;
+        if (param_.ord == nullptr) {  // create the scalar once; the destructor releases it
+            param_.ord = aclCreateScalar(&ord, aclDataType::ACL_FLOAT);
+        }
+        std::vector<int64_t> dimData = { -1 };
+        if (dims == nullptr) {
+            dims = aclCreateIntArray(dimData.data(), 1);
+        }
+
+        int ret = aclnnLinalgVectorNormGetWorkspaceSize(
+            aclnnVariantPack.aclInTensors.at(0)->tensor,
+            param_.ord,
+            dims,
+            true,                      // keepdim
+            aclDataType::ACL_FLOAT16,  // dtype of the computation result
+            aclnnVariantPack.aclOutTensors.at(0)->tensor,
+            &this->aclnnOpCache_->workspaceSize,
+            &this->aclnnOpCache_->aclExecutor);
+        ATB_SPEED_LOG_DEBUG(opName_ << " SetAclNNWorkspaceExecutor end"
+            << ", ret: " << ret
+            << ", workspaceSize: " << this->aclnnOpCache_->workspaceSize
+            << ", aclExecutor: " << this->aclnnOpCache_->aclExecutor);
+        return ret;
+    }
+
+    int VectorNormOperation::ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream)
+    {
+        ATB_SPEED_LOG_DEBUG(opName_ << " ExecuteAclNNOp start");
+        int ret = aclnnLinalgVectorNorm(
+            workspace,
+            this->aclnnOpCache_->workspaceSize,
+            this->aclnnOpCache_->aclExecutor,
+            stream);
+        ATB_SPEED_LOG_DEBUG(opName_ << " ExecuteAclNNOp end"
+            << ", ret: " << ret);
+        return ret;
+    }
+
+} // namespace atb_speed::common
diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/vector_norm_operation.h b/tests/proftest/layer_test_framework/operations/aclnn/ops/vector_norm_operation.h
new file mode 100644
index 00000000..825a1006
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/vector_norm_operation.h
@@ -0,0 +1,121 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef ATB_SPEED_PLUGIN_ACLNN_VECTOR_NORM_OPERATION_H
+#define ATB_SPEED_PLUGIN_ACLNN_VECTOR_NORM_OPERATION_H
+
+#include "operations/aclnn/core/acl_nn_operation.h"
+
+
+namespace atb_speed::common {
+    /// A struct defines `VectorNorm`'s parameter.
+    struct AclNNVectorNormParam {
+        /// The norm order scalar; null by default and created on first use.
+        aclScalar *ord = nullptr;
+    };
+
+    /// The vector norm operation is used in MoE scenarios, for example:
+    /// Operation's Inputs:
+    /// Name   | Dtype               | Shape |
+    /// -------|---------------------|-------|
+    /// input  | float16 or bfloat16 | [m,n] |
+    ///
+    /// Operation's Outputs:
+    /// Name   | Dtype               | Shape |
+    /// -------|---------------------|-------|
+    /// output | float16 or bfloat16 | [m,1] |
+    ///
+    /// Example:
+    /// \code
+    /// enum TensorIdx : uint32_t {
+    ///     INPUT = 0,
+    ///     OUT,
+    /// };
+    /// atb::Node &vectorNormNode = opGraph.nodes.at(nodeId++);
+    /// atb_speed::common::AclNNVectorNormParam vectorNormParam;
+    /// vectorNormNode.operation = new atb_speed::common::VectorNormOperation("vectorNormOperation", vectorNormParam);
+    /// vectorNormNode.inTensorIds = {INPUT};
+    /// vectorNormNode.outTensorIds = {OUT};
+    /// \endcode
+    class VectorNormOperation : public AclNNOperation {
+    public:
+        /// Class constructor.
+        /// Initialize a `VectorNormOperation` object.
+        /// \param name The name of the AclNN operation.
+        /// \param param The param of the AclNN operation.
+        explicit VectorNormOperation(const std::string &name, AclNNVectorNormParam param);
+
+        ~VectorNormOperation() override;
+
+        /// Infer the output tensor's shape.
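+        /// A 3-D input keeps its shape; a 2-D input [m,n] yields an [m,1] output,
+        /// matching the 2-D branch of `InferShape` in vector_norm_operation.cpp above.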
+        /// \param inTensorDesc inTensorDesc of the AclNN operation.
+        /// \param outTensorDesc outTensorDesc of the AclNN operation.
+        atb::Status InferShape(
+            const atb::SVector<atb::TensorDesc> &inTensorDesc,
+            atb::SVector<atb::TensorDesc> &outTensorDesc
+        ) const override;
+
+        /// Get the number of input tensors.
+        /// \return The operation's input tensor count.
+        [[nodiscard]] uint32_t GetInputNum() const override;
+
+        /// Get the number of output tensors.
+        /// \return The operation's output tensor count.
+        [[nodiscard]] uint32_t GetOutputNum() const override;
+
+    protected:
+        /// Prepare the operation's input and output tensors.
+        /// \param variantPack An `atb::VariantPack` object containing tensor info passed through the ATB framework.
+        /// \return A status code that indicates whether the variantPack was created successfully.
+        atb::Status CreateAclNNVariantPack(const atb::VariantPack &variantPack) override;
+
+        /// Call the AclNN operation's first-phase API to get the workspace size and `aclOpExecutor`.
+        /// \return The return value of AclNN's first-phase API.
+        int SetAclNNWorkspaceExecutor() override;
+
+        /// Call the AclNN operation's second-phase API to execute the operation.
+        /// \return The return value of AclNN's second-phase API.
+        int ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream) override;
+
+        /// Prepare the operation's input tensors.
+        /// \param variantPack An `atb::VariantPack` object containing tensor info passed through the ATB framework.
+        /// \return A status code that indicates whether the variantPack was created successfully.
+        atb::Status CreateAclNNInTensorVariantPack(const atb::VariantPack &variantPack) override;
+
+        /// Prepare the operation's output tensors.
+        /// \param variantPack An `atb::VariantPack` object containing tensor info passed through the ATB framework.
+        /// \return A status code that indicates whether the variantPack was created successfully.
+        atb::Status CreateAclNNOutTensorVariantPack(const atb::VariantPack &variantPack) override;
+
+        /// Create an `AclNNTensor` wrapper for an `atb::Tensor`.
+        /// \param atbTensor An `atb::Tensor` object containing tensor info passed through the ATB framework.
+        /// \param tensorIdx The index of the tensor.
+        /// \return The created `AclNNTensor`.
+        virtual std::shared_ptr<AclNNTensor> CreateTensor(atb::Tensor atbTensor, size_t tensorIdx);
+
+        /// The dims of vector_norm.
+        aclIntArray *dims = nullptr;
+
+    private:
+        /// An `AclNNVectorNormParam` object that can be reused within the current operation object.
+        AclNNVectorNormParam param_;
+
+        /// A human-readable name for the operation.
+        std::string opName_;
+    };
+} // namespace atb_speed::common
+
+#endif // ATB_SPEED_PLUGIN_ACLNN_VECTOR_NORM_OPERATION_H
diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/w16a16_operation.cpp b/tests/proftest/layer_test_framework/operations/aclnn/ops/w16a16_operation.cpp
new file mode 100644
index 00000000..7c5bfe39
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/w16a16_operation.cpp
@@ -0,0 +1,208 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "w16a16_operation.h"
+#include
+#include
+#include
+#include
+#include
+#include
+#include "acl/acl.h"
+#include "atb_speed/log.h"
+#include "atb_speed/utils/timer.h"
+#include "operations/aclnn/utils/utils.h"
+
+namespace atb_speed {
+namespace common {
+
+W16A16Operation::W16A16Operation(
+    const std::string &name,
+    AclNNMatmulParam param) : AclNNOperation(name), param_(param) {
+}
+
+W16A16Operation::~W16A16Operation() {
+}
+
+atb::Status W16A16Operation::InferShape(
+    const atb::SVector<atb::TensorDesc> &inTensorDescs, atb::SVector<atb::TensorDesc> &outTensorDescs) const
+{
+    ATB_SPEED_LOG_DEBUG(opName_ << " W16A16Operation infer shape start");
+
+    outTensorDescs.at(DIM0).format = inTensorDescs.at(DIM0).format;
+    outTensorDescs.at(DIM0).dtype = inTensorDescs.at(DIM0).dtype;
+    outTensorDescs.at(DIM0).shape.dimNum = inTensorDescs.at(DIM0).shape.dimNum;
+
+    int nDim = param_.transposeB ? DIM0 : DIM1;
+    ATB_SPEED_LOG_DEBUG(opName_ << " W16A16Operation infer shape origin inTensorDescs.at(DIM0).shape.dims[DIM0]"
+                  << inTensorDescs.at(DIM0).shape.dims[DIM0]);
+    ATB_SPEED_LOG_DEBUG(opName_ << " W16A16Operation infer shape origin inTensorDescs.at(DIM1).shape.dims[nDim]"
+                  << inTensorDescs.at(DIM1).shape.dims[nDim]);
+
+    outTensorDescs.at(DIM0).shape.dims[DIM0] = inTensorDescs.at(DIM0).shape.dims[DIM0];
+    outTensorDescs.at(DIM0).shape.dims[DIM1] = inTensorDescs.at(DIM1).shape.dims[nDim];
+
+    ATB_SPEED_LOG_DEBUG(opName_ << " W16A16Operation infer shape end");
+    return atb::NO_ERROR;
+}
+
+uint32_t W16A16Operation::GetInputNum() const
+{
+    uint32_t inputNum = DIM2;  // 2: input and weight
+    if (param_.hasBias) {
+        inputNum += DIM1;  // 1: bias
+    }
+    return inputNum;
+}
+
+uint32_t W16A16Operation::GetOutputNum() const
+{
+    return DIM1;
+}
+
+int W16A16Operation::CreateAclNNVariantPack(const atb::VariantPack &variantPack)
+{
+    ATB_SPEED_LOG_DEBUG(opName_ << " CreateAclNNVariantPack start");
+    int ret = 0;
+    ret = CreateAclNNInTensorVariantPack(variantPack);
+    if (ret != 0) {
+        ATB_SPEED_LOG_ERROR(this->opName_ << " AclNNTensor CreateAclNNInTensorVariantPack fail");
+        return ret;
+    }
+    ret = CreateAclNNOutTensorVariantPack(variantPack);
+    if (ret != 0) {
+        ATB_SPEED_LOG_ERROR(this->opName_ << " AclNNTensor CreateAclNNOutTensorVariantPack fail");
+        return ret;
+    }
+    ATB_SPEED_LOG_DEBUG(opName_ << " CreateAclNNVariantPack end");
+    return atb::NO_ERROR;
+}
+
+atb::Dims W16A16Operation::GetWeightStorageShape(const atb::TensorDesc atbTensorDesc) const
+{
+    atb::Dims storageTensorDims = atbTensorDesc.shape;  // In ND format, the storage shape matches the original shape
+    if (atbTensorDesc.format == ACL_FORMAT_FRACTAL_NZ) {
+        // NZ format: (k, n) => (n / 16, k / 16, 16, 16)
+        // NZ format: (n, k) => (k / 16, n / 16, 16, 16)
+        storageTensorDims.dimNum = NUM4;  // 4 dims
+        auto dim0 = atbTensorDesc.shape.dims[DIM0];
+        uint32_t blockSize = 16;  // 16: tile with a block size of 16
+        storageTensorDims.dims[DIM0] = atbTensorDesc.shape.dims[DIM1] / blockSize;
+        storageTensorDims.dims[DIM1] = dim0 / blockSize;
+        storageTensorDims.dims[DIM2] = blockSize;
+        storageTensorDims.dims[DIM3] = blockSize;
+    }
+    return storageTensorDims;
+}
+
+int W16A16Operation::CreateAclNNInTensorVariantPack(const atb::VariantPack &variantPack)
+{
+    AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack;
+    aclnnVariantPack.aclInTensors.resize(GetInputNum());
+    for (size_t i = 0; i < aclnnVariantPack.aclInTensors.size(); i++) {
+        std::shared_ptr<AclNNTensor> aclnnTensor = std::make_shared<AclNNTensor>();
+        aclnnTensor->tensorIdx = i;
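+        // needUpdateTensorDataPtr: refresh this tensor's device address before each launch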
+        aclnnTensor->needUpdateTensorDataPtr = true;
+        aclnnTensor->atbTensor = variantPack.inTensors.at(i);
+        atb::Tensor squeezedAtbTensor = SqueezeBatchSeq(variantPack.inTensors.at(i));
+
+        // StorageShape
+        atb::Dims storageTensorDims = GetWeightStorageShape(squeezedAtbTensor.desc);
+
+        // ViewShape and Stride
+        atb::Dims viewDims = squeezedAtbTensor.desc.shape;
+        if (i == 1 && this->param_.transposeB) {
+            aclnnTensor->strides = GetTransposeTensorStride(viewDims);
+            viewDims.dims[DIM0] = squeezedAtbTensor.desc.shape.dims[DIM1];
+            viewDims.dims[DIM1] = squeezedAtbTensor.desc.shape.dims[DIM0];
+        } else {
+            aclnnTensor->strides = GetCopyTensorStride(viewDims);
+        }
+
+        aclnnTensor->tensor = aclCreateTensor(
+            viewDims.dims, viewDims.dimNum, squeezedAtbTensor.desc.dtype,
+            aclnnTensor->strides.data(), 0, squeezedAtbTensor.desc.format,
+            storageTensorDims.dims, storageTensorDims.dimNum, squeezedAtbTensor.deviceData);
+        if (aclnnTensor->tensor == nullptr) {
+            ATB_SPEED_LOG_ERROR(this->opName_ << " InTensor aclCreateTensor index " << i << " fail");
+            return atb::ERROR_INTERNAL_ERROR;
+        }
+        aclnnVariantPack.aclInTensors[i] = aclnnTensor;
+    }
+    return atb::NO_ERROR;
+}
+
+int W16A16Operation::CreateAclNNOutTensorVariantPack(const atb::VariantPack &variantPack)
+{
+    AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack;
+    aclnnVariantPack.aclOutTensors.resize(GetOutputNum());
+    for (size_t i = 0; i < aclnnVariantPack.aclOutTensors.size(); ++i) {
+        std::shared_ptr<AclNNTensor> aclnnTensor = std::make_shared<AclNNTensor>();
+        aclnnTensor->tensorIdx = i;
+        aclnnTensor->needUpdateTensorDataPtr = true;
+        aclnnTensor->atbTensor = variantPack.outTensors.at(i);
+        atb::Tensor squeezedAtbTensor = SqueezeBatchSeq(variantPack.outTensors.at(i));
+        aclnnTensor->strides = GetCopyTensorStride(squeezedAtbTensor.desc.shape);
+        aclnnTensor->tensor = aclCreateTensor(
+            squeezedAtbTensor.desc.shape.dims, squeezedAtbTensor.desc.shape.dimNum,
+            squeezedAtbTensor.desc.dtype, aclnnTensor->strides.data(), 0,
+            squeezedAtbTensor.desc.format, squeezedAtbTensor.desc.shape.dims,
+            squeezedAtbTensor.desc.shape.dimNum, squeezedAtbTensor.deviceData);
+        if (aclnnTensor->tensor == nullptr) {
+            ATB_SPEED_LOG_ERROR(this->opName_ << " OutTensor aclCreateTensor index " << i << " fail");
+            return atb::ERROR_INTERNAL_ERROR;
+        }
+        aclnnVariantPack.aclOutTensors[i] = aclnnTensor;
+    }
+    return atb::NO_ERROR;
+}
+
+int W16A16Operation::SetAclNNWorkspaceExecutor()
+{
+    ATB_SPEED_LOG_DEBUG(opName_ << " SetAclNNWorkspaceExecutor start");
+    AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack;
+    float zeroValue = 0.0f;
+    float oneValue = 1.0f;
+    aclScalar* betaZero = aclCreateScalar(&zeroValue, aclDataType::ACL_FLOAT);
+    aclScalar* betaOne = aclCreateScalar(&oneValue, aclDataType::ACL_FLOAT);
+
+    int ret = aclnnAddmmGetWorkspaceSize(
+        // self: the bias tensor when present; otherwise the output tensor, neutralized by beta = 0
+        param_.hasBias ? aclnnVariantPack.aclInTensors.at(DIM2)->tensor
+                       : aclnnVariantPack.aclOutTensors.at(DIM0)->tensor,
+        aclnnVariantPack.aclInTensors.at(DIM0)->tensor,
+        aclnnVariantPack.aclInTensors.at(DIM1)->tensor,
+        param_.hasBias ? betaOne : betaZero,  // beta
+        betaOne,                              // alpha
+        aclnnVariantPack.aclOutTensors.at(DIM0)->tensor,
+        0,                                    // cubeMathType
+        &this->aclnnOpCache_->workspaceSize, &this->aclnnOpCache_->aclExecutor);
+
+    ATB_SPEED_LOG_DEBUG(opName_ << " SetAclNNWorkspaceExecutor end, ret:" << ret
+                  << ", workspaceSize:" << this->aclnnOpCache_->workspaceSize
+                  << ", aclExecutor:" << this->aclnnOpCache_->aclExecutor);
+    return ret;
+}
+
+int W16A16Operation::ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream)
+{
+    ATB_SPEED_LOG_DEBUG(opName_ << " aclnnAddmm start");
+    int ret = aclnnAddmm(
+        workspace, this->aclnnOpCache_->workspaceSize, this->aclnnOpCache_->aclExecutor, stream);
+    ATB_SPEED_LOG_DEBUG(opName_ << " aclnnAddmm end, ret:" << ret);
+    return ret;
+}
+
+} // namespace common
+} // namespace atb_speed
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/w16a16_operation.h b/tests/proftest/layer_test_framework/operations/aclnn/ops/w16a16_operation.h
new file mode 100644
index 00000000..13591e8e
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/w16a16_operation.h
@@ -0,0 +1,105 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef ATB_SPEED_PLUGIN_ACLNN_W16A16_OPERATION_H
+#define ATB_SPEED_PLUGIN_ACLNN_W16A16_OPERATION_H
+#include "operations/aclnn/core/acl_nn_operation.h"
+#include "operations/aclnn/core/acl_nn_operation_cache.h"
+#include "aclnnop/aclnn_addmm.h"
+
+namespace atb_speed {
+namespace common {
+
+/// A struct defines `W16A16Operation`'s parameter.
+struct AclNNMatmulParam {
+    /// A flag indicating whether the second matrix in the matmul operation is transposed.
+    bool transposeB = false;
+    /// A flag indicating whether the matmul operation includes a bias tensor.
+    bool hasBias = false;
+};
+
+/// This class defines a matrix operation that combines the matmul and bias-add operations.
+///
+/// This class makes use of `aclnnAddmmGetWorkspaceSize` and `aclnnAddmm` from the AscendCL API.
+///
+/// Operation's Inputs:
+/// Name            | Dtype                    | Shape | Description |
+/// ----------------|--------------------------|-------|-------------|
+/// input           | FLOAT, FLOAT16, BFLOAT16 | [m,k] |             |
+/// weight          | FLOAT, FLOAT16, BFLOAT16 | [n,k] if `transposeB` is true; otherwise, [k,n] | |
+/// bias            | FLOAT, FLOAT16, BFLOAT16 | [m,n] or can be broadcasted to [m,n] | Optional. Required if `hasBias` is true. |
+///
+/// Operation's Outputs:
+/// Name   | Dtype                              | Shape |
+/// -------|------------------------------------|-------|
+/// out    | the same dtype as the input tensor | [m,n] |
+///
+/// Example:
+/// \code
+/// enum InTensorIdx : uint32_t {
+///     IN_INPUT = 0,
+///     IN_WEIGHT,
+///     IN_BIAS,
+///     OUT,
+/// };
+///
+/// atb::Node linearNode;
+/// AclNNMatmulParam aclNNMatmulParam;
+/// aclNNMatmulParam.hasBias = false;
+/// aclNNMatmulParam.transposeB = true;
+/// linearNode.inTensorIds = {IN_INPUT, IN_WEIGHT};
+/// linearNode.outTensorIds = {OUT};
+/// linearNode.operation = new atb_speed::common::W16A16Operation("W16A16LinearNode", aclNNMatmulParam);
+///
+/// atb::Node linearWithBiasNode;
+/// AclNNMatmulParam aclNNMatmulWithBiasParam;
+/// aclNNMatmulWithBiasParam.hasBias = true;
+/// aclNNMatmulWithBiasParam.transposeB = true;
+/// linearWithBiasNode.inTensorIds = {
+///     IN_INPUT, IN_WEIGHT, IN_BIAS};
+/// linearWithBiasNode.outTensorIds = {OUT};
+/// linearWithBiasNode.operation = new atb_speed::common::W16A16Operation(
+///     "W16A16LinearWithBiasNode", aclNNMatmulWithBiasParam);
+///
+/// // Add the operation node to the graph as required
+/// atb::GraphParam opGraph;
+/// opGraph.nodes.push_back(linearNode);
+/// opGraph.nodes.push_back(linearWithBiasNode);
+/// \endcode
+
+class W16A16Operation : public AclNNOperation {
+public:
+    explicit W16A16Operation(const std::string &name, AclNNMatmulParam param);
+    ~W16A16Operation() override;
+    atb::Status InferShape(const atb::SVector<atb::TensorDesc> &inTensorDescs,
+                           atb::SVector<atb::TensorDesc> &outTensorDescs) const override;
+    uint32_t GetInputNum() const override;
+    uint32_t GetOutputNum() const override;
+
+private:
+    int CreateAclNNVariantPack(const atb::VariantPack &variantPack) override;
+    int SetAclNNWorkspaceExecutor() override;
+    int ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream) override;
+
+    int CreateAclNNInTensorVariantPack(const atb::VariantPack &variantPack) override;
+    int CreateAclNNOutTensorVariantPack(const atb::VariantPack &variantPack) override;
+
+    atb::Dims GetWeightStorageShape(const atb::TensorDesc atbTensorDesc) const;
+
+    AclNNMatmulParam param_;
+};
+} // namespace common
+} // namespace atb_speed
+#endif // ATB_SPEED_PLUGIN_ACLNN_W16A16_OPERATION_H
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/w4a16_operation.cpp b/tests/proftest/layer_test_framework/operations/aclnn/ops/w4a16_operation.cpp
new file mode 100644
index 00000000..f3aef149
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/w4a16_operation.cpp
@@ -0,0 +1,45 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include
+#include
+#include
+#include "acl/acl.h"
+#include "atb_speed/log.h"
+#include "atb_speed/utils/check_util.h"
+#include "operations/aclnn/utils/utils.h"
+#include "w4a16_operation.h"
+
+namespace atb_speed {
+namespace common {
+
+W4A16Operation::W4A16Operation(
+    const std::string &name,
+    AclNNWeightQuantBatchMatmulParam param) : QuantBatchMatmulOperation(name, param), param_(param) {}
+
+atb::Tensor W4A16Operation::PreprocessATBInTensor(atb::Tensor atbTensor, int index)
+{
+    atb::Tensor squeezedAtbTensor = SqueezeBatchSeq(atbTensor);
+    if (index == 1) {  // 1: weight
+        squeezedAtbTensor.desc.dtype = ACL_INT4;
+        squeezedAtbTensor.desc.shape.dims[DIM1] = CheckIntMulOverFlow(
+            squeezedAtbTensor.desc.shape.dims[DIM1], 2);  // 2: double the last dim (two int4 values per int8 element)
+    }
+    return squeezedAtbTensor;
+}
+
+} // namespace common
+} // namespace atb_speed
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/w4a16_operation.h b/tests/proftest/layer_test_framework/operations/aclnn/ops/w4a16_operation.h
new file mode 100644
index 00000000..3c707561
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/w4a16_operation.h
@@ -0,0 +1,101 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef ATB_SPEED_PLUGIN_ACLNN_W4A16_OPERATION_H
+#define ATB_SPEED_PLUGIN_ACLNN_W4A16_OPERATION_H
+#include "quant_batch_matmul_operation.h"
+
+namespace atb_speed {
+namespace common {
+/// This class defines a matrix operation that supports 4-bit weight quantization
+/// while keeping activations in floating-point format.
+///
+/// It inherits from the `QuantBatchMatmulOperation` class.
+///
+/// Operation's Inputs (per channel):
+/// Name            | Dtype               | Shape |
+/// ----------------|---------------------|-------|
+/// input           | float16 or bfloat16 | [m,k] |
+/// weight          | int8                | [n,k/2] if `transposeB` is true; otherwise, [k,n/2] |
+/// antiquant scale | the same dtype as the output tensor | [n,1] if `transposeB` is true; otherwise, [1,n] |
+/// antiquant offset| the same dtype as the output tensor | [n,1] if `transposeB` is true; otherwise, [1,n] |
+/// bias            | int32 if the output tensor's dtype is float16; bfloat16 if the output tensor's dtype is bfloat16 | [n] |
+///
+/// Operation's Inputs (per group):
+/// Name            | Dtype               | Shape |
+/// ----------------|---------------------|-------|
+/// input           | float16 or bfloat16 | [m,k] |
+/// weight          | int8                | [n,k/2] if `transposeB` is true; otherwise, [k,n/2] |
+/// antiquant scale | the same dtype as the output tensor | [n,ceil(k / group_size)] if `transposeB` is true; otherwise, [ceil(k / group_size),n] |
+/// antiquant offset| the same dtype as the output tensor | [n,ceil(k / group_size)] if `transposeB` is true; otherwise, [ceil(k / group_size),n] |
+/// bias            | int32 if the output tensor's dtype is float16; bfloat16 if the output tensor's dtype is bfloat16 | [n] |
+///
+/// Note that, since PyTorch does not support the int4 data type, the weight tensor is provided as an 8-bit tensor.
+/// It is treated as a 4-bit tensor, with the last axis doubled in the `PreprocessATBInTensor` function.
+/// The device data stored in the weight tensor remains unchanged.
+///
+/// Operation's Outputs:
+/// Name   | Dtype               | Shape |
+/// -------|---------------------|-------|
+/// output | float16 or bfloat16 | [m,n] |
+///
+/// Example:
+/// \code
+/// enum InTensorIdx : uint32_t {
+///     IN_INPUT = 0,
+///     IN_WEIGHT,
+///     IN_WEIGHT_ANTI_QUANT_SCALE,
+///     IN_WEIGHT_ANTI_QUANT_OFFSET,
+///     IN_BIAS,
+///     OUT,
+/// };
+///
+/// atb::Node linearNode;
+/// AclNNWeightQuantBatchMatmulParam aclnnQuantBatchMatmulParam;
+/// aclnnQuantBatchMatmulParam.hasBias = false;
+/// aclnnQuantBatchMatmulParam.transposeB = true;
+/// aclnnQuantBatchMatmulParam.quantGroupSize = 0; // 0: per channel; otherwise, per group
+/// linearNode.inTensorIds = {IN_INPUT, IN_WEIGHT, IN_WEIGHT_ANTI_QUANT_SCALE, IN_WEIGHT_ANTI_QUANT_OFFSET};
+/// linearNode.outTensorIds = {OUT};
+/// linearNode.operation = new atb_speed::common::W4A16Operation("W4A16LinearNode", aclnnQuantBatchMatmulParam);
+///
+/// atb::Node linearWithBiasNode;
+/// AclNNWeightQuantBatchMatmulParam aclnnQuantBatchMatmulWithBiasParam;
+/// aclnnQuantBatchMatmulWithBiasParam.hasBias = true;
+/// aclnnQuantBatchMatmulWithBiasParam.transposeB = true;
+/// linearWithBiasNode.inTensorIds = {
+///     IN_INPUT, IN_WEIGHT, IN_WEIGHT_ANTI_QUANT_SCALE, IN_WEIGHT_ANTI_QUANT_OFFSET, IN_BIAS};
+/// linearWithBiasNode.outTensorIds = {OUT};
+/// linearWithBiasNode.operation = new atb_speed::common::W4A16Operation(
+///     "W4A16LinearWithBiasNode", aclnnQuantBatchMatmulWithBiasParam);
+///
+/// // Add the operation node to the graph as required
+/// atb::GraphParam opGraph;
+/// opGraph.nodes.push_back(linearNode);
+/// opGraph.nodes.push_back(linearWithBiasNode);
+/// \endcode
+class W4A16Operation : public QuantBatchMatmulOperation {
+public:
+    explicit W4A16Operation(const std::string &name, AclNNWeightQuantBatchMatmulParam param);
+
+protected:
+    atb::Tensor PreprocessATBInTensor(atb::Tensor atbTensor, int index) override;
+
+private:
+    AclNNWeightQuantBatchMatmulParam param_;
+};
+} // namespace common
+} // namespace atb_speed
+#endif
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/w4a8_operation.cpp b/tests/proftest/layer_test_framework/operations/aclnn/ops/w4a8_operation.cpp
new file mode 100644
index 00000000..3c9879ed
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/w4a8_operation.cpp
@@ -0,0 +1,109 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
+ */
+#include "acl/acl.h"
+#include "atb_speed/log.h"
+#include "operations/aclnn/utils/utils.h"
+#include "operations/aclnn/core/acl_nn_operation.h"
+#include "aclnnop/aclnn_quant_matmul_v5.h"
+#include "w4a8_operation.h"
+
+namespace atb_speed {
+namespace common {
+
+W4A8Operation::W4A8Operation(
+    const std::string &name,
+    AclNNW4A8Param param) : AclNNOperation(name), param_(param) {}
+
+W4A8Operation::~W4A8Operation()
+{
+    ATB_SPEED_LOG_DEBUG("W4A8Operation deconstructor");
+    this->DestroyOperation();
+}
+
+atb::Status W4A8Operation::InferShape(const atb::SVector<atb::TensorDesc> &inTensorDescs,
+                                      atb::SVector<atb::TensorDesc> &outTensorDescs) const
+{
+    ATB_SPEED_LOG_DEBUG(opName_ << " infer shape start");
+    outTensorDescs.at(0).format = inTensorDescs.at(0).format;  // ND or NZ
+    outTensorDescs.at(0).shape.dimNum = inTensorDescs.at(0).shape.dimNum;  // tensor dim count
+    outTensorDescs.at(0).dtype = param_.outDataType;
+    if (inTensorDescs.at(0).shape.dimNum == DIM2) {
+        ATB_SPEED_LOG_DEBUG("[input0 dimNum = 2] CHECK " << opName_ << " inputs shape: [input0]"
+            << inTensorDescs.at(DIM0).shape.dims[DIM0] << ", " << inTensorDescs.at(DIM0).shape.dims[DIM1]);
+        ATB_SPEED_LOG_DEBUG("[input0 dimNum = 2] CHECK " << opName_ << " inputs shape: [input1]"
+            << inTensorDescs.at(DIM1).shape.dims[DIM0] << ", " << inTensorDescs.at(DIM1).shape.dims[DIM1]);
+        outTensorDescs.at(DIM0).shape.dims[DIM0] = inTensorDescs.at(DIM0).shape.dims[DIM0];
+        // 8: each int32 element packs eight int4 values
+        outTensorDescs.at(DIM0).shape.dims[DIM1] = inTensorDescs.at(DIM1).shape.dims[DIM1] * 8;
+    } else {
+        ATB_SPEED_LOG_ERROR(opName_ << " invalid dim num:" << inTensorDescs.at(DIM0).shape.dimNum);
+        return atb::ERROR_INVALID_TENSOR_DIM_NUM;
+    }
+    ATB_SPEED_LOG_DEBUG(opName_ << " infer shape end");
+    return atb::NO_ERROR;
+}
+
+uint32_t W4A8Operation::GetInputNum() const { return 5; }  // 5: x, weight, x_scale, weight_scale, y_offset
+
+uint32_t W4A8Operation::GetOutputNum() const { return 1; }  // 1: y
+
+int W4A8Operation::CreateAclNNInTensorVariantPack(const atb::VariantPack &variantPack)
+{
+    AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack;
+    aclnnVariantPack.aclInTensors.resize(GetInputNum());
+    for (size_t i = 0; i < aclnnVariantPack.aclInTensors.size(); ++i) {
+        atb::Tensor atbTensor = variantPack.inTensors.at(i);
+        if (i == 2) {  // 2: x_scale
+            atbTensor.desc.shape.dimNum = 2;  // 2: dimNum = 2, dims = [m, 1]
+            atbTensor.desc.shape.dims[1] = 1;  // 1: dims[1] = 1
+        }
+        // y_offset (input 4) binds to aclnn argument slot 7; other inputs keep their own index
+        int tensorIdx = i == 4 ? 7 : i;
+        std::shared_ptr<AclNNTensor> aclnnTensor = CreateTensor(atbTensor, tensorIdx);
+        aclnnVariantPack.aclInTensors[i] = aclnnTensor;
+    }
+    return atb::NO_ERROR;
+}
+
+int W4A8Operation::SetAclNNWorkspaceExecutor()
+{
+    ATB_SPEED_LOG_DEBUG(opName_ << " start");
+    AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack;
+    int ret = aclnnQuantMatmulV5GetWorkspaceSize(
+        aclnnVariantPack.aclInTensors.at(0)->tensor,  // 0: input
+        aclnnVariantPack.aclInTensors.at(1)->tensor,  // 1: weight
+        aclnnVariantPack.aclInTensors.at(2)->tensor,  // 2: x1scale
+        aclnnVariantPack.aclInTensors.at(3)->tensor,  // 3: x2scale
+        nullptr,  // yscale
+        nullptr,  // x1offset
+        nullptr,  // x2offset
+        aclnnVariantPack.aclInTensors.at(4)->tensor,  // yoffset
+        nullptr,  // bias
+        false,    // transposeX1
+        false,    // transposeX2
+        param_.groupSize,  // groupSize
+        aclnnVariantPack.aclOutTensors.at(0)->tensor,
+        &this->aclnnOpCache_->workspaceSize,
+        &this->aclnnOpCache_->aclExecutor);
+    ATB_SPEED_LOG_DEBUG(opName_ << " end, ret:"
+                  << ret << ", workspaceSize:" << this->aclnnOpCache_->workspaceSize
+                  << ", aclExecutor:" << this->aclnnOpCache_->aclExecutor);
+    return ret;
+}
+
+int W4A8Operation::ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream)
+{
+    int ret = aclnnQuantMatmulV5(
+        workspace,
+        this->aclnnOpCache_->workspaceSize,
+        this->aclnnOpCache_->aclExecutor,
+        stream);
+    if (ret != 0) {
+        ATB_SPEED_LOG_ERROR("ExecuteAclNNOp failed, ret: " << ret);
+    }
+    return ret;
+}
+
+} // namespace common
+} // namespace atb_speed
+
diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/w4a8_operation.h b/tests/proftest/layer_test_framework/operations/aclnn/ops/w4a8_operation.h
new file mode 100644
index 00000000..5ab6686a
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/w4a8_operation.h
@@ -0,0 +1,42 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
+ */
+#ifndef ATB_SPEED_PLUGIN_ACLNN_W4A8_OPERATION_H
+#define ATB_SPEED_PLUGIN_ACLNN_W4A8_OPERATION_H
+#include "operations/aclnn/core/acl_nn_operation.h"
+#include "operations/fusion/utils.h"
+#include "atb_speed/utils/operation_util.h"
+
+namespace atb_speed {
+namespace common {
+
+/// A struct defines `W4A8Operation`'s parameter.
+struct AclNNW4A8Param {
+    /// The output tensor's data type.
+    aclDataType outDataType = ACL_FLOAT16;
+    /// A flag indicating whether the matmul operation includes a bias tensor.
+    bool hasBias = false;
+    /// Group size of per-group quantization.
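+    /// With per-group quantization, every `groupSize` consecutive elements along the k axis
+    /// share one scale, so the weight-scale tensor holds k / groupSize entries per output channel.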
+    int groupSize = 256;
+};
+
+class W4A8Operation : public AclNNOperation {
+public:
+    explicit W4A8Operation(const std::string &name, AclNNW4A8Param param);
+    ~W4A8Operation() override;
+    atb::Status InferShape(const atb::SVector<atb::TensorDesc> &inTensorDescs,
+                           atb::SVector<atb::TensorDesc> &outTensorDescs) const override;
+    uint32_t GetInputNum() const override;
+    uint32_t GetOutputNum() const override;
+
+protected:
+    int SetAclNNWorkspaceExecutor() override;
+    int ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream) override;
+    int CreateAclNNInTensorVariantPack(const atb::VariantPack &variantPack) override;
+
+private:
+    AclNNW4A8Param param_;
+};
+} // namespace common
+} // namespace atb_speed
+#endif // ATB_SPEED_PLUGIN_ACLNN_W4A8_OPERATION_H
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/w8a16_operation.cpp b/tests/proftest/layer_test_framework/operations/aclnn/ops/w8a16_operation.cpp
new file mode 100644
index 00000000..3135c430
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/w8a16_operation.cpp
@@ -0,0 +1,35 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "atb_speed/log.h"
+#include "operations/aclnn/utils/utils.h"
+#include "w8a16_operation.h"
+
+namespace atb_speed {
+namespace common {
+
+W8A16Operation::W8A16Operation(
+    const std::string &name,
+    AclNNWeightQuantBatchMatmulParam param) : QuantBatchMatmulOperation(name, param), param_(param) {}
+
+atb::Tensor W8A16Operation::PreprocessATBInTensor(atb::Tensor atbTensor, int index)
+{
+    ATB_SPEED_LOG_DEBUG("W8A16 preprocess ATB in tensor " << index);
+    atb::Tensor squeezedATBTensor = SqueezeBatchSeq(atbTensor);
+    return squeezedATBTensor;
+}
+
+} // namespace common
+} // namespace atb_speed
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/w8a16_operation.h b/tests/proftest/layer_test_framework/operations/aclnn/ops/w8a16_operation.h
new file mode 100644
index 00000000..2e4aabcd
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/w8a16_operation.h
@@ -0,0 +1,98 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef ATB_SPEED_PLUGIN_ACLNN_W8A16_OPERATION_H
+#define ATB_SPEED_PLUGIN_ACLNN_W8A16_OPERATION_H
+#include "quant_batch_matmul_operation.h"
+
+namespace atb_speed {
+namespace common {
+
+/// This class defines a matrix operation that supports 8-bit weight quantization
+/// while keeping activations in floating-point format.
+///
+/// It inherits from the `QuantBatchMatmulOperation` class.
+///
+/// Operation's Inputs (per channel):
+/// Name            | Dtype               | Shape |
+/// ----------------|---------------------|-------|
+/// input           | float16 or bfloat16 | [m,k] |
+/// weight          | int8                | [n,k] if `transposeB` is true; otherwise, [k,n] |
+/// antiquant scale | the same dtype as the output tensor | [n,1] if `transposeB` is true; otherwise, [1,n] |
+/// antiquant offset| the same dtype as the output tensor | [n,1] if `transposeB` is true; otherwise, [1,n] |
+/// bias            | int32 if the output tensor's dtype is float16; bfloat16 if the output tensor's dtype is bfloat16 | [n] |
+///
+/// Operation's Inputs (per group):
+/// Name            | Dtype               | Shape |
+/// ----------------|---------------------|-------|
+/// input           | float16 or bfloat16 | [m,k] |
+/// weight          | int8                | [n,k] if `transposeB` is true; otherwise, [k,n] |
+/// antiquant scale | the same dtype as the output tensor | [n,ceil(k / group_size)] if `transposeB` is true; otherwise, [ceil(k / group_size),n] |
+/// antiquant offset| the same dtype as the output tensor | [n,ceil(k / group_size)] if `transposeB` is true; otherwise, [ceil(k / group_size),n] |
+/// bias            | int32 if the output tensor's dtype is float16; bfloat16 if the output tensor's dtype is bfloat16 | [n] |
+///
+/// Operation's Outputs:
+/// Name   | Dtype               | Shape |
+/// -------|---------------------|-------|
+/// output | float16 or bfloat16 | [m,n] |
+///
+/// Example:
+/// \code
+/// enum InTensorIdx : uint32_t {
+///     IN_INPUT = 0,
+///     IN_WEIGHT,
+///     IN_WEIGHT_ANTI_QUANT_SCALE,
+///     IN_WEIGHT_ANTI_QUANT_OFFSET,
+///     IN_BIAS,
+///     OUT,
+/// };
+///
+/// atb::Node linearNode;
+/// AclNNWeightQuantBatchMatmulParam aclnnQuantBatchMatmulParam;
+/// aclnnQuantBatchMatmulParam.hasBias = false;
+/// aclnnQuantBatchMatmulParam.transposeB = true;
+/// aclnnQuantBatchMatmulParam.quantGroupSize = 0; // 0: per channel; otherwise, per group
+/// linearNode.inTensorIds = {IN_INPUT, IN_WEIGHT, IN_WEIGHT_ANTI_QUANT_SCALE, IN_WEIGHT_ANTI_QUANT_OFFSET};
+/// linearNode.outTensorIds = {OUT};
+/// linearNode.operation = new atb_speed::common::W8A16Operation("W8A16LinearNode", aclnnQuantBatchMatmulParam);
+///
+/// atb::Node linearWithBiasNode;
+/// AclNNWeightQuantBatchMatmulParam aclnnQuantBatchMatmulWithBiasParam;
+/// aclnnQuantBatchMatmulWithBiasParam.hasBias = true;
+/// aclnnQuantBatchMatmulWithBiasParam.transposeB = true;
+/// linearWithBiasNode.inTensorIds = {
+///     IN_INPUT, IN_WEIGHT, IN_WEIGHT_ANTI_QUANT_SCALE, IN_WEIGHT_ANTI_QUANT_OFFSET, IN_BIAS};
+/// linearWithBiasNode.outTensorIds = {OUT};
+/// linearWithBiasNode.operation = new atb_speed::common::W8A16Operation(
+///     "W8A16LinearWithBiasNode", aclnnQuantBatchMatmulWithBiasParam);
+///
+/// // Add the operation node to the graph as required
+/// atb::GraphParam opGraph;
+/// opGraph.nodes.push_back(linearNode);
+/// opGraph.nodes.push_back(linearWithBiasNode);
+/// \endcode
+class W8A16Operation : public QuantBatchMatmulOperation {
+public:
+    explicit W8A16Operation(const std::string &name, AclNNWeightQuantBatchMatmulParam param);
+
+protected:
+    atb::Tensor PreprocessATBInTensor(atb::Tensor atbTensor, int index) override;
+
+private:
+    AclNNWeightQuantBatchMatmulParam param_;
+};
+} // namespace common
+} // namespace atb_speed
+#endif
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/w8a8_operation.cpp b/tests/proftest/layer_test_framework/operations/aclnn/ops/w8a8_operation.cpp
new file mode 100644
index 00000000..6a0455a1
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/w8a8_operation.cpp
@@ -0,0 +1,241 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "acl/acl.h"
+#include "atb_speed/log.h"
+#include "operations/aclnn/utils/utils.h"
+#include "operations/aclnn/core/acl_nn_operation.h"
+#include "aclnnop/aclnn_quant_matmul_v4.h"
+#include "w8a8_operation.h"
+
+namespace atb_speed {
+namespace common {
+
+W8A8Operation::W8A8Operation(
+    const std::string &name,
+    AclNNQuantMatmulParam param) : AclNNOperation(name), param_(param) {}
+
+W8A8Operation::~W8A8Operation()
+{
+    ATB_SPEED_LOG_DEBUG("W8A8Operation deconstructor");
+    this->DestroyOperation();
+}
+
+atb::Status W8A8Operation::InferShape(const atb::SVector<atb::TensorDesc> &inTensorDescs,
+                                      atb::SVector<atb::TensorDesc> &outTensorDescs) const
+{
+    ATB_SPEED_LOG_DEBUG(opName_ << " infer shape start");
+    outTensorDescs.at(0).format = inTensorDescs.at(0).format;
+    outTensorDescs.at(0).shape.dimNum = inTensorDescs.at(0).shape.dimNum;
+    // When dequantization is deferred to a downstream op, the matmul outputs INT32
+    outTensorDescs.at(0).dtype = param_.isOutDequantBias ? ACL_INT32 : (param_.isBF16 ? ACL_BF16 : ACL_FLOAT16);
+
+    int nDim = param_.transposeB ? DIM0 : DIM1;
+    if (inTensorDescs.at(0).shape.dimNum == DIM3) {
+        ATB_SPEED_LOG_DEBUG("[input0 dimNum = 3] CHECK " << opName_ << " inputs shape: [input0]"
+            << inTensorDescs.at(DIM0).shape.dims[DIM0] << ", "
+            << inTensorDescs.at(DIM0).shape.dims[DIM1] << ", " << inTensorDescs.at(DIM0).shape.dims[DIM2]);
+        ATB_SPEED_LOG_DEBUG("[input0 dimNum = 3] CHECK " << opName_ << " inputs shape: [input1]"
+            << inTensorDescs.at(DIM1).shape.dims[DIM0] << ", " << inTensorDescs.at(DIM1).shape.dims[DIM1]);
+        outTensorDescs.at(DIM0).shape.dims[DIM0] = inTensorDescs.at(DIM0).shape.dims[DIM0];
+        outTensorDescs.at(DIM0).shape.dims[DIM1] = inTensorDescs.at(DIM0).shape.dims[DIM1];
+        outTensorDescs.at(DIM0).shape.dims[DIM2] = inTensorDescs.at(DIM1).shape.dims[nDim];
+    } else if (inTensorDescs.at(0).shape.dimNum == DIM2) {
+        ATB_SPEED_LOG_DEBUG("[input0 dimNum = 2] CHECK " << opName_ << " inputs shape: [input0]"
+            << inTensorDescs.at(DIM0).shape.dims[DIM0] << ", " << inTensorDescs.at(DIM0).shape.dims[DIM1]);
+        ATB_SPEED_LOG_DEBUG("[input0 dimNum = 2] CHECK " << opName_ << " inputs shape: [input1]"
+            << inTensorDescs.at(DIM1).shape.dims[DIM0] << ", " << inTensorDescs.at(DIM1).shape.dims[DIM1]);
+        outTensorDescs.at(DIM0).shape.dims[DIM0] = inTensorDescs.at(DIM0).shape.dims[DIM0];
+        outTensorDescs.at(DIM0).shape.dims[DIM1] = inTensorDescs.at(DIM1).shape.dims[nDim];
+    } else {
+        ATB_SPEED_LOG_ERROR(opName_ << " invalid dim num:" << inTensorDescs.at(DIM0).shape.dimNum);
+    }
+    ATB_SPEED_LOG_DEBUG(opName_ << " infer shape end");
+    return atb::NO_ERROR;
+}
+
+uint32_t W8A8Operation::GetInputNum() const
+{
+    uint32_t inputNum = 3;  // 3: input, weight, weight scale
+    ATB_SPEED_LOG_DEBUG("initial inputNum: " << inputNum);
+    if (param_.hasPerTokenScale) {
+        ATB_SPEED_LOG_DEBUG("QuantBatchMatmul & hasPerTokenScale");
+        ++inputNum;
+    }
+    if (param_.hasBias) {
+        ATB_SPEED_LOG_DEBUG("QuantBatchMatmul & hasBias");
+        ++inputNum;
+    }
+    ATB_SPEED_LOG_DEBUG("final inputNum: " << inputNum);
+    return inputNum;
+}
+
+uint32_t W8A8Operation::GetOutputNum() const { return NUM1; }
+
+int W8A8Operation::CreateAclNNVariantPack(const atb::VariantPack &variantPack)
+{
+    ATB_SPEED_LOG_DEBUG(opName_ << " CreateAclNNVariantPack start");
+    int ret = 0;
+    ret = CreateAclNNInTensorVariantPack(variantPack);
+    if (ret != 0) {
+        ATB_SPEED_LOG_ERROR(this->opName_ << " AclNNTensor CreateAclNNInTensorVariantPack fail");
+        return ret;
+    }
+
+    ret = CreateAclNNOutTensorVariantPack(variantPack);
+    if (ret != 0) {
+        ATB_SPEED_LOG_ERROR(this->opName_ << " AclNNTensor CreateAclNNOutTensorVariantPack fail");
+        return ret;
+    }
+
+    ATB_SPEED_LOG_DEBUG(opName_ << " CreateAclNNVariantPack end");
+    return atb::NO_ERROR;
+}
+
+atb::Dims W8A8Operation::GetWeightStorageShape(const atb::TensorDesc atbTensorDesc) const
+{
+    atb::Dims storageTensorDims = atbTensorDesc.shape;  // In ND format, the storage shape matches the original shape
+    // ND to NZ
+    if (atbTensorDesc.format == ACL_FORMAT_FRACTAL_NZ) {
+        // NZ format: (k, n) => (n / 32, k / 16, 16, 32)
+        // NZ format: (n, k) => (k / 32, n / 16, 16, 32)
+        storageTensorDims.dimNum = NUM4;  // 4 dims
+        auto dim0 = atbTensorDesc.shape.dims[DIM0];
+        // m0 and n0 are alignment units: float16: n0 = m0 = 16; int8: n0 = 32, m0 = 16
+        uint32_t blockSize = 16;  // m0, outer axis
+        uint32_t n0 = 32;         // n0, inner axis; w8a8 weights are int8
+        storageTensorDims.dims[DIM0] = atbTensorDesc.shape.dims[DIM1] / n0;
+        storageTensorDims.dims[DIM1] = dim0 / blockSize;
+        storageTensorDims.dims[DIM2] = blockSize;
+        storageTensorDims.dims[DIM3] = n0;
+    }
+    return storageTensorDims;
+}
+
+int W8A8Operation::CreateAclNNInTensorVariantPack(const atb::VariantPack &variantPack)
+{
+    AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack;
+    aclnnVariantPack.aclInTensors.resize(GetInputNum());
+    for (size_t i = 0; i < aclnnVariantPack.aclInTensors.size(); ++i) {
+        std::shared_ptr<AclNNTensor> aclnnTensor = std::make_shared<AclNNTensor>();
+        aclnnTensor->tensorIdx = i;
+        aclnnTensor->needUpdateTensorDataPtr = true;
+        aclnnTensor->atbTensor = variantPack.inTensors.at(i);
+        atb::Tensor atbTensor = variantPack.inTensors.at(i);
+
+        if (param_.matmulBackend == atb_speed::common::OpBackend::ACLNN) {
+            // StorageShape
+            atb::Dims storageTensorDims = GetWeightStorageShape(atbTensor.desc);
+
+            // ViewShape and Stride
+            atb::Dims viewDims = atbTensor.desc.shape;
+            // aclInTensors[1] is the weight
+            if (i == 1 && this->param_.transposeB) {
+                aclnnTensor->strides = GetTransposeTensorStride(viewDims);
+                viewDims.dims[DIM0] = atbTensor.desc.shape.dims[DIM1];
+                viewDims.dims[DIM1] = atbTensor.desc.shape.dims[DIM0];
+            } else {
+                aclnnTensor->strides = GetCopyTensorStride(viewDims);
+            }
+            // offset is 0
+            aclnnTensor->tensor = aclCreateTensor(
+                viewDims.dims, viewDims.dimNum, atbTensor.desc.dtype,
+                aclnnTensor->strides.data(), 0, atbTensor.desc.format,
+                storageTensorDims.dims, storageTensorDims.dimNum, atbTensor.deviceData);
+        } else {
+            aclnnTensor->strides = GetCopyTensorStride(atbTensor.desc.shape);
+            aclnnTensor->tensor = aclCreateTensor(
+                atbTensor.desc.shape.dims, atbTensor.desc.shape.dimNum, atbTensor.desc.dtype,
+                aclnnTensor->strides.data(), 0, atbTensor.desc.format, atbTensor.desc.shape.dims,
+                atbTensor.desc.shape.dimNum, atbTensor.deviceData);
+        }
+
+        if (aclnnTensor->tensor == nullptr) {
+            ATB_SPEED_LOG_ERROR(this->opName_ << " InTensor aclCreateTensor index " << i << " fail");
+            return atb::ERROR_INTERNAL_ERROR;
+        }
+        aclnnVariantPack.aclInTensors[i] = aclnnTensor;
+    }
+    return atb::NO_ERROR;
+}
+
+int W8A8Operation::CreateAclNNOutTensorVariantPack(const atb::VariantPack &variantPack)
+{
+    AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack;
+    aclnnVariantPack.aclOutTensors.resize(NUM1);
+    for (size_t i = 0; i < aclnnVariantPack.aclOutTensors.size(); ++i) {
+        std::shared_ptr<AclNNTensor> aclnnTensor = std::make_shared<AclNNTensor>();
+        aclnnTensor->tensorIdx = i;
+        aclnnTensor->needUpdateTensorDataPtr = true;
+        aclnnTensor->atbTensor = variantPack.outTensors.at(i);
+        atb::Tensor atbTensor = variantPack.outTensors.at(i);
+        aclnnTensor->strides = GetCopyTensorStride(atbTensor.desc.shape);
+        aclnnTensor->tensor = aclCreateTensor(
+            atbTensor.desc.shape.dims, atbTensor.desc.shape.dimNum, atbTensor.desc.dtype,
+            aclnnTensor->strides.data(), 0, atbTensor.desc.format, atbTensor.desc.shape.dims,
+            atbTensor.desc.shape.dimNum, atbTensor.deviceData);
+
+        if (aclnnTensor->tensor == nullptr) {
+            ATB_SPEED_LOG_ERROR(this->opName_ << " OutTensor aclCreateTensor index " << i << " fail");
+            return atb::ERROR_INTERNAL_ERROR;
+        }
+        aclnnVariantPack.aclOutTensors[i] = aclnnTensor;
+    }
+    return atb::NO_ERROR;
+}
+
+int W8A8Operation::SetAclNNWorkspaceExecutor()
+{
+    ATB_SPEED_LOG_DEBUG(opName_ << " start");
+    AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack;
+    uint32_t inputIdx = 3;  // 3: index of the first optional input
+    aclTensor* perTokenScaleTensor = param_.hasPerTokenScale ?
+        aclnnVariantPack.aclInTensors.at(inputIdx++)->tensor : nullptr;
+    // When dequantization is deferred, biasTensor is set to nullptr
+    aclTensor* biasTensor = param_.isOutDequantBias ? nullptr :
+        param_.hasBias ? aclnnVariantPack.aclInTensors.at(inputIdx++)->tensor : nullptr;
+    int ret = aclnnQuantMatmulV4GetWorkspaceSize(
+        aclnnVariantPack.aclInTensors.at(0)->tensor,  // 0: input
+        aclnnVariantPack.aclInTensors.at(1)->tensor,  // 1: weight
+        aclnnVariantPack.aclInTensors.at(2)->tensor,  // 2: scale
+        nullptr,              // offset
+        perTokenScaleTensor,  // per token scale
+        biasTensor,           // bias
+        false,                // transposeX1
+        param_.transposeB,    // transposeX2
+        aclnnVariantPack.aclOutTensors.at(0)->tensor,
+        &this->aclnnOpCache_->workspaceSize,
+        &this->aclnnOpCache_->aclExecutor);
+    ATB_SPEED_LOG_DEBUG(opName_ << " end, ret:"
+                  << ret << ", workspaceSize:" << this->aclnnOpCache_->workspaceSize
+                  << ", aclExecutor:" << this->aclnnOpCache_->aclExecutor);
+    return ret;
+}
+
+int W8A8Operation::ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream)
+{
+    int ret = aclnnQuantMatmulV4(
+        workspace,
+        this->aclnnOpCache_->workspaceSize,
+        this->aclnnOpCache_->aclExecutor,
+        stream);
+    if (ret != 0) {
+        ATB_SPEED_LOG_ERROR("ExecuteAclNNOp failed, ret: " << ret);
+    }
+    return ret;
+}
+
+} // namespace common
+} // namespace atb_speed
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/aclnn/ops/w8a8_operation.h b/tests/proftest/layer_test_framework/operations/aclnn/ops/w8a8_operation.h
new file mode 100644
index 00000000..bc8c2654
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/aclnn/ops/w8a8_operation.h
@@ -0,0 +1,117 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef ATB_SPEED_PLUGIN_ACLNN_W8A8_OPERATION_H
+#define ATB_SPEED_PLUGIN_ACLNN_W8A8_OPERATION_H
+#include "operations/aclnn/core/acl_nn_operation.h"
+#include "operations/fusion/utils.h"
+#include "atb_speed/utils/operation_util.h"
+
+namespace atb_speed {
+namespace common {
+
+/// A struct defines `W8A8Operation`'s parameter.
+struct AclNNQuantMatmulParam {
+    /// A flag indicating whether the matmul operation includes a bias tensor.
+    bool hasBias = false;
+    /// A flag indicating whether the second matrix in the matmul operation is transposed.
+    bool transposeB = true;
+    /// Selects the matmul backend (`OpBackend::ATB` or `OpBackend::ACLNN`).
+    int matmulBackend = atb_speed::common::OpBackend::ATB;
+    /// A flag indicating whether the matmul operation includes a perTokenScaleOptional tensor.
+    bool hasPerTokenScale = false;
+    /// A flag indicating whether the tensor type is bfloat16.
+    bool isBF16 = true;
+    /// A flag indicating whether the dequant/bias step is deferred to a downstream operation.
+    bool isOutDequantBias = false;
+    /// A flag indicating whether the matmul operation includes an offset tensor.
+    bool hasOffset = false;
+};
+
+/// This class defines a matrix operation that supports
+/// dynamic per-token activation quantization and weight per-channel quantization.
+///
+/// This class makes use of `aclnnQuantMatmulV4GetWorkspaceSize` and `aclnnQuantMatmulV4` from the AscendCL API.
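+///
+/// Conceptually, out[m,n] = (input[m,k] x weight[k,n] + bias[n]) * weight_scale[n] * per_token_scale[m],
+/// with both scales applied while dequantizing the int32 matmul result.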
+///
+/// Operation's Inputs:
+/// Name | Dtype | Shape |
+/// ----------------|---------|-------|
+/// input | int8 | [m,k] |
+/// weight | int8 | [n,k] if `transposeB` is true; otherwise, [k,n] |
+/// weight scale | float32 if the output tensor's dtype is float16; bfloat16 if the output tensor's dtype is bfloat16 | [n] |
+/// per token scale | float32 | [m] |
+/// bias | int32 | [n] |
+///
+/// Operation's Outputs:
+/// Name | Dtype | Shape |
+/// -------|---------------------|-------|
+/// output | float16 or bfloat16 | [m,n] |
+///
+/// Example:
+/// \code
+/// enum InTensorIdx : uint32_t {
+///     IN_INPUT = 0,
+///     IN_WEIGHT,
+///     IN_WEIGHT_SCALE,
+///     IN_PER_TOKEN_SCALE,
+///     IN_BIAS,
+///     OUT,
+/// };
+///
+/// atb::Node linearNode;
+/// AclNNQuantMatmulParam aclnnQuantMatmulParam;
+/// aclnnQuantMatmulParam.hasBias = false;
+/// aclnnQuantMatmulParam.transposeB = true;
+/// linearNode.inTensorIds = {IN_INPUT, IN_WEIGHT, IN_WEIGHT_SCALE, IN_PER_TOKEN_SCALE};
+/// linearNode.outTensorIds = {OUT};
+/// linearNode.operation = new atb_speed::common::W8A8Operation("W8A8LinearNode", aclnnQuantMatmulParam);
+///
+/// atb::Node linearWithBiasNode;
+/// AclNNQuantMatmulParam aclnnQuantMatmulWithBiasParam;
+/// aclnnQuantMatmulWithBiasParam.hasBias = true;
+/// aclnnQuantMatmulWithBiasParam.transposeB = true;
+/// linearWithBiasNode.inTensorIds = {IN_INPUT, IN_WEIGHT, IN_WEIGHT_SCALE, IN_PER_TOKEN_SCALE, IN_BIAS};
+/// linearWithBiasNode.outTensorIds = {OUT};
+/// linearWithBiasNode.operation = new atb_speed::common::W8A8Operation(
+///     "W8A8LinearWithBiasNode", aclnnQuantMatmulWithBiasParam);
+///
+/// // Add the operation node to the graph as required
+/// atb::GraphParam opGraph;
+/// opGraph.nodes.push_back(linearNode);
+/// opGraph.nodes.push_back(linearWithBiasNode);
+/// \endcode
+class W8A8Operation : public AclNNOperation {
+public:
+    explicit W8A8Operation(const std::string &name, AclNNQuantMatmulParam param);
+    ~W8A8Operation() override;
+    atb::Status InferShape(const atb::SVector<atb::TensorDesc> &inTensorDescs,
+                           atb::SVector<atb::TensorDesc> &outTensorDescs) const override;
+    uint32_t GetInputNum() const override;
+    uint32_t GetOutputNum() const override;
+
+protected:
+    int CreateAclNNVariantPack(const atb::VariantPack &variantPack) override;
+    int SetAclNNWorkspaceExecutor() override;
+    int ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream) override;
+    int CreateAclNNInTensorVariantPack(const atb::VariantPack &variantPack) override;
+    int CreateAclNNOutTensorVariantPack(const atb::VariantPack &variantPack) override;
+    atb::Dims GetWeightStorageShape(const atb::TensorDesc atbTensorDesc) const;
+
+private:
+    AclNNQuantMatmulParam param_;
+};
+} // namespace common
+} // namespace atb_speed
+#endif // ATB_SPEED_PLUGIN_ACLNN_W8A8_OPERATION_H
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/aclnn/utils/utils.cpp b/tests/proftest/layer_test_framework/operations/aclnn/utils/utils.cpp
new file mode 100644
index 00000000..c0a80f9f
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/aclnn/utils/utils.cpp
@@ -0,0 +1,270 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <algorithm>
+#include <sstream>
+#include <cstring>
+#include "atb_speed/log.h"
+#include "atb_speed/utils/check_util.h"
+#include "utils.h"
+
+namespace atb_speed {
+namespace common {
+
+atb::SVector<int64_t> GetCopyTensorStride(atb::Dims &tensorDims)
+{
+    atb::SVector<int64_t> tmpStrides(tensorDims.dimNum, 1);
+    if (tensorDims.dimNum > 8) {  // 8: maximum number of tensor dimensions
+        ATB_SPEED_LOG_ERROR("Tensor's dimNum is larger than 8, `GetCopyTensorStride` failed.");
+        return tmpStrides;
+    }
+    for (int64_t i = static_cast<int64_t>(tensorDims.dimNum) - 2; i >= 0; i--) {
+        tmpStrides[i] = CheckIntMulOverFlow(tensorDims.dims[i + 1], tmpStrides[i + 1]);
+    }
+    return tmpStrides;
+}
+
+atb::SVector<int64_t> GetTransposeTensorStride(atb::Dims &tensorDims)
+{
+    atb::SVector<int64_t> tmptransposeStrides(tensorDims.dimNum, 1);
+    tmptransposeStrides[tensorDims.dimNum - 1] = tensorDims.dims[tensorDims.dimNum - 1];
+    if (tensorDims.dimNum == 3) {  // 3: dimension count
+        tmptransposeStrides[0] = CheckIntMulOverFlow(  // 0: dim 0
+            tensorDims.dims[1], tensorDims.dims[2]);   // 1, 2: sizes of dim 1 and dim 2
+    }
+    return tmptransposeStrides;
+}
+
+atb::Status CallAclCreateTensor(atb::Dims &viewDims, atb::Dims &storageDims, atb::Tensor &atbTensor,
+    std::shared_ptr<AclNNTensor> aclnnTensor)
+{
+    aclnnTensor->tensor = aclCreateTensor(viewDims.dims,
+        viewDims.dimNum,
+        atbTensor.desc.dtype,
+        aclnnTensor->strides.data(),
+        0,
+        atbTensor.desc.format,
+        storageDims.dims,
+        storageDims.dimNum,
+        atbTensor.deviceData);
+    if (aclnnTensor->tensor == nullptr) {
+        return atb::ERROR_INTERNAL_ERROR;
+    }
+    return atb::NO_ERROR;
+}
+
+bool IsA2()
+{
+    // Reuse ATB's detection logic: it is the more robust check
+    const uint32_t lenOfAtlasA2 = 10;
+    std::string socName = aclrtGetSocName();
+    ATB_SPEED_LOG_DEBUG("SocVersionName:" << std::string(socName));
+    bool isA2 = (std::string(socName).find("Ascend910B") != std::string::npos &&
+                 std::string(socName).length() > lenOfAtlasA2) ||
+                std::string(socName).find("Ascend910_93") != std::string::npos;
+    return isA2;
+}
+
+bool IsA3()
+{
+    std::string socName = aclrtGetSocName();
+    ATB_SPEED_LOG_DEBUG("SocVersionName:" << std::string(socName));
+    bool isA3 = std::string(socName).find("Ascend910_93") != std::string::npos;
+    return isA3;
+}
+
+bool Is310P()
+{
+    std::string socName = aclrtGetSocName();
+    ATB_SPEED_LOG_DEBUG("SocVersionName:" << std::string(socName));
+    bool is310P = std::string(socName).find("Ascend310P") != std::string::npos;
+    return is310P;
+}
+
+atb::Tensor SqueezeBatchSeq(atb::Tensor atbTensor)
+{
+    if (atbTensor.desc.shape.dimNum == DIM3) {
+        atbTensor.desc.shape.dimNum = DIM2;
+        atbTensor.desc.shape.dims[DIM0] = CheckIntMulOverFlow(
+            atbTensor.desc.shape.dims[DIM0], atbTensor.desc.shape.dims[DIM1]);
+        atbTensor.desc.shape.dims[DIM1] = atbTensor.desc.shape.dims[DIM2];
+    }
+    return atbTensor;
+}
+
+std::string PrintAclNNVariankPack(const AclNNVariantPack &aclnnVariantPack)
+{
+    std::stringstream ss;
+    ss << "Plugin Op Cache: AclNNVariantPack ";
+    for (size_t i = 0; i < aclnnVariantPack.aclInTensors.size(); i++) {
+        const atb::TensorDesc &tensorDesc = aclnnVariantPack.aclInTensors[i]->atbTensor.desc;
+        ss << "index " << i << " dtype " << tensorDesc.dtype
+           << " format " << tensorDesc.format << " dimNum " << tensorDesc.shape.dimNum;
+        for (uint64_t j = 0; j < std::min(tensorDesc.shape.dimNum, static_cast<uint64_t>(8)); j++) {  // 8: max dims
+            ss << "dim[" << j << "]=" << tensorDesc.shape.dims[j] << " ";
+        }
+    }
+    return ss.str();
+}
+
+std::string PrintATBVariankPack(const atb::VariantPack &atbVariantPack)
+{
+    std::stringstream ss;
+    ss << "Plugin Op Cache: ATBVariantPack ";
+    for (size_t i = 0; i < atbVariantPack.inTensors.size(); i++) {
+        const atb::TensorDesc &tensorDesc = atbVariantPack.inTensors[i].desc;
+        ss << "index " << i << " dtype " << tensorDesc.dtype
+           << " format " << tensorDesc.format << " dimNum " << tensorDesc.shape.dimNum;
+        for (uint64_t j = 0; j < std::min(tensorDesc.shape.dimNum, static_cast<uint64_t>(8)); j++) {  // 8: max dims
+            ss << "dim[" << j << "]=" << tensorDesc.shape.dims[j] << " ";
+        }
+    }
+    return ss.str();
+}
+
+bool IsHostDataEqual(const std::shared_ptr<AclNNTensor> tensorA, const atb::Tensor &tensorB, int tensorIdx)
+{
+    if (tensorA->intArrayHostData.intArray != nullptr && tensorB.hostData == nullptr) {
+        ATB_SPEED_LOG_DEBUG("Plugin Op Cache: tensor index " << tensorIdx
+            << " aclnnVariantPack hostData is not null but atbVariantPack hostData is");
+        return false;
+    }
+    if (tensorA->intArrayHostData.intArray == nullptr && tensorB.hostData != nullptr) {
+        ATB_SPEED_LOG_DEBUG("Plugin Op Cache: tensor index " << tensorIdx
+            << " aclnnVariantPack hostData is null but atbVariantPack hostData is not");
+        return false;
+    }
+    if (tensorA->intArrayHostData.intArray != nullptr && tensorB.hostData != nullptr) {
+        if (tensorA->intArrayHostData.dataOri.size() * 4 != tensorB.dataSize) {  // 4: size in bytes of each element in dataOri
+            ATB_SPEED_LOG_DEBUG("Plugin Op Cache: tensor index " << tensorIdx << " dataSize not equal");
+            return false;
+        }
+        if (memcmp(tensorA->intArrayHostData.dataOri.data(), tensorB.hostData, tensorB.dataSize) != 0) {
+            ATB_SPEED_LOG_DEBUG("Plugin Op Cache: tensor index " << tensorIdx << " hostData not equal");
+            return false;
+        }
+    }
+    return true;
+}
+
+bool IsTensorDescEqual(const atb::TensorDesc &tensorDescA, const atb::TensorDesc &tensorDescB, int tensorIdx)
+{
+    if (tensorDescA.dtype != tensorDescB.dtype) {
+        ATB_SPEED_LOG_DEBUG("Plugin Op Cache: tensor index " << tensorIdx
+            << " dtype not equal, aclnnVariantPack dtype " << tensorDescA.dtype
+            << " atbVariantPack dtype " << tensorDescB.dtype);
+        return false;
+    }
+    if (tensorDescA.format != tensorDescB.format) {
+        ATB_SPEED_LOG_DEBUG("Plugin Op Cache: tensor index " << tensorIdx
+            << " format not equal, aclnnVariantPack format " << tensorDescA.format
+            << " atbVariantPack format " << tensorDescB.format);
+        return false;
+    }
+    if (tensorDescA.shape.dimNum != tensorDescB.shape.dimNum || \
+        tensorDescA.shape.dimNum > 8 || tensorDescA.shape.dimNum <= 0) {  // 8: maximum number of tensor dimensions
+        ATB_SPEED_LOG_DEBUG("Plugin Op Cache: tensor index " << tensorIdx
+            << " dimNum not equal, aclnnVariantPack dimNum " << tensorDescA.shape.dimNum
+            << " atbVariantPack dimNum " << tensorDescB.shape.dimNum);
+        return false;
+    }
+    for (uint64_t j = 0; j < tensorDescA.shape.dimNum; j++) {
+        if (tensorDescA.shape.dims[j] != tensorDescB.shape.dims[j]) {
+            ATB_SPEED_LOG_DEBUG("Plugin Op Cache: tensor index " << tensorIdx
+                << " shape.dims " << j << " not equal, aclnnVariantPack value "
+                << tensorDescA.shape.dims[j] << " atbVariantPack value " << tensorDescB.shape.dims[j]);
+            return false;
+        }
+    }
+    return true;
+}
+
+bool AreTensorVectorsEqual(
+    const atb::SVector<std::shared_ptr<AclNNTensor>> &aclnnTensors, const atb::SVector<atb::Tensor> &atbTensors)
+{
+    // Check the size of two vectors
+    if (aclnnTensors.size() != atbTensors.size()) {
+        ATB_SPEED_LOG_DEBUG("Plugin Op Cache: size not equal, aclnnVariantPack size "
+            << aclnnTensors.size() << " atbVariantPack size "
+            << atbTensors.size());
+        return false;
+    }
+
+    // Check if every tensor in each vector has consistent data type, format, shape and host data.
+    for (size_t i = 0; i < aclnnTensors.size(); i++) {
+        const std::shared_ptr<AclNNTensor> tensorA = aclnnTensors[i];
+        const atb::Tensor &tensorB = atbTensors[i];
+
+        if (!IsHostDataEqual(tensorA, tensorB, i)) {
+            return false;
+        }
+
+        if (!IsTensorDescEqual(tensorA->atbTensor.desc, tensorB.desc, i)) {
+            return false;
+        }
+    }
+
+    return true;
+}
+
+bool IsVariankPackEqual(const AclNNVariantPack &aclnnVariantPack, const atb::VariantPack &atbVariantPack)
+{
+    ATB_SPEED_LOG_DEBUG(PrintAclNNVariankPack(aclnnVariantPack));
+    ATB_SPEED_LOG_DEBUG(PrintATBVariankPack(atbVariantPack));
+
+    if (!AreTensorVectorsEqual(aclnnVariantPack.aclInTensors, atbVariantPack.inTensors)) {
+        return false;
+    }
+
+    if (!AreTensorVectorsEqual(aclnnVariantPack.aclOutTensors, atbVariantPack.outTensors)) {
+        return false;
+    }
+
+    ATB_SPEED_LOG_DEBUG("Plugin Op Cache: TensorDesc match");
+    return true;
+}
+
+std::shared_ptr<AclNNTensor> CreateTensor(atb::Tensor atbTensor, int tensorIdx)
+{
+    std::shared_ptr<AclNNTensor> aclnnTensor = std::make_shared<AclNNTensor>();
+    aclnnTensor->needUpdateTensorDataPtr = true;
+    aclnnTensor->atbTensor = atbTensor;
+    aclnnTensor->tensorIdx = tensorIdx;
+    aclnnTensor->strides = GetCopyTensorStride(atbTensor.desc.shape);
+    CallAclCreateTensor(atbTensor.desc.shape, atbTensor.desc.shape, atbTensor, aclnnTensor);
+    return aclnnTensor;
+}
+
+int ConvertTensorToSeqLengths(atb::Tensor &tensor, aclIntArray *&actualSeqLengths)
+{
+    static std::vector<int64_t> seqLenCache;
+    size_t dataSize = tensor.dataSize / 8;  // 8: int64 size
+    if (seqLenCache.size() < dataSize) {
+        seqLenCache.resize(dataSize);
+    }
+    if (memcpy_s(seqLenCache.data(), dataSize * 8, tensor.hostData, dataSize * 8) != 0) {  // 8: int64 size
+        ATB_SPEED_LOG_ERROR("memcpy_s failed");
+        return atb::ERROR_INTERNAL_ERROR;
+    }
+    if (actualSeqLengths != nullptr) {
+        aclDestroyIntArray(actualSeqLengths);
+        actualSeqLengths = nullptr;
+    }
+    actualSeqLengths = aclCreateIntArray(static_cast<const int64_t *>(seqLenCache.data()), dataSize);
+    return atb::NO_ERROR;
+}
+} // namespace common
+} // namespace atb_speed
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/aclnn/utils/utils.h b/tests/proftest/layer_test_framework/operations/aclnn/utils/utils.h
new file mode 100644
index 00000000..0959ab0f
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/aclnn/utils/utils.h
@@ -0,0 +1,111 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef ATB_SPEED_PLUGIN_UTILS_H
+#define ATB_SPEED_PLUGIN_UTILS_H
+#include "operations/aclnn/core/acl_nn_operation_cache.h"
+#include "operations/aclnn/core/acl_nn_tensor.h"
+
+namespace atb_speed {
+namespace common {
+
+/// Constant to represent index
+const int DIM0 = 0;
+const int DIM1 = 1;
+const int DIM2 = 2;
+const int DIM3 = 3;
+/// Constant to represent number
+const int NUM1 = 1;
+const int NUM2 = 2;
+const int NUM3 = 3;
+const int NUM4 = 4;
+const int NUM5 = 5;
+const int NUM6 = 6;
+const int NUM7 = 7;
+const int NUM8 = 8;
+const int NUM9 = 9;
+const int NUM10 = 10;
+
+/// Calculate stride along each dimension when copying a tensor.
+///
+/// \param tensorDims The size of each axis in a tensor.
+/// \return The number of steps needed in memory to move to the next element
+/// along each dimension when copying a tensor.
+atb::SVector<int64_t> GetCopyTensorStride(atb::Dims &tensorDims);
+
+/// Calculate stride along each dimension when transposing a tensor.
+///
+/// \param tensorDims The size of each axis in a tensor.
+/// \return The number of steps needed in memory to move to the next element
+/// along each dimension when transposing a tensor.
+atb::SVector<int64_t> GetTransposeTensorStride(atb::Dims &tensorDims);
+
+/// Call `aclCreateTensor` API to create `aclTensor`.
+///
+/// \param viewDims The dimension of tensor's view shape.
+/// \param storageDims The dimension of tensor's storage shape.
+/// \param atbTensor The tensor passed through ATB framework.
+/// \param aclnnTensor A pointer to an `AclNNTensor` object whose `tensor` attribute is updated
+/// using the return value of `aclCreateTensor`.
+/// \return A status code that indicates whether `aclTensor` has been created.
+atb::Status CallAclCreateTensor(atb::Dims &viewDims, atb::Dims &storageDims, atb::Tensor &atbTensor,
+    std::shared_ptr<AclNNTensor> aclnnTensor);
+
+/// Call the `aclrtGetSocName` API to get the hardware info.
+///
+/// \return A bool value that indicates whether the hardware is A2.
+bool IsA2();
+
+/// Call the `aclrtGetSocName` API to get the hardware info.
+///
+/// \return A bool value that indicates whether the hardware is A3.
+bool IsA3();
+
+/// Call the `aclrtGetSocName` API to get the hardware info.
+///
+/// \return A bool value that indicates whether the hardware is 310P.
bool Is310P();
+
+/// Reshape a tensor by merging the batch size axis and the seq len axis if the tensor's shape has three dimensions.
+///
+/// \param atbTensor An `atb::Tensor` object whose tensor shape requires reshaping.
+/// \return The `atb::Tensor` after reshaping.
+atb::Tensor SqueezeBatchSeq(atb::Tensor atbTensor);
+
+/// Check whether `aclnnVariantPack` and `atbVariantPack` are the same, except for tensors' device data.
+///
+/// Two variant packs, `aclnnVariantPack` and `atbVariantPack`, are considered the same if they have the same
+/// number of tensors, and each corresponding tensor in the variant packs
+/// has identical data type, format, shape and host data.
+/// \param aclnnVariantPack An `AclNNVariantPack` object containing tensor info of existing AclNN operation.
+/// \param atbVariantPack An `atb::VariantPack` object containing tensor info passed through ATB framework.
+/// \return A boolean value that indicates whether `aclnnVariantPack` and `atbVariantPack` are the same,
+/// except for tensors' device data.
+bool IsVariankPackEqual(const AclNNVariantPack &aclnnVariantPack, const atb::VariantPack &atbVariantPack);
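+
+/// Illustrative example of the stride helpers declared above; the values
+/// follow from the row-major suffix products computed in `GetCopyTensorStride`:
+/// \code
+/// atb::Dims dims;
+/// dims.dimNum = 3;
+/// dims.dims[0] = 2; dims.dims[1] = 3; dims.dims[2] = 4;
+/// atb::SVector<int64_t> strides = GetCopyTensorStride(dims);  // {12, 4, 1}
+/// \endcode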
+
+/// Create a pointer to `AclNNTensor` by configuring it with tensor information extracted from `atbTensor`.
+///
+/// `needUpdateTensorDataPtr` is set to true, `atbTensor` and `tensorIdx` are updated with input parameters.
+/// `strides` is updated by `GetCopyTensorStride`.
+/// \param atbTensor Tensor passed through the ATB framework.
+/// \param tensorIdx The index of the tensor in `aclOpExecutor`'s parameter list.
+/// \return A pointer to an `AclNNTensor` object whose attributes are updated.
+std::shared_ptr<AclNNTensor> CreateTensor(atb::Tensor atbTensor, int tensorIdx);
+
+/// Copy the host data of `tensor` (interpreted as int64 sequence lengths) into an `aclIntArray`,
+/// destroying any array previously passed in via `actualSeqLengths`.
+int ConvertTensorToSeqLengths(atb::Tensor &tensor, aclIntArray *&actualSeqLengths);
+} // namespace common
+} // namespace atb_speed
+#endif
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/aclrt/ops/aclrt_cmo_async.cpp b/tests/proftest/layer_test_framework/operations/aclrt/ops/aclrt_cmo_async.cpp
new file mode 100644
index 00000000..3f1e49a7
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/aclrt/ops/aclrt_cmo_async.cpp
@@ -0,0 +1,125 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "atb_speed/log.h"
+#include "operations/aclnn/utils/utils.h"
+#include "acl/acl.h"
+#include "aclrt_cmo_async.h"
+
+namespace atb_speed {
+namespace common {
+
+AclrtCmoAsyncOperation::AclrtCmoAsyncOperation(const std::string &opName) : opName_(opName) {}
+
+AclrtCmoAsyncOperation::~AclrtCmoAsyncOperation()
+{
+    ATB_SPEED_LOG_DEBUG("AclrtCmoAsyncOperation deconstructor");
+}
+
+std::string AclrtCmoAsyncOperation::GetName() const
+{
+    return this->opName_;
+}
+
+
+uint32_t AclrtCmoAsyncOperation::GetInputNum() const
+{
+    return NUM1;
+}
+
+uint32_t AclrtCmoAsyncOperation::GetOutputNum() const
+{
+    return 0;
+}
+
+atb::Status AclrtCmoAsyncOperation::InferShape(const atb::SVector<atb::TensorDesc> &inTensorDesc,
+    atb::SVector<atb::TensorDesc> &outTensorDesc) const
+{
+    ATB_SPEED_LOG_DEBUG("inTensorDesc size: " << inTensorDesc.size() << ", outTensorDesc size: "
+        << outTensorDesc.size());
+    return atb::NO_ERROR;
+}
+
+atb::Status AclrtCmoAsyncOperation::Setup(const atb::VariantPack &variantPack, uint64_t &workspaceSize,
+    atb::Context *context)
+{
+    ATB_SPEED_LOG_DEBUG("variantPack outTensors size: "
+        << variantPack.outTensors.size()
+        << ", workspaceSize: "
+        << workspaceSize);
+    ATB_SPEED_LOG_DEBUG(this->opName_ << " setup start");
+
+    if (context == nullptr) {
+        ATB_SPEED_LOG_ERROR(this->opName_ << " setup context is null");
+        return atb::ERROR_INVALID_PARAM;
+    }
+
+    workspaceSize = 0;
+
+    ATB_SPEED_LOG_DEBUG("setup end");
+    return atb::NO_ERROR;
+}
+
+atb::Status AclrtCmoAsyncOperation::Execute(const atb::VariantPack &variantPack, uint8_t *workspace,
+    uint64_t workspaceSize, atb::Context *context)
+{
+    ATB_SPEED_LOG_DEBUG(this->opName_ << " execute start: ");
+
+    if (!context) {
+        ATB_SPEED_LOG_ERROR(this->opName_ << " execute fail, context param is null. Enable log: "
+            << "export ASDOPS_LOG_LEVEL=ERROR, export ASDOPS_LOG_TO_STDOUT=1 to find the first error. "
+            << "For more details, see the MindIE official document."
+            << std::endl, ATB_MODELS_EXECUTION_FAILURE);
+        return atb::ERROR_INVALID_PARAM;
+    }
+
+    std::vector<aclrtStream> streams = context->GetExecuteStreams();
+
+    if (streams.size() < NUM2 || !streams[1]) {
+        ATB_SPEED_LOG_ERROR(this->opName_ << " execute fail, execute stream in context is null. "
+            << "Enable log: export ASDOPS_LOG_LEVEL=ERROR, export ASDOPS_LOG_TO_STDOUT=1 to find the first error. "
+            << "For more details, see the MindIE official document." << std::endl, ATB_MODELS_EXECUTION_FAILURE);
+        return atb::ERROR_INVALID_PARAM;
+    }
+
+    aclrtCmoType cmoType = ACL_RT_CMO_TYPE_PREFETCH;
+
+    ATB_SPEED_LOG_DEBUG("variantPack deviceData: " << variantPack.inTensors.at(0).deviceData
+        << " ,variantPack dataSize: " << variantPack.inTensors.at(0).dataSize
+        << " ,stream: " << streams[1]);
+
+    CheckAcl(aclrtCmoAsync(variantPack.inTensors.at(0).deviceData,
+        variantPack.inTensors.at(0).dataSize,
+        cmoType,
+        streams[1]));
+
+    ATB_SPEED_LOG_DEBUG("aclrtCmoAsync create success.");
+
+    if (workspaceSize != 0 || workspace != nullptr) {
+        ATB_SPEED_LOG_DEBUG("execute workspace: " << workspaceSize);
+    }
+
+    return atb::NO_ERROR;
+}
+
+aclError AclrtCmoAsyncOperation::CheckAcl(aclError ret) const
+{
+    if (ret != ACL_ERROR_NONE) {
+        ATB_SPEED_LOG_ERROR(__FILE__ << ":" << __LINE__ << " aclError:" << ret);
+    }
+    return ret;
+}
+
+} // namespace common
+} // namespace atb_speed
diff --git a/tests/proftest/layer_test_framework/operations/aclrt/ops/aclrt_cmo_async.h b/tests/proftest/layer_test_framework/operations/aclrt/ops/aclrt_cmo_async.h
new file mode 100644
index 00000000..302be057
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/aclrt/ops/aclrt_cmo_async.h
@@ -0,0 +1,51 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +#ifndef ATB_SPEED_PLUGIN_ACLRT_OPERATION_H +#define ATB_SPEED_PLUGIN_ACLRT_OPERATION_H +#include +#include +#include +#include +#include + +namespace atb_speed { +namespace common { + +class AclrtCmoAsyncOperation : public atb::OperationInfra { +public: + explicit AclrtCmoAsyncOperation(const std::string &opName); + + ~AclrtCmoAsyncOperation() override; + + std::string GetName() const override; + atb::Status InferShape(const atb::SVector &inTensorDesc, + atb::SVector &outTensorDesc) const override; + uint32_t GetInputNum() const override; + uint32_t GetOutputNum() const override; + + atb::Status Setup(const atb::VariantPack &variantPack, uint64_t &workspaceSize, atb::Context *context) override; + + atb::Status Execute(const atb::VariantPack &variantPack, uint8_t *workspace, uint64_t workspaceSize, + atb::Context *context) override; + +private: + + aclError CheckAcl(aclError ret) const; + std::string opName_; +}; +} // namespace common +} // namespace atb_speed +#endif \ No newline at end of file diff --git a/tests/proftest/layer_test_framework/operations/fusion/attention/attention_edge.cpp b/tests/proftest/layer_test_framework/operations/fusion/attention/attention_edge.cpp new file mode 100644 index 00000000..16c10ddf --- /dev/null +++ b/tests/proftest/layer_test_framework/operations/fusion/attention/attention_edge.cpp @@ -0,0 +1,841 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "operations/fusion/linear/linear.h"
+#include "operations/fusion/norm/norm_linear.h"
+#include "operations/fusion/attention/attention_edge.h"
+
+#include "atb_speed/log.h"
+#include "atb_speed/utils/check_util.h"
+
+namespace atb_speed {
+    namespace common {
+
+        static const uint64_t ROTARY_COEFF = 2;
+
+        std::map<std::string, std::vector<std::string>> GetAttnInTensorCandidatesEdge()
+        {
+            std::map<std::string, std::vector<std::string>> attnInTensorCandidates = {
+                {"default", {
+                    "in_hidden_states", "in_input_norm_weight", "in_qkv_weight",
+                    "in_attention_out_weight", "in_mlp_weight_0", "in_mlp_down_weight",
+                    "in_post_attention_norm_weight", "in_attention_mask", "in_position_id",
+                    "in_cos_emb", "in_sin_emb", "in_seq_len", "in_place_holder", "in_past_key",
+                    "in_past_value"}
+                },
+                {"hasbias", {"in_qkv_bias"}
+                },
+                {"qk_norm", {"in_q_norm_weight", "in_k_norm_weight"}}
+            };
+            return attnInTensorCandidates;
+        }
+
+        std::map<std::string, std::vector<std::string>> GetQuantAttnInTensorCandidatesEdge()
+        {
+            std::map<std::string, std::vector<std::string>> attnInTensorCandidates = {
+                {"default", {
+                    "in_hidden_states", "in_input_norm_weight", "in_qkv_weight", "in_qkv_weight_input_scale",
+                    "in_qkv_weight_input_offset", "in_qkv_weight_deq_scale", "in_qkv_weight_quant_bias",
+                    "in_attention_out_weight", "in_attention_out_weight_input_scale",
+                    "in_attention_out_weight_input_offset",
+                    "in_attention_out_weight_deq_scale", "in_attention_out_weight_quant_bias", "in_attention_mask",
+                    "in_position_id", "in_cos_emb", "in_sin_emb", "in_seq_len", "in_place_holder", "in_past_key",
+                    "in_past_value"}
+                },
+            };
+            return attnInTensorCandidates;
+        }
+
+        std::map<std::string, std::vector<std::string>> GetAttnIntermediateTensorCandidatesEdge()
+        {
+            std::map<std::string, std::vector<std::string>> attnIntermediateTensorCandidates = {
+                {"default", {
+                    "intermediate_qkv_mixed_linear_out", "internal_q_scaled_out",
+                    "internal_bmm_q_k_out", "internal_attention_scores",
+                    "internal_attention_probs", "internal_k_split", "internal_v_split",
+                    "internal_k_rope", "internal_bmm_v_out", "intermediate_input_norm_out",
+                    "internal_q_split", "internal_q_rope", "internal_q_rope_transpose",
+                    "internal_bmm_v_out_transpose"}
+                },
+                {"decode", {
+                    "internal_k_rope_transpose", "out_present_value_transpose"}
+                },
+                {"gqa", {
+                    "internal_key", "internal_value"}
+                },
+                {"quant", {
+                    "intermediate_qkv_linear_input_quant", "internal_bmm_v_out_quant"}
+                },
+            };
+            return attnIntermediateTensorCandidates;
+        }
+
+        std::map<std::string, std::vector<std::string>> GetAttnOutTensorCandidatesEdge()
+        {
+            std::map<std::string, std::vector<std::string>> attnOutTensorCandidates = {
+                {"default", {
+                    "out_attention", "out_present_key", "out_present_value"}
+                },
+            };
+            return attnOutTensorCandidates;
+        }
+
+        std::map<std::string, uint32_t> ConstructTensorMap(const AttentionParam &param,
+                                                           uint32_t &inTensorNum,
+                                                           uint32_t &outTensorNum,
+                                                           uint32_t &internalTensorNum)
+        {
+            bool isPrefill = param.isPrefill;
+            bool isGQA = param.isGQA;
+            bool isQuant = param.isQuant;
+            bool hasBias = param.isHasQKVBias;
+            bool useQKNorm = param.useQKNorm;
+
+            std::map<std::string, std::vector<std::string>> attnInTensorCandidates = {};
+            if (isQuant) {
+                attnInTensorCandidates = GetQuantAttnInTensorCandidatesEdge();
+            } else {
+                attnInTensorCandidates = GetAttnInTensorCandidatesEdge();
+            }
+
+            auto attnIntermediateTensorCandidates = GetAttnIntermediateTensorCandidatesEdge();
+            auto attnOutTensorCandidates = GetAttnOutTensorCandidatesEdge();
+
+            std::vector<std::string> inTensorList = {};
+            std::vector<std::string> intermediateTensorList = {};
+            std::vector<std::string> outTensorList = {};
+
+            AddTensorToList(attnInTensorCandidates, "default", inTensorList);
+            AddTensorToList(attnIntermediateTensorCandidates, "default", intermediateTensorList);
+            AddTensorToList(attnOutTensorCandidates, "default", outTensorList);
+
+            if (hasBias && !isQuant) {
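+                // Bias in-tensors are separate inputs only on the float path; the quant
+                // path already carries its quant bias tensors in the "default" list above.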
AddTensorToList(attnInTensorCandidates, "hasbias", inTensorList); + } + if (useQKNorm) { + AddTensorToList(attnInTensorCandidates, "qk_norm", inTensorList); + } + if (!isPrefill) { + AddTensorToList(attnIntermediateTensorCandidates, "decode", intermediateTensorList); + } + if (isGQA) { + AddTensorToList(attnIntermediateTensorCandidates, "gqa", intermediateTensorList); + } + if (isQuant) { + AddTensorToList(attnIntermediateTensorCandidates, "quant", intermediateTensorList); + } + + inTensorNum = inTensorList.size(); + outTensorNum = outTensorList.size(); + internalTensorNum = intermediateTensorList.size(); + + return GetTensorMap(inTensorList, outTensorList, intermediateTensorList); + } + + atb::Status RmsNormNode(const AttentionParam ¶m, atb::GraphParam &opGraph, + std::map &tensorMap) + { + atb::Node inputNormNode; + atb::infer::RmsNormParam rmsNormParam; + rmsNormParam.layerType = atb::infer::RmsNormParam::RmsNormType::RMS_NORM_NORM; + rmsNormParam.normParam.epsilon = param.normEps; + CreateOperation(rmsNormParam, &inputNormNode.operation); + inputNormNode.inTensorIds = { + GetTensorIdx(tensorMap, "in_hidden_states"), GetTensorIdx(tensorMap, "in_input_norm_weight") + }; + inputNormNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_input_norm_out")}; + opGraph.nodes.push_back(inputNormNode); + return atb::NO_ERROR; + } + + atb::Status QuantQkvInput(atb::GraphParam &opGraph, + std::map &tensorMap) + { + atb::Node quantNode; + atb::infer::ElewiseParam quantParam; + quantParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_QUANT_PER_CHANNEL; + CREATE_OPERATION(quantParam, &quantNode.operation); + quantNode.inTensorIds = { GetTensorIdx(tensorMap, "intermediate_input_norm_out"), + GetTensorIdx(tensorMap, "in_qkv_weight_input_scale"), + GetTensorIdx(tensorMap, "in_qkv_weight_input_offset") + }; + quantNode.outTensorIds = { GetTensorIdx(tensorMap, "intermediate_qkv_linear_input_quant") }; + opGraph.nodes.push_back(quantNode); + return atb::NO_ERROR; + } + + atb::Status QkvLinearNode(const AttentionParam ¶m, atb::GraphParam &opGraph, + std::map &tensorMap) + { + atb::Node qkvLinearNode; + atb::infer::LinearParam linearParam; + linearParam.hasBias = false; + linearParam.transposeA = false; + linearParam.transposeB = true; + + if (param.isQuant) { + QuantQkvInput(opGraph, tensorMap); + linearParam.hasBias = true; + linearParam.outDataType = ACL_FLOAT16; + CreateOperation(linearParam, &qkvLinearNode.operation); + qkvLinearNode.inTensorIds = { + GetTensorIdx(tensorMap, "intermediate_qkv_linear_input_quant"), + GetTensorIdx(tensorMap, "in_qkv_weight"), + GetTensorIdx(tensorMap, "in_qkv_weight_quant_bias"), + GetTensorIdx(tensorMap, "in_qkv_weight_deq_scale") + }; + qkvLinearNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_qkv_mixed_linear_out")}; + opGraph.nodes.push_back(qkvLinearNode); + } else { + CreateOperation(linearParam, &qkvLinearNode.operation); + qkvLinearNode.inTensorIds = { + GetTensorIdx(tensorMap, "intermediate_input_norm_out"), GetTensorIdx(tensorMap, "in_qkv_weight") + }; + qkvLinearNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_qkv_mixed_linear_out")}; + opGraph.nodes.push_back(qkvLinearNode); + } + + return atb::NO_ERROR; + } + + atb::Status QkvBiasNode(atb::GraphParam &opGraph, std::map &tensorMap) + { + atb::Node qkvAddBiasNode; + atb::infer::ElewiseParam addParam; + addParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_ADD; + CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(addParam, &qkvAddBiasNode.operation)); + 
qkvAddBiasNode.inTensorIds = + atb_speed::common::GetTensorIdxList(tensorMap, {"intermediate_qkv_mixed_linear_out", "in_qkv_bias"}); + qkvAddBiasNode.outTensorIds = + atb_speed::common::GetTensorIdxList(tensorMap, {"intermediate_qkv_mixed_linear_out"}); + opGraph.nodes.push_back(qkvAddBiasNode); + return atb::NO_ERROR; + } + + atb::Status QkvSplitMHANode(atb::GraphParam &opGraph, std::map &tensorMap) + { + static const int ATTENTION_GROUPS = 3; + atb::Node splitQkvNode; + atb::infer::SplitParam splitParam = { -1, ATTENTION_GROUPS }; + CREATE_OPERATION(splitParam, &splitQkvNode.operation); + splitQkvNode.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_qkv_mixed_linear_out")}; + splitQkvNode.outTensorIds = { + GetTensorIdx(tensorMap, "internal_q_split"), GetTensorIdx(tensorMap, "internal_k_split"), + GetTensorIdx(tensorMap, "internal_v_split"), + }; + opGraph.nodes.push_back(splitQkvNode); + + return atb::NO_ERROR; + } + + atb::Status QkvSplitGQANode(const AttentionParam ¶m, atb::GraphParam &opGraph, + std::map &tensorMap) + { + int headSize = param.hiddenSize; + if (param.useQKNorm) { + headSize = CheckIntMulOverFlow(param.numAttentionHeads, param.hiddenSizePerAttentionHead); + } + static const int ATTENTION_GROUPS = param.numAttentionHeads / param.numKeyValueHeads; + atb::Node sliceQNode; + atb::infer::SliceParam sliceQNodeParam; + sliceQNodeParam.offsets = {0, 0, 0}; + sliceQNodeParam.size = {-1, -1, headSize}; + CREATE_OPERATION(sliceQNodeParam, &sliceQNode.operation); + sliceQNode.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_qkv_mixed_linear_out")}; + sliceQNode.outTensorIds = {GetTensorIdx(tensorMap, "internal_q_split")}; + opGraph.nodes.push_back(sliceQNode); + + atb::Node sliceKVNode; + atb::infer::SliceParam sliceKVNodeParam; + sliceKVNodeParam.offsets = {0, 0, headSize}; + sliceKVNodeParam.size = {-1, -1, (headSize / ATTENTION_GROUPS) * 2 }; + CREATE_OPERATION(sliceKVNodeParam, &sliceKVNode.operation); + sliceKVNode.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_qkv_mixed_linear_out")}; + sliceKVNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_qkv_mixed_linear_out")}; + opGraph.nodes.push_back(sliceKVNode); + + atb::Node splitKvNode; + atb::infer::SplitParam splitParam2 = { -1, 2 }; + CREATE_OPERATION(splitParam2, &splitKvNode.operation); + splitKvNode.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_qkv_mixed_linear_out")}; + splitKvNode.outTensorIds = { + GetTensorIdx(tensorMap, "internal_k_split"), + GetTensorIdx(tensorMap, "internal_v_split"), + }; + opGraph.nodes.push_back(splitKvNode); + return atb::NO_ERROR; + } + + atb::Status QKNormNode(const AttentionParam ¶m, atb::GraphParam &opGraph, + std::map &tensorMap) + { + atb::Node qNormNode; + qNormNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "internal_q_split")); + qNormNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_q_norm_weight")); + qNormNode.outTensorIds = {GetTensorIdx(tensorMap, "internal_q_split")}; + qNormNode.inTensorReshapeFuncs.resize(qNormNode.inTensorIds.size()); + + qNormNode.inTensorReshapeFuncs[0] = [=](const atb::Dims &oldShape, atb::Dims &newShape) { + newShape.dimNum = 3; // 3: 新的shape维度为3 + newShape.dims[0] = oldShape.dims[1]; // 0: bs * seq_len + newShape.dims[1] = oldShape.dims[0] * // 1: num_heads + oldShape.dims[2] / param.hiddenSizePerAttentionHead; // 2: head_dim + newShape.dims[2] = param.hiddenSizePerAttentionHead; // 2: head_dim + }; + + atb::Node kNormNode; + kNormNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "internal_k_split")); + 
kNormNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_k_norm_weight")); + kNormNode.outTensorIds = {GetTensorIdx(tensorMap, "internal_k_split")}; + kNormNode.inTensorReshapeFuncs.resize(kNormNode.inTensorIds.size()); + + kNormNode.inTensorReshapeFuncs[0] = [=](const atb::Dims &oldShape, atb::Dims &newShape) { + newShape.dimNum = 3; // 3: 新的shape维度为3 + newShape.dims[0] = oldShape.dims[1]; // 0: seq_len + newShape.dims[1] = oldShape.dims[0] * // 1: num_heads + oldShape.dims[2] / param.hiddenSizePerAttentionHead; // 2: head_dim + newShape.dims[2] = param.hiddenSizePerAttentionHead; // 2: head_dim + }; + + atb::infer::RmsNormParam qkNormParam; + qkNormParam.layerType = atb::infer::RmsNormParam::RmsNormType::RMS_NORM_NORM; + qkNormParam.normParam.epsilon = 1e-6; + CreateOperation(qkNormParam, &qNormNode.operation); + CreateOperation(qkNormParam, &kNormNode.operation); + opGraph.nodes.push_back(qNormNode); + opGraph.nodes.push_back(kNormNode); + + return atb::NO_ERROR; + } + + atb::Status RepeatKVNode(const AttentionParam ¶m, atb::GraphParam &opGraph, + std::map &tensorMap) + { + static const int ATTENTION_GROUPS = param.numAttentionHeads / param.numKeyValueHeads; + atb::Node expandKNode; + atb::infer::RepeatParam expandKParam; + expandKParam.multiples = {1, 1, ATTENTION_GROUPS, 1}; + CreateOperation(expandKParam, &expandKNode.operation); + expandKNode.inTensorIds = {GetTensorIdx(tensorMap, "out_present_key")}; + expandKNode.outTensorIds = {GetTensorIdx(tensorMap, "internal_key")}; + opGraph.nodes.push_back(expandKNode); + + atb::Node expandVNode; + atb::infer::RepeatParam expandVParam; + expandVParam.multiples = {1, 1, ATTENTION_GROUPS, 1}; + CreateOperation(expandVParam, &expandVNode.operation); + expandVNode.inTensorIds = {GetTensorIdx(tensorMap, "out_present_value")}; + expandVNode.outTensorIds = {GetTensorIdx(tensorMap, "internal_value")}; + opGraph.nodes.push_back(expandVNode); + + return atb::NO_ERROR; + } + + atb::Status PermuteNode(const AttentionParam ¶m, atb::GraphParam &opGraph, + std::map &tensorMap) + { + static const uint64_t NUM_KEY_VALUES_HEADS = param.numKeyValueHeads; + + atb::Node permuteQNode; + atb::infer::TransposeParam permuteSeqHnHsParam; + permuteSeqHnHsParam.perm = {0, 2, 1, 3}; + CreateOperation(permuteSeqHnHsParam, &permuteQNode.operation); + permuteQNode.inTensorIds = {GetTensorIdx(tensorMap, "internal_q_rope")}; + permuteQNode.outTensorIds = {GetTensorIdx(tensorMap, "internal_q_rope_transpose")}; + opGraph.nodes.push_back(permuteQNode); + + atb::Node permuteKNode; + CreateOperation(permuteSeqHnHsParam, &permuteKNode.operation); + permuteKNode.inTensorIds = {GetTensorIdx(tensorMap, "internal_k_rope")}; + + if (param.isPrefill) { + permuteKNode.outTensorIds = {GetTensorIdx(tensorMap, "out_present_key")}; + } else { + permuteKNode.outTensorIds = {GetTensorIdx(tensorMap, "internal_k_rope_transpose")}; + } + opGraph.nodes.push_back(permuteKNode); + + atb::Node permuteVNode; + CREATE_OPERATION(permuteSeqHnHsParam, &permuteVNode.operation); + permuteVNode.inTensorIds = {GetTensorIdx(tensorMap, "internal_v_split")}; + if (param.isPrefill) { + permuteVNode.outTensorIds = {GetTensorIdx(tensorMap, "out_present_value")}; + } else { + permuteVNode.outTensorIds = {GetTensorIdx(tensorMap, "out_present_value_transpose")}; + } + + permuteVNode.inTensorReshapeFuncs.resize(permuteVNode.inTensorIds.size()); + permuteVNode.inTensorReshapeFuncs[0] = [&](const atb::Dims &oldShape, atb::Dims &newShape) { + newShape.dimNum = 4; // 扩展维度为4维 + newShape.dims[0] = oldShape.dims[0]; // 0: 
bs + newShape.dims[1] = oldShape.dims[1]; // 1: seq_len + newShape.dims[2] = NUM_KEY_VALUES_HEADS; // 2: 第2维度按头数切分 + newShape.dims[3] = oldShape.dims[2] / NUM_KEY_VALUES_HEADS; // 3: 第3维度将旧维度除以头数 + }; + opGraph.nodes.push_back(permuteVNode); + return atb::NO_ERROR; + } + + atb::Status ConcatNode(atb::GraphParam &opGraph, std::map &tensorMap) + { + atb::Node concatKeyNode; + atb::infer::ConcatParam concatKeyParam = {2}; + CreateOperation(concatKeyParam, &concatKeyNode.operation); + concatKeyNode.inTensorIds = { + GetTensorIdx(tensorMap, "in_past_key"), GetTensorIdx(tensorMap, "internal_k_rope_transpose") + }; + concatKeyNode.outTensorIds = {GetTensorIdx(tensorMap, "out_present_key")}; + opGraph.nodes.push_back(concatKeyNode); + + atb::Node concatValueNode; + atb::infer::ConcatParam concatValueParam = {2}; + CreateOperation(concatValueParam, &concatValueNode.operation); + concatValueNode.inTensorIds = { + GetTensorIdx(tensorMap, "in_past_value"), GetTensorIdx(tensorMap, "out_present_value_transpose") + }; + concatValueNode.outTensorIds = {GetTensorIdx(tensorMap, "out_present_value")}; + opGraph.nodes.push_back(concatValueNode); + return atb::NO_ERROR; + } + + atb::Status ReshapeRopeNode(const AttentionParam ¶m, atb::Node &ropeNode) + { + static const uint64_t NUM_ATTENTION_HEADS = param.numAttentionHeads; + static const uint64_t NUM_KEY_VALUE_HEADS = param.numKeyValueHeads; + ropeNode.inTensorReshapeFuncs.resize(ropeNode.inTensorIds.size()); + if (param.useQKNorm) { + ropeNode.inTensorReshapeFuncs[0] = [&](const atb::Dims &oldShape, atb::Dims &newShape) { + newShape.dimNum = 4; // 扩展维度为4维 + newShape.dims[0] = 1; // 0: new dim for bs + newShape.dims[1] = oldShape.dims[0]; // 1: seqlen + newShape.dims[2] = oldShape.dims[1]; // 2: num_heads + newShape.dims[3] = oldShape.dims[2]; // 3: head_dim + }; + ropeNode.inTensorReshapeFuncs[1] = [&](const atb::Dims &oldShape, atb::Dims &newShape) { + newShape.dimNum = 4; // 扩展维度为4维 + newShape.dims[0] = 1; // 0: new dim for bs + newShape.dims[1] = oldShape.dims[0]; // 1: seqlen + newShape.dims[2] = oldShape.dims[1]; // 2: num_heads + newShape.dims[3] = oldShape.dims[2]; // 3: head_dim + }; + } else { + ropeNode.inTensorReshapeFuncs[0] = [&](const atb::Dims &oldShape, atb::Dims &newShape) { + newShape.dimNum = 4; // 扩展维度为4维 + newShape.dims[0] = oldShape.dims[0]; // 0: bs + newShape.dims[1] = oldShape.dims[1]; // 1: seqlen + newShape.dims[2] = NUM_ATTENTION_HEADS; // 2: 将第2维度按头数切分 + newShape.dims[3] = oldShape.dims[2] / NUM_ATTENTION_HEADS; // 3: 将旧维度2的大小除以注意力头数 + }; + ropeNode.inTensorReshapeFuncs[1] = [&](const atb::Dims &oldShape, atb::Dims &newShape) { + newShape.dimNum = 4; // 扩展维度为4维 + newShape.dims[0] = oldShape.dims[0]; // 0: 扩展维度 + newShape.dims[1] = oldShape.dims[1]; // 1: seqlen + newShape.dims[2] = NUM_KEY_VALUE_HEADS; // 2: 将第2维度按头数切分 + newShape.dims[3] = oldShape.dims[2] / NUM_KEY_VALUE_HEADS; // 3: 将旧维度2的大小除以注意力头数 + }; + } + return atb::NO_ERROR; + } + + atb::Status RopeNode(const AttentionParam ¶m, atb::GraphParam &opGraph, + std::map &tensorMap) + { + atb::Node ropeNode; + atb::infer::RopeParam ropeParam; + ropeParam.rotaryCoeff = ROTARY_COEFF; + CREATE_OPERATION(ropeParam, &ropeNode.operation); + ropeNode.inTensorIds = { + GetTensorIdx(tensorMap, "internal_q_split"), GetTensorIdx(tensorMap, "internal_k_split"), + GetTensorIdx(tensorMap, "in_cos_emb"), GetTensorIdx(tensorMap, "in_sin_emb"), + GetTensorIdx(tensorMap, "in_seq_len"), + }; + ropeNode.outTensorIds = { + GetTensorIdx(tensorMap, "internal_q_rope"), + GetTensorIdx(tensorMap, 
"internal_k_rope"), + }; + ReshapeRopeNode(param, ropeNode); + opGraph.nodes.push_back(ropeNode); + PermuteNode(param, opGraph, tensorMap); + if (!param.isPrefill) { + ConcatNode(opGraph, tensorMap); + } + return atb::NO_ERROR; + } + + atb::Status ReshapeBmmQKNode(const AttentionParam ¶m, atb::Node &bmmQKNode) + { + static const bool IS_GQA = param.isGQA; + static const int ATTENTION_GROUPS = param.numAttentionHeads / param.numKeyValueHeads; + + bmmQKNode.inTensorReshapeFuncs.resize(bmmQKNode.inTensorIds.size()); + bmmQKNode.inTensorReshapeFuncs[0] = [&](const atb::Dims &oldShape, atb::Dims &newShape) { + newShape.dimNum = 3; // 扩展维度为3维 + + newShape.dims[0] = oldShape.dims[0] * oldShape.dims[1]; // 0: 合轴 + newShape.dims[1] = oldShape.dims[2]; // 1: 使旧维度的第2维度赋值给新维度的第1维 + newShape.dims[2] = oldShape.dims[3]; // 2:使旧维度的第3维度赋值给新维度的第2维 + // } + }; + bmmQKNode.inTensorReshapeFuncs[1] = [&](const atb::Dims &oldShape, atb::Dims &newShape) { + newShape.dimNum = 3; // 定义新维度为3 + if (IS_GQA) { + newShape.dims[0] = oldShape.dims[1] * ATTENTION_GROUPS; // 0: 合轴 + newShape.dims[1] = oldShape.dims[2] / ATTENTION_GROUPS; // 1:dims[2] / ATTENTION_GROUPS + newShape.dims[2] = oldShape.dims[3] ; // 2:使旧维度的第3维度赋值给新维度的第2维 + } else { + newShape.dims[0] = oldShape.dims[0] * oldShape.dims[1]; // 0: 合轴 + newShape.dims[1] = oldShape.dims[2]; // 1: 使旧维度的第2维度赋值给新维度的第1维 + newShape.dims[2] = oldShape.dims[3]; // 2: 使旧维度的第3维度赋值给新维度的第2维 + } + }; + return atb::NO_ERROR; + } + + atb::Status ReshapeAddMaskNode(const AttentionParam ¶m, atb::Node &addMaskNode) + { + addMaskNode.inTensorReshapeFuncs.resize(addMaskNode.inTensorIds.size()); + if (param.useQKNorm) { + static const uint64_t NUM_KEY_VALUE_HEADS = param.numKeyValueHeads; + addMaskNode.inTensorReshapeFuncs[0] = [&](const atb::Dims &oldShape, atb::Dims &newShape) { + newShape.dimNum = 4; // 扩展维度为4维 + newShape.dims[0] = oldShape.dims[0] / NUM_KEY_VALUE_HEADS; // 0: 扩展维度 + newShape.dims[1] = NUM_KEY_VALUE_HEADS; // 1: 按头数切分 + newShape.dims[2] = oldShape.dims[1]; // 2: 使旧维度的第2维度赋值给新维度的第1维 + newShape.dims[3] = oldShape.dims[2]; // 3: 使旧维度的第3维度赋值给新维度的第2维 + }; + addMaskNode.inTensorReshapeFuncs[1] = [&](const atb::Dims &oldShape, atb::Dims &newShape) { + newShape.dimNum = 4; // 扩展维度为4维 + newShape.dims[0] = 1; // 0: 扩展维度 + newShape.dims[1] = 1; // 1: 扩展维度 + newShape.dims[2] = oldShape.dims[0]; // 2: 使旧维度的第2维度赋值给新维度的第0维 + newShape.dims[3] = oldShape.dims[1]; // 3: 使旧维度的第3维度赋值给新维度的第1维 + }; + } else { + static const uint64_t NUM_ATTENTION_HEADS = param.numAttentionHeads; + addMaskNode.inTensorReshapeFuncs[0] = [&](const atb::Dims &oldShape, atb::Dims &newShape) { + newShape.dimNum = 4; // 扩展维度为4维 + newShape.dims[0] = oldShape.dims[0] / NUM_ATTENTION_HEADS; // 0: 扩展维度 + newShape.dims[1] = NUM_ATTENTION_HEADS; // 1: 按头数切分 + newShape.dims[2] = oldShape.dims[1]; // 2: 使旧维度的第2维度赋值给新维度的第1维 + newShape.dims[3] = oldShape.dims[2]; // 3: 使旧维度的第3维度赋值给新维度的第2维 + }; + addMaskNode.inTensorReshapeFuncs[1] = [&](const atb::Dims &oldShape, atb::Dims &newShape) { + newShape.dimNum = 4; // 扩展维度为4维 + newShape.dims[0] = 1; // 0: 扩展维度 + newShape.dims[1] = 1; // 1: 扩展维度 + newShape.dims[2] = oldShape.dims[0]; // 2: 使旧维度的第2维度赋值给新维度的第0维 + newShape.dims[3] = oldShape.dims[1]; // 3: 使旧维度的第3维度赋值给新维度的第1维 + }; + } + return atb::NO_ERROR; + } + + atb::Status AttentionScoreNode(const AttentionParam ¶m, atb::GraphParam &opGraph, + std::map &tensorMap) + { + static const uint64_t HEAD_DIM = param.useQKNorm ? 
+ param.hiddenSizePerAttentionHead : param.hiddenSize / param.numAttentionHeads; + atb::Node bmmQKNode; + atb::infer::LinearParam matmulParam; + matmulParam.hasBias = false; + matmulParam.transposeA = false; + matmulParam.transposeB = true; + CreateOperation(matmulParam, &bmmQKNode.operation); + + if (!param.isGQA) { + bmmQKNode.inTensorIds = { + GetTensorIdx(tensorMap, "internal_q_rope_transpose"), + GetTensorIdx(tensorMap, "out_present_key"), + }; + } else { + bmmQKNode.inTensorIds = { + GetTensorIdx(tensorMap, "internal_q_rope_transpose"), + GetTensorIdx(tensorMap, "internal_key"), + }; + } + bmmQKNode.outTensorIds = {GetTensorIdx(tensorMap, "internal_bmm_q_k_out")}; + ReshapeBmmQKNode(param, bmmQKNode); + opGraph.nodes.push_back(bmmQKNode); + + atb::Node mulsQNode; + float scalingAttr = 1.0 / sqrt(HEAD_DIM); + atb::infer::ElewiseParam scalingElewiseMulsParam; + scalingElewiseMulsParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_MULS; + scalingElewiseMulsParam.mulsParam.varAttr = scalingAttr; + CreateOperation(scalingElewiseMulsParam, &mulsQNode.operation); + mulsQNode.inTensorIds = {GetTensorIdx(tensorMap, "internal_bmm_q_k_out")}; + mulsQNode.outTensorIds = {GetTensorIdx(tensorMap, "internal_q_scaled_out")}; + opGraph.nodes.push_back(mulsQNode); + return atb::NO_ERROR; + } + + atb::Status AttentionMaskNode(const AttentionParam ¶m, atb::GraphParam &opGraph, + std::map &tensorMap) + { + atb::Node addMaskNode; + atb::infer::ElewiseParam addParam; + addParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_ADD; + CreateOperation(addParam, &addMaskNode.operation); + addMaskNode.inTensorIds = { + GetTensorIdx(tensorMap, "internal_q_scaled_out"), + GetTensorIdx(tensorMap, "in_attention_mask"), + }; + addMaskNode.outTensorIds = {GetTensorIdx(tensorMap, "internal_attention_scores")}; + ReshapeAddMaskNode(param, addMaskNode); + opGraph.nodes.push_back(addMaskNode); + return atb::NO_ERROR; + } + + atb::Status SoftMaxNode(atb::GraphParam &opGraph, + std::map &tensorMap) + { + atb::Node softMaxNode; + atb::infer::SoftmaxParam softmaxParam = {{-1}}; + CreateOperation(softmaxParam, &softMaxNode.operation); + softMaxNode.inTensorIds = {GetTensorIdx(tensorMap, "internal_attention_scores")}; + softMaxNode.outTensorIds = {GetTensorIdx(tensorMap, "internal_attention_probs")}; + opGraph.nodes.push_back(softMaxNode); + return atb::NO_ERROR; + } + + atb::Status ReshapeValueLinearNode(const AttentionParam ¶m, atb::Node &bmmVNode) + { + static const int64_t NUM_ATTENTION_HEADS = param.numAttentionHeads; + static const int ATTENTION_GROUPS = param.numAttentionHeads / param.numKeyValueHeads; + static const int HEAD_DIM = param.hiddenSizePerAttentionHead; + bmmVNode.inTensorReshapeFuncs.resize(bmmVNode.inTensorIds.size()); + bmmVNode.inTensorReshapeFuncs[0] = [&](const atb::Dims &oldShape, atb::Dims &newShape) { + newShape.dimNum = 3; // 扩展维度为3维 + newShape.dims[0] = oldShape.dims[0] * oldShape.dims[1]; // 0: 合轴 + newShape.dims[1] = oldShape.dims[2]; // 1: 使旧维度的第2维度赋值给新维度的第1维 + newShape.dims[2] = oldShape.dims[3]; // 2: 使旧维度的第3维度赋值给新维度的第2维 + }; + if (param.useQKNorm) { + bmmVNode.inTensorReshapeFuncs[1] = [&](const atb::Dims &oldShape, atb::Dims &newShape) { + newShape.dimNum = 3; // 扩展维度为3维 + newShape.dims[0] = oldShape.dims[0] * oldShape.dims[1]; // 0: 合轴 + newShape.dims[1] = oldShape.dims[2]; // 1: 使旧维度的第2维度赋值给新维度的第1维 + newShape.dims[2] = oldShape.dims[3]; // 2: 使旧维度的第3维度赋值给新维度的第2维 + if (NUM_ATTENTION_HEADS != newShape.dims[0]) { + newShape.dims[0] = newShape.dims[0] * 
ATTENTION_GROUPS; // 0: keep num_head consistent + newShape.dims[1] = newShape.dims[1] / ATTENTION_GROUPS; // 1: keep num_head consistent + } + if (HEAD_DIM != newShape.dims[2]) { // 2: modify HEAD_DIM with ATTENTION_GROUPS + newShape.dims[0] = newShape.dims[0] / ATTENTION_GROUPS; // 0: keep head_dim consistent + newShape.dims[2] = newShape.dims[2] * ATTENTION_GROUPS; // 2: keep head_dim consistent + } + }; + } else { + bmmVNode.inTensorReshapeFuncs[1] = [&](const atb::Dims &oldShape, atb::Dims &newShape) { + newShape.dimNum = 3; // 扩展维度为3维 + newShape.dims[0] = oldShape.dims[0] * oldShape.dims[1]; + newShape.dims[1] = oldShape.dims[2]; // 1: 使旧维度的第2维度赋值给新维度的第1维 + newShape.dims[2] = oldShape.dims[3]; // 2: 使旧维度的第3维度赋值给新维度的第2维 + if (NUM_ATTENTION_HEADS != newShape.dims[0]) { + newShape.dims[0] = newShape.dims[0] * ATTENTION_GROUPS; // 0: keep num_head consistent + newShape.dims[1] = newShape.dims[1] / ATTENTION_GROUPS; // 1: keep num_head consistent + } + }; + } + return atb::NO_ERROR; + } + + atb::Status ValueLinearNode(const AttentionParam ¶m, atb::GraphParam &opGraph, + std::map &tensorMap) + { + atb::Node bmmVNode; + atb::infer::LinearParam matmulParam2; + matmulParam2.hasBias = false; + matmulParam2.transposeA = false; + matmulParam2.transposeB = false; + CreateOperation(matmulParam2, &bmmVNode.operation); + if (param.isGQA) { + bmmVNode.inTensorIds = { + GetTensorIdx(tensorMap, "internal_attention_probs"), + GetTensorIdx(tensorMap, "internal_value"), + }; + } else { + bmmVNode.inTensorIds = { + GetTensorIdx(tensorMap, "internal_attention_probs"), + GetTensorIdx(tensorMap, "out_present_value"), + }; + } + bmmVNode.outTensorIds = {GetTensorIdx(tensorMap, "internal_bmm_v_out")}; + ReshapeValueLinearNode(param, bmmVNode); + opGraph.nodes.push_back(bmmVNode); + return atb::NO_ERROR; + } + + atb::Status QuantValueInput(atb::GraphParam &opGraph, std::map &tensorMap) + { + atb::Node quantValueNode; + atb::infer::ElewiseParam quantParam; + quantParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_QUANT_PER_CHANNEL; + CREATE_OPERATION(quantParam, &quantValueNode.operation); + + quantValueNode.inTensorIds = { + GetTensorIdx(tensorMap, "internal_bmm_v_out_transpose"), + GetTensorIdx(tensorMap, "in_attention_out_weight_input_scale"), + GetTensorIdx(tensorMap, "in_attention_out_weight_input_offset") + }; + + quantValueNode.outTensorIds = { GetTensorIdx(tensorMap, "internal_bmm_v_out_quant") }; + opGraph.nodes.push_back(quantValueNode); + return atb::NO_ERROR; + } + atb::Status OutLinearPermuteNode(const AttentionParam ¶m, atb::GraphParam &opGraph, + std::map &tensorMap) + { + static const uint64_t NUM_ATTENTION_HEADS = param.numAttentionHeads; + atb::Node permuteAttentionNode; + atb::infer::TransposeParam permuteParam; + permuteParam.perm = {0, 2, 1, 3}; + CREATE_OPERATION(permuteParam, &permuteAttentionNode.operation); + permuteAttentionNode.inTensorIds = {GetTensorIdx(tensorMap, "internal_bmm_v_out")}; + permuteAttentionNode.outTensorIds = {GetTensorIdx(tensorMap, "internal_bmm_v_out_transpose")}; + permuteAttentionNode.inTensorReshapeFuncs.resize(permuteAttentionNode.inTensorIds.size()); + permuteAttentionNode.inTensorReshapeFuncs[0] = [&](const atb::Dims &oldShape, atb::Dims &newShape) { + newShape.dimNum = 4; // 扩展维度为4维 + if (oldShape.dims[0] != static_cast(NUM_ATTENTION_HEADS)) { + newShape.dims[0] = 1; // 0: origin + newShape.dims[1] = oldShape.dims[0]; // 1: 使旧维度的第0维度赋值给新维度的第1维 + } else { + newShape.dims[0] = oldShape.dims[0] / NUM_ATTENTION_HEADS; // 0: 扩展维度 + newShape.dims[1] 
= NUM_ATTENTION_HEADS; // 1: 按头数切分 + } + newShape.dims[2] = oldShape.dims[1]; // 2: 使旧维度的第2维度赋值给新维度的第1维 + newShape.dims[3] = oldShape.dims[2]; // 3: 使旧维度的第3维度赋值给新维度的第2维 + }; + opGraph.nodes.push_back(permuteAttentionNode); + return atb::NO_ERROR; + } + + + atb::Status OutLinearNode(const AttentionParam ¶m, atb::GraphParam &opGraph, + std::map &tensorMap) + { + OutLinearPermuteNode(param, opGraph, tensorMap); + + atb::Node outLinearNode; + atb::infer::LinearParam outLinearParam; + outLinearParam.hasBias = false; + outLinearParam.transposeA = false; + outLinearParam.transposeB = true; + + if (param.isQuant) { + QuantValueInput(opGraph, tensorMap); + outLinearParam.hasBias = true; + outLinearParam.outDataType = ACL_FLOAT16; + CreateOperation(outLinearParam, &outLinearNode.operation); + outLinearNode.inTensorIds = { + GetTensorIdx(tensorMap, "internal_bmm_v_out_quant"), + GetTensorIdx(tensorMap, "in_attention_out_weight"), + GetTensorIdx(tensorMap, "in_attention_out_weight_quant_bias"), + GetTensorIdx(tensorMap, "in_attention_out_weight_deq_scale") + }; + } else { + CreateOperation(outLinearParam, &outLinearNode.operation); + outLinearNode.inTensorIds = { + GetTensorIdx(tensorMap, "internal_bmm_v_out_transpose"), + GetTensorIdx(tensorMap, "in_attention_out_weight"), + }; + } + outLinearNode.outTensorIds = {GetTensorIdx(tensorMap, "out_attention")}; + outLinearNode.inTensorReshapeFuncs.resize(outLinearNode.inTensorIds.size()); + + outLinearNode.inTensorReshapeFuncs[0] = [&](const atb::Dims &oldShape, atb::Dims &newShape) { + newShape.dimNum = 2; // 合并维度为2维 + newShape.dims[0] = oldShape.dims[0] * oldShape.dims[1]; // 0: 将第0维与第1维相乘 + newShape.dims[1] = oldShape.dims[2] * oldShape.dims[3]; // 1: 将第2维与第3维相乘 + }; + opGraph.nodes.push_back(outLinearNode); + return atb::NO_ERROR; + } + + atb::Status AttentionNode(const AttentionParam ¶m, atb::GraphParam &opGraph, + std::map &tensorMap) + { + RmsNormNode(param, opGraph, tensorMap); + QkvLinearNode(param, opGraph, tensorMap); + if (param.isHasQKVBias && !param.isQuant) { + QkvBiasNode(opGraph, tensorMap); + } + if (param.isGQA) { + QkvSplitGQANode(param, opGraph, tensorMap); + } else { + QkvSplitMHANode(opGraph, tensorMap); + } + if (param.useQKNorm) { + QKNormNode(param, opGraph, tensorMap); + } + RopeNode(param, opGraph, tensorMap); + if (param.isGQA) { + RepeatKVNode(param, opGraph, tensorMap); + } + AttentionScoreNode(param, opGraph, tensorMap); + AttentionMaskNode(param, opGraph, tensorMap); + SoftMaxNode(opGraph, tensorMap); + ValueLinearNode(param, opGraph, tensorMap); + OutLinearNode(param, opGraph, tensorMap); + return atb::NO_ERROR; + } + + atb::Status AttentionEdge(const AttentionParam ¶m, atb::Operation **operation) + { + static const uint64_t KEY_VALUES_HEADS_NUM = param.numKeyValueHeads; + static const uint64_t HIDDEN_SIZE = param.hiddenSize; + + atb::GraphParam opGraph; + opGraph.name = "Attention"; + std::map tensorMap = ConstructTensorMap(param, + opGraph.inTensorNum, + opGraph.outTensorNum, + opGraph.internalTensorNum); + opGraph.inferShapeFunc = [=](const atb::SVector &inTensorDescs, + atb::SVector &outTensorDescs) { + outTensorDescs.at(0) = inTensorDescs.at(0); + if (!param.isPrefill) { + // 将in_past_key赋值到outTensorDescs[1] + outTensorDescs.at(1) = inTensorDescs.at(GetTensorIdx(tensorMap, "in_past_key")); + outTensorDescs.at(1).shape.dims[2] += 1; // dims[2] + 1 + // 将in_past_value赋值到outTensorDescs[2] + outTensorDescs.at(2) = inTensorDescs.at(GetTensorIdx(tensorMap, "in_past_value")); + outTensorDescs.at(2).shape.dims[2] += 1; // 
+        } else {
+            outTensorDescs.at(1) = inTensorDescs.at(0);
+            outTensorDescs.at(1).shape.dimNum = 4; // expand to 4 dimensions
+            outTensorDescs.at(1).shape.dims[0] = inTensorDescs.at(0).shape.dims[0]; // bs
+            outTensorDescs.at(1).shape.dims[1] = KEY_VALUES_HEADS_NUM;
+            outTensorDescs.at(1).shape.dims[2] = inTensorDescs.at(0).shape.dims[1]; // 2: dim 2 is seqlen
+            // assign HIDDEN_SIZE / KEY_VALUES_HEADS_NUM to dim 3
+            outTensorDescs.at(1).shape.dims[3] = HIDDEN_SIZE / KEY_VALUES_HEADS_NUM;
+
+            outTensorDescs.at(2) = inTensorDescs.at(0); // copy inTensorDescs[0] to outTensorDescs[2]
+            outTensorDescs.at(2).shape.dimNum = 4; // expand to 4 dimensions
+            outTensorDescs.at(2).shape.dims[0] = inTensorDescs.at(0).shape.dims[0]; // 0: dim 0 is bs
+            outTensorDescs.at(2).shape.dims[1] = KEY_VALUES_HEADS_NUM; // 1: dim 1 of outTensorDescs[2]
+            outTensorDescs.at(2).shape.dims[2] = inTensorDescs.at(0).shape.dims[1]; // 2: dim 2 is seqlen
+            // assign dim 3 of outTensorDescs[2]
+            outTensorDescs.at(2).shape.dims[3] = HIDDEN_SIZE / KEY_VALUES_HEADS_NUM;
+        }
+        return atb::NO_ERROR;
+    };
+    CHECK_OPERATION_STATUS_RETURN(AttentionNode(param, opGraph, tensorMap));
+    CREATE_OPERATION(opGraph, operation);
+    return atb::NO_ERROR;
+}
+} // namespace common
+} // namespace atb_speed
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/fusion/attention/attention_edge.h b/tests/proftest/layer_test_framework/operations/fusion/attention/attention_edge.h
new file mode 100644
index 00000000..706e4b64
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/fusion/attention/attention_edge.h
@@ -0,0 +1,49 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef ATB_SPEED_MODELS_ATTENTION_H
+#define ATB_SPEED_MODELS_ATTENTION_H
+
+#include "atb/atb_infer.h"
+#include "atb_speed/utils/operation_util.h"
+
+#include "operations/fusion/parallel_layer_v2.h"
+
+namespace atb_speed {
+namespace common {
+
+struct AttentionParam {
+    float normEps = 0;                  /// The epsilon used by the layer normalization layers.
+    int layerId = 0;                    /// The current layer id.
+    int numHiddenLayers = 0;            /// The number of hidden layers.
+    int numAttentionHeads = 8;          /// The number of attention heads.
+    int numKeyValueHeads = 1;           /// The number of key/value heads.
+    int hiddenSize = 0;                 /// The size of the hidden layers.
+    int seqLength = 1;                  /// The input sequence length.
+    bool isPrefill = false;             /// A flag indicating whether this is the prefill phase.
+    bool isGQA = false;                 /// A flag indicating whether the attention type is GQA.
+    bool isQuant = false;               /// A flag indicating whether quantization is enabled.
+    bool isHasQKVBias = false;          /// A flag indicating whether the qkv linear has bias.
+    bool useQKNorm = false;             /// A flag indicating whether q/k normalization is used.
+    int hiddenSizePerAttentionHead = 0; /// The hidden size per attention head.
+    QuantParam quantParam;              /// The quantization parameters (a struct).
+};
+
+/// This function builds the attention operation for the 310B; it is used only on the 310B.
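+///
+/// Example (a minimal usage sketch in the style of the other fusion operations; the parameter
+/// values below are illustrative assumptions, not required defaults):
+/// \code
+/// atb_speed::common::AttentionParam attnParam;
+/// attnParam.numAttentionHeads = 32;   // assumed model configuration
+/// attnParam.numKeyValueHeads = 8;
+/// attnParam.hiddenSize = 4096;
+/// attnParam.isGQA = true;             // grouped query attention
+/// attnParam.isPrefill = true;
+/// atb::Operation *attentionOp = nullptr;
+/// atb_speed::common::AttentionEdge(attnParam, &attentionOp);
+/// \endcode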
+atb::Status AttentionEdge(const AttentionParam &param, atb::Operation **operation);
+
+} // namespace common
+} // namespace atb_speed
+#endif
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/fusion/attention/fusion_attention.cpp b/tests/proftest/layer_test_framework/operations/fusion/attention/fusion_attention.cpp
new file mode 100644
index 00000000..28acd761
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/fusion/attention/fusion_attention.cpp
@@ -0,0 +1,636 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "atb_speed/log.h"
+#include "atb_speed/utils/check_util.h"
+#include "operations/aclnn/ops/dequant_rope_quant_kvcache_operation.h"
+#include "operations/fusion/attention/qkv_linear_split.h"
+#include "operations/fusion/attention/self_attention.h"
+#include "operations/fusion/utils.h"
+#include "operations/fusion/infer_shape_functions.h"
+#include "operations/fusion/attention/fusion_attention.h"
+
+namespace atb_speed {
+namespace common {
+
+std::map<std::string, std::vector<std::string>> GetAttnInTensorCandidates()
+{
+    std::map<std::string, std::vector<std::string>> attnInTensorCandidates = {
+        {"default", {
+            "in_input", "in_norm_weight", "in_norm_bias", "in_norm_new_weight", "in_norm_new_bias",
+            "in_weight_0", "in_scale_0", "in_offset_0", "in_descale_0", "in_bias_0", "in_compress_idx_0",
+            "in_weight_1", "in_scale_1", "in_offset_1", "in_descale_1", "in_bias_1", "in_compress_idx_1",
+            "in_weight_2", "in_scale_2", "in_offset_2", "in_descale_2", "in_bias_2", "in_compress_idx_2",
+            "in_cos_embed", "in_sin_embed", "in_seq_len", "in_k_cache", "in_v_cache", "in_attention_mask",
+            "in_token_offset", "in_layer_id", "in_block_tables",
+            "in_slots_in_pa_or_logn_in_fa",
+            "in_weight_dense", "in_scale_dense", "in_offset_dense", "in_descale_dense", "in_bias_dense",
+            "in_compress_idx_dense"}
+        },
+        {"alibi_mask_compress", {"in_slopes"}},
+        {"compress_head_alibi", {"in_batch_wins", "in_ra_seq_len"}},
+        {"compress_head_rope",
+            {"in_batch_wins", "in_ra_seq_len", "in_pffset_index", "in_ra_offset", "in_reshape_seq_len"}},
+        {"speculate", {"in_q_len"}},
+        {"kv_quant_scale",
+            {"in_k_quant_scale", "in_k_dequant_scale", "in_v_quant_scale", "in_v_dequant_scale"}
+        },
+        {"kv_quant_offset",
+            {"in_k_quant_offset", "in_k_dequant_offset", "in_v_quant_offset", "in_v_dequant_offset"}
+        },
+        {"fa3_quant",
+            {"in_q_quant_scale", "in_k_quant_scale", "in_v_quant_scale", "in_qk_descale",
+            "q_offset", "kv_offset", "fa3_v_quant_scale", "fa3_offset"}
+        },
+        {"reduce_quant",
+            {"in_reduce_quant_scale", "in_reduce_quant_offset", "in_gather_quant_scale", "in_gather_quant_offset"}
+        },
+        {"lora", {
+            "in_seq_len_cum_sum", "in_lora_a_0", "in_lora_b_0", "in_lora_a_1", "in_lora_b_1",
+            "in_lora_a_2", "in_lora_b_2", "in_dense_lora_a", "in_dense_lora_b"}
+        },
+        {"lora_with_mask", {"in_im_mask"}},
+        {"log_n_scale", {"in_log_n_scale"}},
+        {"qk_norm", {"in_q_norm_weight", "in_k_norm_weight"}},
+        {"add_norm", {"in_residual_add"}},
{"add_rmsnorm_quant", {"in_qkv_scale_fill", "in_qkv_offset_fill"}}, + {"cmo_mlp_first_matmul_weight", {"in_mlp_weight_0"}}, + {"flash_comm", { + "send_counts", "sdispls", "send_count", "recv_counts", "rdispls", "recv_count", + "fake_rs_shape", "fake_ag_shape"} + }, + }; + return attnInTensorCandidates; +} + +std::map> GetAttnIntermediateTensorCandidates() +{ + std::map> attnIntermediateTensorCandidates = { + {"default", + {"intermediate_q", "intermediate_k", "intermediate_v", "intermediate_self_attention"} + }, + {"kv_quant_scale", + {"intermediate_k_int8", "intermediate_v_int8"} + }, + {"q_quant_scale", + {"intermediate_q_int8"} + }, + {"dequant_rope", + {"intermediate_qkv_rope"} + } + }; + return attnIntermediateTensorCandidates; +} + +template +atb::Status ConstructAttentionQuantTensorMap( + const FusionAttentionParam ¶m, + std::map> &attnInTensorCandidates, + std::map> &attnIntermediateTensorCandidates, + std::vector &inTensorList, std::vector &intermediateTensorList) +{ + // 添加KV cache int8特性的Tensor + if (param.pageAttentionParam.quantType == atb::infer::PagedAttentionParam::QuantType::TYPE_DEQUANT_FUSION) { + AddTensorToList(attnInTensorCandidates, "kv_quant_scale", inTensorList); + if (!param.enableRopeQuantKvcache) { + AddTensorToList(attnIntermediateTensorCandidates, "kv_quant_scale", intermediateTensorList); + } + if (param.pageAttentionParam.hasQuantOffset) { + AddTensorToList(attnInTensorCandidates, "kv_quant_offset", inTensorList); + } + } + + // 添加FA3特性的Tensor + if (!param.isPrefill && param.pageAttentionParam.quantType == \ + atb::infer::PagedAttentionParam::QuantType::TYPE_QUANT_QKV_ONLINE) { + AddTensorToList(attnIntermediateTensorCandidates, "q_quant_scale", intermediateTensorList); + } + if (param.pageAttentionParam.quantType == atb::infer::PagedAttentionParam::QuantType::TYPE_QUANT_QKV_ONLINE) { + AddTensorToList(attnInTensorCandidates, "fa3_quant", inTensorList); + AddTensorToList(attnIntermediateTensorCandidates, "kv_quant_scale", intermediateTensorList); + } + return atb::NO_ERROR; +} + + +template +std::map ConstructTensorMap( + const FusionAttentionParam ¶m, + uint32_t &inTensorNum, uint32_t &outTensorNum, uint32_t &internalTensorNum) +{ + auto attnInTensorCandidates = GetAttnInTensorCandidates(); + auto attnIntermediateTensorCandidates = GetAttnIntermediateTensorCandidates(); + + std::vector inTensorList = {}; + std::vector intermediateTensorList = {}; + std::vector outTensorList = {"out"}; + + // 添加默认的Tensor + AddTensorToList(attnInTensorCandidates, "default", inTensorList); + // 添加AddRmsNormQuant特性的Tensor + if (param.enableAddNorm) { + AddTensorToList(attnInTensorCandidates, "add_rmsnorm_quant", inTensorList); + } + AddTensorToList(attnIntermediateTensorCandidates, "default", intermediateTensorList); + + // 添加Mask Alibi特性的Tensor + if ( + param.selfAttentionParam.maskType == atb::infer::SelfAttentionParam::MASK_TYPE_ALIBI_COMPRESS || \ + param.selfAttentionParam.maskType == atb::infer::SelfAttentionParam::MASK_TYPE_ALIBI_COMPRESS_SQRT || \ + param.selfAttentionParam.maskType == atb::infer::SelfAttentionParam::MASK_TYPE_ALIBI_COMPRESS_LEFT_ALIGN + ) { + AddTensorToList(attnInTensorCandidates, "alibi_mask_compress", inTensorList); + } + + // 添加多头压缩特性的Tensor + if (param.pageAttentionParam.compressType == atb::infer::PagedAttentionParam::CompressType::COMPRESS_TYPE_KVHEAD) { + AddTensorToList(attnInTensorCandidates, "compress_head_alibi", inTensorList); + } else if (param.pageAttentionParam.compressType == + 
+        AddTensorToList(attnInTensorCandidates, "compress_head_rope", inTensorList);
+    }
+    // Add tensors for the parallel decoding (speculation) feature
+    if (param.pageAttentionParam.calcType == atb::infer::PagedAttentionParam::CalcType::CALC_TYPE_SPEC) {
+        AddTensorToList(attnInTensorCandidates, "speculate", inTensorList);
+    }
+
+    ConstructAttentionQuantTensorMap(param, attnInTensorCandidates, attnIntermediateTensorCandidates,
+        inTensorList, intermediateTensorList);
+
+    // Add tensors for the lora feature
+    if (param.supportLora) {
+        if (param.useImMask) {
+            AddTensorToList(attnInTensorCandidates, "lora_with_mask", inTensorList);
+        }
+        AddTensorToList(attnInTensorCandidates, "lora", inTensorList);
+    }
+    // Add tensors for the lccl all reduce int8 feature
+    if (param.selfOutLinearTensorParallelInfo.quantType != \
+        atb::infer::AllReduceParam::QuantType::QUANT_TYPE_UNDEFINED) {
+        AddTensorToList(attnInTensorCandidates, "reduce_quant", inTensorList);
+    }
+
+    // Add tensors for the logN attention feature
+    if (param.pageAttentionParam.scaleType == atb::infer::PagedAttentionParam::ScaleType::SCALE_TYPE_LOGN) {
+        AddTensorToList(attnInTensorCandidates, "log_n_scale", inTensorList);
+    }
+
+    // Add tensors for qk_norm
+    if (param.useQKNorm) {
+        AddTensorToList(attnInTensorCandidates, "qk_norm", inTensorList);
+    }
+
+    // Add tensors for the add norm fusion feature
+    if (param.enableAddNorm) {
+        AddTensorToList(attnInTensorCandidates, "add_norm", inTensorList);
+    }
+
+    if (param.enableAddNorm) {
+        outTensorList.push_back("out_add");
+    }
+
+    // Add tensors for the CMO feature
+    if (param.enablePreFetchWeight) {
+        AddTensorToList(attnInTensorCandidates, "cmo_mlp_first_matmul_weight", inTensorList);
+    }
+
+    // Add the dequant rope tensor
+    if (param.enableRopeQuantKvcache) {
+        AddTensorToList(attnIntermediateTensorCandidates, "dequant_rope", intermediateTensorList);
+    }
+
+    // Add flashcomm1.0 tensors
+    if (param.enableFlashComm) {
+        AddTensorToList(attnInTensorCandidates, "flash_comm", inTensorList);
+    }
+
+    inTensorNum = inTensorList.size();
+    outTensorNum = outTensorList.size();
+    internalTensorNum = intermediateTensorList.size();
+
+    return GetTensorMap(inTensorList, outTensorList, intermediateTensorList);
+}
+
+template <typename NormParamType>
+atb::Status AddFAttnQKVLinearSplitNode(const FusionAttentionParam<NormParamType> &param,
+    atb::GraphParam &opGraph, std::map<std::string, uint32_t> &tensorMap)
+{
+    atb::Node qkvLinearSplitNode;
+    CHECK_OPERATION_STATUS_RETURN(QKVLinearSplit(param, &qkvLinearSplitNode.operation));
+    std::vector<std::string> qkvInTensorNames = {
+        "in_input", "in_norm_weight", "in_norm_bias", "in_norm_new_weight", "in_norm_new_bias",
+        "in_weight_0", "in_scale_0", "in_offset_0", "in_descale_0", "in_bias_0", "in_compress_idx_0",
+        "in_weight_1", "in_scale_1", "in_offset_1", "in_descale_1", "in_bias_1", "in_compress_idx_1",
+        "in_weight_2", "in_scale_2", "in_offset_2", "in_descale_2", "in_bias_2", "in_compress_idx_2",
+    };
+    // Add tensors for the AddRmsNormQuant feature
+    if (param.enableAddNorm) {
+        qkvInTensorNames.push_back("in_qkv_scale_fill");
+        qkvInTensorNames.push_back("in_qkv_offset_fill");
+        qkvInTensorNames.push_back("in_residual_add");
+    }
+    if (param.supportLora) {
+        if (param.useImMask) {
+            qkvInTensorNames.push_back("in_im_mask");
+        }
+        qkvInTensorNames.push_back("in_seq_len_cum_sum");
+        qkvInTensorNames.push_back("in_lora_a_0");
+        qkvInTensorNames.push_back("in_lora_b_0");
+        qkvInTensorNames.push_back("in_lora_a_1");
+        qkvInTensorNames.push_back("in_lora_b_1");
+        qkvInTensorNames.push_back("in_lora_a_2");
+        qkvInTensorNames.push_back("in_lora_b_2");
+    }
+    if (param.useQKNorm) {
+        qkvInTensorNames.push_back("in_q_norm_weight");
qkvInTensorNames.push_back("in_k_norm_weight"); + } + if (param.enableFlashComm) { + qkvInTensorNames.push_back("send_counts"); + qkvInTensorNames.push_back("sdispls"); + qkvInTensorNames.push_back("send_count"); + qkvInTensorNames.push_back("recv_counts"); + qkvInTensorNames.push_back("rdispls"); + qkvInTensorNames.push_back("recv_count"); + qkvInTensorNames.push_back("fake_ag_shape"); + } + qkvLinearSplitNode.inTensorIds = GetTensorIdxList(tensorMap, qkvInTensorNames); + std::vector qkvOutTensorNames = {"intermediate_q", "intermediate_k", "intermediate_v"}; + + if (param.enableRopeQuantKvcache) { // 3 -> 1 + qkvOutTensorNames = {"intermediate_qkv_rope"}; + } + if (param.enableAddNorm) { + qkvOutTensorNames.push_back("out_add"); + } + qkvLinearSplitNode.outTensorIds = GetTensorIdxList(tensorMap, qkvOutTensorNames); + opGraph.nodes.push_back(qkvLinearSplitNode); + return atb::NO_ERROR; +} + +template +int64_t AddRopeQuantKvcacheOperation( + atb::GraphParam& opGraph, const FusionAttentionParam& param, + std::map& tensorMap) +{ + atb::Node dequantRopeQuantKvcacheNode; + AclNNDequantRopeQuantKvcacheParam aclnnParam; + + int64_t sizeSpiltsZero = CheckIntMulOverFlow(param.selfAttentionParam.headNum, param.headDim); + int64_t sizeSpiltsOne = CheckIntMulOverFlow(param.selfAttentionParam.kvHeadNum, param.headDim); + aclnnParam.sizeSpilts = {sizeSpiltsZero, sizeSpiltsOne, sizeSpiltsOne}; + + aclnnParam.kvOutput = true; + aclnnParam.quantMode = "static"; + aclnnParam.layout = "BSND"; + LinearQuantType quantType = GetLinearQuantType( + param.packQuantType, param.layerLinearQuantType[Q_LINEAR_INDEX], param.enableNormQuantOp); + aclnnParam.enableDequant = (!param.isPrefill && param.isBF16 && \ + (quantType == LINEAR_W8A8_QUANT || quantType == LINEAR_W8A8_DEQUANT)); + dequantRopeQuantKvcacheNode.operation = new atb_speed::common::DequantRopeQuantKvcacheOperation( + "aclnnDequantRopeQuantKvcacheNode", aclnnParam + ); + dequantRopeQuantKvcacheNode.inTensorIds = { + GetTensorIdx(tensorMap, "intermediate_qkv_rope"), // 1: input_x + GetTensorIdx(tensorMap, "in_cos_embed"), // 2: cos + GetTensorIdx(tensorMap, "in_sin_embed"), // 3: sin + GetTensorIdx(tensorMap, "in_k_cache"), // 4: k_cache + GetTensorIdx(tensorMap, "in_v_cache"), // 5: v_cache + GetTensorIdx(tensorMap, "in_slots_in_pa_or_logn_in_fa"), // 6: indices + GetTensorIdx(tensorMap, "in_k_quant_scale"), // 7: scale_k + GetTensorIdx(tensorMap, "in_v_quant_scale"), // 8: scale_v + GetTensorIdx(tensorMap, "in_k_quant_offset"), // 9: offset_k + GetTensorIdx(tensorMap, "in_v_quant_offset"), // 10: offset_v + }; + if (aclnnParam.enableDequant) { + dequantRopeQuantKvcacheNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_descale_0")); // 11: weight_scale + dequantRopeQuantKvcacheNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_bias_0")); // 12: bias + } + dequantRopeQuantKvcacheNode.outTensorIds = { + GetTensorIdx(tensorMap, "intermediate_q"), // q_out // 1024, 8, 128 + GetTensorIdx(tensorMap, "intermediate_k"), // k_out // 1024, 1, 128 + GetTensorIdx(tensorMap, "intermediate_v"), // v_out // 1024, 1, 128 + }; + opGraph.nodes.push_back(dequantRopeQuantKvcacheNode); + + return atb::NO_ERROR; +} + + +template +atb::Status AddFAttnRopeNode(const FusionAttentionParam ¶m, + atb::GraphParam &opGraph, std::map &tensorMap) +{ + atb::Node ropeNode; + atb_speed::common::RotaryPositionEmbeddingParam ropeParam; + ropeParam.rotaryType = param.rotaryType; + ropeParam.isFA = param.isFA; + ropeParam.headDim = param.headDim; + ropeParam.headNum = 
param.selfAttentionParam.headNum; + ropeParam.kvHeadNum = param.selfAttentionParam.kvHeadNum; + ropeParam.ropeParam = param.ropeParam; + + RotaryPositionEmbedding(ropeParam, &ropeNode.operation); + + ropeNode.inTensorIds = { // [B,S,N,D] PA [BS,ND] + GetTensorIdx(tensorMap, "intermediate_q"), GetTensorIdx(tensorMap, "intermediate_k"), + GetTensorIdx(tensorMap, "in_cos_embed"), GetTensorIdx(tensorMap, "in_sin_embed"), + GetTensorIdx(tensorMap, "in_seq_len") + }; + if (!param.isFA) { + ropeNode.inTensorReshapeFuncs.resize(ropeNode.inTensorIds.size()); + ropeNode.inTensorReshapeFuncs.at(0) = &SqueezeHeadNumHeadDim; + ropeNode.inTensorReshapeFuncs.at(1) = &SqueezeHeadNumHeadDim; + } + ropeNode.outTensorIds = { // FA [B,S,N,D] PA [BS,N,D] + GetTensorIdx(tensorMap, "intermediate_q"), GetTensorIdx(tensorMap, "intermediate_k"), + }; + opGraph.nodes.push_back(ropeNode); + return atb::NO_ERROR; +} + +template +atb::Status AddKVValueQuantNode(const FusionAttentionParam ¶m, + atb::GraphParam &opGraph, std::map &tensorMap, bool isK) +{ + atb::Node kvValueQuantNode; + atb::infer::ElewiseParam kvValueQuantParam; + kvValueQuantParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_QUANT_PER_CHANNEL; + CREATE_OPERATION(kvValueQuantParam, &kvValueQuantNode.operation); + if (isK) { + kvValueQuantNode.inTensorIds = { + GetTensorIdx(tensorMap, "intermediate_k"), GetTensorIdx(tensorMap, "in_k_quant_scale"), + GetTensorIdx(tensorMap, "in_k_quant_offset"), + }; + kvValueQuantNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_k_int8")}; + } else { + kvValueQuantNode.inTensorIds = { + GetTensorIdx(tensorMap, "intermediate_v"), GetTensorIdx(tensorMap, "in_v_quant_scale"), + GetTensorIdx(tensorMap, "in_v_quant_offset"), + }; + kvValueQuantNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_v_int8")}; + } + kvValueQuantNode.inTensorReshapeFuncs.resize(kvValueQuantNode.inTensorIds.size()); + kvValueQuantNode.inTensorReshapeFuncs[1] = [=](const atb::Dims &oldShape, atb::Dims &newShape) { + UnsqueezeHeadNumHeadDim(oldShape, newShape, param.selfAttentionParam.kvHeadNum, param.headDim); + }; + kvValueQuantNode.inTensorReshapeFuncs[2] = [=](const atb::Dims &oldShape, atb::Dims &newShape) { + UnsqueezeHeadNumHeadDim(oldShape, newShape, param.selfAttentionParam.kvHeadNum, param.headDim); + }; + opGraph.nodes.push_back(kvValueQuantNode); + return atb::NO_ERROR; +} + +atb::Status AddQKVQuantNode(atb::GraphParam &opGraph, std::map &tensorMap, std::string nodeType) +{ + atb::Node qkvQuantNode; + atb::infer::ElewiseParam qkvQuantParam; + qkvQuantParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_QUANT_PER_CHANNEL; + CREATE_OPERATION(qkvQuantParam, &qkvQuantNode.operation); + if (nodeType == "Q") { + qkvQuantNode.inTensorIds = { + GetTensorIdx(tensorMap, "intermediate_q"), GetTensorIdx(tensorMap, "in_q_quant_scale"), + GetTensorIdx(tensorMap, "q_offset") + }; + qkvQuantNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_q_int8")}; + } else if (nodeType == "K") { + qkvQuantNode.inTensorIds = { + GetTensorIdx(tensorMap, "intermediate_k"), GetTensorIdx(tensorMap, "in_k_quant_scale"), + GetTensorIdx(tensorMap, "kv_offset") + }; + qkvQuantNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_k_int8")}; + } else if (nodeType == "V") { + qkvQuantNode.inTensorIds = { + GetTensorIdx(tensorMap, "intermediate_v"), GetTensorIdx(tensorMap, "in_v_quant_scale"), + GetTensorIdx(tensorMap, "kv_offset") + }; + qkvQuantNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_v_int8")}; + } + 
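+    // Note: the Q/K/V branches above share the same per-channel quant operation; only the tensor ids differ.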
qkvQuantNode.inTensorReshapeFuncs.resize(qkvQuantNode.inTensorIds.size()); + opGraph.nodes.push_back(qkvQuantNode); + return atb::NO_ERROR; +} + +template +atb::Status AddSelfOutLinearParallelNode(const FusionAttentionParam ¶m, + atb::GraphParam &opGraph, std::map &tensorMap) +{ + atb::Node selfOutLinearParallelNode; + atb_speed::common::LinearParallelParam selfOutLinearParam; + if (param.enableFlashComm) { + selfOutLinearParam.parallelType = atb_speed::common::REDUCE_SCATTER; + } else { + selfOutLinearParam.parallelType = atb_speed::common::ROW_PARALLEL; + } + selfOutLinearParam.fusionLinearParam.quantType = GetLinearQuantType( + param.denseQuantType == atb_speed::common::PackQuantType::PACK_QUANT_UNDEFINED \ + ? param.packQuantType : param.denseQuantType, + param.layerLinearQuantType[DENSE_LINEAR_INDEX], false, + param.layerLinearDescs[DENSE_LINEAR_INDEX]); + selfOutLinearParam.biasAfterSync = param.selfOutLinearTensorParallelInfo.worldSize > 1 && \ + selfOutLinearParam.fusionLinearParam.quantType == atb_speed::common::LinearQuantType::NO_QUANT && \ + param.selfAttnHasBias; + selfOutLinearParam.fusionLinearParam.isBF16 = param.isBF16; + selfOutLinearParam.fusionLinearParam.hasBias = param.selfAttnHasBias && !selfOutLinearParam.biasAfterSync; + selfOutLinearParam.fusionLinearParam.supportLora = param.supportLora; + selfOutLinearParam.fusionLinearParam.useImMask = param.useImMask; + selfOutLinearParam.fusionLinearParam.loraEnableGMM = param.loraEnableGMM; + selfOutLinearParam.fusionLinearParam.transposeType = param.layerLinearTransposeType[DENSE_LINEAR_INDEX]; + selfOutLinearParam.fusionLinearParam.quantGroupSize = param.quantGroupSize; + selfOutLinearParam.fusionLinearParam.matmulBackend = param.matmulBackend; + selfOutLinearParam.fusionLinearParam.isPrefill = param.isPrefill; + selfOutLinearParam.tensorParallelInfo = param.selfOutLinearTensorParallelInfo; + selfOutLinearParam.supportLcoc = param.supportLcoc; + selfOutLinearParam.enableMC2 = param.enableMC2; + CHECK_OPERATION_STATUS_RETURN(LinearParallel(selfOutLinearParam, &selfOutLinearParallelNode.operation)); + std::vector denseInTensorNames = { + "intermediate_self_attention", "in_weight_dense", "in_scale_dense", "in_offset_dense", "in_descale_dense", + "in_bias_dense", "in_compress_idx_dense" + }; + if (param.supportLora) { + if (param.useImMask) { + denseInTensorNames.push_back("in_im_mask"); + } + denseInTensorNames.push_back("in_seq_len_cum_sum"); + denseInTensorNames.push_back("in_dense_lora_a"); + denseInTensorNames.push_back("in_dense_lora_b"); + } + if (param.selfOutLinearTensorParallelInfo.quantType != \ + atb::infer::AllReduceParam::QuantType::QUANT_TYPE_UNDEFINED) { + denseInTensorNames.push_back("in_reduce_quant_scale"); + denseInTensorNames.push_back("in_reduce_quant_offset"); + denseInTensorNames.push_back("in_gather_quant_scale"); + denseInTensorNames.push_back("in_gather_quant_offset"); + } + if (selfOutLinearParam.parallelType == atb_speed::common::REDUCE_SCATTER) { + denseInTensorNames.push_back("send_counts"); + denseInTensorNames.push_back("sdispls"); + denseInTensorNames.push_back("recv_count"); + denseInTensorNames.push_back("fake_rs_shape"); + } + selfOutLinearParallelNode.inTensorIds = GetTensorIdxList(tensorMap, denseInTensorNames); + if (!param.isFA) { + selfOutLinearParallelNode.inTensorReshapeFuncs.resize(selfOutLinearParallelNode.inTensorIds.size()); + selfOutLinearParallelNode.inTensorReshapeFuncs.at(0) = &SqueezeHeadNumHeadDim; + } + selfOutLinearParallelNode.outTensorIds = {GetTensorIdx(tensorMap, 
"out")}; + opGraph.nodes.push_back(selfOutLinearParallelNode); + return atb::NO_ERROR; +} + +template +atb::Status AddQScaleNode(const FusionAttentionParam ¶m, + atb::GraphParam &opGraph, std::map &tensorMap) +{ + atb::Node qScaleNode; + atb::infer::ElewiseParam qScaleParam; + qScaleParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_MULS; + qScaleParam.mulsParam.varAttr = 1.0 / sqrt(param.headDim); + CREATE_OPERATION(qScaleParam, &qScaleNode.operation); + qScaleNode.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_q")}; + qScaleNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_q")}; + opGraph.nodes.push_back(qScaleNode); + return atb::NO_ERROR; +} + +template +atb::Status Attention(const FusionAttentionParam ¶m, atb::Operation **operation) +{ + atb::GraphParam opGraph; + opGraph.name = "Attention"; + std::map tensorMap = ConstructTensorMap( + param, opGraph.inTensorNum, opGraph.outTensorNum, opGraph.internalTensorNum); + ATB_SPEED_LOG_DEBUG("opGraph.inTensorNum " << opGraph.inTensorNum); + ATB_SPEED_LOG_DEBUG("opGraph.outTensorNum " << opGraph.outTensorNum); + ATB_SPEED_LOG_DEBUG("opGraph.internalTensorNum " << opGraph.internalTensorNum); + + if (param.layerLinearDescs.size() != 0 && \ + CheckParamVectorSize(param.layerLinearDescs, DENSE_LINEAR_INDEX + 1) != atb::NO_ERROR) { + ATB_SPEED_LOG_ERROR("The size of param.layerLinearDescs is wrong, please check"); + return atb::ERROR_INVALID_PARAM; + } + if (param.layerLinearQuantType.size() != 0 && \ + CheckParamVectorSize(param.layerLinearQuantType, DENSE_LINEAR_INDEX + 1) != atb::NO_ERROR) { + ATB_SPEED_LOG_ERROR("The size of param.layerLinearQuantType is wrong, please check"); + return atb::ERROR_INVALID_PARAM; + } + if (CheckParamVectorSize(param.layerLinearTransposeType, DENSE_LINEAR_INDEX + 1) != atb::NO_ERROR) { + ATB_SPEED_LOG_ERROR("The size of param.layerLinearTransposeType is wrong, please check"); + return atb::ERROR_INVALID_PARAM; + } + + if (param.enableRopeQuantKvcache) { + // AddQNormLinearNode only, skip others + CHECK_OPERATION_STATUS_RETURN(AddFAttnQKVLinearSplitNode(param, opGraph, tensorMap)); + CHECK_OPERATION_STATUS_RETURN(AddRopeQuantKvcacheOperation(opGraph, param, tensorMap)); + } else { + // QKV Node + CHECK_OPERATION_STATUS_RETURN(AddFAttnQKVLinearSplitNode(param, opGraph, tensorMap)); + + // Rope Node + if (param.rotaryType != RotaryType::NO_ROTARY) { + CHECK_OPERATION_STATUS_RETURN(AddFAttnRopeNode(param, opGraph, tensorMap)); + } + + // QScale Node + if (param.enableQScale) { + CHECK_OPERATION_STATUS_RETURN(AddQScaleNode(param, opGraph, tensorMap)); + } + + bool atbAttentionDequant = param.attnBackend == atb_speed::common::OpBackend::ATB && \ + param.pageAttentionParam.quantType == atb::infer::PagedAttentionParam::QuantType::TYPE_DEQUANT_FUSION; + bool aclnnAttentionDequant = param.attnBackend == atb_speed::common::OpBackend::ACLNN && \ + param.aclnnIncreAttentionParam.hasKVQuant; + if (atbAttentionDequant || aclnnAttentionDequant) { + // K Quant + CHECK_OPERATION_STATUS_RETURN(AddKVValueQuantNode(param, opGraph, tensorMap, true)); + // V Quant + CHECK_OPERATION_STATUS_RETURN(AddKVValueQuantNode(param, opGraph, tensorMap, false)); + } + } + + // FA3 QKV Quant Node + if (!param.isPrefill && param.pageAttentionParam.quantType == \ + atb::infer::PagedAttentionParam::QuantType::TYPE_QUANT_QKV_ONLINE) { + CHECK_OPERATION_STATUS_RETURN(AddQKVQuantNode(opGraph, tensorMap, "Q")); + } + if (param.pageAttentionParam.quantType == atb::infer::PagedAttentionParam::QuantType::TYPE_QUANT_QKV_ONLINE) { 
+ CHECK_OPERATION_STATUS_RETURN(AddQKVQuantNode(opGraph, tensorMap, "K")); + CHECK_OPERATION_STATUS_RETURN(AddQKVQuantNode(opGraph, tensorMap, "V")); + } + + // SelfAttention Node + CHECK_OPERATION_STATUS_RETURN(AddSelfAttention(opGraph, param, tensorMap)); + + // Dense Node + CHECK_OPERATION_STATUS_RETURN(AddSelfOutLinearParallelNode(param, opGraph, tensorMap)); + + opGraph.inferShapeFunc = [=] + (const atb::SVector &inTensorDescs, atb::SVector &outTensorDescs) { + outTensorDescs.at(0) = inTensorDescs.at(0); + if (param.enableAddNorm) { + outTensorDescs.at(1) = inTensorDescs.at(0); + } + return atb::NO_ERROR; + }; + + CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(opGraph, operation)); + return atb::NO_ERROR; +} + +template atb::Status ConstructAttentionQuantTensorMap( + const FusionAttentionParam ¶m, + std::map> &attnInTensorCandidates, + std::map> &attnIntermediateTensorCandidates, + std::vector &inTensorList, std::vector &intermediateTensorList); +template std::map ConstructTensorMap( + const FusionAttentionParam ¶m, + uint32_t &inTensorNum, uint32_t &outTensorNum, uint32_t &internalTensorNum); +template atb::Status AddFAttnRopeNode(const FusionAttentionParam ¶m, + atb::GraphParam &opGraph, std::map &tensorMap); +template atb::Status AddKVValueQuantNode(const FusionAttentionParam ¶m, + atb::GraphParam &opGraph, std::map &tensorMap, bool isK); +template atb::Status AddSelfOutLinearParallelNode(const FusionAttentionParam ¶m, + atb::GraphParam &opGraph, std::map &tensorMap); +template atb::Status AddQScaleNode(const FusionAttentionParam ¶m, + atb::GraphParam &opGraph, std::map &tensorMap); +template atb::Status Attention( + const FusionAttentionParam ¶m, + atb::Operation **operation); + +template atb::Status ConstructAttentionQuantTensorMap( + const FusionAttentionParam ¶m, + std::map> &attnInTensorCandidates, + std::map> &attnIntermediateTensorCandidates, + std::vector &inTensorList, std::vector &intermediateTensorList); +template std::map ConstructTensorMap( + const FusionAttentionParam ¶m, + uint32_t &inTensorNum, uint32_t &outTensorNum, uint32_t &internalTensorNum); +template atb::Status AddFAttnRopeNode( + const FusionAttentionParam ¶m, + atb::GraphParam &opGraph, std::map &tensorMap); +template atb::Status AddKVValueQuantNode( + const FusionAttentionParam ¶m, + atb::GraphParam &opGraph, std::map &tensorMap, bool isK); +template atb::Status AddSelfOutLinearParallelNode( + const FusionAttentionParam ¶m, + atb::GraphParam &opGraph, std::map &tensorMap); +template atb::Status AddQScaleNode(const FusionAttentionParam ¶m, + atb::GraphParam &opGraph, std::map &tensorMap); +template atb::Status Attention( + const FusionAttentionParam ¶m, + atb::Operation **operation); +} // namespace common +} // namespace atb_speed \ No newline at end of file diff --git a/tests/proftest/layer_test_framework/operations/fusion/attention/fusion_attention.h b/tests/proftest/layer_test_framework/operations/fusion/attention/fusion_attention.h new file mode 100644 index 00000000..48b319ed --- /dev/null +++ b/tests/proftest/layer_test_framework/operations/fusion/attention/fusion_attention.h @@ -0,0 +1,316 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef ATB_SPEED_MODELS_COMMON_ATTENTION_H
+#define ATB_SPEED_MODELS_COMMON_ATTENTION_H
+
+#include
+#include
+#include
+#include "atb_speed/log.h"
+#include "atb_speed/utils/operation_util.h"
+#include "operations/fusion/linear/linear.h"
+#include "operations/fusion/linear/linear_parallel.h"
+#include "operations/fusion/norm/norm_linear.h"
+#include "operations/fusion/embedding/positional_embedding.h"
+#include "operations/aclnn/ops/attn_operation.h"
+
+namespace atb_speed {
+namespace common {
+
+/// The categories of the FusionAttention's input tensors.
+/// Input tensors will be arranged according to the order of their categories.
+enum AttnInTensorCategory : unsigned int {
+    /// Default tensors
+    ATTN_DEFAULT = 0,
+    /// Tensors required by addRmsNormQuant, addRmsNormDynamicQuant
+    ATTN_ADD_RMS_NORM_QUANT,
+    /// Tensors required when passing compressed alibi mask
+    ATTN_ALIBI_MASK_COMPRESS,
+    /// Tensors required by kv head compression with alibi mask
+    ATTN_COMPRESS_HEAD_ALIBI,
+    /// Tensors required by kv head compression with rope
+    ATTN_COMPRESS_HEAD_ROPE,
+    /// Tensors required by omni attention
+    ATTN_OMNI,
+    /// Tensors required by speculation
+    ATTN_SPECULATE,
+    /// Tensors required by int8 quantization for the KV cache
+    ATTN_KV_QUANT_SCALE,
+    /// Tensors required by int8 quantization for the KV cache
+    ATTN_KV_QUANT_OFFSET,
+    /// Tensors required by flash attention 3
+    ATTN_FA3,
+    /// The mask tensor before applying lora adapters
+    ATTN_LORA_MASK,
+    /// Tensors needed for LoRA
+    ATTN_LORA,
+    /// Tensors required by the quantization of the all reduce operation
+    ATTN_REDUCE_QUANT,
+    /// Tensors required when applying logarithmic scaling to the attention
+    ATTN_LOG_N_SCALE,
+    /// Tensors required by q/k normalization
+    ATTN_QK_NORM,
+    /// Tensors required by add rmsnorm
+    ATTN_ADD_NORM,
+    /// Tensors required by CMO
+    ATTN_CMO,
+    /// Tensors required by flashcomm1.0
+    ATTN_FC,
+    /// A flag signifying the end of all categories
+    ATTN_END
+};
+
+/// The index of the q linear within the layer
+const uint64_t Q_LINEAR_INDEX = 0;
+/// The index of the k linear within the layer
+const uint64_t K_LINEAR_INDEX = 1;
+/// The index of the v linear within the layer
+const uint64_t V_LINEAR_INDEX = 2;
+/// The index of the dense linear within the layer
+const uint64_t DENSE_LINEAR_INDEX = 3;
+
+/// Parameters for the FusionAttention module
+/// \tparam NormParamType Type of the normalization parameters. Available options are `atb::infer::RmsNormParam`
+/// and `atb::infer::LayerNormParam`.
+template <typename NormParamType>
+struct FusionAttentionParam {
+    // QKV linear param
+    /// A flag indicating whether the model structure is grouped query attention or multi head attention
+    bool isGroupedQueryAttention = false;
+    /// When `isBF16` is true, bfloat16 precision is used; otherwise, float16 precision is used.
+    bool isBF16 = false;
+    /// A flag indicating whether to reshape before splitting the packed output of the qkv linear operation
+    bool splitWithStride = false;
+    /// A flag indicating whether the qkv linear has bias
+    bool qkvHasBias = false;
+    /// A flag indicating whether normalization is skipped
+    bool skipNorm = false;
+    /// A flag indicating whether normalization has bias
+    bool normHasBias = false;
+    /// A flag indicating whether to use the NormQuant fusion operation
+    bool enableNormQuantOp = true;
+    /// A flag indicating whether to prefetch weights
+    bool enablePreFetchWeight = false;
+    /// A flag indicating whether lora is enabled.
+    bool supportLora = false;
+    /// A flag indicating whether a mask is used before applying lora adapters.
+    bool useImMask = false;
+    /// It should be activated when batch inputs include multiple LoRA adapters
+    bool loraEnableGMM = false;
+    /// A flag indicating whether to use q-norm and k-norm.
+    bool useQKNorm = false;
+    /// A flag indicating whether RopeQuantKvcache is enabled.
+    bool enableRopeQuantKvcache = false;
+    /// The backend of the attention module; refer to `OpBackend` for the supported values
+    int attnBackend = atb_speed::common::OpBackend::ATB;
+    /// The group size used for dequantizing the weight tensor in the per-group quantization approach
+    int quantGroupSize = 0;
+    /// Indicates the pack type and the quantization type of the qkv linear.
+    int packQuantType = atb_speed::common::PackQuantType::PACK_QUANT_UNDEFINED;
+    /// Specifies the quantization type for the following linear modules:
+    /// q linear, k linear, v linear, dense linear, gate linear, up linear, and down linear.
+    std::vector<int> layerLinearQuantType = {};
+    /// Specifies the weight description of the following linear modules:
+    /// qkv linear, dense linear, gateup linear and down linear.
+    std::vector<common::LinearDesc> layerLinearDescs = {
+        common::LinearDesc::INVALID_DESC, common::LinearDesc::INVALID_DESC, common::LinearDesc::INVALID_DESC,
+        common::LinearDesc::INVALID_DESC, common::LinearDesc::INVALID_DESC, common::LinearDesc::INVALID_DESC,
+        common::LinearDesc::INVALID_DESC
+    };
+    /// Defines the transpose type of the second matrix in the matmul operation for the following linear modules:
+    /// q linear, k linear, v linear, dense linear, gate linear, up linear, and down linear.
+    std::vector<int> layerLinearTransposeType = {};
+    /// Normalization parameters for float operation
+    NormParamType normParamType;
+    /// Normalization parameters for quantization operation
+    NormParamType normQuantParamType;
+    // rope param
+    /// The type of rotary position embedding. Refer to `RotaryType`
+    /// in the `operations/fusion/positional_embedding.h` for more details.
+    atb_speed::common::RotaryType rotaryType;
+    /// Parameters for the rope operation
+    atb::infer::RopeParam ropeParam;
+    // self attention param
+    /// A flag indicating whether to apply logarithmic scaling to the attention
+    bool enableLogN = false;
+    /// A flag indicating whether to scale q by 1/sqrt(headDim) with a separate Muls node
+    bool enableQScale = false;
+    /// A flag indicating whether split fuse is enabled
+    bool enableSplitFuse = false;
+    /// If `isFA` is true, Flash Attention is used; otherwise, Paged Attention is used
+    bool isFA = true;
+    /// A flag indicating the prefill and decode phases
+    bool isPrefill = false;
+    /// The dimension per attention head
+    int headDim = 0;
+    /// Parameters for the self attention operation from the ATB backend
+    atb::infer::SelfAttentionParam selfAttentionParam;
+    /// Parameters for the page attention operation from the ATB backend
+    atb::infer::PagedAttentionParam pageAttentionParam;
+    /// Parameters for the reshape and cache operation from the ATB backend
+    atb::infer::ReshapeAndCacheParam reshapeCacheParm;
+    atb::infer::ReshapeAndCacheOmniParam reshapeCacheOmniParm;
+    /// Parameters for the attention operation from the AclNN backend (used in the decode phase)
+    atb_speed::common::AclNNAttnParam aclnnIncreAttentionParam;
+    // self out linear param
+    /// A flag indicating whether the dense linear has bias
+    bool selfAttnHasBias = false;
+    /// A flag that indicates whether low-latency computation over communication is enabled
+    bool supportLcoc = false;
+    /// A flag indicating whether MC2 is enabled
+    bool enableMC2 = false;
+    /// The quantization type of the dense linear. Refer to `PackQuantType` in the `operations/utils.h`.
+    int denseQuantType = atb_speed::common::PackQuantType::PACK_QUANT_UNDEFINED;
+    /// Details about tensor parallelism
+    atb_speed::common::TensorParallelInfo selfOutLinearTensorParallelInfo;
+    /// A flag indicating whether to use the atb matmul backend
+    int matmulBackend = atb_speed::common::OpBackend::ATB;
+    /// A flag indicating whether addNorm fusion is enabled in attention
+    bool enableAddNorm = false;
+    /// Specifies whether the input norm enables antioutlier
+    bool isAntiOutlier = false;
+    /// Specifies whether omni attention is enabled
+    bool enableOmniattention = false;
+    /// Specifies whether omni compression is applied
+    bool isomnicompressed = false;
+    /// A flag indicating whether to use flashcomm1.0
+    bool enableFlashComm = false;
+    /// A flag indicating whether to use pmcc obfuscation
+    bool enableModelConfuscation = false;
+    /// A handle used by pmcc model obfuscation
+    int32_t modelConfuscationFd = 0;
+    /// Hidden size per rank
+    int32_t hiddenSizePerRank = 0;
+};
+
+template <typename NormParamType>
+std::map<std::string, uint32_t> ConstructTensorMap(const FusionAttentionParam<NormParamType> &param,
+    uint32_t &inTensorNum, uint32_t &outTensorNum, uint32_t &internalTensorNum);
+
+/// This function is the main entrance for the fusion attention module.
+/// It consists of a QKVLinearSplit operation, a Rope operation, an Elementwise operation to quantize
+/// intermediate kv tensors, a ReshapeAndCache operation, a SelfAttention/PageAttention operation
+/// and a LinearParallel operation for the dense linear.
+///
+/// \tparam NormParamType Type of the normalization parameters. Available options are `atb::infer::RmsNormParam`
+/// and `atb::infer::LayerNormParam`.
+/// \param param Parameters for the normalization and linear module
+/// \param operation The address of a pointer to a default operation
+/// \return A flag indicating whether the operation has been successfully created.
+/// +/// Operation's inputs: +/// Name | Requirements | Dtype | Shape | Description | +/// -------------------|--------------|------------------|-------|----------| +/// in_input | Required | float16/bfloat16 | `isFA` is false: [len(all_seq),num_heads,head_dim] | Hidden States | +/// ^ | ^ | ^ | `isFA` is true: [bsz,seq_len,num_heads,head_dim] | ^ | +/// in_norm_weight | ^ | Refer to `atb_speed::common::QKVLinearSplit` in the `operations/fusion/attention/qkv_linear_split.h` for more details. ||| +/// in_norm_bias | ^ | ^ ||| +/// in_norm_new_weight | ^ | ^ ||| +/// in_norm_new_bias | ^ | ^ ||| +/// in_weight_0 | ^ | ^ ||| +/// in_scale_0 | ^ | ^ ||| +/// in_offset_0 | ^ | ^ ||| +/// in_descale_0 | ^ | ^ ||| +/// in_bias_0 | ^ | ^ ||| +/// in_compress_idx_0 | ^ | ^ ||| +/// in_weight_1 | ^ | ^ ||| +/// in_scale_1 | ^ | ^ ||| +/// in_offset_1 | ^ | ^ ||| +/// in_descale_1 | ^ | ^ ||| +/// in_bias_1 | ^ | ^ ||| +/// in_compress_idx_1 | ^ | ^ ||| +/// in_weight_2 | ^ | ^ ||| +/// in_scale_2 | ^ | ^ ||| +/// in_offset_2 | ^ | ^ ||| +/// in_descale_2 | ^ | ^ ||| +/// in_bias_2 | ^ | ^ ||| +/// in_compress_idx_2 | ^ | ^ ||| +/// in_cos_embed | ^ | float16/bfloat16 | [len(all_seq),head_dim] | The cosine part of the rotary embedding. | +/// in_sin_embed | ^ | ^ | ^ | The sine part of the rotary embedding. | +/// in_seq_len | ^ | int32 | [batch_size] | The total number of input and output tokens.
In the prefill phase, each element equals the length of the prompt.
For flash attention, each element is set to 1 in the decode phase.
For paged attention, each element is set to the number of input tokens plus output tokens in the decode phase. | +/// in_k_cache | ^ | Refer to `atb_speed::common::AddSelfAttention` in the `operations/fusion/attention/self_attention.h` for more details. ||| +/// in_v_cache | ^ | ^ ||| +/// in_attention_mask | ^ | ^ ||| +/// in_token_offset | ^ | ^ ||| +/// in_layer_id | ^ | ^ ||| +/// in_block_tables | ^ | ^ ||| +/// in_slots_in_pa_or_logn_in_fa | ^ | ^ ||| +/// in_weight_dense | ^ | Weights for the dense linear. Refer to `atb_speed::common::LinearParallel` in the `operations/fusion/linear/linear_parallel.h` for more details. ||| +/// in_scale_dense | ^ | ^ ||| +/// in_offset_dense | ^ | ^ ||| +/// in_descale_dense | ^ | ^ ||| +/// in_bias_dense | ^ | ^ ||| +/// in_compress_idx_dense | ^ | ^ ||| +/// in_slopes | `param.selfAttentionParam.maskType` is in one of `atb::infer::SelfAttentionParam::MASK_TYPE_ALIBI_COMPRESS`, `atb::infer::SelfAttentionParam::MASK_TYPE_ALIBI_COMPRESS_SQRT` and `atb::infer::SelfAttentionParam::MASK_TYPE_ALIBI_COMPRESS_LEFT_ALIGN` | Refer to `atb_speed::common::AddSelfAttention` in the `operations/fusion/attention/self_attention.h` for more details. ||| +/// in_batch_wins | `param.pageAttentionParam.compressType` equals to `atb::infer::PagedAttentionParam::CompressType::COMPRESS_TYPE_KVHEAD` or `atb::infer::PagedAttentionParam::CompressType::COMPRESS_TYPE_KVHEAD_ROPE` | ^ ||| +/// in_ra_seq_len | ^ | ^ ||| +/// in_pffset_index | `param.pageAttentionParam.compressType == atb::infer::PagedAttentionParam::CompressType::COMPRESS_TYPE_KVHEAD_ROPE` | ^ ||| +/// in_ra_offset | ^ | ^ ||| +/// in_reshape_seq_len | ^ | ^ ||| +/// in_q_len | `param.pageAttentionParam.calcType == atb::infer::PagedAttentionParam::CalcType::CALC_TYPE_SPEC` | ^ ||| +/// in_k_quant_scale | `param.pageAttentionParam.quantType == atb::infer::PagedAttentionParam::QuantType::TYPE_DEQUANT_FUSION` | float16/bfloat16 | [head_num * head_dim] | | +/// in_k_dequant_scale | ^ | Refer to `atb_speed::common::AddSelfAttention` in the `operations/fusion/attention/self_attention.h` for more details. ||| +/// in_v_quant_scale | ^ | float16/bfloat16 | [head_num * head_dim] | | +/// in_v_dequant_scale | ^ | Refer to `atb_speed::common::AddSelfAttention` in the `operations/fusion/attention/self_attention.h` for more details. ||| +/// in_k_quant_offset | `param.pageAttentionParam.quantType == atb::infer::PagedAttentionParam::QuantType::TYPE_DEQUANT_FUSION` and `param.pageAttentionParam.hasQuantOffset` is true | int8 | [head_num * head_dim] | | +/// in_k_dequant_offset| ^ | Refer to `atb_speed::common::AddSelfAttention` in the `operations/fusion/attention/self_attention.h` for more details. ||| +/// in_v_quant_offset | ^ | int8 | [head_num * head_dim] | | +/// in_v_dequant_offset| ^ | Refer to `atb_speed::common::AddSelfAttention` in the `operations/fusion/attention/self_attention.h` for more details. ||| +/// in_q_quant_scale | `param.pageAttentionParam.quantType == atb::infer::PagedAttentionParam::QuantType::TYPE_QUANT_QKV_ONLINE` | float16 | [head_num,head_dim] | | +/// in_k_quant_scale | ^ | float16 | [kv_head_num,head_dim] | | +/// in_v_quant_scale | ^ | float16 | [kv_head_num,head_dim] | | +/// in_qk_descale | ^ | Refer to `atb_speed::common::AddSelfAttention` in the `operations/fusion/attention/self_attention.h` for more details. 
||| +/// q_offset | ^ | int8 | [head_num,head_dim] | | +/// kv_offset | ^ | int8 | [kv_head_num,head_dim] | | +/// fa3_v_quant_scale | ^ | Refer to `atb_speed::common::AddSelfAttention` in the `operations/fusion/attention/self_attention.h` for more details. ||| +/// fa3_offset | ^ | int32 | [head_num] | | +/// in_im_mask | `param.supportLora` is true and `param.useImMask` is true | Refer to atb_speed::common::QKVLinearSplit in the `operations/fusion/attention/qkv_linear_split.h` for more details. ||| +/// in_seq_len_cum_sum | `param.supportLora` is true | ^ ||| +/// in_lora_a_0 | ^ | ^ ||| +/// in_lora_b_0 | ^ | ^ ||| +/// in_lora_a_1 | ^ | ^ ||| +/// in_lora_b_1 | ^ | ^ ||| +/// in_lora_a_2 | ^ | ^ ||| +/// in_lora_b_2 | ^ | ^ ||| +/// in_dense_lora_a | ^ | ^ ||| +/// in_dense_lora_b | ^ | ^ ||| +/// in_reduce_quant_scale | `param.selfOutLinearTensorParallelInfo.quantType != atb::infer::AllReduceParam::QuantType::QUANT_TYPE_UNDEFINED` | Refer to `atb_speed::common::LinearParallel` in the `operations/fusion/linear/linear_parallel.h` for more details. ||| +/// in_reduce_quant_offset| ^ | ^ ||| +/// in_gather_quant_scale | ^ | ^ ||| +/// in_gather_quant_offset| ^ | ^ ||| +/// in_log_n_scale | `param.pageAttentionParam.scaleType == atb::infer::PagedAttentionParam::ScaleType::SCALE_TYPE_LOGN` | Refer to `atb_speed::common::AddSelfAttention` in the `operations/fusion/attention/self_attention.h` for more details. ||| +/// in_q_norm_weight | `param.useQKNorm` is true | Refer to atb_speed::common::QKVLinearSplit in the `operations/fusion/attention/qkv_linear_split.h` for more details. ||| +/// in_k_norm_weight | ^ | ^ ||| +/// in_residual_add | `param.enableAddNorm` is true | ^ ||| +/// +/// Operations's outputs: +/// Name | Dtype | Shape | +/// -----------|----------------------|----------------------| +/// out | The same as in_input | The same as in_input | +/// out_add | The same as in_input | The same as in_input | +/// +/// Example: +/// \code +/// atb::Node fusionAttentionNode; +/// atb_speed::common::FusionAttentionParam fusionAttentionParam; +/// // Modify fusionAttentionParam's attribute if needed. +/// Attention(fusionAttentionParam, &fusionAttentionNode.operation); +/// fusionAttentionNode.inTensorIds = {...}; // Passing inputs for the operation in order +/// fusionAttentionNode.outTensorIds = {...}; // Tensor index for out +/// atb::GraphParam opGraph; +/// opGraph.nodes.push_back(fusionAttentionNode); +/// \endcode +template +atb::Status Attention(const FusionAttentionParam ¶m, atb::Operation **operation); +} // namespace common +} // namespace atb_speed +#endif \ No newline at end of file diff --git a/tests/proftest/layer_test_framework/operations/fusion/attention/qkv_linear_split.cpp b/tests/proftest/layer_test_framework/operations/fusion/attention/qkv_linear_split.cpp new file mode 100644 index 00000000..12c16ad3 --- /dev/null +++ b/tests/proftest/layer_test_framework/operations/fusion/attention/qkv_linear_split.cpp @@ -0,0 +1,586 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include
+#include "atb_speed/utils/check_util.h"
+#include "operations/fusion/utils.h"
+#include "operations/fusion/infer_shape_functions.h"
+#include "operations/fusion/attention/fusion_attention.h"
+#include "operations/aclnn/ops/rms_norm_operation.h"
+#include "operations/fusion/attention/qkv_linear_split.h"
+
+namespace atb_speed {
+namespace common {
+
+std::map<std::string, std::vector<std::string>> GetQKVInTensorCandidates()
+{
+    std::map<std::string, std::vector<std::string>> qkvInTensorCandidates = {
+        {"default", {
+            "in_qkv_input", "in_qkv_norm_weight", "in_qkv_norm_bias", "in_qkv_norm_new_weight",
+            "in_qkv_norm_new_bias",
+            "in_qkv_weight_0", "in_qkv_scale_0", "in_qkv_offset_0", "in_qkv_descale_0", "in_qkv_bias_0",
+            "in_qkv_compress_idx_0",
+            "in_qkv_weight_1", "in_qkv_scale_1", "in_qkv_offset_1", "in_qkv_descale_1", "in_qkv_bias_1",
+            "in_qkv_compress_idx_1",
+            "in_qkv_weight_2", "in_qkv_scale_2", "in_qkv_offset_2", "in_qkv_descale_2", "in_qkv_bias_2",
+            "in_qkv_compress_idx_2"}
+        },
+        {"lora", {
+            "in_seq_len_cum_sum", "in_qkv_lora_a_0", "in_qkv_lora_b_0",
+            "in_qkv_lora_a_1", "in_qkv_lora_b_1", "in_qkv_lora_a_2", "in_qkv_lora_b_2"}
+        },
+        {"lora_with_mask", {"in_im_mask"}},
+        {"qk_norm", {"in_q_norm_weight", "in_k_norm_weight"}},
+        {"add_norm", {"in_residual_add"}},
+        {"add_rmsnorm_quant", {"in_qkv_scale_fill", "in_qkv_offset_fill"}},
+        {"flash_comm", {
+            "send_counts", "sdispls", "send_count", "recv_counts", "rdispls", "recv_count", "fake_ag_shape"}
+        },
+    };
+    return qkvInTensorCandidates;
+}
+
+std::map<std::string, std::vector<std::string>> GetQKVIntermediateTensorCandidates()
+{
+    std::map<std::string, std::vector<std::string>> qkvIntermediateTensorCandidates = {
+        {"qkv_pack", {"intermediate_qkv"}},
+        {"qk_norm", {"intermediate_q", "intermediate_k", "intermediate_q_rstd_out", "intermediate_k_rstd_out"}},
+        {"add_norm", {"out_add"}},
+    };
+    return qkvIntermediateTensorCandidates;
+}
+
+std::map<std::string, std::vector<std::string>> GetQKVOutTensorCandidates()
+{
+    std::map<std::string, std::vector<std::string>> qkvOutTensorCandidates = {
+        {"default", {"out_q", "out_k", "out_v"}},
+        {"add_norm", {"out_add"}},
+        {"dequant_rope", {"intermediate_qkv_rope"}},
+    };
+    return qkvOutTensorCandidates;
+}
+
+template <typename NormParamType>
+std::map<std::string, uint32_t> ConstructQKVTensorMap(
+    const FusionAttentionParam<NormParamType> &param,
+    uint32_t &inTensorNum, uint32_t &outTensorNum, uint32_t &internalTensorNum)
+{
+    auto qkvInTensorCandidates = GetQKVInTensorCandidates();
+    auto qkvIntermediateTensorCandidates = GetQKVIntermediateTensorCandidates();
+    auto qkvOutTensorCandidates = GetQKVOutTensorCandidates();
+
+    std::vector<std::string> inTensorList = {};
+    std::vector<std::string> intermediateTensorList = {};
+    std::vector<std::string> outTensorList = {};
+
+    std::vector<uint64_t> qkvLinearIndex = {Q_LINEAR_INDEX, K_LINEAR_INDEX, V_LINEAR_INDEX};
+    bool isPack = CheckPack(param.packQuantType, param.layerLinearDescs, qkvLinearIndex);
+
+    // Add the default tensors
+    AddTensorToList(qkvInTensorCandidates, "default", inTensorList);
+    // Add tensors for the AddRmsNormQuant feature
+    if (param.enableAddNorm) {
+        AddTensorToList(qkvInTensorCandidates, "add_rmsnorm_quant", inTensorList);
+        AddTensorToList(qkvInTensorCandidates, "add_norm", inTensorList);
+    }
+    if (isPack && !param.enableRopeQuantKvcache) {
+        AddTensorToList(qkvIntermediateTensorCandidates, "qkv_pack", intermediateTensorList);
+        if (param.useQKNorm) {
+            AddTensorToList(qkvIntermediateTensorCandidates, "qk_norm", intermediateTensorList);
+            AddTensorToList(qkvInTensorCandidates, "qk_norm", inTensorList);
+        }
+    }
+
+    // Add tensors for the Lora feature
+    if (param.supportLora) {
+        if (param.useImMask) {
+            AddTensorToList(qkvInTensorCandidates, "lora_with_mask", inTensorList);
+        }
+        AddTensorToList(qkvInTensorCandidates, "lora", inTensorList);
+    }
+
+    // Add tensors for the flashcomm 1.0 feature
+    if (param.enableFlashComm) {
+        AddTensorToList(qkvInTensorCandidates, "flash_comm", inTensorList);
+    }
+    // Add the out tensors
+    if (param.enableRopeQuantKvcache) {
+        AddTensorToList(qkvOutTensorCandidates, "dequant_rope", outTensorList);
+    } else {
+        AddTensorToList(qkvOutTensorCandidates, "default", outTensorList);
+    }
+    if (param.enableAddNorm) {
+        AddTensorToList(qkvOutTensorCandidates, "add_norm", outTensorList);
+    }
+    inTensorNum = inTensorList.size();
+    outTensorNum = outTensorList.size();
+    internalTensorNum = intermediateTensorList.size();
+
+    return GetTensorMap(inTensorList, outTensorList, intermediateTensorList);
+}
+
+template <typename NormParamType>
+atb::Status AddQNormLinearNode(const FusionAttentionParam<NormParamType> &param,
+    atb::GraphParam &opGraph, std::map<std::string, uint32_t> &tensorMap, bool isAntiOutlier, bool isPack)
+{
+    atb::Node qNormLinearNode;
+    atb_speed::common::NormLinearParam<NormParamType> qNormLinearParam;
+    qNormLinearParam.isAntiOutlier = isAntiOutlier;
+    if (param.layerLinearQuantType.size() != 0 && \
+        CheckParamVectorSize(param.layerLinearQuantType, Q_LINEAR_INDEX + 1) != atb::NO_ERROR) {
+        ATB_SPEED_LOG_ERROR("The size of param.layerLinearQuantType is wrong, please check");
+        return atb::ERROR_INVALID_PARAM;
+    }
+    if (CheckParamVectorSize(param.layerLinearTransposeType, Q_LINEAR_INDEX + 1) != atb::NO_ERROR) {
+        ATB_SPEED_LOG_ERROR("The size of param.layerLinearTransposeType is wrong, please check");
+        return atb::ERROR_INVALID_PARAM;
+    }
+    qNormLinearParam.fusionLinearParam.quantType = GetLinearQuantType(
+        param.packQuantType, param.layerLinearQuantType[Q_LINEAR_INDEX], param.enableNormQuantOp,
+        param.layerLinearDescs[Q_LINEAR_INDEX]);
+    qNormLinearParam.fusionLinearParam.isBF16 = param.isBF16;
+    qNormLinearParam.fusionLinearParam.hasBias = param.qkvHasBias;
+    qNormLinearParam.fusionLinearParam.supportLora = param.supportLora;
+    qNormLinearParam.fusionLinearParam.useImMask = param.useImMask;
+    qNormLinearParam.fusionLinearParam.loraEnableGMM = param.loraEnableGMM;
+    qNormLinearParam.fusionLinearParam.transposeType = param.layerLinearTransposeType[0];
+    qNormLinearParam.fusionLinearParam.quantGroupSize = param.quantGroupSize;
+    qNormLinearParam.fusionLinearParam.matmulBackend = param.matmulBackend;
+    qNormLinearParam.fusionLinearParam.isThrowDequant = false;
+    qNormLinearParam.fusionLinearParam.isPrefill = param.isPrefill;
+    qNormLinearParam.skipNorm = param.skipNorm;
+    qNormLinearParam.normHasBias = param.normHasBias;
+    qNormLinearParam.normParamType = param.normParamType;
+    qNormLinearParam.normQuantParamType = param.normQuantParamType;
+    qNormLinearParam.enableAddNorm = param.enableAddNorm;
+    qNormLinearParam.fusionLinearParam.enableFlashComm = param.enableFlashComm;
+    qNormLinearParam.fusionLinearParam.flashCommParallelInfo.rank = param.selfOutLinearTensorParallelInfo.rank;
+    qNormLinearParam.fusionLinearParam.flashCommParallelInfo.worldSize =
+        param.selfOutLinearTensorParallelInfo.worldSize;
+    qNormLinearParam.fusionLinearParam.flashCommParallelInfo.backend =
+        param.selfOutLinearTensorParallelInfo.backend;
+    qNormLinearParam.enableModelConfuscation = param.enableModelConfuscation;
+    qNormLinearParam.modelConfuscationFd = param.modelConfuscationFd;
+    qNormLinearParam.hiddenSizePerRank = param.hiddenSizePerRank;
+    qNormLinearParam.modelObfuscationParallelInfo = param.selfOutLinearTensorParallelInfo;
+    CHECK_OPERATION_STATUS_RETURN(NormLinear(qNormLinearParam, &qNormLinearNode.operation));
+
+    std::vector<std::string>
qInTensor = { + "in_qkv_input", "in_qkv_norm_weight", "in_qkv_norm_bias", "in_qkv_norm_new_weight", + "in_qkv_norm_new_bias", + "in_qkv_weight_0", "in_qkv_scale_0", "in_qkv_offset_0", "in_qkv_descale_0", "in_qkv_bias_0", + "in_qkv_compress_idx_0" + }; + if (param.enableAddNorm) { + qInTensor.push_back("in_qkv_scale_fill"); + qInTensor.push_back("in_qkv_offset_fill"); + qInTensor.push_back("in_residual_add"); + } + if (param.supportLora) { + if (param.useImMask) { + qInTensor.push_back("in_im_mask"); + } + qInTensor.push_back("in_seq_len_cum_sum"); + qInTensor.push_back("in_qkv_lora_a_0"); + qInTensor.push_back("in_qkv_lora_b_0"); + } + if (param.enableFlashComm) { + qInTensor.push_back("send_counts"); + qInTensor.push_back("sdispls"); + qInTensor.push_back("send_count"); + qInTensor.push_back("recv_counts"); + qInTensor.push_back("rdispls"); + qInTensor.push_back("recv_count"); + qInTensor.push_back("fake_ag_shape"); + } + qNormLinearNode.inTensorIds = GetTensorIdxList(tensorMap, qInTensor); + std::vector qOutTensor; + if (param.enableRopeQuantKvcache) { + qOutTensor = {"intermediate_qkv_rope"}; + } else { + qOutTensor = {isPack ? "intermediate_qkv" : "out_q"}; + } + if (param.enableAddNorm) { + qOutTensor.push_back("out_add"); + } + qNormLinearNode.outTensorIds = GetTensorIdxList(tensorMap, qOutTensor); + opGraph.nodes.push_back(qNormLinearNode); + return atb::NO_ERROR; +} + +template +atb::Status AddSplitQKVNode(const FusionAttentionParam ¶m, atb::GraphParam &opGraph, + std::map &tensorMap) +{ + atb::Node splitQKVNode; + atb::infer::SplitParam splitQKVParam; + if (param.splitWithStride) { + splitQKVParam = {2, 3, {param.selfAttentionParam.headNum / param.selfAttentionParam.kvHeadNum, 1, 1}}; + } else { + splitQKVParam = {(param.isFA ? 2 : 1), 3, { + CheckIntMulOverFlow(param.selfAttentionParam.headNum, param.headDim), + CheckIntMulOverFlow(param.selfAttentionParam.kvHeadNum, param.headDim), + CheckIntMulOverFlow(param.selfAttentionParam.kvHeadNum, param.headDim)}}; + } + CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(splitQKVParam, &splitQKVNode.operation)); + splitQKVNode.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_qkv")}; + splitQKVNode.outTensorIds = {GetTensorIdxList(tensorMap, {param.useQKNorm ? "intermediate_q" : "out_q", + param.useQKNorm ? 
"intermediate_k" : "out_k", "out_v"})}; + if (param.splitWithStride) { + splitQKVNode.inTensorReshapeFuncs.resize(splitQKVNode.inTensorIds.size()); + splitQKVNode.inTensorReshapeFuncs[0] = [=](const atb::Dims &oldShape, atb::Dims &newShape) { + InternlmV2QKVSplit( + oldShape, newShape, + param.selfAttentionParam.headNum, param.selfAttentionParam.kvHeadNum, param.headDim); + }; + } + opGraph.nodes.push_back(splitQKVNode); + return atb::NO_ERROR; +} + +template +atb::Status AddSplitMixedQKVNode(const FusionAttentionParam ¶m, atb::GraphParam &opGraph, + std::map &tensorMap) +{ + atb::Node splitMixedQKVNode; + atb::infer::SplitParam splitMixedQKVParam; + if (param.splitWithStride) { + splitMixedQKVParam = {-2, 3, {}}; + } else { + splitMixedQKVParam = {-1, 3, {}}; + } + CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(splitMixedQKVParam, &splitMixedQKVNode.operation)); + splitMixedQKVNode.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_qkv")}; + splitMixedQKVNode.outTensorIds = GetTensorIdxList(tensorMap, {"out_q", "out_k", "out_v"}); + if (param.splitWithStride) { + splitMixedQKVNode.inTensorReshapeFuncs.resize(splitMixedQKVNode.inTensorIds.size()); + splitMixedQKVNode.inTensorReshapeFuncs[0] = [=](const atb::Dims &oldShape, atb::Dims &newShape) { + size_t dim = 0; + newShape.dims[dim++] = oldShape.dims[0]; // PA ntokens | FA batch + if (param.isFA) { + newShape.dims[dim++] = oldShape.dims[1]; // FA seqlen + } + newShape.dims[dim++] = param.selfAttentionParam.headNum; // headNum + newShape.dims[dim++] = 3; // 3 -> q, k, v + newShape.dims[dim++] = param.headDim; // dk + newShape.dimNum = dim; + }; + } + opGraph.nodes.push_back(splitMixedQKVNode); + return atb::NO_ERROR; +} + +template +atb::Status AddKNormLinearNode(const FusionAttentionParam ¶m, + atb::GraphParam &opGraph, std::map &tensorMap, bool isAntiOutlier) +{ + atb::Node kNormLinearNode; + atb_speed::common::NormLinearParam kNormLinearParam; + kNormLinearParam.isAntiOutlier = isAntiOutlier; + kNormLinearParam.fusionLinearParam.quantType = GetLinearQuantType( + param.packQuantType, param.layerLinearQuantType[K_LINEAR_INDEX], param.enableNormQuantOp, + param.layerLinearDescs[K_LINEAR_INDEX]); + kNormLinearParam.fusionLinearParam.isBF16 = param.isBF16; + kNormLinearParam.fusionLinearParam.hasBias = param.qkvHasBias; + kNormLinearParam.fusionLinearParam.supportLora = param.supportLora; + kNormLinearParam.fusionLinearParam.useImMask = param.useImMask; + kNormLinearParam.fusionLinearParam.loraEnableGMM = param.loraEnableGMM; + kNormLinearParam.fusionLinearParam.transposeType = param.layerLinearTransposeType[1]; + kNormLinearParam.fusionLinearParam.quantGroupSize = param.quantGroupSize; + kNormLinearParam.fusionLinearParam.matmulBackend = param.matmulBackend; + kNormLinearParam.fusionLinearParam.isPrefill = param.isPrefill; + kNormLinearParam.skipNorm = param.skipNorm; + kNormLinearParam.normHasBias = param.normHasBias; + kNormLinearParam.normParamType = param.normParamType; + kNormLinearParam.normQuantParamType = param.normQuantParamType; + CHECK_OPERATION_STATUS_RETURN(NormLinear(kNormLinearParam, &kNormLinearNode.operation)); + std::vector kInTensor = { + "in_qkv_input", "in_qkv_norm_weight", "in_qkv_norm_bias", "in_qkv_norm_new_weight", + "in_qkv_norm_new_bias", + "in_qkv_weight_1", "in_qkv_scale_1", "in_qkv_offset_1", "in_qkv_descale_1", "in_qkv_bias_1", + "in_qkv_compress_idx_1" + }; + if (param.supportLora) { + if (param.useImMask) { + kInTensor.push_back("in_im_mask"); + } + kInTensor.push_back("in_seq_len_cum_sum"); + 
+
+template <typename NormParamType>
+atb::Status AddSplitMixedQKVNode(const FusionAttentionParam<NormParamType> &param, atb::GraphParam &opGraph,
+    std::map<std::string, uint32_t> &tensorMap)
+{
+    atb::Node splitMixedQKVNode;
+    atb::infer::SplitParam splitMixedQKVParam;
+    if (param.splitWithStride) {
+        splitMixedQKVParam = {-2, 3, {}};
+    } else {
+        splitMixedQKVParam = {-1, 3, {}};
+    }
+    CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(splitMixedQKVParam, &splitMixedQKVNode.operation));
+    splitMixedQKVNode.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_qkv")};
+    splitMixedQKVNode.outTensorIds = GetTensorIdxList(tensorMap, {"out_q", "out_k", "out_v"});
+    if (param.splitWithStride) {
+        splitMixedQKVNode.inTensorReshapeFuncs.resize(splitMixedQKVNode.inTensorIds.size());
+        splitMixedQKVNode.inTensorReshapeFuncs[0] = [=](const atb::Dims &oldShape, atb::Dims &newShape) {
+            size_t dim = 0;
+            newShape.dims[dim++] = oldShape.dims[0];  // PA ntokens | FA batch
+            if (param.isFA) {
+                newShape.dims[dim++] = oldShape.dims[1];  // FA seqlen
+            }
+            newShape.dims[dim++] = param.selfAttentionParam.headNum;  // headNum
+            newShape.dims[dim++] = 3;  // 3 -> q, k, v
+            newShape.dims[dim++] = param.headDim;  // dk
+            newShape.dimNum = dim;
+        };
+    }
+    opGraph.nodes.push_back(splitMixedQKVNode);
+    return atb::NO_ERROR;
+}
+
+template <typename NormParamType>
+atb::Status AddKNormLinearNode(const FusionAttentionParam<NormParamType> &param,
+    atb::GraphParam &opGraph, std::map<std::string, uint32_t> &tensorMap, bool isAntiOutlier)
+{
+    atb::Node kNormLinearNode;
+    atb_speed::common::NormLinearParam<NormParamType> kNormLinearParam;
+    kNormLinearParam.isAntiOutlier = isAntiOutlier;
+    kNormLinearParam.fusionLinearParam.quantType = GetLinearQuantType(
+        param.packQuantType, param.layerLinearQuantType[K_LINEAR_INDEX], param.enableNormQuantOp,
+        param.layerLinearDescs[K_LINEAR_INDEX]);
+    kNormLinearParam.fusionLinearParam.isBF16 = param.isBF16;
+    kNormLinearParam.fusionLinearParam.hasBias = param.qkvHasBias;
+    kNormLinearParam.fusionLinearParam.supportLora = param.supportLora;
+    kNormLinearParam.fusionLinearParam.useImMask = param.useImMask;
+    kNormLinearParam.fusionLinearParam.loraEnableGMM = param.loraEnableGMM;
+    kNormLinearParam.fusionLinearParam.transposeType = param.layerLinearTransposeType[1];
+    kNormLinearParam.fusionLinearParam.quantGroupSize = param.quantGroupSize;
+    kNormLinearParam.fusionLinearParam.matmulBackend = param.matmulBackend;
+    kNormLinearParam.fusionLinearParam.isPrefill = param.isPrefill;
+    kNormLinearParam.skipNorm = param.skipNorm;
+    kNormLinearParam.normHasBias = param.normHasBias;
+    kNormLinearParam.normParamType = param.normParamType;
+    kNormLinearParam.normQuantParamType = param.normQuantParamType;
+    CHECK_OPERATION_STATUS_RETURN(NormLinear(kNormLinearParam, &kNormLinearNode.operation));
+    std::vector<std::string> kInTensor = {
+        "in_qkv_input", "in_qkv_norm_weight", "in_qkv_norm_bias", "in_qkv_norm_new_weight",
+        "in_qkv_norm_new_bias",
+        "in_qkv_weight_1", "in_qkv_scale_1", "in_qkv_offset_1", "in_qkv_descale_1", "in_qkv_bias_1",
+        "in_qkv_compress_idx_1"
+    };
+    if (param.supportLora) {
+        if (param.useImMask) {
+            kInTensor.push_back("in_im_mask");
+        }
+        kInTensor.push_back("in_seq_len_cum_sum");
+        kInTensor.push_back("in_qkv_lora_a_1");
+        kInTensor.push_back("in_qkv_lora_b_1");
+    }
+    kNormLinearNode.inTensorIds = GetTensorIdxList(tensorMap, kInTensor);
+    kNormLinearNode.outTensorIds = {GetTensorIdx(tensorMap, "out_k")};
+    opGraph.nodes.push_back(kNormLinearNode);
+    return atb::NO_ERROR;
+}
+
+template <typename NormParamType>
+atb::Status AddQKNormNode(const FusionAttentionParam<NormParamType> &param,
+    atb::GraphParam &opGraph, std::map<std::string, uint32_t> &tensorMap)
+{
+    ATB_SPEED_LOG_DEBUG("QKnorm using aclnn rmsnorm");
+    atb::Node qNormNode;
+    qNormNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "intermediate_q"));
+    qNormNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_q_norm_weight"));
+    qNormNode.outTensorIds.push_back(GetTensorIdx(tensorMap, "out_q"));
+    qNormNode.outTensorIds.push_back(GetTensorIdx(tensorMap, "intermediate_q_rstd_out"));
+    qNormNode.inTensorReshapeFuncs.resize(qNormNode.inTensorIds.size());
+    qNormNode.inTensorReshapeFuncs[0] = [=](const atb::Dims &oldShape, atb::Dims &newShape) {
+        newShape.dimNum = 3;  // 3: the new shape has 3 dims
+        newShape.dims[0] = oldShape.dims[0];  // 0: bs * seq_len
+        newShape.dims[1] = oldShape.dims[1] / param.headDim;  // 1: number of q heads
+        newShape.dims[2] = param.headDim;  // 2: headDim (e.g. 128)
+    };
+    qNormNode.operation = \
+        new atb_speed::common::RmsNormOperation("QRmsNormNode", param.normParamType.normParam.epsilon);
+
+    atb::Node kNormNode;
+    kNormNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "intermediate_k"));
+    kNormNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_k_norm_weight"));
+    kNormNode.outTensorIds.push_back(GetTensorIdx(tensorMap, "out_k"));
+    kNormNode.outTensorIds.push_back(GetTensorIdx(tensorMap, "intermediate_k_rstd_out"));
+    kNormNode.inTensorReshapeFuncs.resize(kNormNode.inTensorIds.size());
+    kNormNode.inTensorReshapeFuncs[0] = [=](const atb::Dims &oldShape, atb::Dims &newShape) {
+        newShape.dimNum = 3;  // 3: the new shape has 3 dims
+        newShape.dims[0] = oldShape.dims[0];  // 0: bs * seq_len
+        newShape.dims[1] = oldShape.dims[1] / param.headDim;  // 1: number of kv heads
+        newShape.dims[2] = param.headDim;  // 2: headDim (e.g. 128)
+    };
+    kNormNode.operation = \
+        new atb_speed::common::RmsNormOperation("KRmsNormNode", param.normParamType.normParam.epsilon);
+
+    opGraph.nodes.push_back(qNormNode);
+    opGraph.nodes.push_back(kNormNode);
+    ATB_SPEED_LOG_DEBUG("Add QKnorm to OpGraph.");
+    return atb::NO_ERROR;
+}
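The reshape lambdas above view the flat linear output as a per-head tensor so the RmsNorm is applied independently to each head. A plain-arithmetic illustration of that view change, with assumed example sizes:

```cpp
// [num_tokens, hidden] is reinterpreted as [num_tokens, hidden / headDim, headDim];
// no data moves, only the logical shape changes. Numbers are illustrative.
#include <cstdio>

int main()
{
    const long numTokens = 16, hidden = 4096, headDim = 128;
    const long heads = hidden / headDim;  // 32 heads, each normalized on its own
    std::printf("[%ld, %ld] -> [%ld, %ld, %ld]\n", numTokens, hidden, numTokens, heads, headDim);
    return 0;
}
```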
+
+template <typename NormParamType>
+atb::Status AddVNormLinearNode(const FusionAttentionParam<NormParamType> &param,
+    atb::GraphParam &opGraph, std::map<std::string, uint32_t> &tensorMap, bool isAntiOutlier)
+{
+    atb::Node vNormLinearNode;
+    atb_speed::common::NormLinearParam<NormParamType> vNormLinearParam;
+    vNormLinearParam.isAntiOutlier = isAntiOutlier;
+    vNormLinearParam.fusionLinearParam.quantType = GetLinearQuantType(
+        param.packQuantType, param.layerLinearQuantType[V_LINEAR_INDEX], param.enableNormQuantOp,
+        param.layerLinearDescs[V_LINEAR_INDEX]);
+    vNormLinearParam.fusionLinearParam.isBF16 = param.isBF16;
+    vNormLinearParam.fusionLinearParam.hasBias = param.qkvHasBias;
+    vNormLinearParam.fusionLinearParam.supportLora = param.supportLora;
+    vNormLinearParam.fusionLinearParam.useImMask = param.useImMask;
+    vNormLinearParam.fusionLinearParam.loraEnableGMM = param.loraEnableGMM;
+    vNormLinearParam.fusionLinearParam.transposeType = param.layerLinearTransposeType[V_LINEAR_INDEX];
+    vNormLinearParam.fusionLinearParam.quantGroupSize = param.quantGroupSize;
+    vNormLinearParam.fusionLinearParam.isPrefill = param.isPrefill;
+    vNormLinearParam.skipNorm = param.skipNorm;
+    vNormLinearParam.normHasBias = param.normHasBias;
+    vNormLinearParam.normParamType = param.normParamType;
+    vNormLinearParam.normQuantParamType = param.normQuantParamType;
+    CHECK_OPERATION_STATUS_RETURN(NormLinear(vNormLinearParam, &vNormLinearNode.operation));
+    std::vector<std::string> vInTensor = {
+        "in_qkv_input", "in_qkv_norm_weight", "in_qkv_norm_bias", "in_qkv_norm_new_weight",
+        "in_qkv_norm_new_bias",
+        "in_qkv_weight_2", "in_qkv_scale_2", "in_qkv_offset_2", "in_qkv_descale_2", "in_qkv_bias_2",
+        "in_qkv_compress_idx_2"
+    };
+    if (param.supportLora) {
+        if (param.useImMask) {
+            vInTensor.push_back("in_im_mask");
+        }
+        vInTensor.push_back("in_seq_len_cum_sum");
+        vInTensor.push_back("in_qkv_lora_a_2");
+        vInTensor.push_back("in_qkv_lora_b_2");
+    }
+    vNormLinearNode.inTensorIds = GetTensorIdxList(tensorMap, vInTensor);
+    vNormLinearNode.outTensorIds = {GetTensorIdx(tensorMap, "out_v")};
+    opGraph.nodes.push_back(vNormLinearNode);
+    return atb::NO_ERROR;
+}
+
+template <typename NormParamType>
+void QKVLinearSplitInferShapeFunc(const FusionAttentionParam<NormParamType> &param,
+    atb::GraphParam &opGraph, uint32_t inQKVInputIdx, uint32_t inResidualAddInputIdx, uint32_t inFakeAgShapeIdx)
+{
+    if (param.isFA) {
+        opGraph.inferShapeFunc = [inQKVInputIdx, inResidualAddInputIdx, param]
+            (const atb::SVector<atb::TensorDesc> &inTensorDescs, atb::SVector<atb::TensorDesc> &outTensorDescs) {
+            outTensorDescs.at(0) = inTensorDescs.at(inQKVInputIdx);
+            outTensorDescs.at(0).shape.dimNum = 4;  // 0, 4: output 0 has 4 dims
+            outTensorDescs.at(0).shape.dims[0] = inTensorDescs.at(inQKVInputIdx).shape.dims[0];  // 0, 0, 0: batch size
+            outTensorDescs.at(0).shape.dims[1] = inTensorDescs.at(inQKVInputIdx).shape.dims[1];  // 0, 1, 1: seq len
+            outTensorDescs.at(0).shape.dims[2] = param.selfAttentionParam.headNum;  // 0, 2: headNum
+            outTensorDescs.at(0).shape.dims[3] = param.headDim;  // 0, 3: headDim
+
+            outTensorDescs.at(1) = outTensorDescs.at(0);
+            outTensorDescs.at(1).shape.dims[2] = param.selfAttentionParam.kvHeadNum;  // 0, 2: kvHeadNum
+
+            outTensorDescs.at(2) = outTensorDescs.at(1);  // 2: output 2 shares the same desc as output 1
+            if (param.enableAddNorm) {
+                outTensorDescs.at(3) = inTensorDescs.at(inResidualAddInputIdx);  // 3: AddNorm fusion adds an output at index 3
+            }
+            return atb::NO_ERROR;
+        };
+    } else {
+        opGraph.inferShapeFunc = [inQKVInputIdx, inResidualAddInputIdx, param, inFakeAgShapeIdx]
+            (const atb::SVector<atb::TensorDesc> &inTensorDescs, atb::SVector<atb::TensorDesc> &outTensorDescs) {
+            outTensorDescs.at(0) = inTensorDescs.at(inQKVInputIdx);
+            outTensorDescs.at(0).shape.dimNum = 3;  // 0, 3: output 0 has 3 dims
+            if (param.enableFlashComm) {
+                outTensorDescs.at(0).shape.dims[0] = inTensorDescs.at(inFakeAgShapeIdx).shape.dims[0];
+            } else {
+                outTensorDescs.at(0).shape.dims[0] = \
+                    inTensorDescs.at(inQKVInputIdx).shape.dims[0];  // 0, 0, 0: batch size * seq len
+            }
+            outTensorDescs.at(0).shape.dims[1] = param.selfAttentionParam.headNum;  // 0, 1: headNum
+            outTensorDescs.at(0).shape.dims[2] = param.headDim;  // 0, 2: headDim
+
+            outTensorDescs.at(1) = outTensorDescs.at(0);
+            outTensorDescs.at(1).shape.dims[1] = param.selfAttentionParam.kvHeadNum;  // 0, 1: kvHeadNum
+
+            outTensorDescs.at(2) = outTensorDescs.at(1);  // 2: output 2 shares the same desc as output 1
+            if (param.enableAddNorm) {
+                outTensorDescs.at(3) = inTensorDescs.at(inResidualAddInputIdx);  // 3: AddNorm fusion adds an output at index 3
+            }
+            return atb::NO_ERROR;
+        };
+    }
+}
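A numeric walk-through of the PA (`isFA == false`) branch above; the values mirror the desc assignments rather than calling the real infer-shape code, and the sizes are assumed for illustration.

```cpp
// With 10 total tokens, headNum = 32, kvHeadNum = 8, headDim = 128:
//   out_q: [10, 32, 128]
//   out_k: [10, 8, 128]  (copy of out_q with the kv head count substituted)
//   out_v: [10, 8, 128]  (copy of out_k)
// With enableAddNorm, a fourth output mirrors in_residual_add.
#include <cstdio>

int main()
{
    const long tokens = 10, headNum = 32, kvHeadNum = 8, headDim = 128;
    std::printf("out_q: [%ld, %ld, %ld]\n", tokens, headNum, headDim);
    std::printf("out_k: [%ld, %ld, %ld]\n", tokens, kvHeadNum, headDim);
    std::printf("out_v: [%ld, %ld, %ld]\n", tokens, kvHeadNum, headDim);
    return 0;
}
```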
param.layerLinearDescs is wrong, please check"); + return atb::ERROR_INVALID_PARAM; + } + + std::vector qkvLinearIndex = {Q_LINEAR_INDEX, K_LINEAR_INDEX, V_LINEAR_INDEX}; + bool isPack = CheckPack(param.packQuantType, param.layerLinearDescs, qkvLinearIndex); + bool isAntiOutlier = CheckAntiOutlier(param.packQuantType); + isAntiOutlier = isAntiOutlier || param.isAntiOutlier; + + atb::GraphParam opGraph; + opGraph.name = isPack ? "QKVLinearSplitPack" : "QKVLinearSplitNoPack"; + std::map tensorMap = ConstructQKVTensorMap( + param, opGraph.inTensorNum, opGraph.outTensorNum, opGraph.internalTensorNum); + ATB_SPEED_LOG_DEBUG("qkv opGraph.inTensorNum " << opGraph.inTensorNum); + ATB_SPEED_LOG_DEBUG("qkv opGraph.outTensorNum " << opGraph.outTensorNum); + ATB_SPEED_LOG_DEBUG("qkv opGraph.internalTensorNum " << opGraph.internalTensorNum); + + CHECK_PARAM_GT(param.selfAttentionParam.kvHeadNum, 0); + CHECK_PARAM_GT(param.selfAttentionParam.headNum, 0); + CHECK_PARAM_GT(param.headDim, 0); + CHECK_PARAM_LT(param.headDim, 576); // 576: headDim上界 + + CHECK_OPERATION_STATUS_RETURN(AddQNormLinearNode(param, opGraph, tensorMap, isAntiOutlier, isPack)); + + if (!param.enableRopeQuantKvcache) { + if (isPack && param.isGroupedQueryAttention) { // Split GQA + CHECK_OPERATION_STATUS_RETURN(AddSplitQKVNode(param, opGraph, tensorMap)); + if (param.useQKNorm) { + CHECK_OPERATION_STATUS_RETURN(AddQKNormNode(param, opGraph, tensorMap)); + } + } else if (isPack && !param.isGroupedQueryAttention) { // Split MHA + CHECK_OPERATION_STATUS_RETURN(AddSplitMixedQKVNode(param, opGraph, tensorMap)); + } else { // isPack: false + if (param.layerLinearQuantType.size() != 0 && \ + CheckParamVectorSize(param.layerLinearQuantType, V_LINEAR_INDEX + 1) != atb::NO_ERROR) { + ATB_SPEED_LOG_ERROR("The size of param.layerLinearQuantType is wrong, please check"); + return atb::ERROR_INVALID_PARAM; + } + if (CheckParamVectorSize(param.layerLinearTransposeType, V_LINEAR_INDEX + 1) != atb::NO_ERROR) { + ATB_SPEED_LOG_ERROR("The size of param.layerLinearTransposeType is wrong, please check"); + return atb::ERROR_INVALID_PARAM; + } + CHECK_OPERATION_STATUS_RETURN(AddKNormLinearNode(param, opGraph, tensorMap, isAntiOutlier)); + CHECK_OPERATION_STATUS_RETURN(AddVNormLinearNode(param, opGraph, tensorMap, isAntiOutlier)); + } + + uint32_t inQKVInputIdx = GetTensorIdx(tensorMap, "in_qkv_input"); + uint32_t inResidualAddInputIdx = GetTensorIdx(tensorMap, "in_residual_add"); + uint32_t inFakeAgShapeIdx = GetTensorIdx(tensorMap, "fake_ag_shape"); + QKVLinearSplitInferShapeFunc(param, opGraph, inQKVInputIdx, inResidualAddInputIdx, inFakeAgShapeIdx); + } + + CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(opGraph, operation)); + return atb::NO_ERROR; +} + +template atb::Status AddQNormLinearNode( + const FusionAttentionParam ¶m, + atb::GraphParam &opGraph, std::map &tensorMap, bool isAntiOutlier, bool isPack); +template atb::Status AddQNormLinearNode( + const FusionAttentionParam ¶m, + atb::GraphParam &opGraph, std::map &tensorMap, bool isAntiOutlier, bool isPack); + +template atb::Status AddSplitQKVNode( + const FusionAttentionParam ¶m, + atb::GraphParam &opGraph, std::map &tensorMap); +template atb::Status AddSplitQKVNode( + const FusionAttentionParam ¶m, + atb::GraphParam &opGraph, std::map &tensorMap); + +template atb::Status AddSplitMixedQKVNode( + const FusionAttentionParam ¶m, + atb::GraphParam &opGraph, std::map &tensorMap); +template atb::Status AddSplitMixedQKVNode( + const FusionAttentionParam ¶m, + atb::GraphParam &opGraph, std::map 
+
+template atb::Status AddKNormLinearNode<atb::infer::RmsNormParam>(
+    const FusionAttentionParam<atb::infer::RmsNormParam> &param,
+    atb::GraphParam &opGraph, std::map<std::string, uint32_t> &tensorMap, bool isAntiOutlier);
+template atb::Status AddKNormLinearNode<atb::infer::LayerNormParam>(
+    const FusionAttentionParam<atb::infer::LayerNormParam> &param,
+    atb::GraphParam &opGraph, std::map<std::string, uint32_t> &tensorMap, bool isAntiOutlier);
+
+template atb::Status AddVNormLinearNode<atb::infer::RmsNormParam>(
+    const FusionAttentionParam<atb::infer::RmsNormParam> &param,
+    atb::GraphParam &opGraph, std::map<std::string, uint32_t> &tensorMap, bool isAntiOutlier);
+template atb::Status AddVNormLinearNode<atb::infer::LayerNormParam>(
+    const FusionAttentionParam<atb::infer::LayerNormParam> &param,
+    atb::GraphParam &opGraph, std::map<std::string, uint32_t> &tensorMap, bool isAntiOutlier);
+
+template atb::Status AddQKNormNode<atb::infer::RmsNormParam>(
+    const FusionAttentionParam<atb::infer::RmsNormParam> &param,
+    atb::GraphParam &opGraph, std::map<std::string, uint32_t> &tensorMap);
+template atb::Status AddQKNormNode<atb::infer::LayerNormParam>(
+    const FusionAttentionParam<atb::infer::LayerNormParam> &param,
+    atb::GraphParam &opGraph, std::map<std::string, uint32_t> &tensorMap);
+
+template std::map<std::string, uint32_t> ConstructQKVTensorMap<atb::infer::RmsNormParam>(
+    const FusionAttentionParam<atb::infer::RmsNormParam> &param,
+    uint32_t &inTensorNum, uint32_t &outTensorNum, uint32_t &internalTensorNum);
+template std::map<std::string, uint32_t> ConstructQKVTensorMap<atb::infer::LayerNormParam>(
+    const FusionAttentionParam<atb::infer::LayerNormParam> &param,
+    uint32_t &inTensorNum, uint32_t &outTensorNum, uint32_t &internalTensorNum);
+
+template atb::Status QKVLinearSplit<atb::infer::RmsNormParam>(
+    const FusionAttentionParam<atb::infer::RmsNormParam> &param,
+    atb::Operation **operation);
+template atb::Status QKVLinearSplit<atb::infer::LayerNormParam>(
+    const FusionAttentionParam<atb::infer::LayerNormParam> &param,
+    atb::Operation **operation);
+
+template void QKVLinearSplitInferShapeFunc<atb::infer::RmsNormParam>(
+    const FusionAttentionParam<atb::infer::RmsNormParam> &param,
+    atb::GraphParam &opGraph, uint32_t inQKVInputIdx, uint32_t inResidualAddInputIdx, uint32_t inFakeAgShapeIdx);
+template void QKVLinearSplitInferShapeFunc<atb::infer::LayerNormParam>(
+    const FusionAttentionParam<atb::infer::LayerNormParam> &param,
+    atb::GraphParam &opGraph, uint32_t inQKVInputIdx, uint32_t inResidualAddInputIdx, uint32_t inFakeAgShapeIdx);
+} // namespace common
+} // namespace atb_speed
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/fusion/attention/qkv_linear_split.h b/tests/proftest/layer_test_framework/operations/fusion/attention/qkv_linear_split.h
new file mode 100644
index 00000000..567f0435
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/fusion/attention/qkv_linear_split.h
@@ -0,0 +1,100 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef ATB_SPEED_MODELS_COMMON_QKV_H
+#define ATB_SPEED_MODELS_COMMON_QKV_H
+
+#include
+#include "operations/fusion/attention/fusion_attention.h"
+
+
+namespace atb_speed {
+namespace common {
+
+/// This function performs normalization and qkv linear operations.
+/// It supports grouped query attention, multi head attention and quantization scenarios.
+/// It also accepts packed qkv linear weights.
+///
+/// \tparam NormParamType Types of the normalization parameters. Available options are `atb::infer::RmsNormParam`
+/// and `atb::infer::LayerNormParam`.
+/// \param param Parameters for the FusionAttention module +/// \param operation the address of a pointer to a default operation +/// \return A flag indicating whether the operation has been successfully created. +/// +/// Operation's inputs: +/// Name | Requirements | Dtype | Shape | Description | +/// -----------------------|--------------|------------------|-------|----------| +/// in_qkv_input | Required | float16/bfloat16 | `isFA` is false: [len(all_seq),num_heads,head_dim] | Hidden States | +/// ^ | ^ | ^ | `isFA` is true: [bsz,seq_len,num_heads,head_dim] | ^ | +/// in_qkv_norm_weight | ^ | Refer to `atb_speed::common::NormLinear` in the `operations/fusion/norm/norm_linear.h` for more details. ||| +/// in_qkv_norm_bias | ^ | ^ ||| +/// in_qkv_norm_new_weight | ^ | ^ ||| +/// in_qkv_norm_new_bias | ^ | ^ ||| +/// in_qkv_weight_0 | ^ | If qkv are packed, these are concatenated qkv linear weights; otherwise, these are weights for the q linear operation.
Refer to `atb_speed::common::FusionLinear` in the `operations/fusion/linear/linear.h` for more details. ||| +/// in_qkv_scale_0 | ^ | ^ ||| +/// in_qkv_offset_0 | ^ | ^ ||| +/// in_qkv_descale_0 | ^ | ^ ||| +/// in_qkv_bias_0 | ^ | ^ ||| +/// in_qkv_compress_idx_0 | ^ | ^ ||| +/// in_qkv_weight_1 | ^ | If qkv are not packed, these are weights for the k linear operation; otherwise, placeholders should be provided.
Refer to `atb_speed::common::FusionLinear` in the `operations/fusion/linear/linear.h` for more details. ||| +/// in_qkv_scale_1 | ^ | ^ ||| +/// in_qkv_offset_1 | ^ | ^ ||| +/// in_qkv_descale_1 | ^ | ^ ||| +/// in_qkv_bias_1 | ^ | ^ ||| +/// in_qkv_compress_idx_1 | ^ | ^ ||| +/// in_qkv_weight_2 | ^ | If qkv are not packed, these are weights for the v linear operation; otherwise, placeholders should be provided.
Refer to `atb_speed::common::FusionLinear` in the `operations/fusion/linear/linear.h` for more details. |||
+/// in_qkv_scale_2 | ^ | ^ |||
+/// in_qkv_offset_2 | ^ | ^ |||
+/// in_qkv_descale_2 | ^ | ^ |||
+/// in_qkv_bias_2 | ^ | ^ |||
+/// in_qkv_compress_idx_2 | ^ | ^ |||
+/// in_im_mask | `param.supportLora` is true and `param.useImMask` is true | Refer to `atb_speed::common::FusionLinear` in the `operations/fusion/linear/linear.h` for more details. |||
+/// in_seq_len_cum_sum | `param.supportLora` is true | ^ |||
+/// in_qkv_lora_a_0 | ^ | ^ |||
+/// in_qkv_lora_b_0 | ^ | ^ |||
+/// in_qkv_lora_a_1 | ^ | ^ |||
+/// in_qkv_lora_b_1 | ^ | ^ |||
+/// in_qkv_lora_a_2 | ^ | ^ |||
+/// in_qkv_lora_b_2 | ^ | ^ |||
+/// in_q_norm_weight | `param.useQKNorm` is true | | | |
+/// in_k_norm_weight | `param.useQKNorm` is true | | | |
+///
+/// Operation's outputs:
+/// Name | Dtype | Shape |
+/// -----------|--------------------------|------------------------------------------------------|
+/// out_q | The same as in_qkv_input | `isFA` is false: [len(all_seq),head_num,head_dim] |
+/// ^ | ^ | `isFA` is true: [bsz,seq_len,head_num,head_dim] |
+/// out_k | The same as in_qkv_input | `isFA` is false: [len(all_seq),kv_head_num,head_dim] |
+/// ^ | ^ | `isFA` is true: [bsz,seq_len,kv_head_num,head_dim] |
+/// out_v | The same as in_qkv_input | `isFA` is false: [len(all_seq),kv_head_num,head_dim] |
+/// ^ | ^ | `isFA` is true: [bsz,seq_len,kv_head_num,head_dim] |
+///
+/// Example:
+/// \code
+/// atb::Node qkvLinearSplitNode;
+/// atb_speed::common::FusionAttentionParam<atb::infer::RmsNormParam> fusionAttentionParam;
+/// // Modify fusionAttentionParam's attribute if needed.
+/// QKVLinearSplit(fusionAttentionParam, &qkvLinearSplitNode.operation);
+/// qkvLinearSplitNode.inTensorIds = {...};  // Passing inputs for the operation in order
+/// qkvLinearSplitNode.outTensorIds = {...};  // Tensor index for out_q, out_k, out_v
+/// atb::GraphParam opGraph;
+/// opGraph.nodes.push_back(qkvLinearSplitNode);
+/// \endcode
+template <typename NormParamType>
+atb::Status QKVLinearSplit(const FusionAttentionParam<NormParamType> &param, atb::Operation **operation);
+} // namespace common
+} // namespace atb_speed
+#endif
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/fusion/attention/self_attention.cpp b/tests/proftest/layer_test_framework/operations/fusion/attention/self_attention.cpp
new file mode 100644
index 00000000..0320c18f
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/fusion/attention/self_attention.cpp
@@ -0,0 +1,364 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include
+#include "atb_speed/utils/check_util.h"
+#include "operations/aclnn/ops/dequant_rope_quant_kvcache_operation.h"
+#include "operations/aclnn/ops/attn_operation.h"
+#include "operations/fusion/utils.h"
+#include "operations/fusion/infer_shape_functions.h"
+#include "operations/fusion/attention/self_attention.h"
+
+namespace atb_speed {
+namespace common {
+
+template <typename NormParamType>
+int64_t AddSelfAttention(
+    atb::GraphParam& opGraph, const FusionAttentionParam<NormParamType>& param,
+    std::map<std::string, uint32_t>& tensorMap)
+{
+    if (!param.enableRopeQuantKvcache) {
+        if (!param.isFA) {  // PA
+            CHECK_OPERATION_STATUS_RETURN(AddPaKVCacheOperation(opGraph, param, tensorMap));
+        } else if (param.isFA && param.attnBackend == atb_speed::common::OpBackend::ACLNN) {  // ACLNN FA
+            CHECK_OPERATION_STATUS_RETURN(AddFaKVCacheOperation(opGraph, param, tensorMap));
+        }
+    }
+    // SelfAttentionNode
+    atb::Node selfAttentionNode;
+    if (param.isFA) {  // FA
+        if (param.attnBackend == atb_speed::common::OpBackend::ACLNN && param.isPrefill) {  // ACLNN FA Encode
+            CHECK_OPERATION_STATUS_RETURN(ConstructPaEncoderNode(selfAttentionNode, param, tensorMap));
+        } else if (param.attnBackend == atb_speed::common::OpBackend::ACLNN && !param.isPrefill) {  // ACLNN FA Decode
+            CHECK_OPERATION_STATUS_RETURN(ConstructAclNNDecoderNode(selfAttentionNode, param, tensorMap));
+        } else {  // ATB FA
+            CHECK_OPERATION_STATUS_RETURN(ConstructFaNode(selfAttentionNode, param, tensorMap));
+        }
+    } else {
+        if (param.isPrefill && !param.enableSplitFuse) {  // PA Prefill
+            CHECK_OPERATION_STATUS_RETURN(ConstructPaEncoderNode(selfAttentionNode, param, tensorMap));
+        } else if (param.attnBackend == atb_speed::common::OpBackend::ATB) {  // ATB PA Decode
+            CHECK_OPERATION_STATUS_RETURN(ConstructPaDecoderNode(selfAttentionNode, param, tensorMap));
+        } else if (param.attnBackend == atb_speed::common::OpBackend::ACLNN) {  // ACLNN PA Decode
+            CHECK_OPERATION_STATUS_RETURN(ConstructAclNNDecoderNode(selfAttentionNode, param, tensorMap));
+        }
+    }
+
+    selfAttentionNode.outTensorIds = { GetTensorIdx(tensorMap, "intermediate_self_attention") };
+    opGraph.nodes.push_back(selfAttentionNode);
+
+    return atb::NO_ERROR;
+}
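The branch structure above condenses to the following standalone helper, written for illustration only; the function and the returned labels are not part of the framework, they merely restate the routing.

```cpp
#include <string>

// Mirrors AddSelfAttention's dispatch: FA/PA crossed with ATB/ACLNN backends
// and prefill/decode phases.
std::string PickAttentionKernel(bool isFA, bool isPrefill, bool isAclnn, bool enableSplitFuse)
{
    if (isFA) {
        if (isAclnn) { return isPrefill ? "PaEncoder" : "AclNNDecoder"; }
        return "FaNode";  // ATB FA covers both phases
    }
    if (isPrefill && !enableSplitFuse) { return "PaEncoder"; }
    return isAclnn ? "AclNNDecoder" : "PaDecoder";
}
```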
+
+template <typename NormParamType>
+int64_t AddFaKVCacheOperation(
+    atb::GraphParam& opGraph, const FusionAttentionParam<NormParamType>& param,
+    std::map<std::string, uint32_t>& tensorMap)
+{
+    // moveKCache Node
+    atb::infer::KvCacheParam kvCacheParam;
+    atb::Node moveKCacheNode;
+    CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(kvCacheParam, &moveKCacheNode.operation));
+    moveKCacheNode.inTensorIds = {
+        param.aclnnIncreAttentionParam.hasKVQuant ? \
+            GetTensorIdx(tensorMap, "intermediate_k_int8") : GetTensorIdx(tensorMap, "intermediate_k"),
+        GetTensorIdx(tensorMap, "in_layer_id"),
+        GetTensorIdx(tensorMap, "in_k_cache"),
+        GetTensorIdx(tensorMap, "in_token_offset"),
+        GetTensorIdx(tensorMap, "in_seq_len"),
+    };
+    moveKCacheNode.inTensorReshapeFuncs.resize(moveKCacheNode.inTensorIds.size());
+    moveKCacheNode.inTensorReshapeFuncs.at(0) =  // 0: [B,S,N,D]=>[BS,ND]
+        &SqueezeBatchAndHiddenSize;
+    moveKCacheNode.inTensorReshapeFuncs.at(2) = [=](  // 2: [B,S,ND]=>[1,B,S,ND]
+        const atb::Dims& oldShape, atb::Dims& newShape) {
+        UnsqueezeAxis(oldShape, newShape, 0);
+    };
+    moveKCacheNode.outTensorIds = {};
+    opGraph.nodes.push_back(moveKCacheNode);
+
+    atb::Node moveVCacheNode;
+    CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(kvCacheParam, &moveVCacheNode.operation));
+    moveVCacheNode.inTensorIds = {
+        param.aclnnIncreAttentionParam.hasKVQuant ? \
+            GetTensorIdx(tensorMap, "intermediate_v_int8") : GetTensorIdx(tensorMap, "intermediate_v"),
+        GetTensorIdx(tensorMap, "in_layer_id"),
+        GetTensorIdx(tensorMap, "in_v_cache"),
+        GetTensorIdx(tensorMap, "in_token_offset"),
+        GetTensorIdx(tensorMap, "in_seq_len"),
+    };
+    moveVCacheNode.inTensorReshapeFuncs.resize(moveVCacheNode.inTensorIds.size());
+    moveVCacheNode.inTensorReshapeFuncs.at(0) =  // 0: [B,S,N,D]=>[BS,ND]
+        &SqueezeBatchAndHiddenSize;
+    moveVCacheNode.inTensorReshapeFuncs.at(2) = [=](  // 2: [B,S,ND]=>[1,B,S,ND]
+        const atb::Dims& oldShape, atb::Dims& newShape) {
+        UnsqueezeAxis(oldShape, newShape, 0);
+    };
+    moveVCacheNode.outTensorIds = {};
+    opGraph.nodes.push_back(moveVCacheNode);
+    return atb::NO_ERROR;
+}
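A minimal stand-in for the `UnsqueezeAxis` reshape used above, assuming only the `atb::Dims` layout (a `dimNum` plus a fixed dims array); the real helper lives in `operations/fusion/infer_shape_functions.h` and this sketch is not its actual implementation.

```cpp
#include <cstddef>
#include <cstdint>

struct DimsSketch { uint64_t dimNum; int64_t dims[8]; };  // simplified atb::Dims stand-in

// Insert a size-1 axis at `axis`, e.g. axis 0 turns [B,S,ND] into [1,B,S,ND].
void UnsqueezeAxisSketch(const DimsSketch &oldShape, DimsSketch &newShape, size_t axis)
{
    newShape.dimNum = oldShape.dimNum + 1;
    for (size_t i = 0, j = 0; i < newShape.dimNum; ++i) {
        newShape.dims[i] = (i == axis) ? 1 : oldShape.dims[j++];
    }
}
```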
+
+template <typename NormParamType>
+int64_t AddPaKVCacheOperation(
+    atb::GraphParam& opGraph, const FusionAttentionParam<NormParamType>& param,
+    std::map<std::string, uint32_t>& tensorMap)
+{
+    // ReshapeAndCache Node
+    atb::Node reshapeAndCacheNode;
+    if (param.enableOmniattention && param.isomnicompressed) {
+        CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(param.reshapeCacheOmniParm,
+            &reshapeAndCacheNode.operation));
+    } else {
+        CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(param.reshapeCacheParm, &reshapeAndCacheNode.operation));
+    }
+    reshapeAndCacheNode.inTensorIds = {
+        param.pageAttentionParam.quantType == atb::infer::PagedAttentionParam::QuantType::TYPE_DEQUANT_FUSION || \
+        param.pageAttentionParam.quantType == atb::infer::PagedAttentionParam::QuantType::TYPE_QUANT_QKV_ONLINE ? \
+            GetTensorIdx(tensorMap, "intermediate_k_int8") : GetTensorIdx(tensorMap, "intermediate_k"),
+        param.pageAttentionParam.quantType == atb::infer::PagedAttentionParam::QuantType::TYPE_DEQUANT_FUSION || \
+        param.pageAttentionParam.quantType == atb::infer::PagedAttentionParam::QuantType::TYPE_QUANT_QKV_ONLINE ? \
+            GetTensorIdx(tensorMap, "intermediate_v_int8") : GetTensorIdx(tensorMap, "intermediate_v"),
+        GetTensorIdx(tensorMap, "in_k_cache"),
+        GetTensorIdx(tensorMap, "in_v_cache"),
+        GetTensorIdx(tensorMap, "in_slots_in_pa_or_logn_in_fa"),
+    };
+    if (param.reshapeCacheParm.compressType == \
+        atb::infer::ReshapeAndCacheParam::CompressType::COMPRESS_TYPE_KVHEAD
+    ) {
+        reshapeAndCacheNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_batch_wins"));
+        reshapeAndCacheNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_seq_len"));
+    } else if (param.reshapeCacheParm.compressType == \
+        atb::infer::ReshapeAndCacheParam::CompressType::COMPRESS_TYPE_KVHEAD_ROPE ||
+        param.isomnicompressed) {
+        reshapeAndCacheNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_batch_wins"));
+        reshapeAndCacheNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_reshape_seq_len"));
+        reshapeAndCacheNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_pffset_index"));
+    }
+    reshapeAndCacheNode.outTensorIds = {
+        GetTensorIdx(tensorMap, "in_k_cache"),
+        GetTensorIdx(tensorMap, "in_v_cache"),
+    };
+    opGraph.nodes.push_back(reshapeAndCacheNode);
+    return atb::NO_ERROR;
+}
+
+template <typename NormParamType>
+int64_t ConstructAclNNDecoderNode(
+    atb::Node& selfAttentionNode, const FusionAttentionParam<NormParamType>& param,
+    std::map<std::string, uint32_t>& tensorMap)
+{
+    // Inputs: FA QKV [B,S,H]; PA Q [B,S,N,D], KV [num_blocks,block_size,ND]
+    // Outputs: FA [B,S,H]; PA [BS,N,D]
+    selfAttentionNode.inTensorIds = {
+        GetTensorIdx(tensorMap, "intermediate_q"), GetTensorIdx(tensorMap, "in_k_cache"),
+        GetTensorIdx(tensorMap, "in_v_cache"), GetTensorIdx(tensorMap, "in_attention_mask")
+    };
+    if (param.aclnnIncreAttentionParam.isFA) {
+        selfAttentionNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_token_offset"));
+    } else {
+        selfAttentionNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_seq_len"));
+    }
+    selfAttentionNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_block_tables"));
+    if (param.aclnnIncreAttentionParam.hasKVQuant) {
+        selfAttentionNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_k_dequant_scale"));
+        if (param.aclnnIncreAttentionParam.hasQuantOffset) {
+            selfAttentionNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_k_dequant_offset"));
+        }
+    }
+    selfAttentionNode.operation = new atb_speed::common::AttnOperation(
+        "AclNNAttentionNode", param.aclnnIncreAttentionParam);
+    selfAttentionNode.inTensorReshapeFuncs.resize(selfAttentionNode.inTensorIds.size());
+    if (param.isFA) {
+        selfAttentionNode.inTensorReshapeFuncs.at(0) = &SqueezeHeadNumHeadDim;
+    } else {
+        selfAttentionNode.inTensorReshapeFuncs.at(0) = [=](const atb::Dims& oldShape, atb::Dims& newShape) {
+            UnsqueezeAxis(oldShape, newShape, 1);
+        };
+        selfAttentionNode.inTensorReshapeFuncs.at(1) = &SqueezeHeadNumHeadDim;  // 1: in_k_cache
+        selfAttentionNode.inTensorReshapeFuncs.at(2) = &SqueezeHeadNumHeadDim;  // 2: in_v_cache
+    }
+    return atb::NO_ERROR;
+}
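A companion sketch to the reshapes above: the `SqueezeHeadNumHeadDim`-style fold of the trailing head axes, `[..., N, D] -> [..., N*D]`. As with the earlier snippet, the `DimsSketch` struct is a simplified stand-in for `atb::Dims`, not the framework's real helper.

```cpp
#include <cstddef>
#include <cstdint>

struct DimsSketch { uint64_t dimNum; int64_t dims[8]; };  // simplified atb::Dims stand-in

// Merge the last two axes, e.g. [num_blocks, block_size, N, D] -> [num_blocks, block_size, N*D].
void SqueezeHeadNumHeadDimSketch(const DimsSketch &oldShape, DimsSketch &newShape)
{
    newShape.dimNum = oldShape.dimNum - 1;
    for (size_t i = 0; i + 2 < oldShape.dimNum; ++i) {
        newShape.dims[i] = oldShape.dims[i];  // copy untouched leading axes
    }
    newShape.dims[newShape.dimNum - 1] =
        oldShape.dims[oldShape.dimNum - 2] * oldShape.dims[oldShape.dimNum - 1];
}
```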
+
+template <typename NormParamType>
+int64_t ConstructFaNode(
+    atb::Node& selfAttentionNode, const FusionAttentionParam<NormParamType>& param,
+    std::map<std::string, uint32_t>& tensorMap)
+{
+    // Input [nTokens, vHiddenSize], output [nTokens, vHiddenSize]
+    CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(param.selfAttentionParam, &selfAttentionNode.operation));
+    selfAttentionNode.inTensorIds = {
+        GetTensorIdx(tensorMap, "intermediate_q"),
+        GetTensorIdx(tensorMap, "intermediate_k"),
+        GetTensorIdx(tensorMap, "intermediate_v"),
+        GetTensorIdx(tensorMap, "in_k_cache"),
+        GetTensorIdx(tensorMap, "in_v_cache"),
+    };
+    if (param.selfAttentionParam.maskType != atb::infer::SelfAttentionParam::MASK_TYPE_UNDEFINED) {
+        selfAttentionNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_attention_mask"));
+    }
+    selfAttentionNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_token_offset"));
+    selfAttentionNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_seq_len"));
+    selfAttentionNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_layer_id"));
+    if (param.selfAttentionParam.maskType == atb::infer::SelfAttentionParam::MASK_TYPE_ALIBI_COMPRESS || \
+        param.selfAttentionParam.maskType == atb::infer::SelfAttentionParam::MASK_TYPE_ALIBI_COMPRESS_SQRT || \
+        param.selfAttentionParam.maskType == atb::infer::SelfAttentionParam::MASK_TYPE_ALIBI_COMPRESS_LEFT_ALIGN
+    ) {
+        selfAttentionNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_slopes"));
+    }
+    selfAttentionNode.inTensorReshapeFuncs.resize(selfAttentionNode.inTensorIds.size());
+    selfAttentionNode.inTensorReshapeFuncs.at(0) = &SqueezeBatchAndHiddenSize;  // 0: [B,S,N,D]=>[BS,ND]
+    selfAttentionNode.inTensorReshapeFuncs.at(1) = &SqueezeBatchAndHiddenSize;  // 1: [B,S,N,D]=>[BS,ND]
+    selfAttentionNode.inTensorReshapeFuncs.at(2) = &SqueezeBatchAndHiddenSize;  // 2: [B,S,N,D]=>[BS,ND]
+    selfAttentionNode.inTensorReshapeFuncs.at(3) = [=](  // 3: [BS,N,D]=>[1,BS,N,D]
+        const atb::Dims& oldShape, atb::Dims& newShape) {
+        UnsqueezeAxis(oldShape, newShape, 0);
+    };
+    selfAttentionNode.inTensorReshapeFuncs.at(4) = [=](  // 4: [BS,N,D]=>[1,BS,N,D]
+        const atb::Dims& oldShape, atb::Dims& newShape) {
+        UnsqueezeAxis(oldShape, newShape, 0);
+    };
+    return atb::NO_ERROR;
+}
+
+template <typename NormParamType>
+int64_t ConstructPaEncoderNode(
+    atb::Node& selfAttentionNode, const FusionAttentionParam<NormParamType>& param,
+    std::map<std::string, uint32_t>& tensorMap)
+{
+    // Input [BS, N, D], output [BS, N, D]
+    ATB_SPEED_LOG_DEBUG("Enter ConstructPaEncoderNode");
+    CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(param.selfAttentionParam, &selfAttentionNode.operation));
+    selfAttentionNode.inTensorIds = {
+        GetTensorIdx(tensorMap, "intermediate_q"),
+        GetTensorIdx(tensorMap, "intermediate_k"),
+        GetTensorIdx(tensorMap, "intermediate_v"),
+    };
+    if (param.selfAttentionParam.maskType != atb::infer::SelfAttentionParam::MASK_TYPE_UNDEFINED) {
+        selfAttentionNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_attention_mask"));
+    }
+    selfAttentionNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_seq_len"));
+    if (
+        param.selfAttentionParam.maskType == atb::infer::SelfAttentionParam::MASK_TYPE_ALIBI_COMPRESS || \
+        param.selfAttentionParam.maskType == atb::infer::SelfAttentionParam::MASK_TYPE_ALIBI_COMPRESS_SQRT || \
+        param.selfAttentionParam.maskType == atb::infer::SelfAttentionParam::MASK_TYPE_ALIBI_COMPRESS_LEFT_ALIGN
+    ) {
+        selfAttentionNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_slopes"));
+    }
+    selfAttentionNode.inTensorReshapeFuncs.resize(selfAttentionNode.inTensorIds.size());
+    if (param.attnBackend == atb_speed::common::OpBackend::ACLNN) {
+        selfAttentionNode.inTensorReshapeFuncs.at(0) = &SqueezeBatchAndHiddenSize;  // 0: [B,S,N,D]=>[BS,ND]
+        selfAttentionNode.inTensorReshapeFuncs.at(1) = &SqueezeBatchAndHiddenSize;  // 1: [B,S,N,D]=>[BS,ND]
+        selfAttentionNode.inTensorReshapeFuncs.at(2) = &SqueezeBatchAndHiddenSize;  // 2: [B,S,N,D]=>[BS,ND]
+    }
+
+    return atb::NO_ERROR;
+}
+
+template <typename NormParamType>
+int64_t ConstructPaDecoderNode(
+    atb::Node& selfAttentionNode, const FusionAttentionParam<NormParamType>& param,
+    std::map<std::string, uint32_t>& tensorMap)
+{
+    // Inputs: [num_tokens, N, D], [num_blocks, block_size, N, D]
+    // Output: [num_tokens, num_head, head_size]
+    CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(param.pageAttentionParam, &selfAttentionNode.operation));
+    selfAttentionNode.inTensorIds = {
+        param.pageAttentionParam.quantType == atb::infer::PagedAttentionParam::QuantType::TYPE_QUANT_QKV_ONLINE ? \
+            GetTensorIdx(tensorMap, "intermediate_q_int8") : GetTensorIdx(tensorMap, "intermediate_q"),
+        GetTensorIdx(tensorMap, "in_k_cache"),
+        GetTensorIdx(tensorMap, "in_v_cache"),
+        GetTensorIdx(tensorMap, "in_block_tables"),
+    };
+    if (param.pageAttentionParam.compressType == atb::infer::PagedAttentionParam::CompressType::COMPRESS_TYPE_KVHEAD ||
+        param.pageAttentionParam.compressType ==
+        atb::infer::PagedAttentionParam::CompressType::COMPRESS_TYPE_KVHEAD_ROPE) {
+        selfAttentionNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_ra_seq_len"));
+    } else {
+        selfAttentionNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_seq_len"));
+    }
+    if (param.pageAttentionParam.maskType != atb::infer::PagedAttentionParam::MaskType::UNDEFINED) {
+        selfAttentionNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_attention_mask"));
+    }
+    if (param.pageAttentionParam.calcType == atb::infer::PagedAttentionParam::CalcType::CALC_TYPE_SPEC) {
+        selfAttentionNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_q_len"));
+    }
+    if (param.pageAttentionParam.quantType == atb::infer::PagedAttentionParam::QuantType::TYPE_DEQUANT_FUSION) {
+        selfAttentionNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_k_dequant_scale"));
+        if (param.pageAttentionParam.hasQuantOffset) {
+            selfAttentionNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_k_dequant_offset"));
+        }
+        selfAttentionNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_v_dequant_scale"));
+        if (param.pageAttentionParam.hasQuantOffset) {
+            selfAttentionNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_v_dequant_offset"));
+        }
+    }
+    if (param.pageAttentionParam.quantType == atb::infer::PagedAttentionParam::QuantType::TYPE_QUANT_QKV_ONLINE) {
+        selfAttentionNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_qk_descale"));
+        selfAttentionNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "fa3_v_quant_scale"));
+    }
+    if (param.pageAttentionParam.compressType ==
+        atb::infer::PagedAttentionParam::CompressType::COMPRESS_TYPE_KVHEAD_ROPE) {
+        selfAttentionNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_ra_offset"));
+    }
+    if (param.pageAttentionParam.scaleType == atb::infer::PagedAttentionParam::ScaleType::SCALE_TYPE_LOGN) {
+        selfAttentionNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_log_n_scale"));
+    }
+    return atb::NO_ERROR;
+}
+
+
+template int64_t AddFaKVCacheOperation<atb::infer::RmsNormParam>(
+    atb::GraphParam& opGraph, const FusionAttentionParam<atb::infer::RmsNormParam>& param,
+    std::map<std::string, uint32_t>& tensorMap);
+template int64_t AddPaKVCacheOperation<atb::infer::RmsNormParam>(
+    atb::GraphParam& opGraph, const FusionAttentionParam<atb::infer::RmsNormParam>& param,
+    std::map<std::string, uint32_t>& tensorMap);
+template int64_t AddSelfAttention<atb::infer::RmsNormParam>(
+    atb::GraphParam& opGraph, const FusionAttentionParam<atb::infer::RmsNormParam>& param,
+    std::map<std::string, uint32_t>& tensorMap);
+template int64_t ConstructFaNode<atb::infer::RmsNormParam>(
+    atb::Node& selfAttentionNode, const FusionAttentionParam<atb::infer::RmsNormParam>& param,
+    std::map<std::string, uint32_t>& tensorMap);
+template int64_t ConstructPaEncoderNode<atb::infer::RmsNormParam>(
+    atb::Node& selfAttentionNode, const FusionAttentionParam<atb::infer::RmsNormParam>& param,
+    std::map<std::string, uint32_t>& tensorMap);
+template int64_t ConstructPaDecoderNode<atb::infer::RmsNormParam>(
+    atb::Node& selfAttentionNode, const FusionAttentionParam<atb::infer::RmsNormParam>& param,
+    std::map<std::string, uint32_t>& tensorMap);
+
+template int64_t AddFaKVCacheOperation<atb::infer::LayerNormParam>(
+    atb::GraphParam& opGraph, const FusionAttentionParam<atb::infer::LayerNormParam>& param,
+    std::map<std::string, uint32_t>& tensorMap);
+template int64_t AddPaKVCacheOperation<atb::infer::LayerNormParam>(
+    atb::GraphParam& opGraph, const FusionAttentionParam<atb::infer::LayerNormParam>& param,
+    std::map<std::string, uint32_t>& tensorMap);
+template int64_t AddSelfAttention<atb::infer::LayerNormParam>(
+    atb::GraphParam& opGraph, const FusionAttentionParam<atb::infer::LayerNormParam>& param,
+    std::map<std::string, uint32_t>& tensorMap);
+template int64_t ConstructFaNode<atb::infer::LayerNormParam>(
+    atb::Node& selfAttentionNode, const FusionAttentionParam<atb::infer::LayerNormParam>& param,
+    std::map<std::string, uint32_t>& tensorMap);
+template int64_t ConstructPaEncoderNode<atb::infer::LayerNormParam>(
+    atb::Node& selfAttentionNode, const FusionAttentionParam<atb::infer::LayerNormParam>& param,
+    std::map<std::string, uint32_t>& tensorMap);
+template int64_t ConstructPaDecoderNode<atb::infer::LayerNormParam>(
+    atb::Node& selfAttentionNode, const FusionAttentionParam<atb::infer::LayerNormParam>& param,
+    std::map<std::string, uint32_t>& tensorMap);
+
+} // namespace common
+} // namespace atb_speed
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/fusion/attention/self_attention.h b/tests/proftest/layer_test_framework/operations/fusion/attention/self_attention.h
new file mode 100644
index 00000000..f89542ba
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/fusion/attention/self_attention.h
@@ -0,0 +1,117 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef ATB_SPEED_MODELS_COMMON_SELF_ATTENTION_H
+#define ATB_SPEED_MODELS_COMMON_SELF_ATTENTION_H
+
+#include
+#include
+#include "operations/fusion/attention/fusion_attention.h"
+
+namespace atb_speed {
+namespace common {
+
+/// This function adds kv cache movement operation and attention operations to the graph.
+/// It supports flash attention and paged attention.
+/// It supports ATB backend and AclNN backend.
+///
+/// \tparam NormParamType Types of the normalization parameters. Available options are `atb::infer::RmsNormParam`
+/// and `atb::infer::LayerNormParam`.
+/// \param opGraph A reference to the graph
+/// \param param Parameters for the FusionAttention module
+/// \param tensorMap Defines all the required tensors for the current graph, with the key representing
+/// the input tensor name and the value corresponding to the tensor index.
+/// Tensors are ordered by input tensors, output tensors and internal tensors.
+/// \return A flag indicating whether the operation has been successfully created.
+///
+/// This function will use the following tensors:
+/// Key in `tensorMap` | Requirements | Dtype | Shape | Description |
+/// -------------------|--------------|------------------|-------|----------|
+/// in_seq_len | Required | int32 | [batch_size] | The total number of input and output tokens.
In the prefill phase, each element is set to the length of the prompt.
For flash attention, each element is set to 1 in the decode phase.
For paged attention, each element is set to the number of input tokens plus output tokens in the decode phase. | +/// in_k_cache | ^ | float16/bfloat16/int8 | [num_block,block_size,head_num,head_dim] | | +/// in_v_cache | ^ | ^ | ^ | | +/// in_attention_mask | ^ | Refer to SelfAttetion/PagedAttention Operation in the `atb/infer_op_params.h` and AttnOperation in the `operations/aclnn/ops/attn_operation.h` for more details. ||| +/// in_token_offset | ^ | int32 | [batch] | Token offset after calculation. Used only if `isFA` is true. | +/// in_layer_id | ^ | int32 | [1] | The index of kv cache for the current layer. Used only if `isFA` is true. | +/// in_block_tables | ^ | int32 | [num_tokens, max_num_blocks_per_query] | Used only if `isFA` is false. | +/// in_slots_in_pa_or_logn_in_fa | ^ | float32 | `isFA` is true (Prefill phase): [maxSeqLen] | Logarithmic scaling. | +/// ^ | ^ | ^ | `isFA` is true (Decode phase): [batch_size] | ^ | +/// ^ | ^ | ^ | `isFA` is false: [num_tokens] | Storage offset of each token key or value in the cache. | +/// in_slopes | `param.selfAttentionParam.maskType` is in one of `atb::infer::SelfAttentionParam::MASK_TYPE_ALIBI_COMPRESS`, `atb::infer::SelfAttentionParam::MASK_TYPE_ALIBI_COMPRESS_SQRT` and `atb::infer::SelfAttentionParam::MASK_TYPE_ALIBI_COMPRESS_LEFT_ALIGN` | Atlas 800I A2: float32; Atlas 300I DUO: float16 | [head_num] | It is the coefficient of each head of the alibi mask. | +/// in_batch_wins | `param.pageAttentionParam.compressType` equals to `atb::infer::PagedAttentionParam::CompressType::COMPRESS_TYPE_KVHEAD` or `atb::infer::PagedAttentionParam::CompressType::COMPRESS_TYPE_KVHEAD_ROPE` | int32 | [batch*head_num] | Compressed window size | +/// in_ra_seq_len | ^ | int32 | [batch_size] | | +/// in_pffset_index | `param.pageAttentionParam.compressType == atb::infer::PagedAttentionParam::CompressType::COMPRESS_TYPE_KVHEAD_ROPE` | int32 | [batch*head_num] | | +/// in_ra_offset | ^ | float32 | [num_blocks,block_size] | | +/// in_reshape_seq_len | ^ | int32 | [batch] | | +/// in_q_len | `param.pageAttentionParam.calcType == atb::infer::PagedAttentionParam::CalcType::CALC_TYPE_SPEC` | int32 | [batch] | Number of input tokens for the current forward pass. | +/// in_k_dequant_scale | `param.pageAttentionParam.quantType == atb::infer::PagedAttentionParam::QuantType::TYPE_DEQUANT_FUSION` | Refer to the table below for further details. ||| +/// in_v_dequant_scale | ^ | ^ ||| +/// in_k_dequant_offset| `param.pageAttentionParam.quantType == atb::infer::PagedAttentionParam::QuantType::TYPE_DEQUANT_FUSION` and `param.pageAttentionParam.hasQuantOffset` is true | Refer to the table below for further details. ||| +/// in_v_dequant_offset| ^ | ^ ||| +/// in_log_n_scale | `param.pageAttentionParam.scaleType == atb::infer::PagedAttentionParam::ScaleType::SCALE_TYPE_LOGN` | float32 | [batch] | logarithmic scaling | +/// in_qk_descale | `param.pageAttentionParam.quantType == atb::infer::PagedAttentionParam::QuantType::TYPE_QUANT_QKV_ONLINE` | float32 | [head_num] | | +/// fa3_v_quant_scale | ^ | float32 | [head_num] | | +/// intermediate_q | Required | float16/bfloat16 | `isFA` is true: [bsz,seq_len,head_num,head_dim] | The output of the q linear operation. | +/// ^ | ^ | ^ | `isFA` is false: [len(all_seq),head_num,head_dim] | ^ | +/// intermediate_k | ^ | ^ | `isFA` is true: [bsz,seq_len,kv_head_num,head_dim] | The output of the k linear operation. 
|
+/// ^ | ^ | ^ | `isFA` is false: [len(all_seq),kv_head_num,head_dim] | ^ |
+/// intermediate_v | ^ | ^ | `isFA` is true: [bsz,seq_len,kv_head_num,head_dim] | The output of the v linear operation. |
+/// ^ | ^ | ^ | `isFA` is false: [len(all_seq),kv_head_num,head_dim] | ^ |
+/// intermediate_self_attention | ^ | ^ | `isFA` is true: [bsz,seq_len,head_num*head_dim] | The output of the attention operation. |
+/// ^ | ^ | ^ | `isFA` is false: [len(all_seq),head_num*head_dim] | ^ |
+/// intermediate_k_int8 | `param.pageAttentionParam.quantType == atb::infer::PagedAttentionParam::QuantType::TYPE_DEQUANT_FUSION` | int8 | The same as intermediate_k | intermediate_k after int8 quantization |
+/// intermediate_v_int8 | ^ | int8 | The same as intermediate_k | intermediate_v after int8 quantization |
+/// intermediate_q_int8 | `param.pageAttentionParam.quantType == atb::infer::PagedAttentionParam::QuantType::TYPE_DEQUANT_FUSION` | int8 | The same as intermediate_k | intermediate_q after int8 quantization |
+/// ^ | `!param.isPrefill && param.pageAttentionParam.quantType == atb::infer::PagedAttentionParam::QuantType::TYPE_QUANT_QKV_ONLINE` | ^ | ^ | ^ |
+///
+/// Detailed information for the in_k_dequant_scale, in_v_dequant_scale, in_k_dequant_offset and in_v_dequant_offset.
+/// Key in `tensorMap` | Dtype of the output tensor | Attention Backend | Dtype | Shape | Description |
+/// -------------------|----------------------------|-------------------|-------|-------|----------|
+/// in_k_dequant_scale | float16 | ATB | int64 | [head_num*head_dim] | |
+/// in_v_dequant_scale | ^ | ^ | ^ | ^ | |
+/// in_k_dequant_offset| ^ | ^ | int32 | ^ | |
+/// in_v_dequant_offset| ^ | ^ | ^ | ^ | |
+/// in_k_dequant_scale | bfloat16 | ATB | float32 | ^ | |
+/// in_v_dequant_scale | ^ | ^ | ^ | ^ | |
+/// in_k_dequant_offset| ^ | ^ | int32 | ^ | |
+/// in_v_dequant_offset| ^ | ^ | ^ | ^ | |
+/// in_k_dequant_scale | float16 | AclNN | float16 | [2,head_num*head_dim] | |
+/// in_v_dequant_scale | ^ | ^ | ^ | [1] | placeholder|
+/// in_k_dequant_offset| ^ | ^ | ^ | [2,head_num*head_dim] | placeholder |
+/// in_v_dequant_offset| ^ | ^ | ^ | [1] | placeholder |
+/// in_k_dequant_scale | bfloat16 | AclNN | bfloat16 | [2,head_num*head_dim] | placeholder |
+/// in_v_dequant_scale | ^ | ^ | ^ | [1] | placeholder |
+/// in_k_dequant_offset| ^ | ^ | ^ | [2,head_num*head_dim] | placeholder |
+/// in_v_dequant_offset| ^ | ^ | ^ | [1] | placeholder |
+///
+/// Example:
+/// \code
+/// atb_speed::common::FusionAttentionParam<atb::infer::RmsNormParam> fusionAttentionParam;
+/// // Modify fusionAttentionParam's attribute if needed.
+/// // Define all the required tensors and corresponding tensor index.
+/// std::map<std::string, uint32_t> tensorMap = {{"in_k_cache", 0}, {"in_v_cache", 1}, ...}
+/// atb::GraphParam opGraph;
+/// AddSelfAttention(opGraph, fusionAttentionParam, tensorMap);
+/// \endcode
+template <typename NormParamType>
+int64_t AddSelfAttention(
+    atb::GraphParam &opGraph, const FusionAttentionParam<NormParamType> &param,
+    std::map<std::string, uint32_t> &tensorMap);
+
+} // namespace common
+} // namespace atb_speed
+#endif
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/fusion/common_op_base.h b/tests/proftest/layer_test_framework/operations/fusion/common_op_base.h
new file mode 100644
index 00000000..200849b9
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/fusion/common_op_base.h
@@ -0,0 +1,34 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved.
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef ATB_SPEED_LAYERS_COMMON_H +#define ATB_SPEED_LAYERS_COMMON_H + + +namespace atb_speed { +namespace common { +class CommonOpBase { +public: + int inTensorNum; + int outTensorNum; + int interTensorNum; + int nodeCount; + + CommonOpBase(int a, int b, int c, int d) : inTensorNum(a), outTensorNum(b), interTensorNum(c), nodeCount(d) {} +}; +} // namespace common +} // namespace atb_speed +#endif diff --git a/tests/proftest/layer_test_framework/operations/fusion/embedding/positional_embedding.cpp b/tests/proftest/layer_test_framework/operations/fusion/embedding/positional_embedding.cpp new file mode 100644 index 00000000..0aca600f --- /dev/null +++ b/tests/proftest/layer_test_framework/operations/fusion/embedding/positional_embedding.cpp @@ -0,0 +1,298 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "operations/fusion/infer_shape_functions.h" +#include "atb_speed/utils/check_util.h" +#include "operations/fusion/embedding/positional_embedding.h" + +namespace atb_speed { +namespace common { + +enum PositionalEmbeddingGatherTensorIdx : uint32_t { + IN_POSITION_IDS = 0, + IN_COS_TABLE, + IN_SIN_TABLE, + OUT_COS_EMBEDDING, + OUT_SIN_EMBEDDING, +}; + +static const uint64_t IN_TENSOR_COUNT = 3; +static const uint64_t OUT_TENSOR_COUNT = 2; +static const uint64_t INTERMEDIATE_TENSOR_COUNT = 0; +static const uint64_t NODE_COUNT = 2; + +atb::Status PositionalEmbeddingGather(atb::Operation **operation) +{ + atb::GraphParam opGraph; + opGraph.inTensorNum = IN_TENSOR_COUNT; + opGraph.outTensorNum = OUT_TENSOR_COUNT; + opGraph.internalTensorNum = INTERMEDIATE_TENSOR_COUNT; + opGraph.name = "PositionalEmbeddingGather"; + + atb::Node cosEmbeddingNode; + atb::infer::GatherParam cosEmbeddingGatherParam; + CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(cosEmbeddingGatherParam, &cosEmbeddingNode.operation)); + cosEmbeddingNode.inTensorIds = { + PositionalEmbeddingGatherTensorIdx::IN_COS_TABLE, PositionalEmbeddingGatherTensorIdx::IN_POSITION_IDS + }; + cosEmbeddingNode.outTensorIds = {PositionalEmbeddingGatherTensorIdx::OUT_COS_EMBEDDING}; + opGraph.nodes.push_back(cosEmbeddingNode); + + atb::Node sinEmbeddingNode; + atb::infer::GatherParam sinEmbeddingGatherParam; + CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(sinEmbeddingGatherParam, &sinEmbeddingNode.operation)); + sinEmbeddingNode.inTensorIds = { + PositionalEmbeddingGatherTensorIdx::IN_SIN_TABLE, PositionalEmbeddingGatherTensorIdx::IN_POSITION_IDS + }; + sinEmbeddingNode.outTensorIds = {PositionalEmbeddingGatherTensorIdx::OUT_SIN_EMBEDDING}; + opGraph.nodes.push_back(sinEmbeddingNode); + + opGraph.inferShapeFunc = [=](const atb::SVector &inTensorDescs, + atb::SVector &outTensorDescs) { + outTensorDescs.at(0) = inTensorDescs.at(IN_COS_TABLE); + if (inTensorDescs.at(IN_COS_TABLE).shape.dimNum >= 3) { // 3: 如果IN_COS_TABLE的维度大于3 + outTensorDescs.at(0).shape.dimNum = 3; // 3: 第一个输出tensor的维度为3 + } else { + outTensorDescs.at(0).shape.dimNum = 2; // 2: 第一个输出tensor的维度为2 + } + outTensorDescs.at(0).shape.dims[0] = 1; + // unpadInputs=True场景下,for loop只循环一次;unpadInputs=False场景下,for loop循环两次,将bsz和seqLen合轴 + CHECK_TENSORDESC_DIMNUM_VALID(inTensorDescs.at(IN_POSITION_IDS).shape.dimNum); + for (uint64_t i = 0; i < inTensorDescs.at(IN_POSITION_IDS).shape.dimNum; i++) { + outTensorDescs.at(0).shape.dims[0] = CheckIntMulOverFlow( + outTensorDescs.at(0).shape.dims[0], inTensorDescs.at(IN_POSITION_IDS).shape.dims[i]); + } + + outTensorDescs.at(1) = outTensorDescs.at(0); + return atb::NO_ERROR; + }; + + CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(opGraph, operation)); + return atb::NO_ERROR; +} +static const uint64_t POS_EMB_IN_TENSOR_COUNT = 5; +static const uint64_t POS_EMB_OUT_TENSOR_COUNT = 2; +static const uint64_t POS_EMB_INTERMEDIATE_TENSOR_2D_COUNT = 6; +static const uint64_t POS_EMB_INTERMEDIATE_TENSOR_1D_COUNT = 0; +static const uint64_t POS_EMB_NODE_2D_COUNT = 5; +static const uint64_t POS_EMB_NODE_1D_COUNT = 1; + +static const uint64_t DIM_NUM_1 = 1; +static const uint64_t DIM_NUM_2 = 2; +static const uint64_t DIM_NUM_3 = 3; +static const uint64_t DIM_NUM_4 = 4; +static const int64_t DIM_LAST = -1; +static const uint64_t DIM_0 = 0; +static const uint64_t DIM_1 = 1; +static const uint64_t DIM_2 = 2; +static const uint64_t DIM_3 = 3; +static const uint64_t SPLIT_NUM_2 = 2; +static const uint64_t SPLIT_NUM_3 = 3; + +static void 
+static const uint64_t POS_EMB_IN_TENSOR_COUNT = 5;
+static const uint64_t POS_EMB_OUT_TENSOR_COUNT = 2;
+static const uint64_t POS_EMB_INTERMEDIATE_TENSOR_2D_COUNT = 6;
+static const uint64_t POS_EMB_INTERMEDIATE_TENSOR_1D_COUNT = 0;
+static const uint64_t POS_EMB_NODE_2D_COUNT = 5;
+static const uint64_t POS_EMB_NODE_1D_COUNT = 1;
+
+static const uint64_t DIM_NUM_1 = 1;
+static const uint64_t DIM_NUM_2 = 2;
+static const uint64_t DIM_NUM_3 = 3;
+static const uint64_t DIM_NUM_4 = 4;
+static const int64_t DIM_LAST = -1;
+static const uint64_t DIM_0 = 0;
+static const uint64_t DIM_1 = 1;
+static const uint64_t DIM_2 = 2;
+static const uint64_t DIM_3 = 3;
+static const uint64_t SPLIT_NUM_2 = 2;
+static const uint64_t SPLIT_NUM_3 = 3;
+
+static void SqueezeRopeIntensor(const atb::Dims &oldShape, atb::Dims &newShape)
+{
+    if (oldShape.dimNum == DIM_NUM_4) {
+        newShape.dimNum = DIM_NUM_2;
+        newShape.dims[0] = CheckIntMulOverFlow(oldShape.dims[0], oldShape.dims[1]);
+        newShape.dims[1] = CheckIntMulOverFlow(oldShape.dims[DIM_2], oldShape.dims[DIM_3]);
+    } else if (oldShape.dimNum == DIM_NUM_3) {
+        newShape.dimNum = DIM_NUM_2;
+        newShape.dims[0] = CheckIntMulOverFlow(oldShape.dims[0], oldShape.dims[1]);
+        newShape.dims[1] = oldShape.dims[DIM_2];
+    } else {
+        newShape = oldShape;
+    }
+}
+
+enum class RotaryPositionEmbeddingTensorId : int {
+    IN_QUERY = 0,
+    IN_KEY,
+    IN_ROPE_COS,
+    IN_ROPE_SIN,
+    IN_SEQLEN,
+
+    OUT_QUERY,
+    OUT_KEY,
+
+    INTERMEDIATE_QCHUNK0,
+    INTERMEDIATE_QCHUNK1,
+    INTERMEDIATE_KCHUNK0,
+    INTERMEDIATE_KCHUNK1,
+    INTERMEDIATE_QOUT,
+    INTERMEDIATE_KOUT,
+};
+
+#define POS_EMB_CAST(x) static_cast<uint32_t>(RotaryPositionEmbeddingTensorId::x)
+
+int64_t AddInferShapeFunc(atb::GraphParam &opGraph, const RotaryPositionEmbeddingParam &param)
+{
+    if (param.isFA) {
+        opGraph.inferShapeFunc = [=](
+            const atb::SVector<atb::TensorDesc> &inTensorDescs,
+            atb::SVector<atb::TensorDesc> &outTensorDescs
+        ) {
+            outTensorDescs.at(0) = inTensorDescs.at(POS_EMB_CAST(IN_QUERY));
+            outTensorDescs.at(0).shape.dimNum = DIM_NUM_4;
+            outTensorDescs.at(0).shape.dims[0] = inTensorDescs.at(POS_EMB_CAST(IN_QUERY)).shape.dims[0];
+            outTensorDescs.at(0).shape.dims[1] = inTensorDescs.at(POS_EMB_CAST(IN_QUERY)).shape.dims[1];
+            outTensorDescs.at(0).shape.dims[DIM_2] = param.headNum;
+            outTensorDescs.at(0).shape.dims[DIM_3] = param.headDim;
+            outTensorDescs.at(1) = inTensorDescs.at(POS_EMB_CAST(IN_KEY));
+            outTensorDescs.at(1).shape.dimNum = DIM_NUM_4;
+            outTensorDescs.at(1).shape.dims[0] = inTensorDescs.at(POS_EMB_CAST(IN_KEY)).shape.dims[0];
+            outTensorDescs.at(1).shape.dims[1] = inTensorDescs.at(POS_EMB_CAST(IN_KEY)).shape.dims[1];
+            outTensorDescs.at(1).shape.dims[DIM_2] = param.kvHeadNum;
+            outTensorDescs.at(1).shape.dims[DIM_3] = param.headDim;
+            return atb::NO_ERROR;
+        };
+    } else {
+        opGraph.inferShapeFunc = [=] (
+            const atb::SVector<atb::TensorDesc> &inTensorDescs,
+            atb::SVector<atb::TensorDesc> &outTensorDescs
+        ) {
+            outTensorDescs.at(0) = inTensorDescs.at(POS_EMB_CAST(IN_QUERY));
+            outTensorDescs.at(0).shape.dimNum = DIM_NUM_3;
+            outTensorDescs.at(0).shape.dims[0] = inTensorDescs.at(POS_EMB_CAST(IN_QUERY)).shape.dims[0];
+            outTensorDescs.at(0).shape.dims[DIM_1] = param.headNum;
+            outTensorDescs.at(0).shape.dims[DIM_2] = param.headDim;
+            outTensorDescs.at(1) = inTensorDescs.at(POS_EMB_CAST(IN_KEY));
+            outTensorDescs.at(1).shape.dimNum = DIM_NUM_3;
+            outTensorDescs.at(1).shape.dims[0] = inTensorDescs.at(POS_EMB_CAST(IN_KEY)).shape.dims[0];
+            outTensorDescs.at(1).shape.dims[DIM_1] = param.kvHeadNum;
+            outTensorDescs.at(1).shape.dims[DIM_2] = param.headDim;
+            return atb::NO_ERROR;
+        };
+    }
+    return atb::NO_ERROR;
+}
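A worked example of the HALF_ROTARY path built below: q/k are split on the last axis into two equal chunks, the rope kernel rotates only chunk 0, and the untouched chunk 1 is concatenated back afterwards. The head size is an assumed example value.

```cpp
#include <cassert>

int main()
{
    const int headDim = 128;              // illustrative head size
    const int rotated = headDim / 2;      // chunk 0, fed to the rope kernel
    const int passThrough = headDim / 2;  // chunk 1, concatenated back unchanged
    assert(rotated + passThrough == headDim);
    return 0;
}
```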
param.headDim);
+        };
+    }
+    opGraph.nodes.push_back(splitQNode);
+
+    atb::Node splitKNode;
+    atb::infer::SplitParam splitKParam;
+    splitKParam.splitDim = DIM_LAST;
+    splitKParam.splitNum = SPLIT_NUM_2;
+    CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(splitKParam, &splitKNode.operation));
+    splitKNode.inTensorIds = {POS_EMB_CAST(IN_KEY)};
+    splitKNode.outTensorIds = {POS_EMB_CAST(INTERMEDIATE_KCHUNK0), POS_EMB_CAST(INTERMEDIATE_KCHUNK1)};
+    if (!param.isFA) {
+        splitKNode.inTensorReshapeFuncs.resize(splitKNode.inTensorIds.size());
+        splitKNode.inTensorReshapeFuncs.at(0) = [=](const atb::Dims &oldShape, atb::Dims &newShape) {
+            UnsqueezeHeadNumHeadDim(oldShape, newShape, param.kvHeadNum, param.headDim);
+        };
+    }
+    opGraph.nodes.push_back(splitKNode);
+    return atb::NO_ERROR;
+}
+
+int64_t AddCatLV(atb::GraphParam &opGraph, const RotaryPositionEmbeddingParam &param)
+{
+    atb::Node cat1Node;
+    atb::infer::ConcatParam cat1Param;
+    cat1Param.concatDim = DIM_LAST;
+    CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(cat1Param, &cat1Node.operation));
+    cat1Node.inTensorIds = {POS_EMB_CAST(INTERMEDIATE_QOUT), POS_EMB_CAST(INTERMEDIATE_QCHUNK1)};
+    cat1Node.outTensorIds = {POS_EMB_CAST(OUT_QUERY)};
+    if (!param.isFA) {
+        cat1Node.inTensorReshapeFuncs.resize(cat1Node.inTensorIds.size());
+        cat1Node.inTensorReshapeFuncs[0] = [=](const atb::Dims &oldShape, atb::Dims &newShape) {
+            UnsqueezeHeadNumHeadDim(oldShape, newShape, param.headNum, param.headDim / 2); // 2: headDim is split in half, so divide by 2
+        };
+    }
+    opGraph.nodes.push_back(cat1Node);
+
+    atb::Node cat2Node;
+    atb::infer::ConcatParam cat2Param;
+    cat2Param.concatDim = DIM_LAST;
+    CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(cat2Param, &cat2Node.operation));
+    cat2Node.inTensorIds = {POS_EMB_CAST(INTERMEDIATE_KOUT), POS_EMB_CAST(INTERMEDIATE_KCHUNK1)};
+    cat2Node.outTensorIds = {POS_EMB_CAST(OUT_KEY)};
+    if (!param.isFA) {
+        cat2Node.inTensorReshapeFuncs.resize(cat2Node.inTensorIds.size());
+        cat2Node.inTensorReshapeFuncs[0] = [=](const atb::Dims &oldShape, atb::Dims &newShape) {
+            UnsqueezeHeadNumHeadDim(oldShape, newShape, param.kvHeadNum, param.headDim / 2); // 2: headDim is split in half, so divide by 2
+        };
+    }
+    opGraph.nodes.push_back(cat2Node);
+    return atb::NO_ERROR;
+}
+
+atb::Status RotaryPositionEmbedding(const RotaryPositionEmbeddingParam &param, atb::Operation **operation)
+{
+    atb::GraphParam opGraph;
+    opGraph.name = "RotaryPositionEmbedding";
+    opGraph.inTensorNum = POS_EMB_IN_TENSOR_COUNT;
+    opGraph.outTensorNum = POS_EMB_OUT_TENSOR_COUNT;
+    opGraph.internalTensorNum = param.rotaryType == HALF_ROTARY ?
+        POS_EMB_INTERMEDIATE_TENSOR_2D_COUNT : POS_EMB_INTERMEDIATE_TENSOR_1D_COUNT;
+
+    if (param.rotaryType == HALF_ROTARY) {
+        // split q and k to half
+        CHECK_OPERATION_STATUS_RETURN(AddSplitKV(opGraph, param));
+
+        atb::Node ropeNode;
+        CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(param.ropeParam, &ropeNode.operation));
+        ropeNode.inTensorIds = {POS_EMB_CAST(INTERMEDIATE_QCHUNK0), POS_EMB_CAST(INTERMEDIATE_KCHUNK0),
+                                POS_EMB_CAST(IN_ROPE_COS), POS_EMB_CAST(IN_ROPE_SIN),
+                                POS_EMB_CAST(IN_SEQLEN)};
+        ropeNode.outTensorIds = {POS_EMB_CAST(INTERMEDIATE_QOUT), POS_EMB_CAST(INTERMEDIATE_KOUT)};
+        ropeNode.inTensorReshapeFuncs.resize(ropeNode.inTensorIds.size());
+        if (param.isFA) {
+            ropeNode.inTensorReshapeFuncs.at(DIM_2) = &SqueezeRopeIntensor;
+            ropeNode.inTensorReshapeFuncs.at(DIM_3) = &SqueezeRopeIntensor;
+        } else {
+            ropeNode.inTensorReshapeFuncs.at(DIM_0) = &SqueezeHeadNumHeadDim;
+            ropeNode.inTensorReshapeFuncs.at(DIM_1) = &SqueezeHeadNumHeadDim;
+            ropeNode.inTensorReshapeFuncs.at(DIM_2) = &SqueezeHeadNumHeadDim;
+            ropeNode.inTensorReshapeFuncs.at(DIM_3) = &SqueezeHeadNumHeadDim;
+        }
+        opGraph.nodes.push_back(ropeNode);
+
+        CHECK_OPERATION_STATUS_RETURN(AddCatLV(opGraph, param));
+    } else {
+        atb::Node ropeNode;
+        CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(param.ropeParam, &ropeNode.operation));
+        ropeNode.inTensorIds = {POS_EMB_CAST(IN_QUERY), POS_EMB_CAST(IN_KEY), POS_EMB_CAST(IN_ROPE_COS),
+                                POS_EMB_CAST(IN_ROPE_SIN), POS_EMB_CAST(IN_SEQLEN)};
+        ropeNode.outTensorIds = {POS_EMB_CAST(OUT_QUERY), POS_EMB_CAST(OUT_KEY)};
+        ropeNode.inTensorReshapeFuncs.resize(ropeNode.inTensorIds.size());
+        ropeNode.inTensorReshapeFuncs.at(DIM_2) = &SqueezeRopeIntensor;
+        ropeNode.inTensorReshapeFuncs.at(DIM_3) = &SqueezeRopeIntensor;
+        opGraph.nodes.push_back(ropeNode);
+    }
+    CHECK_OPERATION_STATUS_RETURN(AddInferShapeFunc(opGraph, param));
+
+    CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(opGraph, operation));
+
+    return atb::NO_ERROR;
+}
+
+} // namespace common
+} // namespace atb_speed
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/fusion/embedding/positional_embedding.h b/tests/proftest/layer_test_framework/operations/fusion/embedding/positional_embedding.h
new file mode 100644
index 00000000..c67fffe9
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/fusion/embedding/positional_embedding.h
@@ -0,0 +1,155 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef ATB_SPEED_MODELS_COMMON_LAYER_POSITIONAL_EMBEDDING_H
+#define ATB_SPEED_MODELS_COMMON_LAYER_POSITIONAL_EMBEDDING_H
+
+#include "nlohmann/json.hpp"
+#include "atb/atb_infer.h"
+#include "atb_speed/log.h"
+#include "operations/fusion/linear/linear_parallel.h"
+
+namespace atb_speed {
+namespace common {
+/// An enum to represent the type of rotary position embedding.
+enum RotaryType : uint32_t {
+    /// No rotary position embedding.
+    NO_ROTARY = 0,
+    /// 1D rotary position embedding.
+    HALF_ROTARY,
+    /// 2D rotary position embedding.
+    ALL_ROTARY,
+};
+
+/// A struct that defines the rotary position embedding parameters.
+struct RotaryPositionEmbeddingParam {
+    atb_speed::common::RotaryType rotaryType = ALL_ROTARY; /// The type of rotary position embedding.
+    bool isFA = true; /// A flag indicating whether flash attention is used.
+    /// Number of attention heads per rank, which equals `num_attention_heads` // `world_size`.
+    /// `num_attention_heads` is defined in model_path -> config.json.
+    int headNum = 0;
+    /// Hidden size of each attention head, which equals `hidden_size` / `num_attention_heads`.
+    /// `hidden_size` and `num_attention_heads` are defined in model_path -> config.json.
+    int headDim = 0;
+    /// Number of key-value heads per rank, which equals `num_key_value_heads` // `world_size`.
+    /// `num_key_value_heads` is taken from model_path -> config.json if defined, otherwise `num_attention_heads`.
+    int kvHeadNum = 0;
+    atb::infer::RopeParam ropeParam; /// Parameters to be passed through to the ATB rope operation.
+};
+
+/// Create a `RotaryPositionEmbedding` operation.
+///
+/// This function supports 1D and 2D Rotary Position Embedding (RoPE). RoPE is a positional encoding technique that
+/// encodes the relative position of tokens in a sequence by rotating the embedding vectors of the tokens.
+/// For more details, check out the RoPE paper: `RoFormer: Enhanced Transformer with Rotary Position Embedding`.
+///
+/// \param param Parameters of the RoPE operation, see `RotaryPositionEmbeddingParam` for more details.
+/// \param operation The address to be filled with the created operation object.
+/// \return A flag indicating the status of the operation creation.
+///
+/// Operation's Inputs:
+/// Name | Dtype | Shape |
+/// ------------------- | ----- | ----- |
+/// query | float16/bfloat16 | [len(all_seq), num_heads, head_dim] in the PA case, [bsz, seq_len, num_heads, head_dim] in the FA case |
+/// key | float16/bfloat16 | same as query's shape |
+/// rope_cos | float16/float32/bfloat16, float32 to enable high-precision rope | [len(all_seq), head_dim/2] if and only if ropeParam.rotaryCoeff equals 2, else [len(all_seq), head_dim] |
+/// rope_sin | float16/float32/bfloat16, float32 to enable high-precision rope | [len(all_seq), head_dim/2] if and only if ropeParam.rotaryCoeff equals 2, else [len(all_seq), head_dim] |
+/// seq_len | int32/uint32 | [bsz] |
+///
+/// Operation's Outputs:
+/// Name | Dtype | Shape |
+/// ------------------- | ----- | ----- |
+/// embedded_query | same as input query | same as input query |
+/// embedded_key | same as input key | same as input key |
+///
+/// Example:
+/// \code
+/// enum TensorIdx : uint32_t {
+///     IN_QUERY_ID = 0,
+///     IN_KEY_ID,
+///     IN_ROPE_COS_ID,
+///     IN_ROPE_SIN_ID,
+///     IN_SEQ_LEN_ID,
+///     OUT_EMBEDDED_QUERY_ID,
+///     OUT_EMBEDDED_KEY_ID
+/// };
+/// atb::Node ropeNode;
+/// atb_speed::common::RotaryPositionEmbeddingParam ropeParam;
+/// // Modify ropeParam's attributes if needed.
+/// RotaryPositionEmbedding(ropeParam, &ropeNode.operation);
+/// ropeNode.inTensorIds = {
+///     IN_QUERY_ID,
+///     IN_KEY_ID,
+///     IN_ROPE_COS_ID,
+///     IN_ROPE_SIN_ID,
+///     IN_SEQ_LEN_ID
+/// };
+/// ropeNode.outTensorIds = {
+///     OUT_EMBEDDED_QUERY_ID,
+///     OUT_EMBEDDED_KEY_ID
+/// };
+/// graph.nodes.push_back(ropeNode); // Add node to its graph.
+/// \endcode
+atb::Status RotaryPositionEmbedding(const RotaryPositionEmbeddingParam &param, atb::Operation **operation);
+
+/// Create a `PositionalEmbeddingGather` operation.
+///
+/// This function gets the positional embedding from the cosine table and the sine table according to the position indices.
+///
+/// \param operation The address to be filled with the created operation object.
+///
+/// Operation's Inputs:
+/// Name | Dtype | Shape |
+/// -------------------|-------|-------|
+/// position_ids | int64/int32/uint32 | [len(all_seq)] or [bsz, seq_len] |
+/// cosine_table | float16/bfloat16 | depends on the model |
+/// sine_table | float16/bfloat16 | depends on the model |
+///
+/// Operation's Outputs:
+/// Name | Dtype | Shape |
+/// -------------------|-------|-------|
+/// cos_embedding | same as cosine_table | [len(all_seq), cosine_table.shape[:]] or [bsz, seq_len, cosine_table.shape[:]] |
+/// sin_embedding | same as sine_table | same as cos_embedding |
+///
+/// Example:
+/// \code
+/// enum TensorIdx : uint32_t {
+///     IN_POSITION_IDS_ID = 0,
+///     IN_COSINE_TABLE_ID,
+///     IN_SINE_TABLE_ID,
+///     OUT_COSINE_EMBEDDING_ID,
+///     OUT_SINE_EMBEDDING_ID
+/// };
+/// std::vector<atb::Tensor> Tensors = {...}; // Prepare tensors here.
+/// atb::Operation *op = nullptr;
+/// atb_speed::Model::Node positionalEmbeddingGatherNode;
+/// CHECK_OPERATION_STATUS_RETURN(atb_speed::common::PositionalEmbeddingGather(&op));
+/// positionalEmbeddingGatherNode.operation.reset(op);
+/// // Assume the input and output tensors are already set in graph.
+/// positionalEmbeddingGatherNode.inTensors = {
+///     Tensors.at(IN_POSITION_IDS_ID),
+///     Tensors.at(IN_COSINE_TABLE_ID),
+///     Tensors.at(IN_SINE_TABLE_ID)
+/// };
+/// positionalEmbeddingGatherNode.outTensors = {
+///     Tensors.at(OUT_COSINE_EMBEDDING_ID),
+///     Tensors.at(OUT_SINE_EMBEDDING_ID)
+/// };
+/// graph.nodes.push_back(positionalEmbeddingGatherNode); // Add node to its graph.
+/// \endcode
+atb::Status PositionalEmbeddingGather(atb::Operation **operation);
+} // namespace common
+} // namespace atb_speed
+#endif
diff --git a/tests/proftest/layer_test_framework/operations/fusion/embedding/word_embedding.cpp b/tests/proftest/layer_test_framework/operations/fusion/embedding/word_embedding.cpp
new file mode 100644
index 00000000..e86fc3da
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/fusion/embedding/word_embedding.cpp
@@ -0,0 +1,100 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "operations/fusion/embedding/word_embedding.h"
+#include "atb_speed/utils/check_util.h"
+#include "operations/fusion/linear/linear_parallel.h"
+
+namespace atb_speed {
+namespace common {
+
+enum WordEmbeddingTensorIdx : uint32_t {
+    IN_EMBEDDING_WEIGHTS = 0,
+    IN_INPUT_IDS,
+    OUT_HIDDEN_STATES,
+    INTERMEDIATE_GATHER,
+    INTERMEDIATE_ALLGATHER_OUT_ID,
+};
+
+static const uint64_t IN_TENSOR_COUNT = 2;
+static const uint64_t OUT_TENSOR_COUNT = 1;
+static const uint64_t INTERMEDIATE_TENSOR_NO_ALL_GATHER_COUNT = 0;
+static const uint64_t INTERMEDIATE_TENSOR_ALL_GATHER_COUNT = 2;
+static const uint64_t NODE_NO_ALL_GATHER_COUNT = 1;
+static const uint64_t NODE_ALL_GATHER_COUNT = 3;
+
+atb::Status WordEmbedding(const WordEmbeddingParam &param, atb::Operation **operation)
+{
+    atb::GraphParam opGraph;
+    opGraph.inTensorNum = IN_TENSOR_COUNT;
+    opGraph.outTensorNum = OUT_TENSOR_COUNT;
+    opGraph.internalTensorNum = param.tensorParallelInfo.worldSize > 1 ? \
+        INTERMEDIATE_TENSOR_ALL_GATHER_COUNT : INTERMEDIATE_TENSOR_NO_ALL_GATHER_COUNT;
+    opGraph.name = "WordEmbedding";
+
+    atb::Node inputIdEmbeddingNode;
+    atb::infer::GatherParam inputembedinggatherparam;
+    inputembedinggatherparam.axis = param.axis;
+    CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(inputembedinggatherparam, &inputIdEmbeddingNode.operation));
+    inputIdEmbeddingNode.inTensorIds = {
+        WordEmbeddingTensorIdx::IN_EMBEDDING_WEIGHTS, WordEmbeddingTensorIdx::IN_INPUT_IDS
+    };
+    inputIdEmbeddingNode.outTensorIds = {
+        param.tensorParallelInfo.worldSize > 1 ? \
+        WordEmbeddingTensorIdx::INTERMEDIATE_GATHER : WordEmbeddingTensorIdx::OUT_HIDDEN_STATES
+    };
+    opGraph.nodes.push_back(inputIdEmbeddingNode);
+
+    if (param.tensorParallelInfo.worldSize > 1) {
+        LinearParallelParam parallelParam;
+        parallelParam.parallelType = COLUMN_PARALLEL;
+        parallelParam.tensorParallelInfo = param.tensorParallelInfo;
+        parallelParam.unpadInputs = param.unpadInputs;
+
+        std::map<std::string, uint32_t> tensorMap = {
+            {"intermediate_linear_out", WordEmbeddingTensorIdx::INTERMEDIATE_GATHER},
+            {"intermediate_sync_out", WordEmbeddingTensorIdx::INTERMEDIATE_ALLGATHER_OUT_ID},
+            {"out", WordEmbeddingTensorIdx::OUT_HIDDEN_STATES}};
+
+        CHECK_OPERATION_STATUS_RETURN(AddCommunicationOp(parallelParam, opGraph, tensorMap));
+    }
+
+    opGraph.inferShapeFunc = [=](const atb::SVector<atb::TensorDesc> &inTensorDescs,
+                                 atb::SVector<atb::TensorDesc> &outTensorDescs) {
+        outTensorDescs.at(0).dtype = inTensorDescs.at(IN_EMBEDDING_WEIGHTS).dtype;
+        outTensorDescs.at(0).format = inTensorDescs.at(IN_EMBEDDING_WEIGHTS).format;
+        if (param.unpadInputs) {
+            outTensorDescs.at(0).shape.dimNum = 2; // 2: the first output tensor has 2 dims
+            outTensorDescs.at(0).shape.dims[0] = inTensorDescs.at(IN_INPUT_IDS).shape.dims[0];
+            outTensorDescs.at(0).shape.dims[1] = CheckIntMulOverFlow(
+                inTensorDescs.at(IN_EMBEDDING_WEIGHTS).shape.dims[1], param.tensorParallelInfo.worldSize);
+        } else {
+            outTensorDescs.at(0).shape.dimNum = 3; // 3: the first output tensor has 3 dims
+            outTensorDescs.at(0).shape.dims[0] = inTensorDescs.at(IN_INPUT_IDS).shape.dims[0];
+            outTensorDescs.at(0).shape.dims[1] = inTensorDescs.at(IN_INPUT_IDS).shape.dims[1];
+            outTensorDescs.at(0).shape.dims[2] = // 2: dimension 2
+                CheckIntMulOverFlow(
+                    inTensorDescs.at(IN_EMBEDDING_WEIGHTS).shape.dims[1],
+                    param.tensorParallelInfo.worldSize);
+        }
+        return atb::NO_ERROR;
+    };
+
+    CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(opGraph, operation));
+    return atb::NO_ERROR;
+}
+} // namespace common
+} // namespace atb_speed
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/fusion/embedding/word_embedding.h b/tests/proftest/layer_test_framework/operations/fusion/embedding/word_embedding.h
new file mode 100644
index 00000000..31c9bf85
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/fusion/embedding/word_embedding.h
@@ -0,0 +1,80 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef ATB_SPEED_MODELS_COMMON_LAYER_WORD_EMBEDDING_H
+#define ATB_SPEED_MODELS_COMMON_LAYER_WORD_EMBEDDING_H
+
+#include "atb/atb_infer.h"
+#include "atb_speed/utils/operation_util.h"
+#include "operations/fusion/linear/linear_parallel.h"
+
+namespace atb_speed {
+namespace common {
+/// A struct that defines the `WordEmbedding` operation's parameters.
+struct WordEmbeddingParam {
+    /// Whether the input tensor is unpadded.
+    /// If false, the input tensor shape is [batch_size, seq_len]; requests shorter than seq_len are padded.
+    /// If true, the input tensor shape is [(seq_len_1 + seq_len_2 + ... + seq_len_n)].
+    bool unpadInputs = false;
+    /// Which axis to gather slices from the input tensors.
+    int axis = 0;
+    /// A struct defined in `/fusion/linear/linear_parallel.h`. The vocabulary list will be split according to the
+    /// settings of the struct; under default parameters, even if the model runs on multiple GPUs,
+    /// the vocabulary will not be split.
+    atb_speed::common::TensorParallelInfo tensorParallelInfo;
+};
+
+/// Create a `WordEmbedding` graph operation.
+/// \param param `WordEmbedding`'s parameters, see `WordEmbeddingParam` for more details.
+/// \param operation The address pointer to the `WordEmbedding` operation.
+///
+/// Operation's Inputs:
+/// Name | Dtype | Shape |
+/// ---------------------- | ----- | ----- |
+/// embedding_weight | float16/float32/bfloat16/int32/uint32 | [vocab_size, hidden_size] |
+/// input_ids | int64/int32/uint32 | no constraint |
+///
+/// Operation's Outputs:
+/// Name | Dtype | Shape |
+/// ---------------------- | ----- | ----- |
+/// output | float16/float32/bfloat16/int32/uint32 | [len(all_seq), hidden_size] or [bsz, seq_len, hidden_size] |
+///
+/// Example:
+/// \code
+/// enum TensorIdx : uint32_t {
+///     IN_EMBEDDING_WEIGHT_ID = 0,
+///     IN_INPUT_IDS_ID,
+///     OUT_OUTPUT_ID,
+/// };
+/// std::vector<atb::Tensor> Tensors = {...}; // Prepare tensors here.
+/// atb::Operation *op = nullptr;
+/// atb_speed::Model::Node wordEmbeddingNode;
+/// atb_speed::common::WordEmbeddingParam wordEmbeddingParam;
+/// // Modify wordEmbeddingParam's attributes if needed.
+/// CHECK_OPERATION_STATUS_RETURN(WordEmbedding(wordEmbeddingParam, &op));
+/// wordEmbeddingNode.operation.reset(op);
+/// wordEmbeddingNode.inTensors = {
+///     Tensors.at(IN_EMBEDDING_WEIGHT_ID),
+///     Tensors.at(IN_INPUT_IDS_ID)
+/// };
+/// wordEmbeddingNode.outTensors = {
+///     Tensors.at(OUT_OUTPUT_ID)
+/// };
+/// graph.nodes.push_back(wordEmbeddingNode); // Add node to its graph.
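+/// // Note (illustrative): with unpadInputs == true the output comes back as
+/// // [len(all_seq), hidden_size]; with unpadInputs == false it is [bsz, seq_len, hidden_size],
+/// // mirroring the inferShapeFunc in word_embedding.cpp.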
+/// \endcode
atb::Status WordEmbedding(const WordEmbeddingParam &param, atb::Operation **operation);
+} // namespace common
+} // namespace atb_speed
+#endif
diff --git a/tests/proftest/layer_test_framework/operations/fusion/infer_shape_functions.cpp b/tests/proftest/layer_test_framework/operations/fusion/infer_shape_functions.cpp
new file mode 100644
index 00000000..cffd9d35
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/fusion/infer_shape_functions.cpp
@@ -0,0 +1,104 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "atb_speed/utils/check_util.h"
+#include "operations/aclnn/utils/utils.h"
+#include "operations/fusion/infer_shape_functions.h"
+
+namespace atb_speed {
+namespace common {
+
+void SqueezeHeadNumHeadDim(const atb::Dims &oldShape, atb::Dims &newShape)
+{
+    newShape = oldShape;
+    if (oldShape.dimNum >= 2) { // 2: merge the last two dims of the input tensor; shapes with fewer than 2 dims are left unchanged
+        newShape.dimNum = oldShape.dimNum - 1;
+        newShape.dims[newShape.dimNum - 1] = \
+            CheckIntMulOverFlow(oldShape.dims[oldShape.dimNum - 2], oldShape.dims[oldShape.dimNum - 1]); // 2: index
+    }
+}
+
+void UnsqueezeHeadNumHeadDim(const atb::Dims &oldShape, atb::Dims &newShape, int32_t headNum, int32_t headDim)
+{
+    newShape = oldShape;
+    if (oldShape.dimNum == 0) {
+        return;
+    }
+    newShape.dimNum = oldShape.dimNum + 1;
+    newShape.dims[newShape.dimNum - 2] = headNum; // -2: headNum
+    newShape.dims[newShape.dimNum - 1] = headDim; // -1: headDim
+}
+
+void UnsqueezeAxis(const atb::Dims &oldShape, atb::Dims &newShape, int32_t axis)
+{
+    newShape = oldShape;
+    newShape.dimNum = oldShape.dimNum + 1;
+    newShape.dims[axis] = 1;
+    for (uint64_t i = axis + 1; i < std::min(newShape.dimNum, static_cast<uint64_t>(8)); i++) { // 8: maximum tensor dim count
+        newShape.dims[i] = oldShape.dims[i - 1];
+    }
+}
+
+void SqueezeBatchAndSeq(const atb::Dims& oldShape, atb::Dims& newShape)
+{
+    if (oldShape.dimNum == NUM3) { // 3: if the input shape is [B, S, H], squeeze it to [B*S, H]
+        newShape.dimNum = NUM2;
+        newShape.dims[DIM0] = CheckIntMulOverFlow(oldShape.dims[DIM0], oldShape.dims[DIM1]);
+        newShape.dims[DIM1] = oldShape.dims[DIM2];
+    } else {
+        newShape = oldShape;
+    }
+}
+
+void SqueezeBatchAndHiddenSize(const atb::Dims& oldShape, atb::Dims& newShape)
+{
+    if (oldShape.dimNum == 4) { // 4: if the input is [B,S,N,D], merge it to [BS,ND]
+        newShape.dimNum = 2; // 2: [BS,ND]
+        newShape.dims[0] = CheckIntMulOverFlow(oldShape.dims[0], oldShape.dims[1]); // 0,0,1: [B,S] => [BS]
+        newShape.dims[1] = CheckIntMulOverFlow(oldShape.dims[2], oldShape.dims[3]); // 1,2,3: [N,D] => [ND]
+    } else {
+        newShape = oldShape;
+    }
+}
+
+void InternlmV2QKVSplit(
+    const atb::Dims& oldShape, atb::Dims& newShape, int32_t headNum, int32_t kvHeadNum, int32_t headDim)
+{
+    if (kvHeadNum == 0 || headDim == 0) {
+        ATB_SPEED_LOG_ERROR("kvHeadNum or headDim is 0 in InternlmV2QKVSplit, "
+                            << "reshape failed, newShape remains the same as oldShape");
+        newShape = oldShape;
+        return;
+    }
+    newShape.dimNum = 4; // 4: the new shape has four dimensions
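+    // The packed qkv weight groups, per kv head, (headNum / kvHeadNum) query heads plus one key and
+    // one value head, which is where the (2 + headNum / kvHeadNum) factor below comes from.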
+    size_t newShapeDimIndex = 0;
+    size_t oldShapeDimIndex = 0;
+    newShape.dims[newShapeDimIndex++] = oldShape.dims[oldShapeDimIndex++];
+    newShape.dims[newShapeDimIndex++] = \
+        oldShape.dims[oldShapeDimIndex++] / (CheckIntMulOverFlow(
+            (2 + headNum / kvHeadNum), headDim) // 2: k + v linear
+        );
+    if ((2 + headNum / kvHeadNum) // 2: k + v linear
+        > std::numeric_limits<int64_t>::max()) {
+        newShape = oldShape;
+        return;
+    }
+    newShape.dims[newShapeDimIndex++] = 2 + headNum / kvHeadNum; // 2: k + v linear
+    newShape.dims[newShapeDimIndex++] = headDim;
+}
+
+} // namespace common
+} // namespace atb_speed
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/fusion/infer_shape_functions.h b/tests/proftest/layer_test_framework/operations/fusion/infer_shape_functions.h
new file mode 100644
index 00000000..2496f452
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/fusion/infer_shape_functions.h
@@ -0,0 +1,45 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef ATB_SPEED_MODELS_COMMON_INFER_SHAPE_FUNCTIONS_H
+#define ATB_SPEED_MODELS_COMMON_INFER_SHAPE_FUNCTIONS_H
+
+#include <cstdint>
+#include <atb/atb_infer.h>
+#include "atb_speed/log.h"
+
+
+namespace atb_speed {
+namespace common {
+
+/// If oldShape's dimNum is smaller than 2, do nothing. Otherwise, squeeze the shape from [..., headNum, headDim]
+/// to [..., headNum * headDim].
+void SqueezeHeadNumHeadDim(const atb::Dims &oldShape, atb::Dims &newShape);
+/// Unsqueeze shape from [..., headNum * headDim] to [..., headNum, headDim].
+void UnsqueezeHeadNumHeadDim(const atb::Dims &oldShape, atb::Dims &newShape, int32_t headNum, int32_t headDim);
+/// Unsqueeze shape at `axis`, e.g. [..., x, ...] to [..., 1, x, ...], where x in oldShape is at `axis`.
+void UnsqueezeAxis(const atb::Dims &oldShape, atb::Dims &newShape, int32_t axis);
+/// If the input shape is [B, S, N, D], squeeze it to [B*S, N*D].
+void SqueezeBatchAndHiddenSize(const atb::Dims& oldShape, atb::Dims& newShape);
+/// Reshape before splitting the packed qkv linear for the InternlmV2 model, from [B, S]
+/// to [B, S / ((`headNum` / `kvHeadNum` + 2) * `headDim`), `headNum` / `kvHeadNum` + 2, `headDim`]
+void InternlmV2QKVSplit(
+    const atb::Dims& oldShape, atb::Dims& newShape, int32_t headNum, int32_t kvHeadNum, int32_t headDim);
+/// If the input shape is [B, S, H], squeeze it to [B*S, H].
+void SqueezeBatchAndSeq(const atb::Dims& oldShape, atb::Dims& newShape);
+} // namespace common
+} // namespace atb_speed
+#endif
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/fusion/linear/linear.cpp b/tests/proftest/layer_test_framework/operations/fusion/linear/linear.cpp
new file mode 100644
index 00000000..de06a7e0
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/fusion/linear/linear.cpp
@@ -0,0 +1,690 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <map>
+#include <string>
+#include "atb_speed/log.h"
+#include "atb_speed/utils/check_util.h"
+#include "operations/aclnn/ops/w8a16_operation.h"
+#include "operations/aclnn/ops/w4a16_operation.h"
+#include "operations/aclnn/ops/w4a8_operation.h"
+#include "operations/aclnn/ops/w8a8_operation.h"
+#include "operations/aclnn/ops/w16a16_operation.h"
+#include "operations/aclnn/ops/grouped_matmul_operation.h"
+#include "operations/aclnn/ops/dynamic_quant_operation.h"
+#include "operations/aclnn/utils/utils.h"
+#include "operations/fusion/utils.h"
+#include "operations/fusion/linear/linear.h"
+
+namespace atb_speed {
+namespace common {
+
+// Whether this is the LINEAR_W8A8_QUANT / LINEAR_W8A8_DEQUANT case with the ACLNN matmul backend enabled
+bool IsAclnnPerTensor(const FusionLinearParam &param)
+{
+    return param.matmulBackend == atb_speed::common::OpBackend::ACLNN &&
+        (param.quantType == LINEAR_W8A8_QUANT || param.quantType == LINEAR_W8A8_DEQUANT);
+}
+
+// Whether to use the aclnn QuantBatchMatmul interface
+bool UseQuantBatchMatmul(const FusionLinearParam &param)
+{
+    // All device types: dynamic and pdmix
+    return IsAclnnPerTensor(param) || \
+        param.quantType == LINEAR_W8A8_DYNAMIC_QUANT || \
+        param.quantType == LINEAR_W8A8_DYNAMIC_DEQUANT || \
+        param.quantType == LINEAR_W4A8_DYNAMIC_QUANT || \
+        param.quantType == LINEAR_W4A8_DYNAMIC_DEQUANT;
+}
+
+// Whether the aclnn QuantBatchMatmul throws the dequant bias out
+bool IsOutDequantBias(const FusionLinearParam &param)
+{
+    bool isBF16 = param.isBF16;
+    bool isPerTensor = (param.quantType == LINEAR_W8A8_QUANT || param.quantType == LINEAR_W8A8_DEQUANT);
+    bool isDecode = !param.isPrefill;
+    bool enableDequantBias = param.enableDequantBias;
+    return isBF16 && isPerTensor && isDecode && enableDequantBias;
+}
+
+std::map<std::string, std::vector<std::string>> GetLinearInTensorCandidates()
+{
+    std::map<std::string, std::vector<std::string>> linearInTensorCandidates = {
+        {"default", {
+            "in_input", "in_weight", "in_scale", "in_offset", "in_descale", "in_bias", "in_compress_idx"}
+        },
+        {"lora", {"in_group_list", "in_lora_a", "in_lora_b"}},
+        {"lora_with_mask", {"in_im_mask"}},
+        {"addrmsnormdynamicquant", {"dynamic_input_scale"}},
+        {"swiglu_quant", {"intermediate_swiglu_dynamic_scale"}},
+        {"add_swiglu_quant_sacle_in", {"swiglu_quant_input_scale"}},
+        {"flash_comm", {
+            "send_counts", "sdispls", "send_count", "recv_counts", "rdispls", "recv_count", "fake_ag_shape"}
+        },
+    };
+    return linearInTensorCandidates;
+}
+
+std::map<std::string, std::vector<std::string>> GetLinearIntermediateTensorCandidates()
+{
+    std::map<std::string, std::vector<std::string>> linearIntermediateTensorCandidates = {
+        {"quant_input", {"intermediate_quant_input"}},
+        {"lora", {"intermediate_base_linear_out", "intermediate_lora_a_out", "intermediate_lora_b_out"}},
+        {"dynamic_quant", {"intermediate_input_scale"}},
+        {"lora_with_mask", {"intermediate_im_mask_out"}},
+        {"flashComm", {"intermediate_allgather_out"}},
+        {"flashComm_dynamic_quant", {"intermediate_allgather_input_scale_out"}},
+    };
+    return linearIntermediateTensorCandidates;
+}
+
+std::map<std::string, uint32_t> ConstructLinearTensorMap(
+    const FusionLinearParam &param,
+    uint32_t &inTensorNum, uint32_t &outTensorNum, uint32_t &internalTensorNum)
+{
+    auto linearInTensorCandidates = GetLinearInTensorCandidates();
+    auto linearIntermediateTensorCandidates = GetLinearIntermediateTensorCandidates();
+
+    std::vector<std::string> inTensorList = {};
+    std::vector<std::string> intermediateTensorList = {};
+    std::vector<std::string> outTensorList = {"out"};
+
+    // Add the default tensors
+    AddTensorToList(linearInTensorCandidates, "default", inTensorList);
+
+    if (!param.enableSwigluQuant || !param.isDownLinear || (param.quantType != LINEAR_W8A8_DYNAMIC_DEQUANT
+        && param.quantType != LINEAR_W4A8_DYNAMIC_DEQUANT)) {
+        // Add extra intermediate tensors
+        if (param.quantType == LINEAR_W8A8_QUANT || param.quantType == LINEAR_W8A8_SC_QUANT
+            || ((param.quantType == LINEAR_W8A8_DYNAMIC_QUANT || param.quantType == LINEAR_W4A8_DYNAMIC_QUANT)
+            && !param.enableSwiGLUQuantForSharedExperts)) {
+            AddTensorToList(linearIntermediateTensorCandidates, "quant_input", intermediateTensorList);
+        }
+        // Add intermediate tensors for dynamic quantization
+        if ((param.quantType == LINEAR_W8A8_DYNAMIC_QUANT || param.quantType == LINEAR_W4A8_DYNAMIC_QUANT)
+            && !param.enableSwiGLUQuantForSharedExperts) {
+            AddTensorToList(linearIntermediateTensorCandidates, "dynamic_quant", intermediateTensorList);
+        }
+    }
+    if (param.enableSwigluQuant) {
+        if (param.isDownLinear && param.isPrefill && param.quantType == LINEAR_W8A8_DYNAMIC_DEQUANT) {
+            AddTensorToList(linearInTensorCandidates, "swiglu_quant", inTensorList);
+        }
+    } else {
+        // Add the Flashcomm 1.0 output
+        if (param.enableFlashComm) {
+            AddTensorToList(linearIntermediateTensorCandidates, "flashComm", intermediateTensorList);
+            if (param.quantType == LINEAR_W8A8_DYNAMIC_QUANT || param.quantType == LINEAR_W8A8_DYNAMIC_DEQUANT) {
+                AddTensorToList(linearIntermediateTensorCandidates, "flashComm_dynamic_quant",
+                    intermediateTensorList);
+            }
+        }
+        // Add the AddRmsNormDynamicQuant outputs
+        if (param.quantType == LINEAR_W8A8_DYNAMIC_DEQUANT || param.quantType == LINEAR_W4A8_DYNAMIC_DEQUANT) {
+            AddTensorToList(linearInTensorCandidates, "addrmsnormdynamicquant", inTensorList);
+        }
+    }
+
+    // Add the SwiGLUQuant outputs
+    if (param.enableSwiGLUQuantForSharedExperts) {
+        AddTensorToList(linearInTensorCandidates, "add_swiglu_quant_sacle_in", inTensorList);
+    }
+    if (param.enableFlashComm) {
+        AddTensorToList(linearInTensorCandidates, "flash_comm", inTensorList);
+    }
+    inTensorNum = inTensorList.size();
+    outTensorNum = outTensorList.size();
+    internalTensorNum = intermediateTensorList.size();
+
+    return GetTensorMap(inTensorList, outTensorList, intermediateTensorList);
+}
+
+int64_t AddElewiseQuant(atb::GraphParam &opGraph, const FusionLinearParam &param,
+    std::map<std::string, uint32_t> &tensorMap)
+{
+    if (param.quantType == LINEAR_W8A8_QUANT || param.quantType == LINEAR_W8A8_SC_QUANT) {
+        // quant
+        atb::Node inputQuantNode;
+        atb::infer::ElewiseParam inputQuantParam;
+        inputQuantParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_QUANT_PER_CHANNEL;
+        CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(inputQuantParam, &inputQuantNode.operation));
+        inputQuantNode.inTensorIds = GetTensorIdxList(tensorMap, {"in_input", "in_scale", "in_offset"});
+        inputQuantNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_quant_input")};
+        opGraph.nodes.push_back(inputQuantNode);
+    }
+    if (param.quantType == LINEAR_W8A8_DYNAMIC_QUANT || param.quantType == LINEAR_W4A8_DYNAMIC_QUANT) {
+        atb::Node inputDynamicQuantNode;
+        inputDynamicQuantNode.inTensorIds = GetTensorIdxList(tensorMap, {"in_input"});
+        inputDynamicQuantNode.outTensorIds = GetTensorIdxList(tensorMap, {"intermediate_quant_input",
+            "intermediate_input_scale"});
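+        // Dynamic (per-token) quantization yields two outputs: the int8 activation and a per-token
+        // scale, which the downstream matmul consumes as its per-token dequant factor.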
+        inputDynamicQuantNode.operation = new atb_speed::common::DynamicQuantOperation("DynamicQuantNode");
+        opGraph.nodes.push_back(inputDynamicQuantNode);
+    }
+    return atb::NO_ERROR;
+}
+
+int64_t AddAllGather(atb::GraphParam &opGraph, const FusionLinearParam &param,
+    std::map<std::string, uint32_t> &tensorMap)
+{
+    atb::Node allGatherVNode;
+    atb::infer::AllGatherVParam allGatherVParam;
+    allGatherVParam.rank = param.flashCommParallelInfo.rank;
+    allGatherVParam.rankSize = param.flashCommParallelInfo.worldSize;
+    allGatherVParam.backend = param.flashCommParallelInfo.backend;
+    CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(allGatherVParam, &allGatherVNode.operation));
+    allGatherVNode.inTensorIds = {GetTensorIdx(
+        tensorMap, (param.quantType == LINEAR_W8A8_QUANT || param.quantType == LINEAR_W8A8_DYNAMIC_QUANT
+        || param.quantType == LINEAR_W8A8_SC_QUANT) ? "intermediate_quant_input" : "in_input")};
+    allGatherVNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "recv_count"));
+    allGatherVNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "send_counts"));
+    allGatherVNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "sdispls"));
+    allGatherVNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "fake_ag_shape"));
+    allGatherVNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_allgather_out")};
+
+    CHECK_OPERATION_STATUS_RETURN(common::AddDapEventsBeforeComm(opGraph));
+    opGraph.nodes.push_back(allGatherVNode);
+    CHECK_OPERATION_STATUS_RETURN(common::AddDapEventsAfterComm(opGraph));
+
+    if (param.quantType == LINEAR_W8A8_DYNAMIC_QUANT || param.quantType == LINEAR_W8A8_DYNAMIC_DEQUANT) {
+        atb::Node allGatherInputScaleNode;
+        CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(allGatherVParam, &allGatherInputScaleNode.operation));
+        allGatherInputScaleNode.inTensorIds = {GetTensorIdx(
+            tensorMap, param.quantType == LINEAR_W8A8_DYNAMIC_QUANT
+            ? "intermediate_input_scale" : "dynamic_input_scale")};
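+        // The per-token scales must travel with their activations: this second AllGatherV mirrors
+        // the activation gather above so every rank also holds the remote tokens' scales.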
"intermediate_input_scale" : "dynamic_input_scale")}; + allGatherInputScaleNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "send_count")); + allGatherInputScaleNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "recv_counts")); + allGatherInputScaleNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "rdispls")); + allGatherInputScaleNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "fake_ag_shape")); + allGatherInputScaleNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_allgather_input_scale_out")}; + CHECK_OPERATION_STATUS_RETURN(common::AddDapEventsBeforeComm(opGraph)); + opGraph.nodes.push_back(allGatherInputScaleNode); + CHECK_OPERATION_STATUS_RETURN(common::AddDapEventsAfterComm(opGraph)); + } + return atb::NO_ERROR; +} + +int64_t AddAclNNWeightQuantBatchMatmul(atb::Node &linearNode, const FusionLinearParam ¶m, + std::map &tensorMap) +{ + linearNode.inTensorIds = GetTensorIdxList(tensorMap, { + "in_input", "in_weight", "in_scale", "in_offset" + }); + AclNNWeightQuantBatchMatmulParam aclnnParam; + aclnnParam.transposeB = param.transposeType == TRANSPOSE; + if (param.hasBias) { + linearNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_bias")); + aclnnParam.hasBias = true; + } + if (param.quantType == W8A16) { + aclnnParam.quantGroupSize = param.quantGroupSize; + linearNode.operation = new atb_speed::common::W8A16Operation("W8A16LinearNode", aclnnParam); + } else if (param.quantType == W4A16) { + aclnnParam.quantGroupSize = param.quantGroupSize; // W4A16 group size默认为64,此时精度更高 + linearNode.operation = new atb_speed::common::W4A16Operation("W4A16LinearNode", aclnnParam); + } + if (linearNode.operation == nullptr) { + return atb::ERROR_INVALID_GRAPH; + } + return atb::NO_ERROR; +} + +int64_t AddW4A8Matmul(atb::Node &linearNode, const FusionLinearParam ¶m, + std::map &tensorMap) +{ + const bool containingQuant = param.quantType == LINEAR_W4A8_DYNAMIC_QUANT; + AclNNW4A8Param aclnnParam; + std::string key; + if (param.enableSwigluQuant && param.isDownLinear) { + key = "in_input"; + } else { + key = (containingQuant && !param.enableSwiGLUQuantForSharedExperts) ? + "intermediate_quant_input" : "in_input"; + } + std::string inputScaleKey; + if (param.enableSwigluQuant && param.isDownLinear && param.isPrefill && containingQuant) { + inputScaleKey = "intermediate_quant_input_scale"; + } else { + inputScaleKey = !containingQuant ? "dynamic_input_scale" : param.enableSwiGLUQuantForSharedExperts ? + "swiglu_quant_input_scale" : "intermediate_input_scale"; + } + std::vector tensorNames = {key, "in_weight", inputScaleKey, "in_scale", "in_bias"}; + linearNode.inTensorIds = GetTensorIdxList(tensorMap, tensorNames); + ATB_SPEED_LOG_DEBUG("tensorNames: " << tensorNames << "; inTensorIds: " << linearNode.inTensorIds); + aclnnParam.outDataType = param.isBF16 ? 
+    linearNode.operation = new atb_speed::common::W4A8Operation("W4A8LinearNode", aclnnParam);
+
+    return atb::NO_ERROR;
+}
+
+int64_t AddAclNNQuantMatmul(atb::Node &linearNode, const FusionLinearParam &param,
+    std::map<std::string, uint32_t> &tensorMap)
+{
+    AclNNQuantMatmulParam aclnnQuantMatmulParam;
+    aclnnQuantMatmulParam.transposeB = param.transposeType == TRANSPOSE;
+    aclnnQuantMatmulParam.matmulBackend = param.matmulBackend;
+    std::string key;
+    if (param.enableSwigluQuant && param.isDownLinear) {
+        key = "in_input";
+    } else if (param.enableFlashComm) {
+        key = "intermediate_allgather_out";
+    } else {
+        key = (param.quantType == LINEAR_W8A8_QUANT ||
+            (param.quantType == LINEAR_W8A8_DYNAMIC_QUANT
+            && !param.enableSwiGLUQuantForSharedExperts)) ?
+            "intermediate_quant_input" : "in_input";
+    }
+    std::string inScaleKey = (param.quantType == LINEAR_W8A8_QUANT || param.quantType == LINEAR_W8A8_DEQUANT) ?
+        "in_descale" : "in_scale";
+    std::string inputScaleKey;
+    if (param.enableSwigluQuant && param.isDownLinear &&
+        param.isPrefill && param.quantType == LINEAR_W8A8_DYNAMIC_DEQUANT) {
+        inputScaleKey = "intermediate_swiglu_dynamic_scale";
+    } else if (param.enableFlashComm) {
+        inputScaleKey = "intermediate_allgather_input_scale_out";
+    } else {
+        inputScaleKey = param.quantType == LINEAR_W8A8_DYNAMIC_DEQUANT ?
+            "dynamic_input_scale" : param.enableSwiGLUQuantForSharedExperts ? "swiglu_quant_input_scale" :
+            "intermediate_input_scale";
+    }
+    std::vector<std::string> tensorNames = {key, "in_weight", inScaleKey};
+    // per token
+    if (param.quantType == LINEAR_W8A8_DYNAMIC_QUANT || param.quantType == LINEAR_W8A8_DYNAMIC_DEQUANT) {
+        tensorNames.push_back(inputScaleKey);
+        aclnnQuantMatmulParam.hasPerTokenScale = true;
+    }
+
+    // The per-tensor case must take the bias path
+    if (param.hasBias || param.quantType == LINEAR_W8A8_QUANT || param.quantType == LINEAR_W8A8_DEQUANT) {
+        tensorNames.push_back("in_bias");
+        aclnnQuantMatmulParam.hasBias = true;
+    }
+    linearNode.inTensorIds = GetTensorIdxList(tensorMap, tensorNames);
+    ATB_SPEED_LOG_DEBUG("tensorNames: " << tensorNames << "; inTensorIds: " << linearNode.inTensorIds);
+    linearNode.inTensorReshapeFuncs.resize(linearNode.inTensorIds.size());
+    linearNode.inTensorReshapeFuncs[0] = [=](const atb::Dims &oldShape, atb::Dims &newShape) { // 1: input
+        newShape.dimNum = 2; // dimNum: 2
+        // With TURBO_ATTN enabled (w8a8 pdmix or per-token case, canndev kernel), the input gains an extra dim (2 dims)
+        if (oldShape.dimNum == NUM3) {
+            newShape.dims[DIM0] = oldShape.dims[DIM0] * oldShape.dims[DIM1];
+            newShape.dims[DIM1] = oldShape.dims[DIM2];
+        }
+    };
+    // Reshape for the dynamic-quant input scale
+    if (param.quantType == LINEAR_W8A8_DYNAMIC_QUANT || param.quantType == LINEAR_W8A8_DYNAMIC_DEQUANT) {
+        linearNode.inTensorReshapeFuncs[3] = [=](const atb::Dims &oldShape, atb::Dims &newShape) { // 3: scale at index 3
+            newShape.dimNum = 1; // dimNum: 1
+            // With TURBO_ATTN enabled (canndev kernel), the scale gains an extra dim (2 dims)
+            newShape.dims[0] = oldShape.dimNum == NUM2 ? oldShape.dims[0] * oldShape.dims[1] : oldShape.dims[0];
+        };
+    }
+    aclnnQuantMatmulParam.isBF16 = param.isBF16;
+    aclnnQuantMatmulParam.isOutDequantBias = IsOutDequantBias(param);
+    linearNode.operation = new atb_speed::common::W8A8Operation("W8A8LinearNode", aclnnQuantMatmulParam);
+
+    return atb::NO_ERROR;
+}
+
+int64_t AddAclNNMatmul(atb::Node &linearNode, const FusionLinearParam &param,
+    std::map<std::string, uint32_t> &tensorMap)
+{
+    linearNode.inTensorIds = GetTensorIdxList(tensorMap, {
+        (param.enableFlashComm) ?
+ "intermediate_allgather_out" : "in_input", "in_weight"}); + AclNNMatmulParam aclnnMatmulParam; + aclnnMatmulParam.transposeB = param.transposeType == TRANSPOSE; + if (param.hasBias) { + linearNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_bias")); + aclnnMatmulParam.hasBias = true; + } + linearNode.operation = new atb_speed::common::W16A16Operation("W16A16LinearNode", aclnnMatmulParam); + return atb::NO_ERROR; +} + +int64_t AddAclNNLinear(atb::Node &linearNode, const FusionLinearParam ¶m, + std::map &tensorMap) +{ + if (param.quantType == LINEAR_W4A8_DYNAMIC_QUANT || param.quantType == LINEAR_W4A8_DYNAMIC_DEQUANT) { + CHECK_OPERATION_STATUS_RETURN(AddW4A8Matmul(linearNode, param, tensorMap)); + return atb::NO_ERROR; + } + if (param.quantType == W8A16 || param.quantType == W4A16) { + CHECK_OPERATION_STATUS_RETURN(AddAclNNWeightQuantBatchMatmul(linearNode, param, tensorMap)); + return atb::NO_ERROR; + } + bool useQuantBatchMatmul = UseQuantBatchMatmul(param); + if (useQuantBatchMatmul) { + ATB_SPEED_LOG_DEBUG("AddAclNNQuantMatmul api: " << param.quantType << "," << useQuantBatchMatmul); + CHECK_OPERATION_STATUS_RETURN(AddAclNNQuantMatmul(linearNode, param, tensorMap)); + return atb::NO_ERROR; + } + if (param.quantType == NO_QUANT) { + CHECK_OPERATION_STATUS_RETURN(AddAclNNMatmul(linearNode, param, tensorMap)); + return atb::NO_ERROR; + } + + return atb::NO_ERROR; +} + +int64_t AddLinear(atb::GraphParam &opGraph, const FusionLinearParam ¶m, + std::map &tensorMap) +{ + atb::Node linearNode; + atb::infer::LinearParam linearParam; + int matmulBackend = param.matmulBackend; + if (param.quantType == NO_QUANT) { + if ((param.isBF16 && IsA2()) || IsA3()) { + matmulBackend = atb_speed::common::OpBackend::ATB; + } + } + linearParam.transposeB = param.transposeType == TRANSPOSE; + if (param.quantType != NO_QUANT) { + linearParam.outDataType = param.isBF16 ? ACL_BF16 : ACL_FLOAT16; + } + if (param.enEin) { + linearParam.matmulType = atb::infer::LinearParam::MATMUL_EIN_SUM; + } + // 设置LinearNode outTensor + linearNode.outTensorIds = {GetTensorIdx(tensorMap, "out")}; + // 稀疏量化 + if (param.quantType == LINEAR_W8A8_SC_DEQUANT || param.quantType == LINEAR_W8A8_SC_QUANT) { + atb::infer::LinearSparseParam linearSparseParam; + linearSparseParam.tilingK = 8; // 8: 稀疏量化系数 + linearSparseParam.tilingN = 8; // 8: 稀疏量化稀疏 + CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(linearSparseParam, &linearNode.operation)); + std::string key; + if (param.enableFlashComm) { + key = "intermediate_allgather_out"; + } else { + key = param.quantType == LINEAR_W8A8_SC_DEQUANT ? "in_input" : "intermediate_quant_input"; + } + linearNode.inTensorIds = GetTensorIdxList(tensorMap, { + key, "in_weight", "in_bias", "in_descale", "in_compress_idx" + }); + opGraph.nodes.push_back(linearNode); + return atb::NO_ERROR; + } + // AclNN Linear (W8A16, W4A16, LINEAR_W8A8_DYNAMIC_QUANT, LINEAR_W8A8_DYNAMIC_DEQUANT) + if (param.quantType == W8A16 || param.quantType == W4A16 || UseQuantBatchMatmul(param)) { + CHECK_OPERATION_STATUS_RETURN(AddAclNNLinear(linearNode, param, tensorMap)); + opGraph.nodes.push_back(linearNode); + return atb::NO_ERROR; + } + if (matmulBackend == atb_speed::common::OpBackend::ATB) { + std::string key; + if (param.enableFlashComm) { + key = "intermediate_allgather_out"; + } else { + key = param.quantType == LINEAR_W8A8_QUANT ? 
"intermediate_quant_input" : "in_input"; + } + // 加速库Linear + if (param.quantType == NO_QUANT && param.hasBias) { + linearParam.hasBias = true; + linearNode.inTensorIds = GetTensorIdxList(tensorMap, {key, "in_weight", "in_bias"}); + } else if (param.quantType == NO_QUANT && !param.hasBias) { + linearParam.hasBias = false; + linearNode.inTensorIds = GetTensorIdxList(tensorMap, {key, "in_weight"}); + } else { + linearParam.hasBias = true; + linearNode.inTensorIds = GetTensorIdxList(tensorMap, { + key, "in_weight", "in_bias", "in_descale" + }); + } + CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(linearParam, &linearNode.operation)); + } else { + // AclNN Linear (NO_QUANT) + CHECK_OPERATION_STATUS_RETURN(AddAclNNLinear(linearNode, param, tensorMap)); + } + + opGraph.nodes.push_back(linearNode); + + return atb::NO_ERROR; +} + +atb::Status CreateFusionLinear(const FusionLinearParam ¶m, atb::Operation **operation) +{ + atb::GraphParam opGraph; + opGraph.name = param.quantType == NO_QUANT ? "LinearNoQuant" : \ + param.quantType == LINEAR_W8A8_DEQUANT || param.quantType == LINEAR_W8A8_SC_DEQUANT ? "LinearDequantOnly" : \ + param.quantType == W8A16 ? "LinearW8A16" : \ + param.quantType == W4A16 ? "LinearW4A16" : "LinearQuant"; + std::map tensorMap = ConstructLinearTensorMap( + param, opGraph.inTensorNum, opGraph.outTensorNum, opGraph.internalTensorNum); + + if (param.transposeType == TRANSPOSE_INVALID) { + ATB_SPEED_LOG_ERROR("param.transposeType is invalid"); + return atb::ERROR_INVALID_GRAPH; + } + // dense层: enableSwiGLUQuantForSharedExperts 不开 + // down层: 1) 不开 2) 开、非down 3)开、down、非DYNAMIC_DEQUANT + if (!param.enableSwiGLUQuantForSharedExperts && (!param.enableSwigluQuant || !param.isDownLinear \ + || (param.quantType != LINEAR_W8A8_DYNAMIC_DEQUANT && param.quantType != LINEAR_W4A8_DYNAMIC_DEQUANT))) { + CHECK_OPERATION_STATUS_RETURN(AddElewiseQuant(opGraph, param, tensorMap)); + if (param.enableFlashComm) { + CHECK_OPERATION_STATUS_RETURN(AddAllGather(opGraph, param, tensorMap)); + } + } + if (param.enableCVOverlap) { + CHECK_OPERATION_STATUS_RETURN(atb_speed::common::CreateRecordWithoutNodeId( + opGraph, atb_speed::EventAction::PUSH, atb_speed::common::VECTOR_CONTROL)); + CHECK_OPERATION_STATUS_RETURN(atb_speed::common::CreateWaitWithoutNodeId( + opGraph, atb_speed::EventAction::PUSH, atb_speed::common::CUBE_CONTROL)); + } + CHECK_OPERATION_STATUS_RETURN(AddLinear(opGraph, param, tensorMap)); + + opGraph.inferShapeFunc = [=](const atb::SVector &inTensorDescs, + atb::SVector &outTensorDescs) { + uint32_t inputIdx = GetTensorIdx(tensorMap, "in_input"); + uint32_t weightIdx = GetTensorIdx(tensorMap, "in_weight"); + uint32_t biasIdx = GetTensorIdx(tensorMap, "in_bias"); + outTensorDescs.at(0).format = inTensorDescs.at(inputIdx).format; + outTensorDescs.at(0).dtype = IsOutDequantBias(param) && param.isThrowDequant ? \ + ACL_INT32 : (param.isBF16 ? ACL_BF16 : ACL_FLOAT16); + outTensorDescs.at(0).shape = inTensorDescs.at(inputIdx).shape; + auto outDimSize = outTensorDescs.at(inputIdx).shape.dimNum; + CHECK_TENSORDESC_DIMNUM_VALID(outDimSize); + int nDim = param.transposeType == TransposeType::TRANSPOSE ? 
+
+        if (param.enableFlashComm) {
+            uint32_t fakeAgShapeIdx = GetTensorIdx(tensorMap, "fake_ag_shape");
+            outTensorDescs.at(0).shape.dims[0] = inTensorDescs.at(fakeAgShapeIdx).shape.dims[0];
+        }
+        if (param.quantType == LINEAR_W8A8_SC_DEQUANT || param.quantType == LINEAR_W8A8_SC_QUANT) {
+            outTensorDescs.at(0).shape.dims[outDimSize - 1] = inTensorDescs.at(biasIdx).shape.dims[0];
+        } else if (param.quantType == W4A16) {
+            if (param.transposeType == TransposeType::TRANSPOSE) {
+                outTensorDescs.at(0).shape.dims[outDimSize - 1] = \
+                    inTensorDescs.at(weightIdx).shape.dims[0]; // 0: shape of the n dim
+            } else {
+                outTensorDescs.at(0).shape.dims[outDimSize - 1] = \
+                    CheckIntMulOverFlow(inTensorDescs.at(weightIdx).shape.dims[1], 2); // 1, 2: last dim shape * 2
+            }
+        } else if (param.quantType == LINEAR_W4A8_DYNAMIC_DEQUANT || param.quantType == LINEAR_W4A8_DYNAMIC_QUANT) {
+            outTensorDescs.at(0).shape.dims[outDimSize - 1] = \
+                CheckIntMulOverFlow(inTensorDescs.at(weightIdx).shape.dims[1], 8); // 8: [m, k] @ [k, n//8] -> [m, n]
+        } else if (inTensorDescs.at(weightIdx).shape.dimNum == 3) { // 3: dimNum
+            outTensorDescs.at(0).shape.dims[outDimSize - 1] = inTensorDescs.at(weightIdx).shape.dims[nDim + 1];
+        } else if (param.enEin && inTensorDescs.at(weightIdx).shape.dimNum == 4) { // 4: dimNum
+            outTensorDescs.at(0).shape.dims[outDimSize - 1] = param.transposeType == TransposeType::TRANSPOSE ? \
+                inTensorDescs.at(weightIdx).shape.dims[2] : // 2: dimNum
+                inTensorDescs.at(weightIdx).shape.dims[1] * inTensorDescs.at(weightIdx).shape.dims[3]; // 3: dimNum
+        } else {
+            outTensorDescs.at(0).shape.dims[outDimSize - 1] = inTensorDescs.at(weightIdx).shape.dims[nDim];
+        }
+        return atb::NO_ERROR;
+    };
+
+    CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(opGraph, operation));
+    return atb::NO_ERROR;
+}
+
+std::map<std::string, uint32_t> ConstructLinearWithLoraTensorMap(
+    const FusionLinearParam &param,
+    uint32_t &inTensorNum, uint32_t &outTensorNum, uint32_t &internalTensorNum)
+{
+    auto linearInTensorCandidates = GetLinearInTensorCandidates();
+    auto linearIntermediateTensorCandidates = GetLinearIntermediateTensorCandidates();
+
+    std::vector<std::string> inTensorList = {};
+    std::vector<std::string> intermediateTensorList = {};
+    std::vector<std::string> outTensorList = {"out"};
+
+    // Add the default tensors
+    AddTensorToList(linearInTensorCandidates, "default", inTensorList);
+
+    // Add the LoRA-specific tensors
+    if (param.supportLora) {
+        if (param.useImMask) {
+            AddTensorToList(linearInTensorCandidates, "lora_with_mask", inTensorList);
+            AddTensorToList(linearIntermediateTensorCandidates, "lora_with_mask", intermediateTensorList);
+        }
+        AddTensorToList(linearInTensorCandidates, "lora", inTensorList);
+        AddTensorToList(linearIntermediateTensorCandidates, "lora", intermediateTensorList);
+    }
+
+    inTensorNum = inTensorList.size();
+    outTensorNum = outTensorList.size();
+    internalTensorNum = intermediateTensorList.size();
+
+    return GetTensorMap(inTensorList, outTensorList, intermediateTensorList);
+}
+
+int64_t AddImMask(atb::GraphParam &opGraph, std::map<std::string, uint32_t> &tensorMap)
+{
+    atb::Node mulNode;
+    atb::infer::ElewiseParam mulParam;
+    mulParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_MUL;
+    CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(mulParam, &mulNode.operation));
+    mulNode.inTensorIds = GetTensorIdxList(tensorMap, {"in_input", "in_im_mask"});
+    mulNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_im_mask_out")};
+    opGraph.nodes.push_back(mulNode);
+    return atb::NO_ERROR;
+}
+
+int64_t AddLoraA(atb::GraphParam &opGraph, const FusionLinearParam &param,
+    std::map<std::string, uint32_t> &tensorMap)
+{
+    // Add LoRA A
+    atb::Node loraALinearNode;
+    if (param.loraEnableGMM) {
+        AclNNGroupedMatmulParam aclnnParam;
+        aclnnParam.transposeB = true;
+        loraALinearNode.operation = new atb_speed::common::GroupedMatmulOperation("loraALinearNode", aclnnParam);
+    } else {
+        CHECK_OPERATION_STATUS_RETURN(CreateFusionLinear(param, &loraALinearNode.operation));
+    }
+    if (param.useImMask) {
+        loraALinearNode.inTensorIds = GetTensorIdxList(tensorMap, {"intermediate_im_mask_out", "in_lora_a"});
+    } else {
+        loraALinearNode.inTensorIds = GetTensorIdxList(tensorMap, {"in_input", "in_lora_a"});
+    }
+    if (param.loraEnableGMM) {
+        loraALinearNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_group_list"));
+    } else {
+        // LoRA weights do not support quantization yet; the following indices are placeholders only
+        loraALinearNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_scale"));
+        loraALinearNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_offset"));
+        loraALinearNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_descale"));
+        loraALinearNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_bias"));
+        loraALinearNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_compress_idx"));
+    }
+    loraALinearNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_lora_a_out")};
+    opGraph.nodes.push_back(loraALinearNode);
+    return atb::NO_ERROR;
+}
+
+int64_t AddLoraB(atb::GraphParam &opGraph, const FusionLinearParam &param,
+    std::map<std::string, uint32_t> &tensorMap)
+{
+    // Add LoRA B
+    atb::Node loraBLinearNode;
+    if (param.loraEnableGMM) {
+        AclNNGroupedMatmulParam aclnnParam;
+        aclnnParam.transposeB = false;
+        loraBLinearNode.operation = new atb_speed::common::GroupedMatmulOperation("loraBLinearNode", aclnnParam);
+    } else {
+        CHECK_OPERATION_STATUS_RETURN(CreateFusionLinear(param, &loraBLinearNode.operation));
+    }
+    loraBLinearNode.inTensorIds = GetTensorIdxList(tensorMap, {"intermediate_lora_a_out", "in_lora_b"});
+    if (param.loraEnableGMM) {
+        loraBLinearNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_group_list"));
+    } else {
+        // LoRA weights do not support quantization yet; the following indices are placeholders only
+        loraBLinearNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_scale"));
+        loraBLinearNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_offset"));
+        loraBLinearNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_descale"));
+        loraBLinearNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_bias"));
+        loraBLinearNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_compress_idx"));
+    }
+    loraBLinearNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_lora_b_out")};
+    opGraph.nodes.push_back(loraBLinearNode);
+    return atb::NO_ERROR;
+}
+
+atb::Status CreateFusionLinearWithLora(const FusionLinearParam &param, atb::Operation **operation)
+{
+    atb::GraphParam opGraph;
+    std::map<std::string, uint32_t> tensorMap = ConstructLinearWithLoraTensorMap(
+        param, opGraph.inTensorNum, opGraph.outTensorNum, opGraph.internalTensorNum);
+    opGraph.name = "LinearWithLora";
+
+    // Add the base model's linear
+    atb::Node baseLinearNode;
+    atb_speed::common::FusionLinearParam baseLinearParam = param;
+    baseLinearParam.supportLora = false;
+    baseLinearParam.loraEnableGMM = false;
+    CHECK_OPERATION_STATUS_RETURN(CreateFusionLinear(baseLinearParam, &baseLinearNode.operation));
+    baseLinearNode.inTensorIds = GetTensorIdxList(tensorMap, {
+        "in_input", "in_weight", "in_scale", "in_offset",
+        "in_descale", "in_bias", "in_compress_idx"
+    });
+    baseLinearNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_base_linear_out")};
+    opGraph.nodes.push_back(baseLinearNode);
+
+    atb_speed::common::FusionLinearParam loraLinearParam;
+    loraLinearParam.isBF16 = param.isBF16;
+    loraLinearParam.hasBias = false;
+    loraLinearParam.transposeType = TRANSPOSE;
+    loraLinearParam.loraEnableGMM = param.loraEnableGMM;
+    loraLinearParam.useImMask = param.useImMask;
+    if (param.useImMask) {
+        CHECK_OPERATION_STATUS_RETURN(AddImMask(opGraph, tensorMap));
+    }
+    CHECK_OPERATION_STATUS_RETURN(AddLoraA(opGraph, loraLinearParam, tensorMap));
+    loraLinearParam.transposeType = NOT_TRANSPOSE;
+    CHECK_OPERATION_STATUS_RETURN(AddLoraB(opGraph, loraLinearParam, tensorMap));
+
+    // Merge the base linear output with the LoRA linear output
+    atb::Node addNode;
+    atb::infer::ElewiseParam addParam;
+    addParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_ADD;
+    CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(addParam, &addNode.operation));
+    addNode.inTensorIds = GetTensorIdxList(tensorMap, {"intermediate_base_linear_out", "intermediate_lora_b_out"});
+    addNode.outTensorIds = {GetTensorIdx(tensorMap, "out")};
+    opGraph.nodes.push_back(addNode);
+
+    CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(opGraph, operation));
+    return atb::NO_ERROR;
+}
+
+atb::Status FusionLinear(const FusionLinearParam &param, atb::Operation **operation)
+{
+    if (param.supportLora) {
+        return CreateFusionLinearWithLora(param, operation);
+    } else {
+        return CreateFusionLinear(param, operation);
+    }
+}
+} // namespace common
+} // namespace atb_speed
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/fusion/linear/linear.h b/tests/proftest/layer_test_framework/operations/fusion/linear/linear.h
new file mode 100644
index 00000000..3ecd29d8
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/fusion/linear/linear.h
@@ -0,0 +1,216 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef ATB_SPEED_MODELS_COMMON_LINEAR_H
+#define ATB_SPEED_MODELS_COMMON_LINEAR_H
+
+#include "atb/atb_infer.h"
+#include "atb_speed/utils/operation_util.h"
+#include "operations/fusion/utils.h"
+#include "operations/fusion/parallel_info.h"
+
+namespace atb_speed {
+namespace common {
+
+
+/// Parameters for the fusion linear module
+struct FusionLinearParam {
+    /// Specifies how the linear module is quantized.
+    /// Refer to the `LinearQuantType` definition in operations/utils.h.
+    LinearQuantType quantType = NO_QUANT;
+    /// When `isBF16` is true, bfloat16 precision is used; otherwise, float16 precision is used.
+    bool isBF16 = false;
+    /// Specifies whether the linear module has a bias.
+    bool hasBias = false;
+    /// A flag indicating whether LoRA is enabled.
+    bool supportLora = false;
+    /// A flag indicating whether a mask is used before applying the LoRA adapter.
+    bool useImMask = false;
+    /// A flag indicating whether the grouped matmul operation is enabled;
+    /// it should be activated when batch inputs include multiple LoRA adapters.
+    bool loraEnableGMM = false;
+    /// Defines whether the second matrix in the matmul operation is transposed.
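+    /// For example (illustrative), a weight stored as [n, k] pairs with TRANSPOSE, while a [k, n]
+    /// layout pairs with NOT_TRANSPOSE; see the shape tables below.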
+    int transposeType = TRANSPOSE;
+    /// The group size used for dequantizing the weight tensor in the per-group quantization approach
+    int quantGroupSize = 0;
+    /// The backend used for the matmul operation; defaults to the atb backend
+    int matmulBackend = atb_speed::common::OpBackend::ATB;
+    /// A flag indicating whether to use EinMatmul.
+    bool enEin = false;
+    /// A flag indicating whether the dequant step is thrown out of the linear operation.
+    bool isThrowDequant = false;
+    /// A flag indicating the prefill and decode phases
+    bool isPrefill = false;
+    /// A flag indicating whether bias is applied during dequantization.
+    bool enableDequantBias = false;
+    /// A flag indicating whether the model uses cube and vector parallel
+    bool enableCVOverlap = false;
+    /// A flag indicating whether swiglu quantization is enabled for shared experts
+    bool enableSwiGLUQuantForSharedExperts = false;
+    /// A flag indicating whether to use swigluQuant
+    bool enableSwigluQuant = false;
+    /// A flag indicating whether this is the down linear
+    bool isDownLinear = false;
+    /// Tensor parallel information for flashcomm 1.0
+    TensorParallelInfo flashCommParallelInfo;
+    /// A flag indicating whether flashcomm is enabled
+    bool enableFlashComm = false;
+};
+
+/// Check whether the quant type is w8a8 per-tensor and the matmul backend is ACLNN
+///
+/// \param param The linear parameters to check
+/// \return True if the matmul backend is ACLNN under w8a8 per-tensor
+bool IsAclnnPerTensor(const FusionLinearParam &param);
+
+/// Check whether this is a w8a8 dynamic scenario (the aclnn per-tensor case also counts),
+/// i.e. whether the aclnn QuantBatchMatmul operator should be used.
+///
+/// \param param The linear parameters to check
+/// \return True if quantType is DYNAMIC or ACLNN with W8A8
+bool UseQuantBatchMatmul(const FusionLinearParam &param);
+
+/// This function is the main entrance for all types of linear modules.
+/// It will call different operations based on the `quantType`.
+/// Note that the linear module with `quantType` equal to `LINEAR_W8A8_DYNAMIC_DEQUANT` is not implemented yet.
+///
+/// \param param Parameters for the fusion linear module
+/// \param operation the address of a pointer to a default operation
+/// \return A flag that indicates whether the operation has been successfully created.
+///
+/// Operation's inputs when `quantType` is `NO_QUANT`:
+/// Name            | Dtype            | Shape |
+/// ----------------|------------------|-------|
+/// in_input        | float16/bfloat16 | [m,k] |
+/// in_weight       | float16/bfloat16 | [n,k] if `transposeB` is true; otherwise, [k,n] |
+/// in_scale        | float16          | [1] |
+/// in_offset       | float16          | [1] |
+/// in_descale      | float16          | [1] |
+/// in_bias         | float16          | [1] |
+/// in_compress_idx | float16          | [1] |
+///
+/// Operation's inputs when `quantType` is `LINEAR_W8A8_DEQUANT`:
+/// Name            | Dtype   | Shape |
+/// ----------------|---------|-------|
+/// in_input        | int8    | [m,k] |
+/// in_weight       | int8    | [n,k] if `transposeB` is true; otherwise, [k,n] |
+/// in_scale        | float16 | [1] |
+/// in_offset       | float16 | [1] |
+/// in_descale      | int64 if the output tensor's dtype is float16; float32 if the output tensor's dtype is bfloat16 | [n] |
+/// in_bias         | int32   | [n] |
+/// in_compress_idx | float16 | [1] |
+///
+/// Operation's inputs when `quantType` is `LINEAR_W8A8_QUANT`:
+/// Name            | Dtype            | Shape |
+/// ----------------|------------------|-------|
+/// in_input        | float16/bfloat16 | [m,k] |
+/// in_weight       | int8             | [n,k] if `transposeB` is true; otherwise, [k,n] |
+/// in_scale        | the same dtype as in_input | [1] |
+/// in_offset       | the same dtype as in_input | [1] |
+/// in_descale      | int64 if the output tensor's dtype is float16; float32 if the output tensor's dtype is bfloat16 | [n] |
+/// in_bias         | int32            | [n] |
+/// in_compress_idx | float16          | [1] |
+///
+/// Operation's inputs when `quantType` is `W4A16`:
+/// Name            | Dtype | Shape | Description |
+/// ----------------|-------|-------|---------|
+/// in_input        | float16/bfloat16 | [m,k] | |
+/// in_weight       | int8  | [n,k/2] if `transposeB` is true; otherwise, [k,n/2] | |
+/// in_scale        | the same dtype as the output tensor | [n,1]/[n,ceil(k / group_size)] if `transposeB` is true; otherwise, [1,n]/[ceil(k / group_size),n] | |
+/// in_offset       | the same dtype as the output tensor | [n,1]/[n,ceil(k / group_size)] if `transposeB` is true; otherwise, [1,n]/[ceil(k / group_size),n] | |
+/// in_descale      | float16 | [1] | |
+/// in_bias         | int32 if the output tensor's dtype is float16; bfloat16 if the output tensor's dtype is bfloat16 | [n] | Used when `hasBias` is true. |
+/// in_compress_idx | float16 | [1] | |
+///
+/// Operation's inputs when `quantType` is `W8A16`:
+/// Name            | Dtype | Shape | Description |
+/// ----------------|-------|-------|---------|
+/// in_input        | float16/bfloat16 | [m,k] | |
+/// in_weight       | int8  | [n,k] if `transposeB` is true; otherwise, [k,n] | |
+/// in_scale        | the same dtype as the output tensor | [n,1]/[n,ceil(k / group_size)] if `transposeB` is true; otherwise, [1,n]/[ceil(k / group_size),n] | |
+/// in_offset       | the same dtype as the output tensor | [n,1]/[n,ceil(k / group_size)] if `transposeB` is true; otherwise, [1,n]/[ceil(k / group_size),n] | |
+/// in_descale      | float16 | [1] | |
+/// in_bias         | int32 if the output tensor's dtype is float16; bfloat16 if the output tensor's dtype is bfloat16 | [n] | Used when `hasBias` is true. |
+/// in_compress_idx | float16 | [1] | |
+///
+/// Operation's inputs when `quantType` is `LINEAR_W8A8_SC_DEQUANT`:
+/// Name            | Dtype   | Shape |
+/// ----------------|---------|-------|
+/// in_input        | int8    | [m,k] |
+/// in_weight       | int8    | One dimensional tensor with variable shape |
+/// in_scale        | float16 | [1] |
+/// in_offset       | float16 | [1] |
+/// in_descale      | int64   | [n] |
+/// in_bias         | int32   | [n] |
+/// in_compress_idx | int8    | One dimensional tensor with variable shape |
+///
+/// Operation's inputs when `quantType` is `LINEAR_W8A8_SC_QUANT`:
+/// Name            | Dtype   | Shape |
+/// ----------------|---------|-------|
+/// in_input        | float16 | [m,k] |
+/// in_weight       | int8    | One dimensional tensor with variable shape |
+/// in_scale        | float16 | [1] |
+/// in_offset       | float16 | [1] |
+/// in_descale      | int64   | [n] |
+/// in_bias         | int32   | [n] |
+/// in_compress_idx | int8    | One dimensional tensor with variable shape |
+///
+/// Operation's inputs when `quantType` is `LINEAR_W8A8_DYNAMIC_QUANT`:
+/// Name            | Dtype   | Shape |
+/// ----------------|---------|-------|
+/// in_input        | float16 | [m,k] |
+/// in_weight       | int8    | [n,k] if `transposeB` is true; otherwise, [k,n] |
+/// in_scale        | float16 | [n,1] |
+/// in_offset       | float16 | [n,1] |
+/// in_descale      | float16 | [1] |
+/// in_bias         | float16 | [1] |
+/// in_compress_idx | float16 | [1] |
+///
+/// Operation's optional inputs:
+/// Name          | Dtype            | Shape       | Condition |
+/// --------------|------------------|-------------|---------------|
+/// in_im_mask    | float16          | [m,1]       | Required and used when `supportLora` and `useImMask` are true |
+/// in_group_list | int64            | [batchSize] | Required when `supportLora` is true and only used when `loraEnableGMM` is true |
+/// in_lora_a     | float16/bfloat16 | [r,k] if `transposeB` is true; otherwise, [k,r] | Required and used when `supportLora` is true |
+/// in_lora_b     | float16/bfloat16 | [n,r] if `transposeB` is true; otherwise, [r,n] | Required and used when `supportLora` is true |
+///
+/// Operation's Outputs:
+/// Name   | Dtype            | Shape |
+/// -------|------------------|-------|
+/// out    | float16/bfloat16 | [m,n] |
+/// Note that operations with `quantType` equal to `LINEAR_W8A8_DYNAMIC_QUANT`, `LINEAR_W8A8_SC_DEQUANT`
+/// and `LINEAR_W8A8_SC_QUANT` do not support bfloat16.
+///
+/// Example:
+/// \code
+/// enum TensorIdx : uint32_t {
+///     IN_INPUT = 0,
+///     IN_WEIGHT,
+///     IN_PLACEHOLDER,
+///     OUT,
+/// };
+///
+/// atb::Node linearNode;
+/// atb_speed::common::FusionLinearParam linearParam;
+/// // Modify linearParam's attribute if needed.
+/// FusionLinear(linearParam, &linearNode.operation);
+/// linearNode.inTensorIds = {IN_INPUT, IN_WEIGHT, IN_PLACEHOLDER, IN_PLACEHOLDER, IN_PLACEHOLDER, IN_PLACEHOLDER, IN_PLACEHOLDER};
+/// linearNode.outTensorIds = {OUT};
+/// atb::GraphParam opGraph;
+/// opGraph.nodes.push_back(linearNode);
+/// \endcode
+atb::Status FusionLinear(const FusionLinearParam &param, atb::Operation **operation);
+} // namespace common
+} // namespace atb_speed
+#endif
\ No newline at end of file
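For reference, a minimal sketch (not part of the patch) of wiring FusionLinear with LoRA enabled; the tensor order follows the optional-input table above, and the enum values here are illustrative:

    enum TensorIdx : uint32_t {
        IN_INPUT = 0, IN_WEIGHT, IN_PLACEHOLDER,
        IN_GROUP_LIST, IN_LORA_A, IN_LORA_B, OUT,
    };

    atb::Node linearNode;
    atb_speed::common::FusionLinearParam linearParam;
    linearParam.supportLora = true;    // appends in_group_list/in_lora_a/in_lora_b to the inputs
    linearParam.loraEnableGMM = true;  // the batch carries multiple LoRA adapters
    FusionLinear(linearParam, &linearNode.operation);
    // The seven base inputs come first, then the LoRA inputs (see the table above).
    linearNode.inTensorIds = {IN_INPUT, IN_WEIGHT, IN_PLACEHOLDER, IN_PLACEHOLDER, IN_PLACEHOLDER,
                              IN_PLACEHOLDER, IN_PLACEHOLDER, IN_GROUP_LIST, IN_LORA_A, IN_LORA_B};
    linearNode.outTensorIds = {OUT};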
diff --git a/tests/proftest/layer_test_framework/operations/fusion/linear/linear_parallel.cpp b/tests/proftest/layer_test_framework/operations/fusion/linear/linear_parallel.cpp
new file mode 100644
index 00000000..05a89db0
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/fusion/linear/linear_parallel.cpp
@@ -0,0 +1,662 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include
+#include "atb_speed/base/event_manager.h"
+#include "atb_speed/log.h"
+#include "atb_speed/utils/check_util.h"
+#include "operations/aclnn/ops/argmax_operation.h"
+#include "operations/aclnn/ops/matmul_allreduce_operation.h"
+#include "operations/aclnn/ops/max_v2_operation.h"
+#include "operations/fusion/utils.h"
+#include "operations/fusion/linear/linear_parallel.h"
+
+namespace atb_speed {
+namespace common {
+
+std::map<std::string, std::vector<std::string>> GetLinearParallelInTensorCandidates()
+{
+    std::map<std::string, std::vector<std::string>> linearParallelInTensorCandidates = {
+        {"default", {
+            "in_input", "in_weight", "in_scale", "in_offset", "in_descale", "in_bias", "in_compress_idx"}
+        },
+        {"reduce_quant", {
+            "in_reduce_quant_scale", "in_reduce_quant_offset", "in_gather_quant_scale", "in_gather_quant_offset"}
+        },
+        {"lora", {"in_seq_len_cum_sum", "in_lora_a", "in_lora_b"}},
+        {"lora_with_mask", {"in_im_mask"}},
+        {"swiglu_quant", {"intermediate_swiglu_dynamic_scale"}},
+        {"flash_comm", {"send_counts", "sdispls", "recv_count", "fake_rs_shape"}},
+    };
+    return linearParallelInTensorCandidates;
+}
+
+std::map<std::string, std::vector<std::string>> GetLinearParallelIntermediateTensorCandidates()
+{
+    std::map<std::string, std::vector<std::string>> linearParallelIntermediateTensorCandidates = {
+        {"linear_out", {"intermediate_linear_out"}},
+        {"sync_out", {"intermediate_sync_out"}},
+        {"quant_out", {"intermediate_quant_out"}},
+        {"argmax", {"argmax_out", "argmax_withvalue_out", "transpose_argmax_out", "transpose_argmax_withvalue_out"}},
+        {"inner_tp", {"intermediate_inner_tp_input", "intermediate_inner_linear_out"}},
+        {"inner_tp_prefill", {"intermediate_tp_allgather"}},
+    };
+    return linearParallelIntermediateTensorCandidates;
+}
+
+std::map<std::string, std::vector<std::string>> GetLinearParallelOutTensorCandidates()
+{
+    std::map<std::string, std::vector<std::string>> linearParallelOutTensorCandidates = {
+        {"argmax_out", {"all_argmax_out", "all_argmaxwithvalue_out"}},
+    };
+    return linearParallelOutTensorCandidates;
+}
+
+bool IsDownDynamicDeQuant(const LinearParallelParam &param)
+{
+    return param.fusionLinearParam.isDownLinear && param.fusionLinearParam.isPrefill &&
+        param.fusionLinearParam.enableSwigluQuant &&
+        (param.fusionLinearParam.quantType == atb_speed::common::LinearQuantType::LINEAR_W8A8_DYNAMIC_DEQUANT ||
+        param.fusionLinearParam.quantType == atb_speed::common::LinearQuantType::LINEAR_W4A8_DYNAMIC_DEQUANT);
+}
+
+std::map<std::string, uint32_t> ConstructTensorMap(
+    const LinearParallelParam &param,
+    uint32_t &inTensorNum, uint32_t &outTensorNum, uint32_t &internalTensorNum,
+    bool enableLcoc)
+{
+    auto linearParallelInTensorCandidates = GetLinearParallelInTensorCandidates();
+    auto linearParallelIntermediateTensorCandidates = GetLinearParallelIntermediateTensorCandidates();
+    auto linearParallelOutTensorCandidates = GetLinearParallelOutTensorCandidates();
+    std::vector<std::string> inTensorList = {};
+    std::vector<std::string> intermediateTensorList = {};
+    std::vector<std::string> outTensorList = {};
+    if (!param.isArgmaxlogits) {
+        outTensorList = {"out"};
+    }
+
+    // Add the default tensors
+    AddTensorToList(linearParallelInTensorCandidates, "default", inTensorList);
+
+    // Add the extra intermediate tensors
+    if (enableLcoc) {
+        if (param.biasAfterSync && !param.isArgmaxlogits) {
+            AddTensorToList(linearParallelIntermediateTensorCandidates, "sync_out", intermediateTensorList);
+        }
+    } else if (param.innerTensorParallelInfo.rankIds.size() > 1) {  // Add tensors for inner communication
+        AddTensorToList(linearParallelIntermediateTensorCandidates, "inner_tp", intermediateTensorList);
+        if (param.isPrefill) {
+            AddTensorToList(linearParallelIntermediateTensorCandidates, "inner_tp_prefill", intermediateTensorList);
+        }
+    } else {
+        AddTensorToList(linearParallelIntermediateTensorCandidates, "linear_out", intermediateTensorList);
+        if (param.tensorParallelInfo.quantType == atb::infer::AllReduceParam::QuantType::QUANT_TYPE_PER_CHANNEL && \
+            param.fusionLinearParam.quantType == atb_speed::common::LinearQuantType::NO_QUANT) {
+            AddTensorToList(linearParallelIntermediateTensorCandidates, "quant_out", intermediateTensorList);
+        }
+        // In the all-gather case the inter-device communication output cannot be written in place
+        if (param.parallelType == COLUMN_PARALLEL && !param.isArgmaxlogits) {
+            AddTensorToList(linearParallelIntermediateTensorCandidates, "sync_out", intermediateTensorList);
+        }
+    }
+
+    // Add tensors for the Lora feature
+    if (param.fusionLinearParam.supportLora) {
+        if (param.fusionLinearParam.useImMask) {
+            AddTensorToList(linearParallelInTensorCandidates, "lora_with_mask", inTensorList);
+        }
+        AddTensorToList(linearParallelInTensorCandidates, "lora", inTensorList);
+    }
+    // Add tensors for the lccl reduce int8 feature
+    if (param.tensorParallelInfo.quantType == atb::infer::AllReduceParam::QuantType::QUANT_TYPE_PER_CHANNEL && \
+        param.fusionLinearParam.quantType == atb_speed::common::LinearQuantType::NO_QUANT) {
+        AddTensorToList(linearParallelInTensorCandidates, "reduce_quant", inTensorList);
+    }
+    if (IsDownDynamicDeQuant(param)) {
+        AddTensorToList(linearParallelInTensorCandidates, "swiglu_quant", inTensorList);
+    }
+    // Add tensors for the post-processing that is moved ahead (argmax on logits)
+    if (param.isArgmaxlogits) {
+        AddTensorToList(linearParallelIntermediateTensorCandidates, "argmax", intermediateTensorList);
+        AddTensorToList(linearParallelOutTensorCandidates, "argmax_out", outTensorList);
+    }
+    // Add tensors for flashcomm 1.0
+    if (param.parallelType == REDUCE_SCATTER) {
+        AddTensorToList(linearParallelInTensorCandidates, "flash_comm", inTensorList);
+    }
+    inTensorNum = inTensorList.size();
+    outTensorNum = outTensorList.size();
+    internalTensorNum = intermediateTensorList.size();
+
+    return GetTensorMap(inTensorList, outTensorList, intermediateTensorList);
+}
+
+int64_t AddAllReduceOp(const LinearParallelParam &param, atb::GraphParam &opGraph,
+    std::map<std::string, uint32_t> &tensorMap)
+{
+    atb::Node allReduceNode;
+    atb::infer::AllReduceParam allReduceParam;
+    allReduceParam.rank = param.tensorParallelInfo.rank;
+    allReduceParam.rankSize = param.tensorParallelInfo.worldSize;
+    allReduceParam.backend = param.tensorParallelInfo.backend;
+    allReduceParam.rankTableFile = param.tensorParallelInfo.rankTableFile;
+    allReduceParam.quantType = param.tensorParallelInfo.quantType;
+    allReduceParam.outDataType = param.tensorParallelInfo.outDataType;
+    allReduceParam.commDomain = param.tensorParallelInfo.commDomain;
+    allReduceParam.hcclComm = param.tensorParallelInfo.hcommInfo;
+    CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(allReduceParam, &allReduceNode.operation));
+    if (param.tensorParallelInfo.quantType == atb::infer::AllReduceParam::QuantType::QUANT_TYPE_PER_CHANNEL) {
+        // When the linear itself is NO_QUANT, a per-channel quant node has been inserted before the all-reduce
+        bool isQuant = param.fusionLinearParam.quantType == atb_speed::common::LinearQuantType::NO_QUANT;
+        std::vector<std::string> allReduceInTensors = {
+            isQuant ? "intermediate_quant_out" : "intermediate_linear_out", \
"intermediate_quant_out" : "intermediate_linear_out", \ + "in_reduce_quant_scale", "in_gather_quant_offset" + }; + allReduceNode.inTensorIds = {GetTensorIdxList(tensorMap, allReduceInTensors)}; + } else { + allReduceNode.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_linear_out")}; + } + allReduceNode.outTensorIds = { + GetTensorIdx(tensorMap, param.biasAfterSync ? "intermediate_linear_out" : "out") + }; + opGraph.nodes.push_back(allReduceNode); + + if (param.biasAfterSync) { + atb::Node addNode; + atb::infer::ElewiseParam addParam; + addParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_ADD; + CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(addParam, &addNode.operation)); + addNode.inTensorIds = GetTensorIdxList(tensorMap, { + "intermediate_linear_out", "in_bias" + }); + addNode.outTensorIds = {GetTensorIdx(tensorMap, "out")}; + opGraph.nodes.push_back(addNode); + } + return atb::NO_ERROR; +} + +int64_t AddReduceScatterOp(const LinearParallelParam ¶m, atb::GraphParam &opGraph, + std::map &tensorMap) +{ + atb::Node reduceScatterVNode; + atb::infer::ReduceScatterVParam reduceScatterVParam; + reduceScatterVParam.rank = param.tensorParallelInfo.rank; + reduceScatterVParam.rankSize = param.tensorParallelInfo.worldSize; + reduceScatterVParam.rankTableFile = param.tensorParallelInfo.rankTableFile; + reduceScatterVParam.backend = param.tensorParallelInfo.backend; + reduceScatterVParam.commDomain = param.tensorParallelInfo.commDomain; + reduceScatterVParam.hcclComm = param.tensorParallelInfo.hcommInfo; + CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(reduceScatterVParam, &reduceScatterVNode.operation)); + reduceScatterVNode.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_linear_out")}; + reduceScatterVNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "send_counts")); + reduceScatterVNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "sdispls")); + reduceScatterVNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "recv_count")); + reduceScatterVNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "fake_rs_shape")); + reduceScatterVNode.outTensorIds = {GetTensorIdx(tensorMap, "out")}; + CHECK_OPERATION_STATUS_RETURN(common::AddDapEventsBeforeComm(opGraph)); + opGraph.nodes.push_back(reduceScatterVNode); + CHECK_OPERATION_STATUS_RETURN(common::AddDapEventsBeforeComm(opGraph)); + return atb::NO_ERROR; +} + +atb::Status CreateArgmax(atb::GraphParam &opGraph, std::map &tensorMap) +{ + atb::Node argmaxNode; + atb_speed::common::AclNNArgMaxParam argMaxParam; + argMaxParam.keepdim = true; + argmaxNode.operation = new atb_speed::common::ArgMaxOperation("argmaxNode", argMaxParam); + argmaxNode.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_linear_out")}; + argmaxNode.outTensorIds = {GetTensorIdx(tensorMap, "argmax_out")}; + opGraph.nodes.push_back(argmaxNode); + + return atb::NO_ERROR; +} + +atb::Status CreateArgmaxwithValue(atb::GraphParam &opGraph, std::map &tensorMap) +{ + atb::Node maxNode; + atb_speed::common::AclNNMaxV2Param maxV2Param; + maxV2Param.keepdim = true; + maxNode.operation = new atb_speed::common::MaxV2Operation("maxNode", maxV2Param); + maxNode.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_linear_out")}; + maxNode.outTensorIds = {GetTensorIdx(tensorMap, "argmax_withvalue_out")}; + opGraph.nodes.push_back(maxNode); + + return atb::NO_ERROR; +} + +int64_t AddCommunicationOp(const LinearParallelParam ¶m, atb::GraphParam &opGraph, + std::map &tensorMap) +{ + CHECK_OPERATION_STATUS_RETURN(AddDapEventsBeforeComm(opGraph)); + if (param.parallelType == 
+int64_t AddCommunicationOp(const LinearParallelParam &param, atb::GraphParam &opGraph,
+    std::map<std::string, uint32_t> &tensorMap)
+{
+    CHECK_OPERATION_STATUS_RETURN(AddDapEventsBeforeComm(opGraph));
+    if (param.parallelType == ROW_PARALLEL) {
+        CHECK_OPERATION_STATUS_RETURN(AddAllReduceOp(param, opGraph, tensorMap));
+    } else if (param.parallelType == REDUCE_SCATTER) {
+        CHECK_OPERATION_STATUS_RETURN(AddReduceScatterOp(param, opGraph, tensorMap));
+    } else {
+        atb::Node allGatherNode;
+        atb::infer::AllGatherParam allGatherParam;
+        allGatherParam.rank = param.tensorParallelInfo.rank;
+        allGatherParam.rankSize = param.tensorParallelInfo.worldSize;
+        allGatherParam.backend = param.tensorParallelInfo.backend;
+        allGatherParam.rankTableFile = param.tensorParallelInfo.rankTableFile;
+        allGatherParam.hcclComm = param.tensorParallelInfo.hcommInfo;
+        allGatherParam.commDomain = param.tensorParallelInfo.commDomain;
+        CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(allGatherParam, &allGatherNode.operation));
+        allGatherNode.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_linear_out")};
+        allGatherNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_sync_out")};
+        opGraph.nodes.push_back(allGatherNode);
+    }
+    CHECK_OPERATION_STATUS_RETURN(AddDapEventsAfterComm(opGraph));
+
+    if (param.parallelType == COLUMN_PARALLEL) {
+        atb::Node transposeNode;
+        atb::infer::TransposeParam transposeParam;
+        if (param.unpadInputs) {
+            transposeParam.perm = {1, 0, 2};
+        } else {
+            transposeParam.perm = {1, 2, 0, 3};
+        }
+        CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(transposeParam, &transposeNode.operation));
+        transposeNode.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_sync_out")};
+        transposeNode.outTensorIds = {GetTensorIdx(tensorMap, "out")};
+        opGraph.nodes.push_back(transposeNode);
+    }
+    return atb::NO_ERROR;
+}
+
+int64_t AddCommunicationArgmaxOp(const LinearParallelParam &param, atb::GraphParam &opGraph,
+    std::map<std::string, uint32_t> &tensorMap)
+{
+    atb::Node allGatherNode;
+    atb::infer::AllGatherParam allGatherParam;
+    allGatherParam.rank = param.tensorParallelInfo.rank;
+    allGatherParam.rankSize = param.tensorParallelInfo.worldSize;
+    allGatherParam.backend = param.tensorParallelInfo.backend;
+    allGatherParam.rankTableFile = param.tensorParallelInfo.rankTableFile;
+    CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(allGatherParam, &allGatherNode.operation));
+    allGatherNode.inTensorIds = {GetTensorIdx(tensorMap, "argmax_out")};
+    allGatherNode.outTensorIds = {GetTensorIdx(tensorMap, "transpose_argmax_out")};
+    opGraph.nodes.push_back(allGatherNode);
+    atb::Node transposeNode;
+    atb::infer::TransposeParam transposeParam;
+    transposeParam.perm = {1, 0, 2};
+    CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(transposeParam, &transposeNode.operation));
+    transposeNode.inTensorIds = {GetTensorIdx(tensorMap, "transpose_argmax_out")};
+    transposeNode.outTensorIds = {GetTensorIdx(tensorMap, "all_argmax_out")};
+    opGraph.nodes.push_back(transposeNode);
+    return atb::NO_ERROR;
+}
+
+int64_t AddCommunicationMaxOp(const LinearParallelParam &param, atb::GraphParam &opGraph,
+    std::map<std::string, uint32_t> &tensorMap)
+{
+    atb::Node allGatherNode;
+    atb::infer::AllGatherParam allGatherParam;
+    allGatherParam.rank = param.tensorParallelInfo.rank;
+    allGatherParam.rankSize = param.tensorParallelInfo.worldSize;
+    allGatherParam.backend = param.tensorParallelInfo.backend;
+    allGatherParam.rankTableFile = param.tensorParallelInfo.rankTableFile;
+    CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(allGatherParam, &allGatherNode.operation));
+    allGatherNode.inTensorIds = {GetTensorIdx(tensorMap, "argmax_withvalue_out")};
+    allGatherNode.outTensorIds = {GetTensorIdx(tensorMap, "transpose_argmax_withvalue_out")};
+    opGraph.nodes.push_back(allGatherNode);
+    atb::Node transposeNode;
+    atb::infer::TransposeParam transposeParam;
+    transposeParam.perm = {1, 0, 2};
+    CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(transposeParam, &transposeNode.operation));
+    transposeNode.inTensorIds = {GetTensorIdx(tensorMap, "transpose_argmax_withvalue_out")};
+    transposeNode.outTensorIds = {GetTensorIdx(tensorMap, "all_argmaxwithvalue_out")};
+    opGraph.nodes.push_back(transposeNode);
+
+    return atb::NO_ERROR;
+}
+
+atb::Status AddInnerPreAllGatherNode(const LinearParallelParam &param,
+    atb::GraphParam &opGraph, std::map<std::string, uint32_t> &tensorMap)
+{
+    atb::Node allGatherNode;
+    atb::infer::AllGatherParam allGatherParam;
+    allGatherParam.rank = param.innerTensorParallelInfo.rank;
+    allGatherParam.rankSize = param.innerTensorParallelInfo.rankIds.size();
+    allGatherParam.backend = param.innerTensorParallelInfo.defaultBackend;
+    param.innerTensorParallelInfo.InitCommDomain(allGatherParam.hcclComm, allGatherParam.commDomain);
+
+    CHECK_OPERATION_STATUS_RETURN(CreateOperation(allGatherParam, &allGatherNode.operation));
+    allGatherNode.inTensorIds = {GetTensorIdx(tensorMap, "in_input")};
+    allGatherNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_tp_allgather")};
+    opGraph.nodes.push_back(allGatherNode);
+    ATB_SPEED_LOG_DEBUG("Fusion linear Inner pre AllGather calculation success");
+    return atb::NO_ERROR;
+}
+
+atb::Status AddInnerPreAllGatherSliceNode(const LinearParallelParam &param,
+    atb::GraphParam &opGraph, std::map<std::string, uint32_t> &tensorMap)
+{
+    atb::Node sliceNode;
+    atb::infer::SliceParam sliceParam;
+    sliceParam.offsets.resize(2);  // 2: dimNum
+    sliceParam.offsets[0] = 0;
+    sliceParam.offsets[1] = param.innerTensorParallelInfo.rank * param.innerTpShape;
+    sliceParam.size.resize(2);  // 2: dimNum
+    sliceParam.size[0] = -1;
+    sliceParam.size[1] = param.innerTpShape;
+    CHECK_OPERATION_STATUS_RETURN(CreateOperation(sliceParam, &sliceNode.operation));
+
+    sliceNode.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_tp_allgather")};
+    sliceNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_inner_tp_input")};
+
+    sliceNode.inTensorReshapeFuncs.resize(sliceNode.inTensorIds.size());
+    sliceNode.inTensorReshapeFuncs[0] = [=] (const atb::Dims &oldShape, atb::Dims &newShape) {
+        newShape.dimNum = 2;  // 2: flatten the gathered 3-D tensor into two dimensions
+        newShape.dims[0] = oldShape.dims[0] * oldShape.dims[1];
+        newShape.dims[1] = oldShape.dims[2];  // 2: the last (hidden) dimension
+    };
+    opGraph.nodes.push_back(sliceNode);
+    ATB_SPEED_LOG_DEBUG("Inner pre AllGather Slice calculation success");
+    return atb::NO_ERROR;
+}
+
+atb::Status AddInnerPreAllToAllNode(const LinearParallelParam &param,
+    atb::GraphParam &opGraph, std::map<std::string, uint32_t> &tensorMap)
+{
+    atb::Node allToAllNode;
+    atb::infer::AllToAllParam allToAllParam;
+    allToAllParam.rank = param.innerTensorParallelInfo.rank;
+    allToAllParam.rankSize = param.innerTensorParallelInfo.rankIds.size();
+    allToAllParam.backend = "lccl";
+    allToAllParam.transpose = true;
+    param.innerTensorParallelInfo.InitCommDomain(allToAllParam.hcclComm, allToAllParam.commDomain, "lccl");
+    CHECK_OPERATION_STATUS_RETURN(CreateOperation(allToAllParam, &allToAllNode.operation));
+    allToAllNode.inTensorIds = {GetTensorIdx(tensorMap, "in_input")};
+    allToAllNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_inner_tp_input")};
+    CHECK_OPERATION_STATUS_RETURN(common::AddDapEventsBeforeComm(opGraph));
+    opGraph.nodes.push_back(allToAllNode);
+    CHECK_OPERATION_STATUS_RETURN(common::AddDapEventsAfterComm(opGraph));
+    ATB_SPEED_LOG_DEBUG("Inner pre AllToAll calculation success");
+    return atb::NO_ERROR;
+}
+
+atb::Status AddInnerPostReduceScatterNode(const LinearParallelParam &param,
+    atb::GraphParam &opGraph, std::map<std::string, uint32_t> &tensorMap)
+{
+    atb::Node reduceScatterNode;
+    atb::infer::ReduceScatterParam reduceScatterParam;
+    reduceScatterParam.rank = param.innerTensorParallelInfo.rank;
+    reduceScatterParam.rankSize = param.innerTensorParallelInfo.rankIds.size();
+    reduceScatterParam.backend = param.innerTensorParallelInfo.defaultBackend;
+    param.innerTensorParallelInfo.InitCommDomain(reduceScatterParam.hcclComm, reduceScatterParam.commDomain);
+    CHECK_OPERATION_STATUS_RETURN(CreateOperation(reduceScatterParam, &reduceScatterNode.operation));
+    reduceScatterNode.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_inner_linear_out")};
+    reduceScatterNode.outTensorIds = {GetTensorIdx(tensorMap, param.tensorParallelInfo.worldSize > 1 ?
+        "intermediate_linear_out" : "out")};
+    CHECK_OPERATION_STATUS_RETURN(common::AddDapEventsBeforeComm(opGraph));
+    opGraph.nodes.push_back(reduceScatterNode);
+    CHECK_OPERATION_STATUS_RETURN(common::AddDapEventsAfterComm(opGraph));
+    ATB_SPEED_LOG_DEBUG("Inner Post Reduce Scatter calculation success");
+    return atb::NO_ERROR;
+}
+
+atb::Status AddFusionLinearNode(const LinearParallelParam &param,
+    atb::GraphParam &opGraph, std::map<std::string, uint32_t> &tensorMap)
+{
+    atb::Node linearNode;
+    atb_speed::common::FusionLinearParam linearParam = param.fusionLinearParam;
+    CHECK_OPERATION_STATUS_RETURN(FusionLinear(linearParam, &linearNode.operation));
+    std::vector<std::string> linearInTensor = {
+        param.innerTensorParallelInfo.rankIds.size() > 1 ? "intermediate_inner_tp_input" : "in_input",
+        "in_weight", "in_scale", "in_offset", "in_descale", "in_bias", "in_compress_idx"
+    };
+    if (param.fusionLinearParam.supportLora) {
+        if (param.fusionLinearParam.useImMask) {
+            linearInTensor.push_back("in_im_mask");
+        }
+        linearInTensor.push_back("in_seq_len_cum_sum");
+        linearInTensor.push_back("in_lora_a");
+        linearInTensor.push_back("in_lora_b");
+    }
+    if (IsDownDynamicDeQuant(param)) {
+        linearInTensor.push_back("intermediate_swiglu_dynamic_scale");
+    }
+    linearNode.inTensorIds = GetTensorIdxList(tensorMap, linearInTensor);
+    linearNode.outTensorIds = {GetTensorIdx(tensorMap, param.innerTensorParallelInfo.rankIds.size() > 1 ?
+        "intermediate_inner_linear_out" : "intermediate_linear_out")};
+    opGraph.nodes.push_back(linearNode);
+    return atb::NO_ERROR;
+}
+
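+// CreateLinearParallelMC2 relies on the fused aclnn MatmulAllreduce kernel (MC2), so the
+// matmul and the all-reduce execute as one node and no separate communication op is added;
+// when biasAfterSync is set, the bias is added to the synchronized output afterwards.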
"intermediate_sync_out" : "out") + }; + CHECK_OPERATION_STATUS_RETURN(common::AddDapEventsBeforeComm(opGraph)); + opGraph.nodes.push_back(linearParallelNode); + CHECK_OPERATION_STATUS_RETURN(common::AddDapEventsAfterComm(opGraph)); + + if (param.biasAfterSync) { + atb::Node addNode; + atb::infer::ElewiseParam addParam; + addParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_ADD; + CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(addParam, &addNode.operation)); + addNode.inTensorIds = GetTensorIdxList(tensorMap, { + "intermediate_sync_out", "in_bias" + }); + addNode.outTensorIds = {GetTensorIdx(tensorMap, "out")}; + opGraph.nodes.push_back(addNode); + } + + CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(opGraph, operation)); + return atb::NO_ERROR; +} + +void LinearParallelInferShape(atb::GraphParam &opGraph, const LinearParallelParam ¶m, + std::map &tensorMap) +{ + opGraph.inferShapeFunc = [=](const atb::SVector &inTensorDescs, + atb::SVector &outTensorDescs) { + uint32_t inputIdx = GetTensorIdx(tensorMap, "in_input"); + uint32_t weightIdx = GetTensorIdx(tensorMap, "in_weight"); + uint32_t biasIdx = GetTensorIdx(tensorMap, "in_bias"); + uint32_t resultDim = 2; + if (param.isArgmaxlogits) { + outTensorDescs.at(0).dtype = aclDataType::ACL_INT32; + outTensorDescs.at(0).format = inTensorDescs.at(inputIdx).format; + outTensorDescs.at(0).shape.dimNum = resultDim; // 二维 [batch_size,1] + outTensorDescs.at(0).shape.dims[0] = inTensorDescs.at(inputIdx).shape.dims[0]; + outTensorDescs.at(0).shape.dims[1] = param.worldSize; + outTensorDescs.at(1).dtype = inTensorDescs.at(inputIdx).dtype; + outTensorDescs.at(1).format = inTensorDescs.at(inputIdx).format; + outTensorDescs.at(1).shape.dimNum = resultDim; // 二维 [batch_size,1] + outTensorDescs.at(1).shape.dims[0] = inTensorDescs.at(inputIdx).shape.dims[0]; + outTensorDescs.at(1).shape.dims[1] = param.worldSize; + } else { + outTensorDescs.at(0) = inTensorDescs.at(inputIdx); + if (param.fusionLinearParam.isDownLinear && param.fusionLinearParam.enableSwigluQuant) { + outTensorDescs.at(0).dtype = param.fusionLinearParam.isBF16 ? \ + aclDataType::ACL_BF16 : aclDataType::ACL_FLOAT16; + } + CHECK_TENSORDESC_DIMNUM_VALID(inTensorDescs.at(inputIdx).shape.dimNum); + auto dimLast = inTensorDescs.at(inputIdx).shape.dimNum - 1; + int nDim = param.fusionLinearParam.transposeType == TransposeType::TRANSPOSE ? 
+            int nDim = param.fusionLinearParam.transposeType == TransposeType::TRANSPOSE ? 0 : 1;
+            if (param.parallelType == COLUMN_PARALLEL) {
+                outTensorDescs.at(0).shape.dims[dimLast] = \
+                    CheckIntMulOverFlow(inTensorDescs.at(weightIdx).shape.dims[nDim],
+                        param.tensorParallelInfo.worldSize);
+            } else if (param.parallelType == REDUCE_SCATTER) {
+                uint32_t fakeRsShapeIdx = GetTensorIdx(tensorMap, "fake_rs_shape");
+                outTensorDescs.at(0).shape.dims[0] = inTensorDescs.at(fakeRsShapeIdx).shape.dims[0];
+                outTensorDescs.at(0).shape.dims[dimLast] = inTensorDescs.at(weightIdx).shape.dims[nDim];
+            } else {
+                if (param.fusionLinearParam.quantType == LINEAR_W8A8_SC_DEQUANT || \
+                    param.fusionLinearParam.quantType == LINEAR_W8A8_SC_QUANT) {
+                    outTensorDescs.at(0).shape.dims[dimLast] = inTensorDescs.at(biasIdx).shape.dims[0];
+                } else if (param.fusionLinearParam.quantType == W4A16) {
+                    if (param.fusionLinearParam.transposeType == TransposeType::TRANSPOSE) {
+                        outTensorDescs.at(0).shape.dims[dimLast] = \
+                            inTensorDescs.at(weightIdx).shape.dims[0];  // 0: the n dimension
+                    } else {
+                        outTensorDescs.at(0).shape.dims[dimLast] = \
+                            CheckIntMulOverFlow(inTensorDescs.at(weightIdx).shape.dims[1], 2);  // dims[1] * 2: each int8 packs two int4 weights
+                    }
+                } else {
+                    outTensorDescs.at(0).shape.dims[dimLast] = inTensorDescs.at(weightIdx).shape.dims[nDim];
+                }
+            }
+        }
+        return atb::NO_ERROR;
+    };
+}
+
+atb::Status CreateLinearParallel(const LinearParallelParam &param, atb::Operation **operation)
+{
+    atb::GraphParam opGraph;
+    if (param.parallelType == REDUCE_SCATTER) {
+        opGraph.name = "LinearReduceScatter";
+    } else if (param.parallelType == ROW_PARALLEL && !param.biasAfterSync) {
+        opGraph.name = "LinearRowParallelNoAdd";
+    } else {
+        opGraph.name = param.parallelType == COLUMN_PARALLEL ? "LinearColumnParallel" : "LinearRowParallelAndAdd";
+    }
+
+    std::map<std::string, uint32_t> tensorMap = ConstructTensorMap(
+        param, opGraph.inTensorNum, opGraph.outTensorNum, opGraph.internalTensorNum, false);
+
+    if (param.innerTensorParallelInfo.rankIds.size() > 1) {
+        if (param.isPrefill) {
+            CHECK_OPERATION_STATUS_RETURN(AddInnerPreAllGatherNode(param, opGraph, tensorMap));
+            CHECK_OPERATION_STATUS_RETURN(AddInnerPreAllGatherSliceNode(param, opGraph, tensorMap));
+        } else {
+            CHECK_OPERATION_STATUS_RETURN(AddInnerPreAllToAllNode(param, opGraph, tensorMap));
+        }
+    }
+
+    CHECK_OPERATION_STATUS_RETURN(AddFusionLinearNode(param, opGraph, tensorMap));
+
+    if (param.innerTensorParallelInfo.rankIds.size() > 1) {
+        CHECK_OPERATION_STATUS_RETURN(AddInnerPostReduceScatterNode(param, opGraph, tensorMap));
+    }
+
+    if (param.tensorParallelInfo.worldSize > 1) {
+        if (param.tensorParallelInfo.quantType == atb::infer::AllReduceParam::QuantType::QUANT_TYPE_PER_CHANNEL && \
+            param.fusionLinearParam.quantType == atb_speed::common::LinearQuantType::NO_QUANT) {
+            atb::Node quantNode;
+            atb::infer::ElewiseParam quantParam;
+            quantParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_QUANT_PER_CHANNEL;
+            CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(quantParam, &quantNode.operation));
+            quantNode.inTensorIds = GetTensorIdxList(tensorMap, {
+                "intermediate_linear_out", "in_reduce_quant_scale", "in_reduce_quant_offset"
+            });
+            quantNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_quant_out")};
+            opGraph.nodes.push_back(quantNode);
+        }
+        if (param.isArgmaxlogits) {
+            CHECK_OPERATION_STATUS_RETURN(CreateArgmax(opGraph, tensorMap));
+            CHECK_OPERATION_STATUS_RETURN(CreateArgmaxwithValue(opGraph, tensorMap));
+        }
+        if (!param.isArgmaxlogits) {
+            CHECK_OPERATION_STATUS_RETURN(AddCommunicationOp(param, opGraph, tensorMap));
+        } else {
+            CHECK_OPERATION_STATUS_RETURN(AddCommunicationArgmaxOp(param, opGraph, tensorMap));
+            CHECK_OPERATION_STATUS_RETURN(AddCommunicationMaxOp(param, opGraph, tensorMap));
+        }
+    }
+    LinearParallelInferShape(opGraph, param, tensorMap);
+
+    CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(opGraph, operation));
+    return atb::NO_ERROR;
+}
+
+atb::Status CreateLinearParallelLcoc(const LinearParallelParam &param, atb::Operation **operation)
+{
+    atb::GraphParam opGraph;
+    opGraph.name = "LinearParallelLcoc";
+    std::map<std::string, uint32_t> tensorMap = ConstructTensorMap(
+        param, opGraph.inTensorNum, opGraph.outTensorNum, opGraph.internalTensorNum, true);
+    ATB_SPEED_LOG_DEBUG("linear parallel lcoc opGraph.inTensorNum " << opGraph.inTensorNum);
+    ATB_SPEED_LOG_DEBUG("linear parallel lcoc opGraph.outTensorNum " << opGraph.outTensorNum);
+    ATB_SPEED_LOG_DEBUG("linear parallel lcoc opGraph.internalTensorNum " << opGraph.internalTensorNum);
+
+    atb::Node linearParallelNode;
+    atb::infer::LinearParallelParam linearParallelParam;
+    linearParallelParam.transWeight = param.fusionLinearParam.transposeType == TransposeType::TRANSPOSE;
+    linearParallelParam.rank = param.tensorParallelInfo.rank;
+    linearParallelParam.rankSize = param.tensorParallelInfo.worldSize;
+    linearParallelParam.hasResidual = false;
+    linearParallelParam.backend = "lcoc";
+    linearParallelParam.commDomain = param.tensorParallelInfo.commDomain;
+    CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(linearParallelParam, &linearParallelNode.operation));
+
+    linearParallelNode.inTensorIds = GetTensorIdxList(tensorMap, {
+        "in_input", "in_weight"
+    });
+    linearParallelNode.outTensorIds = {
+        GetTensorIdx(tensorMap, param.biasAfterSync ? "intermediate_sync_out" : "out")
+    };
+    CHECK_OPERATION_STATUS_RETURN(common::AddDapEventsBeforeComm(opGraph));
+    opGraph.nodes.push_back(linearParallelNode);
+    CHECK_OPERATION_STATUS_RETURN(common::AddDapEventsAfterComm(opGraph));
+
+    if (param.biasAfterSync) {
+        atb::Node addNode;
+        atb::infer::ElewiseParam addParam;
+        addParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_ADD;
+        CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(addParam, &addNode.operation));
+        addNode.inTensorIds = GetTensorIdxList(tensorMap, {
+            "intermediate_sync_out", "in_bias"
+        });
+        addNode.outTensorIds = {GetTensorIdx(tensorMap, "out")};
+        opGraph.nodes.push_back(addNode);
+    }
+
+    CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(opGraph, operation));
+    return atb::NO_ERROR;
+}
+
+atb::Status LinearParallel(const LinearParallelParam &param, atb::Operation **operation)
+{
+    if (param.tensorParallelInfo.worldSize <= 1 && param.innerTensorParallelInfo.rankIds.size() <= 1) {
+        return FusionLinear(param.fusionLinearParam, operation);
+    } else if (param.parallelType == ROW_PARALLEL) {
+        if (param.tensorParallelInfo.backend == "hccl" && param.enableMC2 && \
+            param.fusionLinearParam.quantType == atb_speed::common::LinearQuantType::NO_QUANT) {
+            return CreateLinearParallelMC2(param, operation);
+        } else if (param.tensorParallelInfo.backend == "lccl" && \
+            param.supportLcoc && !param.fusionLinearParam.supportLora && \
+            param.fusionLinearParam.quantType == atb_speed::common::LinearQuantType::NO_QUANT) {
+            return CreateLinearParallelLcoc(param, operation);
+        }
+        return CreateLinearParallel(param, operation);
+    } else if (param.parallelType == COLUMN_PARALLEL || param.parallelType == REDUCE_SCATTER) {
+        return CreateLinearParallel(param, operation);
+    } else {
+        ATB_SPEED_LOG_ERROR("LinearParallel operation doesn't support parallelType: "
+            << param.parallelType
+            << " Possible values are 1 (row parallel), 2 (column parallel) or 3 (reduce scatter).");
+        return atb::ERROR_INVALID_PARAM;
+    }
+}
+
+} // namespace common
+} // namespace atb_speed
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/fusion/linear/linear_parallel.h b/tests/proftest/layer_test_framework/operations/fusion/linear/linear_parallel.h
new file mode 100644
index 00000000..fc5ec745
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/fusion/linear/linear_parallel.h
@@ -0,0 +1,133 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef ASCEND_SPEED_INFERENCE_COMMON_LINEAR_PARALLEL_H
+#define ASCEND_SPEED_INFERENCE_COMMON_LINEAR_PARALLEL_H
+
+#include
+#include "acl/acl.h"
+#include "atb_speed/utils/operation_util.h"
+#include "operations/fusion/parallel_info.h"
+#include "operations/fusion/linear/linear.h"
+
+namespace atb_speed {
+namespace common {
+
+/// Types of tensor parallelism, categorized based on how the weight matrix is split across devices.
+enum LinearParallelType : uint32_t {
+    /// No parallelism is applied to the weight matrix
+    UNDEFINED = 0,
+    /// The weight matrix is split along its rows
+    ROW_PARALLEL,
+    /// The weight matrix is split along its columns
+    COLUMN_PARALLEL,
+    /// The weight matrix is split along its rows (dim 0)
+    REDUCE_SCATTER,
+};
+
+/// Parameters for the linear parallel module
+struct LinearParallelParam {
+    /// Parameters for the linear module
+    atb_speed::common::FusionLinearParam fusionLinearParam;
+    /// Types of tensor parallelism. Refer to the `LinearParallelType`
+    /// in the `operations/fusion/linear/linear_parallel.h`.
+    int parallelType = UNDEFINED;
+    /// A flag indicating whether to add bias after the all-reduce operation.
+    bool biasAfterSync = false;
+    /// A flag that indicates whether the input includes padding.
+    /// It is applicable only when the `parallelType` is set to `ROW_PARALLEL`.
+    bool unpadInputs = false;
+    /// A flag that indicates whether low-latency computation over communication is enabled
+    bool supportLcoc = false;
+    /// A flag that indicates whether the fused matmul-all-reduce (MC2) kernel is enabled
+    bool enableMC2 = false;
+    /// A flag indicating whether a mask is used before applying the lora adapter.
+    bool useImMask = false;
+    /// A flag indicating whether each card computes the argmax of its local logits.
+    bool isArgmaxlogits = false;
+    /// The world size used when `isArgmaxlogits` is enabled.
+    int worldSize = 1;
+    /// A flag indicating the prefill and decode phases
+    bool isPrefill = false;
+    /// The shape for the inner tp slice
+    int innerTpShape = 0;
+    /// Details about tensor parallelism
+    TensorParallelInfo tensorParallelInfo;
+    /// Details about inner tensor parallelism
+    atb_speed::common::ParallelInfo innerTensorParallelInfo;
+};
+
+/// This function is the main entrance for all types of linear parallel modules.
+/// It will call different operations based on the `parallelType`.
+///
+/// \param param Parameters for the linear parallel module
+/// \param operation the address of a pointer to a default operation
+/// \return A flag that indicates whether the operation has been successfully created.
+///
+/// Operation's inputs and outputs follow the same specifications as the inputs of the linear module.
+/// See `operations/fusion/linear/linear.h` for more details.
+///
+/// In addition, this operation supports the following optional inputs. They are required when
+/// `param.tensorParallelInfo.quantType == atb::infer::AllReduceParam::QuantType::QUANT_TYPE_PER_CHANNEL` and
+/// `param.fusionLinearParam.quantType == atb_speed::common::LinearQuantType::NO_QUANT`.
+/// Name                   | Dtype   | Shape |
+/// -----------------------|---------|-------|
+/// in_reduce_quant_scale  | float16 | [k]   |
+/// in_reduce_quant_offset | int8    | [k]   |
+/// in_gather_quant_scale  | float16 | [k]   |
+/// in_gather_quant_offset | float16 | [1]   |
+///
+/// Example:
+/// \code
+/// enum TensorIdx : uint32_t {
+///     IN_INPUT = 0,
+///     IN_WEIGHT,
+///     IN_PLACEHOLDER,
+///     OUT,
+/// };
+///
+/// atb::Node linearParallelNode;
+/// atb_speed::common::LinearParallelParam linearParallelParam;
+/// // Modify linearParallelParam's attribute if needed.
+/// linearParallelParam.parallelType = atb_speed::common::ROW_PARALLEL;
+/// linearParallelParam.tensorParallelInfo.worldSize = 4;
+/// LinearParallel(linearParallelParam, &linearParallelNode.operation);
+/// linearParallelNode.inTensorIds = {IN_INPUT, IN_WEIGHT, IN_PLACEHOLDER, IN_PLACEHOLDER, IN_PLACEHOLDER, IN_PLACEHOLDER, IN_PLACEHOLDER};
+/// linearParallelNode.outTensorIds = {OUT};
+/// atb::GraphParam opGraph;
+/// opGraph.nodes.push_back(linearParallelNode);
+/// \endcode
+atb::Status LinearParallel(const LinearParallelParam &param, atb::Operation **operation);
+
+/// This function adds communication operations to the graph.
+/// It will call different operations based on the `parallelType`.
+///
+/// \param param Parameters for the `LinearParallelParam` module
+/// \param opGraph A reference to the graph
+/// \param tensorMap Defines all the required tensors for the current graph, with the key representing
+/// the input tensor name and the value corresponding to the tensor index.
+/// \return A flag indicating whether the operation has been successfully created.
+///
+/// This function will use the following tensors if `parallelType` equals `COLUMN_PARALLEL`:
+/// Key in `tensorMap`      | Requirements | Dtype            | Shape | Description |
+/// ------------------------|--------------|------------------|-------|----------|
+/// intermediate_linear_out | Required     | float16/bfloat16 | [m, n] or [m, k, n] | Hidden states |
+/// intermediate_sync_out   | ^            | float16/bfloat16 | [group_size, m, n] or [group_size, m, k, n] | Hidden states of all communication groups |
+/// out                     | ^            | ^                | [m, n * group_size] or [m, k, n * group_size] | Hidden states of all communication groups after tensor reorder |
+int64_t AddCommunicationOp(const LinearParallelParam &param, atb::GraphParam &opGraph,
+    std::map<std::string, uint32_t> &tensorMap);
+} // namespace common
+} // namespace atb_speed
+
+#endif
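A column-parallel sketch under the same conventions (illustrative, not part of the patch): with worldSize = 2 each rank holds half of the output columns, and AddCommunicationOp appends the all-gather plus the transpose that restores the [m, n] layout:

    atb_speed::common::LinearParallelParam parallelParam;
    parallelParam.parallelType = atb_speed::common::COLUMN_PARALLEL;
    parallelParam.tensorParallelInfo.rank = 0;
    parallelParam.tensorParallelInfo.worldSize = 2;  // weight split along columns across 2 cards
    parallelParam.fusionLinearParam.transposeType = TRANSPOSE;

    atb::Operation *op = nullptr;
    // Each rank computes an [m, n/2] partial result; the graph all-gathers it into
    // [2, m, n/2] ("intermediate_sync_out") and transposes back to the full [m, n] output.
    LinearParallel(parallelParam, &op);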
diff --git a/tests/proftest/layer_test_framework/operations/fusion/lmhead/hidden_state_slice.cpp b/tests/proftest/layer_test_framework/operations/fusion/lmhead/hidden_state_slice.cpp
new file mode 100644
index 00000000..69eaac5d
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/fusion/lmhead/hidden_state_slice.cpp
@@ -0,0 +1,61 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
+ *
+ */
+
+#include "atb_speed/utils/check_util.h"
+#include "atb_speed/log.h"
+#include "operations/fusion/linear/linear_parallel.h"
+#include "operations/fusion/lmhead/hidden_state_slice.h"
+
+namespace atb_speed {
+namespace common {
+
+enum HiddenStateSliceTensorIdx : uint32_t {
+    IN_HIDDENSTATES = 0,
+    OUT_HIDDEN_STATES,
+};
+
+static const uint64_t IN_TENSOR_COUNT = 1;
+static const uint64_t OUT_TENSOR_COUNT = 1;
+static const uint64_t NODE_COUNT = 1;
+static const uint64_t NUM3 = 3;  // 3: the slice operates on a 3-D view
+
+atb::Status HiddenStateSlice(const HiddenStateSliceParam &param, atb::Operation **operation)
+{
+    atb::GraphParam opGraph;
+    opGraph.inTensorNum = IN_TENSOR_COUNT;
+    opGraph.outTensorNum = OUT_TENSOR_COUNT;
+    opGraph.internalTensorNum = 0;
+
+    atb::Node sliceNode;
+    atb::infer::SliceParam sliceParam;
+    sliceParam.offsets.resize(NUM3);
+    sliceParam.offsets[0] = param.rank;
+    sliceParam.offsets[1] = 0;
+    sliceParam.offsets[2] = 0;  // 2: the hidden_state dimension
+    sliceParam.size.resize(NUM3);
+    sliceParam.size[0] = 1;
+    sliceParam.size[1] = -1;
+    sliceParam.size[2] = -1;  // -1: keep the full extent of dim 2
+    CREATE_OPERATION(sliceParam, &sliceNode.operation);
+    sliceNode.inTensorIds = {HiddenStateSliceTensorIdx::IN_HIDDENSTATES};
+    sliceNode.outTensorIds = {HiddenStateSliceTensorIdx::OUT_HIDDEN_STATES};
+
+    opGraph.nodes.push_back(sliceNode);
+
+    opGraph.inferShapeFunc = [=](const atb::SVector<atb::TensorDesc> &inTensorDescs,
+                                 atb::SVector<atb::TensorDesc> &outTensorDescs) {
+        outTensorDescs.at(0).dtype = inTensorDescs.at(IN_HIDDENSTATES).dtype;
+        outTensorDescs.at(0).format = inTensorDescs.at(IN_HIDDENSTATES).format;
+        outTensorDescs.at(0).shape.dimNum = 2;  // 2: the output tensor is 2-D
+        outTensorDescs.at(0).shape.dims[0] = inTensorDescs.at(IN_HIDDENSTATES).shape.dims[0] / param.world_size;
+        outTensorDescs.at(0).shape.dims[1] = inTensorDescs.at(IN_HIDDENSTATES).shape.dims[1];
+        return atb::NO_ERROR;
+    };
+
+    CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(opGraph, operation));
+    return atb::NO_ERROR;
+}
+} // namespace common
+} // namespace atb_speed
\ No newline at end of file
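A shape sketch for HiddenStateSlice (illustrative, not part of the patch); it assumes the caller supplies an input that can be viewed as [world_size, tokens_per_rank, hidden] so each rank keeps its own slab:

    atb_speed::common::HiddenStateSliceParam sliceParam;
    sliceParam.rank = 1;
    sliceParam.world_size = 4;

    atb::Operation *op = nullptr;
    // For an input of [4096, hidden_size], rank 1 keeps rows 1024..2047 and the
    // inferred output shape is [1024, hidden_size] (dims[0] / world_size, as above).
    HiddenStateSlice(sliceParam, &op);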
diff --git a/tests/proftest/layer_test_framework/operations/fusion/lmhead/hidden_state_slice.h b/tests/proftest/layer_test_framework/operations/fusion/lmhead/hidden_state_slice.h
new file mode 100644
index 00000000..9218d286
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/fusion/lmhead/hidden_state_slice.h
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
+ *
+ */
+
+#ifndef ATB_SPEED_MODELS_COMMON_LAYER_HIDDEN_STATE_SLICE_H
+#define ATB_SPEED_MODELS_COMMON_LAYER_HIDDEN_STATE_SLICE_H
+
+#include "atb/atb_infer.h"
+#include "atb_speed/utils/operation_util.h"
+#include "operations/fusion/linear/linear_parallel.h"
+
+namespace atb_speed {
+namespace common {
+/// A struct that defines the `HiddenStateSlice` operation's parameters.
+struct HiddenStateSliceParam {
+    /// The rank of this device in the tensor parallelism communication domain in lmhead.
+    int rank = 0;
+    /// The size of the tensor parallelism communication domain in lmhead.
+    int world_size = 1;
+};
+
+/// Create the `HiddenStateSlice` graph operation.
+/// \param param `HiddenStateSlice`'s parameters, see `HiddenStateSliceParam` for more details.
+/// \param operation The address pointer to the `HiddenStateSlice` operation.
+/// \return A flag indicating whether the operation has been successfully created.
+/// Operation's Inputs:
+/// Name             | Dtype | Shape |
+/// ---------------- | ----- | ----- |
+/// in_hidden_states | float16/float/int8/bool/int32/uint32/bf16 | [all_token_size, vocab_size] |
+/// Operation's Outputs:
+/// Name   | Dtype | Shape |
+/// ------ | ----- | ----- |
+/// output | float16/float/int8/bool/int32/uint32/bf16 | [token_size, vocab_size] |
atb::Status HiddenStateSlice(const HiddenStateSliceParam &param, atb::Operation **operation);
+} // namespace common
+} // namespace atb_speed
+#endif
diff --git a/tests/proftest/layer_test_framework/operations/fusion/lmhead/lmhead.cpp b/tests/proftest/layer_test_framework/operations/fusion/lmhead/lmhead.cpp
new file mode 100644
index 00000000..9f7679f0
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/fusion/lmhead/lmhead.cpp
@@ -0,0 +1,313 @@
+
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include
+#include "atb_speed/log.h"
+#include "atb_speed/utils/check_util.h"
+
+#include "operations/aclnn/ops/argmax_operation.h"
+#include "operations/fusion/lmhead/parallel_lmhead_all2all.h"
+#include "operations/fusion/lmhead/lmhead.h"
+
+namespace atb_speed {
+namespace common {
+
+enum LmHeadTensorIdx : uint32_t {
+    IN_HIDDENSTATES = 0,
+    IN_WEIGHT,
+    IN_SCALE,
+    IN_OFFSET,
+    IN_DESCALE,
+    IN_BIAS,
+    IN_COMPRESS_IDX,
+    IN_INDICES,
+    IN_LOGITS_OFFSET,
+    OUT_LOGITS,
+};
+
+static const uint64_t IN_TENSOR_COUNT = 9;
+static const uint64_t OUT_TENSOR_COUNT = 1;
+
+template <typename T>
+int64_t AddSlice(atb::GraphParam &opGraph, const LmHeadParam &param, T &config)
+{
+    atb::Node sliceNode;
+    atb::infer::SliceParam slicePassParam;
+    if (param.unpadInputs) {
+        slicePassParam.offsets = {
+            0, CheckIntMulOverFlow(param.hiddenSizePerAttentionHead,
+                param.linearParallelParam.tensorParallelInfo.rank)
+        };
+        slicePassParam.size = {-1, param.hiddenSizePerAttentionHead};
+    } else {
+        slicePassParam.offsets = {
+            0, 0, CheckIntMulOverFlow(param.hiddenSizePerAttentionHead,
+                param.linearParallelParam.tensorParallelInfo.rank)
+        };
+        slicePassParam.size = {-1, -1, param.hiddenSizePerAttentionHead};
+    }
+    CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(slicePassParam, &sliceNode.operation));
+    if (param.gatherAhead) {
+        sliceNode.inTensorIds = {config.INTERMEDIATE_GATHER_OUT};
+    } else {
+        sliceNode.inTensorIds = {LmHeadTensorIdx::IN_HIDDENSTATES};
+    }
+    sliceNode.outTensorIds = {config.INTERMEDIATE_SLICE_OUT};
+    opGraph.nodes.push_back(sliceNode);
+
+    return atb::NO_ERROR;
+}
+
+template <typename T>
+int64_t AddLinearParallel(
+    atb::GraphParam &opGraph, const LmHeadParam &param,
+    T &config, atb_speed::common::LinearParallelType parallelType)
+{
+    atb::Node linearParallelNode;
+    if (param.enableDpOut && param.linearParallelParam.tensorParallelInfo.worldSize > 1 && !param.gatherAhead) {
+        CHECK_OPERATION_STATUS_RETURN(ParallelLmHeadAllToAll(param, &linearParallelNode.operation));
+    } else {
+        CHECK_OPERATION_STATUS_RETURN(LinearParallel(param.linearParallelParam, &linearParallelNode.operation));
+    }
+    if (parallelType == ROW_PARALLEL) {
+        linearParallelNode.inTensorIds = {
+            config.INTERMEDIATE_SLICE_OUT, LmHeadTensorIdx::IN_WEIGHT, LmHeadTensorIdx::IN_SCALE,
+            LmHeadTensorIdx::IN_OFFSET, LmHeadTensorIdx::IN_DESCALE, LmHeadTensorIdx::IN_BIAS,
+            LmHeadTensorIdx::IN_COMPRESS_IDX,
+        };
+    } else if (param.gatherAhead) {
+        linearParallelNode.inTensorIds = {
+            config.INTERMEDIATE_GATHER_OUT, LmHeadTensorIdx::IN_WEIGHT, LmHeadTensorIdx::IN_SCALE,
+            LmHeadTensorIdx::IN_OFFSET, LmHeadTensorIdx::IN_DESCALE, LmHeadTensorIdx::IN_BIAS,
+            LmHeadTensorIdx::IN_COMPRESS_IDX,
+        };
+    } else {
+        linearParallelNode.inTensorIds = {
+            LmHeadTensorIdx::IN_HIDDENSTATES, LmHeadTensorIdx::IN_WEIGHT, LmHeadTensorIdx::IN_SCALE,
+            LmHeadTensorIdx::IN_OFFSET, LmHeadTensorIdx::IN_DESCALE, LmHeadTensorIdx::IN_BIAS,
+            LmHeadTensorIdx::IN_COMPRESS_IDX,
+        };
+    }
+    if (!param.linearParallelParam.isArgmaxlogits) {
+        linearParallelNode.outTensorIds = {LmHeadTensorIdx::OUT_LOGITS};
+    } else {
+        if (param.gatherAhead) {
+            linearParallelNode.outTensorIds = {LmHeadTensorIdx::OUT_LOGITS + 2,
+                                               LmHeadTensorIdx::OUT_LOGITS + 3};
+        } else {
+            linearParallelNode.outTensorIds = {LmHeadTensorIdx::OUT_LOGITS + 1,
+                                               LmHeadTensorIdx::OUT_LOGITS + 2};
+        }
+    }
+    opGraph.nodes.push_back(linearParallelNode);
+
+    return atb::NO_ERROR;
+}
+
+int64_t AddLogitsOffset(atb::GraphParam &opGraph, const LmHeadParam &param)
+{
+    uint32_t argmaxOutId = LmHeadTensorIdx::OUT_LOGITS + 1;
+    if (param.gatherAhead) {
+        argmaxOutId = argmaxOutId + 1;
+    }
+    uint32_t logitsOffsetId = argmaxOutId + 2;
+    atb::Node addNode;
+    atb::infer::ElewiseParam addParam;
+    addParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_ADD;
+    CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(addParam, &addNode.operation));
+    addNode.inTensorIds = {argmaxOutId, LmHeadTensorIdx::IN_LOGITS_OFFSET};
+    addNode.outTensorIds = {logitsOffsetId};
+    opGraph.nodes.push_back(addNode);
+    return atb::NO_ERROR;
+}
+
+int64_t AddArgMax(atb::GraphParam &opGraph, const LmHeadParam &param)
+{
+    uint32_t maxOutId = LmHeadTensorIdx::OUT_LOGITS + 2;
+    if (param.gatherAhead) {
+        maxOutId = maxOutId + 1;
+    }
+    const uint32_t argmaxMaxlogitsOut = maxOutId + 2;
+    atb::Node argmaxNode;
+    atb_speed::common::AclNNArgMaxParam argMaxParam;
+    argMaxParam.keepdim = true;
+    argmaxNode.operation = new atb_speed::common::ArgMaxOperation("argmaxNode", argMaxParam);
+    argmaxNode.inTensorIds = {maxOutId};
+    argmaxNode.outTensorIds = {argmaxMaxlogitsOut};
+    opGraph.nodes.push_back(argmaxNode);
+    return atb::NO_ERROR;
+}
+
+int64_t AddGatherLogits(atb::GraphParam &opGraph, const LmHeadParam &param)
+{
+    uint32_t argmaxMaxlogitsOut = LmHeadTensorIdx::OUT_LOGITS + 4;
+    if (param.gatherAhead) {
+        argmaxMaxlogitsOut = argmaxMaxlogitsOut + 1;
+    }
+    uint32_t logitsOffsetId = argmaxMaxlogitsOut - 1;
+    uint32_t inT32Logits = argmaxMaxlogitsOut + 1;
+    atb::Node gatherNode;
+    atb::infer::GatherParam gatherparam;
+    gatherparam.axis = 1;
+    gatherparam.batchDims = 1;
+    CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(gatherparam, &gatherNode.operation));
+    gatherNode.inTensorIds = {logitsOffsetId, argmaxMaxlogitsOut};
+    gatherNode.outTensorIds = {inT32Logits};
+    opGraph.nodes.push_back(gatherNode);
+
+    atb::Node castNode;
+    atb::infer::ElewiseParam castParam;
+    castParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_CAST;
+    castParam.outTensorType = ACL_INT64;
+    CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(castParam, &castNode.operation));
+    castNode.inTensorIds = {inT32Logits};
+    castNode.outTensorIds = {LmHeadTensorIdx::OUT_LOGITS};
+    opGraph.nodes.push_back(castNode);
+    return atb::NO_ERROR;
+}
+
+template <typename T>
+atb::Status CreateLmHead(
+    const LmHeadParam &param, atb::Operation **operation, T config,
+    atb_speed::common::LinearParallelType parallelType)
+{
+    uint32_t RESULT_OFFSET_4 = 4;
+    uint32_t RESULT_OFFSET_5 = 5;
+    uint32_t RESULT_DIM_2 = 2;
+    atb::GraphParam opGraph;
+    opGraph.inTensorNum = IN_TENSOR_COUNT;
+    opGraph.outTensorNum = OUT_TENSOR_COUNT;
+    if (param.linearParallelParam.isArgmaxlogits) {
+        opGraph.internalTensorNum =
+            param.gatherAhead ? config.intermediateTensorCount + RESULT_OFFSET_5
+                              : config.intermediateTensorCount + RESULT_OFFSET_4;
+    } else {
+        opGraph.internalTensorNum =
+            param.gatherAhead ? config.intermediateTensorCount : config.intermediateTensorCount - 1;
+    }
+    opGraph.name = "LmHead";
+
+    if (param.gatherAhead) {
+        atb::Node gatherNode;
+        atb::infer::GatherParam gatherParam;
+        gatherParam.axis = param.unpadInputs ? 0 : 1;
+        CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(gatherParam, &gatherNode.operation));
+        gatherNode.inTensorIds = {LmHeadTensorIdx::IN_HIDDENSTATES, LmHeadTensorIdx::IN_INDICES};
+        gatherNode.outTensorIds = {config.INTERMEDIATE_GATHER_OUT};
+        opGraph.nodes.push_back(gatherNode);
+    }
+    if (parallelType == ROW_PARALLEL) {
+        CHECK_OPERATION_STATUS_RETURN(AddSlice(opGraph, param, config));
+    }
+    CHECK_OPERATION_STATUS_RETURN(AddLinearParallel(opGraph, param, config, parallelType));
+    if (param.linearParallelParam.isArgmaxlogits) {
+        CHECK_OPERATION_STATUS_RETURN(AddLogitsOffset(opGraph, param));
+        CHECK_OPERATION_STATUS_RETURN(AddArgMax(opGraph, param));
+        CHECK_OPERATION_STATUS_RETURN(AddGatherLogits(opGraph, param));
+    }
+    opGraph.inferShapeFunc = [=](const atb::SVector<atb::TensorDesc> &inTensorDescs,
+                                 atb::SVector<atb::TensorDesc> &outTensorDescs) {
+        if (param.linearParallelParam.isArgmaxlogits) {
+            outTensorDescs.at(0).format = inTensorDescs.at(IN_HIDDENSTATES).format;
+            outTensorDescs.at(0).dtype = aclDataType::ACL_INT64;
+            outTensorDescs.at(0).shape.dimNum = RESULT_DIM_2;  // 2-D: [batch_size, 1]
+            outTensorDescs.at(0).shape.dims[0] = inTensorDescs.at(IN_INDICES).shape.dims[0];
+            outTensorDescs.at(0).shape.dims[1] = 1;
+        } else {
+            outTensorDescs.at(0) = inTensorDescs.at(IN_HIDDENSTATES);
+            CHECK_TENSORDESC_DIMNUM_VALID(inTensorDescs.at(IN_HIDDENSTATES).shape.dimNum);
+            auto dimLast = inTensorDescs.at(IN_HIDDENSTATES).shape.dimNum - 1;
+            if (param.gatherAhead) {
+                outTensorDescs.at(0).shape.dims[param.unpadInputs ? 0 : 1] = inTensorDescs.at(IN_INDICES).shape.dims[0];
+            } else if (param.enableDpOut && param.linearParallelParam.tensorParallelInfo.worldSize > 1) {
+                outTensorDescs.at(0).shape.dims[0] = \
+                    inTensorDescs.at(IN_HIDDENSTATES).shape.dims[0] / \
+                    param.linearParallelParam.tensorParallelInfo.worldSize;
+            }
+            if (parallelType == COLUMN_PARALLEL) {
0 : 1; + outTensorDescs.at(0).shape.dims[dimLast] = \ + CheckIntMulOverFlow(inTensorDescs.at(IN_WEIGHT).shape.dims[nDim], + param.linearParallelParam.tensorParallelInfo.worldSize); + } else { + outTensorDescs.at(0).shape.dims[dimLast] = inTensorDescs.at(IN_WEIGHT).shape.dims[0]; + } + } + return atb::NO_ERROR; + }; + + CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(opGraph, operation)); + return atb::NO_ERROR; +} + +class LmHeadNoParallelConfig { +public: + uint64_t nodeCount = 2; + uint64_t intermediateTensorCount = 1; + + enum LmHeadNoParallelTensorIdx : uint32_t { + INTERMEDIATE_GATHER_OUT = LmHeadTensorIdx::OUT_LOGITS + 1, + INTERMEDIATE_SLICE_OUT // no usage + }; +}; + +class LmHeadRowParallelConfig { +public: + + uint64_t nodeCount = 3; + uint64_t intermediateTensorCount = 2; + + enum LmHeadRowParallelTensorIdx : uint32_t { + INTERMEDIATE_SLICE_OUT = LmHeadTensorIdx::OUT_LOGITS + 1, + INTERMEDIATE_GATHER_OUT, + }; +}; + +class LmHeadColumnParallelConfig { +public: + + uint64_t nodeCount = 2; + uint64_t intermediateTensorCount = 1; + + enum LmHeadColumnParallelTensorIdx : uint32_t { + INTERMEDIATE_GATHER_OUT = LmHeadTensorIdx::OUT_LOGITS + 1, + INTERMEDIATE_SLICE_OUT // no usage + }; +}; + +atb::Status LmHead(const LmHeadParam ¶m, atb::Operation **operation) +{ + if (param.linearParallelParam.tensorParallelInfo.worldSize <= 1) { + LmHeadNoParallelConfig lmHeadNoParallelConfig; + return CreateLmHead(param, operation, lmHeadNoParallelConfig, UNDEFINED); + } else if (param.linearParallelParam.parallelType == ROW_PARALLEL) { + LmHeadRowParallelConfig lmHeadRowParallelConfig; + return CreateLmHead(param, operation, lmHeadRowParallelConfig, ROW_PARALLEL); + } else if (param.linearParallelParam.parallelType == COLUMN_PARALLEL) { + LmHeadColumnParallelConfig lmHeadColumnParallelConfig; + return CreateLmHead(param, operation, lmHeadColumnParallelConfig, COLUMN_PARALLEL); + } else { + ATB_SPEED_LOG_ERROR("LmHead operation doesn't support parallelType: " + << param.linearParallelParam.parallelType + << " Possible values are 1 (row parallel) or 2 (column parallel)."); + return atb::ERROR_INVALID_PARAM; + } +} + +} // namespace common +} // namespace atb_speed \ No newline at end of file diff --git a/tests/proftest/layer_test_framework/operations/fusion/lmhead/lmhead.h b/tests/proftest/layer_test_framework/operations/fusion/lmhead/lmhead.h new file mode 100644 index 00000000..a942af09 --- /dev/null +++ b/tests/proftest/layer_test_framework/operations/fusion/lmhead/lmhead.h @@ -0,0 +1,100 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef ATB_SPEED_MODELS_COMMON_LMHEAD_H +#define ATB_SPEED_MODELS_COMMON_LMHEAD_H + +#include +#include "atb_speed/utils/operation_util.h" +#include "operations/fusion/linear/linear.h" +#include "operations/fusion/linear/linear_parallel.h" + +namespace atb_speed { +namespace common { +/// A struct of `LmHead`'s parameters. 
+struct LmHeadParam {
+    /// Whether to use gatherAhead. We recommend using gatherAhead in the prefill stage: only the tokens
+    /// selected by IN_INDICES are transformed into logits, which reduces memory usage.
+    /// gatherAhead is not recommended in the decode stage, where the extra gather costs more time than it saves.
+    bool gatherAhead = false;
+    /// Whether the input tensor is unpadded.
+    /// If false, the input tensor shape is [batch_size, seq_len]. Requests shorter than seq_len are padded.
+    /// If true, the input tensor shape is [(seq_len_1 + seq_len_2 + ... + seq_len_n)].
+    bool unpadInputs = false;
+    /// Hidden size of each attention head in the row parallel setting.
+    /// Only positive if linearParallelParam.parallelType is ROW_PARALLEL.
+    int hiddenSizePerAttentionHead = 0;
+    bool enableDpOut = false;
+    /// Parameters passed to `LinearParallel`, see `operations/fusion/linear/linear_parallel.h` for more details.
+    atb_speed::common::LinearParallelParam linearParallelParam;
+};
+
+/// Create an `LmHead` operation.
+///
+/// \param param `LmHead`'s parameters, see `LmHeadParam` for more details.
+/// \param operation The address to be filled with the created operation object.
+/// \return A flag indicating whether the operation is created successfully.
+///
+/// Operation's Inputs:
+/// Name          | Dtype   | Shape | Description |
+/// --------------|---------| ----- | -------- |
+/// hidden_states | float16 or bfloat16 | [len(all seq_len), hidden_size] if unpadInputs is true, otherwise [bsz, seq_len, hidden_size] | / |
+/// weight        | float16 or bfloat16 | Let the original weight shape be [vocab_size, hidden_size]; if linearParallelParam.parallelType is COLUMN_PARALLEL, then [vocab_size / world_size, hidden_size], otherwise [vocab_size, hidden_size / world_size] | / |
+/// scale         | float16 | [1] | Placeholder, not used |
+/// offset        | float16 | [1] | Placeholder, not used |
+/// descale       | float16 | [1] | Placeholder, not used |
+/// bias          | float16 | [1] | Placeholder, not used |
+/// indices       | int64   | [num_indices] | Optional, only needed when gatherAhead is true |
+///
+/// Operation's Outputs:
+/// Name   | Dtype | Shape |
+/// -------|-------| ----- |
+/// logits | float16 or bfloat16 | [len(all seq_len), vocab_size] if unpadInputs is true, otherwise [bsz, seq_len, vocab_size] |
+///
+/// Example:
+/// \code
+/// enum TensorIdx: uint32_t {
+///     IN_HIDDEN_STATES_ID = 0,
+///     IN_WEIGHT_ID,
+///     IN_INDICES_ID,
+///     OUT_LOGITS_ID,
+///     PLACE_HOLDER_ID,
+/// };
+/// std::vector<atb::Tensor> Tensors = {...};  // Prepare tensors here.
+/// atb::Operation *op = nullptr;
+/// atb_speed::Model::Node lmHeadNode;
+/// atb_speed::common::LmHeadParam lmHeadParam;
+/// // Modify LmHeadParam's attributes if needed.
+/// CHECK_OPERATION_STATUS_RETURN(LmHead(lmHeadParam, &op));
+/// lmHeadNode.operation.reset(op);
+/// // Assume the input and output tensors are already set in graph.
+/// lmHeadNode.inTensors = {
+///     Tensors.at(IN_HIDDEN_STATES_ID),
+///     Tensors.at(IN_WEIGHT_ID),
+///     Tensors.at(PLACE_HOLDER_ID),
+///     Tensors.at(PLACE_HOLDER_ID),
+///     Tensors.at(PLACE_HOLDER_ID),
+///     Tensors.at(PLACE_HOLDER_ID),
+///     Tensors.at(IN_INDICES_ID)
+/// };
+/// lmHeadNode.outTensors = {
+///     Tensors.at(OUT_LOGITS_ID)
+/// };
+/// graph.nodes.push_back(lmHeadNode);  // Add the operation to its graph.
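+/// // The PLACE_HOLDER_ID entries above stand in for the quantization inputs
+/// // (scale, offset, descale, bias), which are unused placeholders on the float path.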
+/// \endcode +atb::Status LmHead(const LmHeadParam ¶m, atb::Operation **operation); +} // namespace common +} // namespace atb_speed +#endif diff --git a/tests/proftest/layer_test_framework/operations/fusion/lmhead/parallel_lmhead_all2all.cpp b/tests/proftest/layer_test_framework/operations/fusion/lmhead/parallel_lmhead_all2all.cpp new file mode 100644 index 00000000..d1f9d7e1 --- /dev/null +++ b/tests/proftest/layer_test_framework/operations/fusion/lmhead/parallel_lmhead_all2all.cpp @@ -0,0 +1,147 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + */ + +#include + +#include "operations/fusion/utils.h" +#include "parallel_lmhead_all2all.h" + +namespace atb_speed { +namespace common { + +template +atb::Status CreateLmHeadLinearNode(const LmHeadParam ¶m, atb::GraphParam &opGraph, T &config, size_t &nodeId) +{ + atb::Node &lmHeadLinearNode = opGraph.nodes.at(nodeId++); + atb::infer::LinearParam lmHeadLinearParam; + lmHeadLinearParam.transposeB = param.linearParallelParam.fusionLinearParam.transposeType == TRANSPOSE; + if (!lmHeadLinearParam.transposeB) { + ATB_SPEED_LOG_ERROR("The lmhead linear node in lmhead-all2all doesn't support transposeType: " + << param.linearParallelParam.fusionLinearParam.transposeType + << " The value must be " << TRANSPOSE << "."); + return atb::ERROR_INVALID_PARAM; + } + lmHeadLinearParam.hasBias = false; + CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(lmHeadLinearParam, &lmHeadLinearNode.operation)); + lmHeadLinearNode.inTensorIds = {config.IN_HIDDENSTATES_ID, config.IN_WEIGHT_ID}; + lmHeadLinearNode.outTensorIds = {config.INTERMEDIATE_LMLINEAR_OUT_ID}; + + return atb::NO_ERROR; +} + +template +atb::Status CreateTransPose1Node(const LmHeadParam ¶m, atb::GraphParam &opGraph, T &config, size_t &nodeId) +{ + atb::Node &transPose1Node = opGraph.nodes.at(nodeId++); + atb::infer::TransposeParam transParam1; + transParam1.perm = { 0, 2, 1 }; + CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(transParam1, &transPose1Node.operation)); + transPose1Node.inTensorIds = { config.INTERMEDIATE_LMLINEAR_OUT_ID }; + transPose1Node.outTensorIds = { config.INTERMEDIATE_TRANS1_OUT_ID }; + transPose1Node.inTensorReshapeFuncs.resize(transPose1Node.inTensorIds.size()); + transPose1Node.inTensorReshapeFuncs.at(0) = [param](const atb::Dims &oldShape, atb::Dims &newShape) { + newShape.dimNum = 3; // 3: rank, token, vocab_size + newShape.dims[0] = param.linearParallelParam.tensorParallelInfo.worldSize; + newShape.dims[1] = oldShape.dims[0] / param.linearParallelParam.tensorParallelInfo.worldSize; + newShape.dims[2] = oldShape.dims[1]; // 2: vocab_size + }; + return atb::NO_ERROR; +} + +template +atb::Status CreateAllToAllNode(const LmHeadParam ¶m, atb::GraphParam &opGraph, T &config, size_t &nodeId) +{ + atb::Node &allToAllNode = opGraph.nodes.at(nodeId++); + atb::infer::AllToAllParam allToAllParam; + allToAllParam.rank = param.linearParallelParam.tensorParallelInfo.rank; + allToAllParam.rankSize = param.linearParallelParam.tensorParallelInfo.worldSize; + allToAllParam.backend = param.linearParallelParam.tensorParallelInfo.backend; + allToAllParam.hcclComm = param.linearParallelParam.tensorParallelInfo.hcommInfo; + allToAllParam.commDomain = param.linearParallelParam.tensorParallelInfo.commDomain; + CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(allToAllParam, &allToAllNode.operation)); + allToAllNode.inTensorIds = {config.INTERMEDIATE_TRANS1_OUT_ID}; + allToAllNode.outTensorIds = {config.INTERMEDIATE_ALLTOALLTP_OUT_ID}; + 
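// The reshape below flattens the transposed [world_size, vocab_shard, token] block to 2D
+    // so that the all-to-all exchanges one contiguous [vocab_shard, token] slice with each peer rank.
+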
allToAllNode.inTensorReshapeFuncs.resize(allToAllNode.inTensorIds.size()); + allToAllNode.inTensorReshapeFuncs.at(0) = [=](const atb::Dims &oldShape, atb::Dims &newShape) { + newShape.dimNum = 2; // 2: rank* vocab_size, token + newShape.dims[0] = oldShape.dims[0] * oldShape.dims[1]; + newShape.dims[1] = oldShape.dims[2]; // 2: token + }; + return atb::NO_ERROR; +} + +template +atb::Status CreateTransPose2Node(const LmHeadParam ¶m, atb::GraphParam &opGraph, T &config, size_t &nodeId) +{ + atb::Node &transPose2Node = opGraph.nodes.at(nodeId++); + atb::infer::TransposeParam trans2Param; + trans2Param.perm = { 1, 0 }; + CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(trans2Param, &transPose2Node.operation)); + transPose2Node.inTensorIds = { config.INTERMEDIATE_ALLTOALLTP_OUT_ID }; + transPose2Node.outTensorIds = { config.OUT_LOGITS_ID }; + + opGraph.inferShapeFunc = [=](const atb::SVector &inTensorDescs, + atb::SVector &outTensorDescs) { + outTensorDescs.at(0) = inTensorDescs.at(0); + auto dimLast = inTensorDescs.at(0).shape.dimNum - 1; + outTensorDescs.at(0).shape.dims[dimLast] = inTensorDescs.at(1).shape.dims[0] * \ + param.linearParallelParam.tensorParallelInfo.worldSize; + outTensorDescs.at(0).shape.dims[0] = inTensorDescs.at(0).shape.dims[0] / \ + param.linearParallelParam.tensorParallelInfo.worldSize; + return atb::NO_ERROR; + }; + return atb::NO_ERROR; +} + +template +atb::Status CreateParallelLmHeadAllToAllBase(const LmHeadParam ¶m, atb::Operation **operation, T config) +{ + atb::GraphParam opGraph; + opGraph.inTensorNum = config.inTensorNum; + opGraph.outTensorNum = config.outTensorNum; + opGraph.internalTensorNum = config.interTensorNum; + opGraph.nodes.resize(config.nodeCount); + opGraph.name = "Parallel_LmHead"; + + size_t nodeId = 0; + CHECK_OPERATION_STATUS_RETURN(CreateLmHeadLinearNode(param, opGraph, config, nodeId)); + CHECK_OPERATION_STATUS_RETURN(CreateTransPose1Node(param, opGraph, config, nodeId)); + CHECK_OPERATION_STATUS_RETURN(CreateAllToAllNode(param, opGraph, config, nodeId)); + CHECK_OPERATION_STATUS_RETURN(CreateTransPose2Node(param, opGraph, config, nodeId)); + + CREATE_OPERATION(opGraph, operation); + return atb::NO_ERROR; +} + +class ParallelLmHeadAllToAllConfig { +public: + + uint64_t inTensorNum = 7; + uint64_t outTensorNum = 1; + uint64_t interTensorNum = 3; + uint64_t nodeCount = 4; + + enum ParallelLmHeadAllToAllId : unsigned int { + IN_HIDDENSTATES_ID = 0, + IN_WEIGHT_ID, + IN_SCALE, + IN_OFFSET, + IN_DESCALE, + IN_BIAS, + IN_COMPRESS_IDX, + OUT_LOGITS_ID, + INTERMEDIATE_LMLINEAR_OUT_ID, + INTERMEDIATE_TRANS1_OUT_ID, + INTERMEDIATE_ALLTOALLTP_OUT_ID, + }; +}; + +atb::Status ParallelLmHeadAllToAll(const LmHeadParam ¶m, atb::Operation **operation) +{ + ParallelLmHeadAllToAllConfig parallelLmHeadAllToAllConfig; + return CreateParallelLmHeadAllToAllBase(param, operation, parallelLmHeadAllToAllConfig); +} +} // namespace common +} // namespace atb_speed \ No newline at end of file diff --git a/tests/proftest/layer_test_framework/operations/fusion/lmhead/parallel_lmhead_all2all.h b/tests/proftest/layer_test_framework/operations/fusion/lmhead/parallel_lmhead_all2all.h new file mode 100644 index 00000000..541f5f23 --- /dev/null +++ b/tests/proftest/layer_test_framework/operations/fusion/lmhead/parallel_lmhead_all2all.h @@ -0,0 +1,18 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
+ *
+ */
+
+#ifndef ATB_SPEED_LAYERS_PARALLEL_LMHEAD_ALLTOALL_H
+#define ATB_SPEED_LAYERS_PARALLEL_LMHEAD_ALLTOALL_H
+#include
+
+#include "operations/fusion/lmhead/lmhead.h"
+#include "atb_speed/utils/operation_util.h"
+
+namespace atb_speed {
+namespace common {
+atb::Status ParallelLmHeadAllToAll(const LmHeadParam &param, atb::Operation **operation);
+} // namespace common
+} // namespace atb_speed
+#endif
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/fusion/mlp/mlp.cpp b/tests/proftest/layer_test_framework/operations/fusion/mlp/mlp.cpp
new file mode 100644
index 00000000..04c9c516
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/fusion/mlp/mlp.cpp
@@ -0,0 +1,701 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include
+#include "operations/aclnn/ops/dequant_swiglu_quant_operation.h"
+#include "operations/fusion/linear/linear.h"
+#include "operations/fusion/utils.h"
+#include "operations/fusion/mlp/mlp.h"
+
+namespace atb_speed {
+namespace common {
+
+std::map<std::string, std::vector<std::string>> GetMlpInTensorCandidates()
+{
+    std::map<std::string, std::vector<std::string>> mlpInTensorCandidates = {
+        {"default", {
+            "in_input", "in_norm_weight", "in_norm_bias", "in_norm_new_weight", "in_norm_new_bias",
+            "in_weight_0", "in_scale_0", "in_offset_0", "in_descale_0", "in_bias_0", "in_compress_idx_0",
+            "in_weight_1", "in_scale_1", "in_offset_1", "in_descale_1", "in_bias_1", "in_compress_idx_1",
+            "in_weight_down", "in_scale_down", "in_offset_down", "in_descale_down", "in_bias_down",
+            "in_compress_idx_down"}
+        },
+        {"add_norm", {"in_residual_add"}},
+        {"add_rmsnorm_quant", {"in_mlp_scale_fill", "in_mlp_offset_fill"}},
+        {"lora", {
+            "in_seq_len_cum_sum", "in_lora_a_0", "in_lora_b_0", "in_lora_a_1", "in_lora_b_1",
+            "in_down_lora_a", "in_down_lora_b"}
+        },
+        {"reduce_quant", {
+            "in_reduce_quant_scale", "in_reduce_quant_offset",
+            "in_gather_quant_scale", "in_gather_quant_offset"}},
+        {"lora_with_mask", {"in_im_mask"}},
+        {"flash_comm", {
+            "send_counts", "sdispls", "send_count", "recv_counts", "rdispls", "recv_count",
+            "fake_rs_shape", "fake_ag_shape"}
+        },
+    };
+    return mlpInTensorCandidates;
+}
+
+std::map<std::string, std::vector<std::string>> GetMlpOutTensorCandidates()
+{
+    std::map<std::string, std::vector<std::string>> mlpOutTensorCandidates = {
+        {"default", {"out_linear"}},
+        {"add_norm", {"out_add"}},
+    };
+    return mlpOutTensorCandidates;
+}
+
+template <typename NormParamType>
+atb::SVector<std::string> ConstructMlpInTensorList(const MlpParam<NormParamType> &param)
+{
+    auto mlpInTensorCandidates = GetMlpInTensorCandidates();
+
+    atb::SVector<std::string> inTensorList = {};
+
+    // Add the default tensors
+    AddTensorToList(mlpInTensorCandidates, "default", inTensorList);
+
+    // Add tensors for the add-norm feature and for the AddRmsNormQuant feature
+    if (param.enableAddNorm) {
+        AddTensorToList(mlpInTensorCandidates, "add_rmsnorm_quant", inTensorList);
+        AddTensorToList(mlpInTensorCandidates, "add_norm", inTensorList);
+    }
+
+    // Add tensors for the LoRA feature
+    if (param.supportLora) {
+        if (param.useImMask) {
+            AddTensorToList(mlpInTensorCandidates, "lora_with_mask", inTensorList);
+        }
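+        // When the im-mask is enabled it is appended first, so it precedes the LoRA tensors below.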
+        AddTensorToList(mlpInTensorCandidates, "lora", inTensorList);
+    }
+
+    // Add tensors for the LCCL reduce int8 feature
+    if (param.downLinearTensorParallelInfo.quantType != \
+        atb::infer::AllReduceParam::QuantType::QUANT_TYPE_UNDEFINED) {
+        AddTensorToList(mlpInTensorCandidates, "reduce_quant", inTensorList);
+    }
+
+    // Add Flashcomm 1.0 tensors
+    if (param.enableFlashComm) {
+        AddTensorToList(mlpInTensorCandidates, "flash_comm", inTensorList);
+    }
+    return inTensorList;
+}
+
+template <typename NormParamType>
+atb::SVector<std::string> ConstructMlpOutTensorList(const MlpParam<NormParamType> &param)
+{
+    auto mlpOutTensorCandidates = GetMlpOutTensorCandidates();
+
+    atb::SVector<std::string> outTensorList = {};
+
+    // Add the output tensors
+    AddTensorToList(mlpOutTensorCandidates, "default", outTensorList);
+    if (param.enableAddNorm) {
+        AddTensorToList(mlpOutTensorCandidates, "add_norm", outTensorList);
+    }
+
+    return outTensorList;
+}
+
+template <typename NormParamType>
+void SetGateUpNormLinearParam(atb_speed::common::NormLinearParam<NormParamType> &gateUpNormLinearParam,
+    const MlpParam<NormParamType> &param, bool isAntiOutlier)
+{
+    gateUpNormLinearParam.isAntiOutlier = isAntiOutlier;
+    gateUpNormLinearParam.fusionLinearParam.quantType = GetLinearQuantType(
+        param.packQuantType, param.layerLinearQuantType[GATE_LINEAR_INDEX], param.enableNormQuantOp,
+        param.layerLinearDescs[GATE_LINEAR_INDEX]);
+    gateUpNormLinearParam.fusionLinearParam.isBF16 = param.isBF16;
+    gateUpNormLinearParam.fusionLinearParam.hasBias = param.gateUpHasBias;
+    gateUpNormLinearParam.fusionLinearParam.supportLora = param.supportLora;
+    gateUpNormLinearParam.fusionLinearParam.useImMask = param.useImMask;
+    gateUpNormLinearParam.fusionLinearParam.loraEnableGMM = param.loraEnableGMM;
+    gateUpNormLinearParam.fusionLinearParam.transposeType = param.layerLinearTransposeType[GATE_LINEAR_INDEX];
+    gateUpNormLinearParam.fusionLinearParam.quantGroupSize = param.quantGroupSize;
+    gateUpNormLinearParam.fusionLinearParam.matmulBackend = param.matmulBackend;
+    gateUpNormLinearParam.fusionLinearParam.isPrefill = param.isPrefill;
+    gateUpNormLinearParam.skipNorm = param.skipNorm;
+    gateUpNormLinearParam.normHasBias = param.normHasBias;
+    gateUpNormLinearParam.enableAddNorm = param.enableAddNorm;
+    gateUpNormLinearParam.normParamType = param.normParamType;
+    gateUpNormLinearParam.normQuantParamType = param.normQuantParamType;
+    gateUpNormLinearParam.fusionLinearParam.enableFlashComm = param.enableFlashComm;
+    gateUpNormLinearParam.fusionLinearParam.flashCommParallelInfo.worldSize =
+        param.downLinearTensorParallelInfo.worldSize;
+    gateUpNormLinearParam.fusionLinearParam.flashCommParallelInfo.rank =
+        param.downLinearTensorParallelInfo.rank;
+    gateUpNormLinearParam.fusionLinearParam.flashCommParallelInfo.backend =
+        param.downLinearTensorParallelInfo.backend;
+    bool gateUpIsQuant = IsLinearDescQuant(param, GATE_LINEAR_INDEX);
+    bool downIsQuant = IsLinearDescQuant(param, DOWN_LINEAR_INDEX);
+    if (param.enableSwigluQuant && gateUpIsQuant && downIsQuant
+        && UseQuantBatchMatmul(gateUpNormLinearParam.fusionLinearParam) && !param.isPrefill) {
+        gateUpNormLinearParam.fusionLinearParam.isThrowDequant = true;  // Linear outputs int32
+        gateUpNormLinearParam.fusionLinearParam.enableSwigluQuant = false;
+    }
+}
+
+template <typename NormParamType>
+atb::Status AddMlpNormLinearGateUp(const MlpParam<NormParamType> &param,
+    bool isAntiOutlier, atb::GraphOpBuilder* &graphBuilder)
+{
+    atb::Operation* normLinearGateUpOp = nullptr;
+    atb_speed::common::NormLinearParam<NormParamType> gateUpNormLinearParam;
+    SetGateUpNormLinearParam(gateUpNormLinearParam, param, isAntiOutlier);
+    CHECK_OPERATION_STATUS_RETURN(NormLinear(gateUpNormLinearParam,
&normLinearGateUpOp)); + + atb::SVector gateUpInTensorNames = { + "in_input", "in_norm_weight", "in_norm_bias", "in_norm_new_weight", "in_norm_new_bias", + "in_weight_0", "in_scale_0", "in_offset_0", "in_descale_0", "in_bias_0", "in_compress_idx_0", + }; + if (param.enableAddNorm) { + gateUpInTensorNames.push_back("in_mlp_scale_fill"); + gateUpInTensorNames.push_back("in_mlp_offset_fill"); + gateUpInTensorNames.push_back("in_residual_add"); + } + if (param.supportLora) { + if (param.useImMask) { + gateUpInTensorNames.push_back("in_im_mask"); + } + gateUpInTensorNames.push_back("in_seq_len_cum_sum"); + gateUpInTensorNames.push_back("in_lora_a_0"); + gateUpInTensorNames.push_back("in_lora_b_0"); + } + if (param.enableFlashComm) { + gateUpInTensorNames.push_back("send_counts"); + gateUpInTensorNames.push_back("sdispls"); + gateUpInTensorNames.push_back("send_count"); + gateUpInTensorNames.push_back("recv_counts"); + gateUpInTensorNames.push_back("rdispls"); + gateUpInTensorNames.push_back("recv_count"); + gateUpInTensorNames.push_back("fake_ag_shape"); + } + atb::SVector gateUpOutTensorNames = {}; + if (param.mlpPackType == MlpPackType::GATE_UP_WEIGHT_PACK) { + gateUpOutTensorNames = {"intermediate_gate_up"} ; + } else if (param.mlpPackType == MlpPackType::GATE_UP_WEIGHT_NO_PACK) { + gateUpOutTensorNames = {"intermediate_gate"}; + } else { + gateUpOutTensorNames = {"intermediate_up"}; + } + if (param.enableAddNorm) { + gateUpOutTensorNames.push_back("out_add"); + } + + graphBuilder->AddOperation(normLinearGateUpOp, gateUpInTensorNames, gateUpOutTensorNames); + return atb::NO_ERROR; +} + +template +void SetUpNormLinearParam(atb_speed::common::NormLinearParam &upNormLinearParam, + const MlpParam ¶m, bool isAntiOutlier) +{ + upNormLinearParam.isAntiOutlier = isAntiOutlier; + upNormLinearParam.fusionLinearParam.quantType = GetLinearQuantType( + param.packQuantType, param.layerLinearQuantType[UP_LINEAR_INDEX], param.enableNormQuantOp, + param.layerLinearDescs[UP_LINEAR_INDEX]); + upNormLinearParam.fusionLinearParam.isBF16 = param.isBF16; + upNormLinearParam.fusionLinearParam.hasBias = param.gateUpHasBias; + upNormLinearParam.fusionLinearParam.supportLora = param.supportLora; + upNormLinearParam.fusionLinearParam.useImMask = param.useImMask; + upNormLinearParam.fusionLinearParam.loraEnableGMM = param.loraEnableGMM; + upNormLinearParam.fusionLinearParam.transposeType = param.layerLinearTransposeType[UP_LINEAR_INDEX]; + upNormLinearParam.fusionLinearParam.quantGroupSize = param.quantGroupSize; + upNormLinearParam.fusionLinearParam.matmulBackend = param.matmulBackend; + upNormLinearParam.fusionLinearParam.isPrefill = param.isPrefill; + upNormLinearParam.skipNorm = param.skipNorm; + upNormLinearParam.normHasBias = param.normHasBias; + upNormLinearParam.normParamType = param.normParamType; + upNormLinearParam.normQuantParamType = param.normQuantParamType; + upNormLinearParam.fusionLinearParam.enableFlashComm = param.enableFlashComm; + upNormLinearParam.fusionLinearParam.flashCommParallelInfo.worldSize = + param.downLinearTensorParallelInfo.worldSize; + upNormLinearParam.fusionLinearParam.flashCommParallelInfo.rank = + param.downLinearTensorParallelInfo.rank; + upNormLinearParam.fusionLinearParam.flashCommParallelInfo.backend = + param.downLinearTensorParallelInfo.backend; +} + +template +atb::Status AddMlpNormLinearUp(const MlpParam ¶m, + bool isAntiOutlier, atb::GraphOpBuilder* &graphBuilder) +{ + atb::Operation* normLinearUpOp = nullptr; + atb_speed::common::NormLinearParam upNormLinearParam; + 
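// The up projection mirrors the gate linear's norm and quant settings; only the weight and LoRA slots differ.
+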
SetUpNormLinearParam(upNormLinearParam, param, isAntiOutlier); + CHECK_OPERATION_STATUS_RETURN(NormLinear(upNormLinearParam, &normLinearUpOp)); + + atb::SVector upInTensorNames = { + "in_input", "in_norm_weight", "in_norm_bias", "in_norm_new_weight", "in_norm_new_bias", + "in_weight_1", "in_scale_1", "in_offset_1", "in_descale_1", "in_bias_1", "in_compress_idx_1", + }; + if (param.supportLora) { + if (param.useImMask) { + upInTensorNames.push_back("in_im_mask"); + } + upInTensorNames.push_back("in_seq_len_cum_sum"); + upInTensorNames.push_back("in_lora_a_1"); + upInTensorNames.push_back("in_lora_b_1"); + } + if (param.enableFlashComm) { + upInTensorNames.push_back("send_counts"); + upInTensorNames.push_back("sdispls"); + upInTensorNames.push_back("send_count"); + upInTensorNames.push_back("recv_counts"); + upInTensorNames.push_back("rdispls"); + upInTensorNames.push_back("recv_count"); + upInTensorNames.push_back("fake_ag_shape"); + } + + graphBuilder->AddOperation(normLinearUpOp, upInTensorNames, {"intermediate_up"}); + return atb::NO_ERROR; +} + +atb::Status AddMlpSplit(atb::GraphOpBuilder* &graphBuilder) +{ + atb::Operation* splitOp = nullptr; + atb::infer::SplitParam splitParam; + splitParam.splitDim = -1; // [batchSize, seqLen, 2 * hiddenSize] + splitParam.splitNum = 2; // 进行二等分 + CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(splitParam, &splitOp)); + + graphBuilder->AddOperation(splitOp, {"intermediate_gate_up"}, {"intermediate_gate", "intermediate_up"}); + return atb::NO_ERROR; +} + +atb::Status AddMlpSwiGLUConcat(atb::GraphOpBuilder* &graphBuilder) +{ + atb::Operation* concatOp = nullptr; + atb::infer::ConcatParam concatParam; + concatParam.concatDim = -1; + CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(concatParam, &concatOp)); + + graphBuilder->AddOperation(concatOp, {"intermediate_gate", "intermediate_up"}, {"intermediate_gate_up"}); + return atb::NO_ERROR; +} + +atb::Status AddMlpMul(atb::GraphOpBuilder* &graphBuilder) +{ + atb::Operation* mulOp = nullptr; + atb::infer::ElewiseParam elewiseParam; + elewiseParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_MUL; + CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(elewiseParam, &mulOp)); + + graphBuilder->AddOperation( + mulOp, {"intermediate_activation_out", "intermediate_up"}, {"intermediate_activation_out"}); + return atb::NO_ERROR; +} + +template +atb::Status AddMlpActivation(const MlpParam ¶m, atb::GraphOpBuilder* &graphBuilder) +{ + if (param.mlpPackType == MlpPackType::GATE_UP_WEIGHT_PACK) { + CHECK_OPERATION_STATUS_RETURN(AddMlpSplit(graphBuilder)); + } + + atb::Operation* activationOp = nullptr; + CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(param.activationParam, &activationOp)); + graphBuilder->AddOperation( + activationOp, + {param.mlpPackType == MlpPackType::UP_WEIGHT_ONLY ? 
"intermediate_up" : "intermediate_gate"}, + {"intermediate_activation_out"}); + + if (param.mlpPackType != MlpPackType::UP_WEIGHT_ONLY) { + CHECK_OPERATION_STATUS_RETURN(AddMlpMul(graphBuilder)); + } + return atb::NO_ERROR; +} + +template +atb::Status AddMlpEdgeActivation(const MlpParam ¶m, atb::GraphOpBuilder* &graphBuilder) +{ + if (param.mlpPackType == MlpPackType::GATE_UP_WEIGHT_PACK) { + CHECK_OPERATION_STATUS_RETURN(AddMlpSplit(graphBuilder)); + } + + atb::Operation* sigmoidOp = nullptr; + atb::infer::ActivationParam activationParam; + activationParam.activationType = atb::infer::ActivationType::ACTIVATION_SIGMOID; + CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(activationParam, &sigmoidOp)); + graphBuilder->AddOperation(sigmoidOp, {"intermediate_gate"}, {"intermediate_activation_out"}); + + atb::Operation* sigmoidMulOp = nullptr; + atb::infer::ElewiseParam elewiseParam; + elewiseParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_MUL; + CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(elewiseParam, &sigmoidMulOp)); + graphBuilder->AddOperation( + sigmoidMulOp, {"intermediate_gate", "intermediate_activation_out"}, {"intermediate_activation_out"}); + + if (param.mlpPackType != MlpPackType::UP_WEIGHT_ONLY) { + CHECK_OPERATION_STATUS_RETURN(AddMlpMul(graphBuilder)); + } + + return atb::NO_ERROR; +} + +template +atb::Status AddMlpSwiGLUActivation(const MlpParam ¶m, atb::GraphOpBuilder* &graphBuilder) +{ + if (param.mlpPackType == MlpPackType::GATE_UP_WEIGHT_NO_PACK) { + CHECK_OPERATION_STATUS_RETURN(AddMlpSwiGLUConcat(graphBuilder)); + } + + atb::Operation* activationOp = nullptr; + CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(param.activationParam, &activationOp)); + graphBuilder->AddOperation(activationOp, {"intermediate_gate_up"}, {"intermediate_activation_out"}); + return atb::NO_ERROR; +} + +template +void SetDownLinearParallelParam(const MlpParam ¶m, + atb_speed::common::LinearParallelParam &downLinearParallelParam) +{ + if (param.enableFlashComm) { + downLinearParallelParam.parallelType = atb_speed::common::REDUCE_SCATTER; + } else { + downLinearParallelParam.parallelType = atb_speed::common::ROW_PARALLEL; + } + downLinearParallelParam.fusionLinearParam.quantType = GetLinearQuantType( + param.downQuantType == atb_speed::common::PackQuantType::PACK_QUANT_UNDEFINED \ + ? 
param.packQuantType : param.downQuantType, + param.layerLinearQuantType[DOWN_LINEAR_INDEX], false, + param.layerLinearDescs[DOWN_LINEAR_INDEX]); + downLinearParallelParam.fusionLinearParam.isDownLinear = true; + downLinearParallelParam.fusionLinearParam.enableSwigluQuant = param.enableSwigluQuant; + bool downIsQuant = IsLinearDescQuant(param, DOWN_LINEAR_INDEX); + if (param.enableSwigluQuant && downIsQuant) { + if (param.isPrefill && downLinearParallelParam.fusionLinearParam.quantType == \ + atb_speed::common::LinearQuantType::LINEAR_W8A8_DYNAMIC_QUANT) { + downLinearParallelParam.fusionLinearParam.quantType = \ + atb_speed::common::LinearQuantType::LINEAR_W8A8_DYNAMIC_DEQUANT; + } else if (downLinearParallelParam.fusionLinearParam.quantType == \ + atb_speed::common::LinearQuantType::LINEAR_W8A8_QUANT) { + downLinearParallelParam.fusionLinearParam.quantType = \ + atb_speed::common::LinearQuantType::LINEAR_W8A8_DEQUANT; + } else if (downLinearParallelParam.fusionLinearParam.quantType == \ + atb_speed::common::LinearQuantType::LINEAR_W4A8_DYNAMIC_QUANT) { + downLinearParallelParam.fusionLinearParam.quantType = \ + atb_speed::common::LinearQuantType::LINEAR_W4A8_DYNAMIC_DEQUANT; + } + } + downLinearParallelParam.biasAfterSync = param.downLinearTensorParallelInfo.worldSize > 1 && \ + downLinearParallelParam.fusionLinearParam.quantType == atb_speed::common::LinearQuantType::NO_QUANT && \ + param.downHasBias; + downLinearParallelParam.fusionLinearParam.hasBias = param.downHasBias && !downLinearParallelParam.biasAfterSync; + downLinearParallelParam.fusionLinearParam.isBF16 = param.isBF16; + downLinearParallelParam.fusionLinearParam.supportLora = param.supportLora; + downLinearParallelParam.fusionLinearParam.useImMask = param.useImMask; + downLinearParallelParam.fusionLinearParam.loraEnableGMM = param.loraEnableGMM; + downLinearParallelParam.fusionLinearParam.transposeType = param.layerLinearTransposeType[DOWN_LINEAR_INDEX]; + downLinearParallelParam.fusionLinearParam.quantGroupSize = param.quantGroupSize; + downLinearParallelParam.fusionLinearParam.matmulBackend = param.matmulBackend; + downLinearParallelParam.fusionLinearParam.isPrefill = param.isPrefill; + downLinearParallelParam.tensorParallelInfo = param.downLinearTensorParallelInfo; + downLinearParallelParam.supportLcoc = param.supportLcoc; + downLinearParallelParam.enableMC2 = param.enableMC2; +} + +template +atb::Status AddMlpLinearDown(const MlpParam ¶m, atb::GraphOpBuilder* &graphBuilder) +{ + atb::Operation* linearDownOp = nullptr; + atb_speed::common::LinearParallelParam downLinearParallelParam; + SetDownLinearParallelParam(param, downLinearParallelParam); + CHECK_OPERATION_STATUS_RETURN(LinearParallel(downLinearParallelParam, &linearDownOp)); + + atb::SVector downInTensorNames = { + "intermediate_activation_out", + "in_weight_down", "in_scale_down", "in_offset_down", "in_descale_down", "in_bias_down", + "in_compress_idx_down" + }; + if (param.supportLora) { + if (param.useImMask) { + downInTensorNames.push_back("in_im_mask"); + } + downInTensorNames.push_back("in_seq_len_cum_sum"); + downInTensorNames.push_back("in_down_lora_a"); + downInTensorNames.push_back("in_down_lora_b"); + } + if (param.downLinearTensorParallelInfo.quantType != atb::infer::AllReduceParam::QuantType::QUANT_TYPE_UNDEFINED) { + downInTensorNames.push_back("in_reduce_quant_scale"); + downInTensorNames.push_back("in_reduce_quant_offset"); + downInTensorNames.push_back("in_gather_quant_scale"); + downInTensorNames.push_back("in_gather_quant_offset"); + } + if 
(param.isPrefill && param.enableSwigluQuant && downLinearParallelParam.fusionLinearParam.quantType == \
+        atb_speed::common::LinearQuantType::LINEAR_W8A8_DYNAMIC_DEQUANT) {
+        downInTensorNames.push_back("intermediate_swiglu_dynamic_scale");
+    }
+    if (downLinearParallelParam.parallelType == atb_speed::common::REDUCE_SCATTER) {
+        downInTensorNames.push_back("send_counts");
+        downInTensorNames.push_back("sdispls");
+        downInTensorNames.push_back("recv_count");
+        downInTensorNames.push_back("fake_rs_shape");
+    }
+    graphBuilder->AddOperation(linearDownOp, downInTensorNames, {"out_linear"});
+    return atb::NO_ERROR;
+}
+
+template <typename NormParamType>
+atb::Status AddDequantSwigluQuantNode(const MlpParam<NormParamType> &param, atb::GraphOpBuilder* &graphBuilder)
+{
+    // Note: derive the quant type from layerLinearDescs here, not from layerLinearQuantType
+    LinearQuantType downQuantType = GetLinearQuantType(
+        param.downQuantType == atb_speed::common::PackQuantType::PACK_QUANT_UNDEFINED ?
+        param.packQuantType : param.downQuantType,
+        param.layerLinearQuantType[DOWN_LINEAR_INDEX],
+        false,
+        param.layerLinearDescs[DOWN_LINEAR_INDEX]);
+    bool gateUpIsQuant = IsLinearDescQuant(param, GATE_LINEAR_INDEX);
+
+    AclNNDequantSwigluQuantParam aclnnParam;
+    aclnnParam.activateLeft = true;
+    aclnnParam.quantMode = "static";
+    atb::SVector<std::string> inTensorNames = { "intermediate_gate_up" };  // 0: x
+    FusionLinearParam linearParam;
+    linearParam.matmulBackend = param.matmulBackend;
+    linearParam.quantType = downQuantType;
+    if (gateUpIsQuant && UseQuantBatchMatmul(linearParam) && !param.isPrefill) {
+        inTensorNames.push_back("in_descale_0");  // 1: weightScaleOptional, fp32
+        // 2: activationScaleOptional, pass null for this argument
+        inTensorNames.push_back("in_bias_0");  // 3: biasOptional, int32
+    }
+    inTensorNames.push_back("in_scale_down");  // 4: quantScaleOptional
+    inTensorNames.push_back("in_offset_down");  // 5: quantOffsetOptional
+    if (downQuantType == atb_speed::common::LinearQuantType::LINEAR_W8A8_DYNAMIC_QUANT && param.isPrefill) {
+        inTensorNames = { "intermediate_gate_up" };  // 0: x
+        aclnnParam.quantMode = "dynamic";  // dynamic mode: pass only x
+    }
+    atb::SVector<std::string> outTensorNames = {"intermediate_activation_out", "intermediate_swiglu_dynamic_scale"};
+    aclnnParam.inTensorsNum = static_cast<uint64_t>(inTensorNames.size());
+    atb::Operation* dequantSwigluQuantOp = new atb_speed::common::DequantSwigluQuantOperation(
+        "aclNNDequantSwigluQuantNode", aclnnParam
+    );
+    graphBuilder->AddOperation(dequantSwigluQuantOp, inTensorNames, outTensorNames);
+    return atb::NO_ERROR;
+}
+
+template <typename NormParamType>
+atb::Status Mlp(const MlpParam<NormParamType> &param, atb::Operation **operation)
+{
+    atb::GraphOpBuilder* graphOpBuilder = nullptr;
+    CHECK_OPERATION_STATUS_RETURN(CreateGraphOpBuilder(&graphOpBuilder));
+    atb::Status res = CreateMlp(param, graphOpBuilder, operation, false);
+    if (DestroyGraphOpBuilder(graphOpBuilder) != atb::NO_ERROR) {
+        ATB_SPEED_LOG_WARN("Destroy graph builder failed. This may lead to a memory leak, please check");
+    }
+    return res;
+}
+
+template <typename NormParamType>
+atb::Status MlpSwiGLU(const MlpParam<NormParamType> &param, atb::Operation **operation)
+{
+    atb::GraphOpBuilder* graphOpBuilder = nullptr;
+    CHECK_OPERATION_STATUS_RETURN(CreateGraphOpBuilder(&graphOpBuilder));
+    atb::Status res = CreateMlp(param, graphOpBuilder, operation, true);
+    if (DestroyGraphOpBuilder(graphOpBuilder) != atb::NO_ERROR) {
+        ATB_SPEED_LOG_WARN("Destroy graph builder failed. This may lead to a memory leak, please check");
+    }
+    return res;
+}
+
+template <typename NormParamType>
+atb::Status CheckMlpParam(const MlpParam<NormParamType> &param)
+{
+    if (param.layerLinearDescs.size() != 0 && \
+        CheckParamVectorSize(param.layerLinearDescs, DOWN_LINEAR_INDEX + 1) != atb::NO_ERROR) {
+        ATB_SPEED_LOG_ERROR("The size of param.layerLinearDescs is wrong, please check");
+        return atb::ERROR_INVALID_PARAM;
+    }
+    if (param.layerLinearQuantType.size() != 0 && \
+        CheckParamVectorSize(param.layerLinearQuantType, DOWN_LINEAR_INDEX + 1) != atb::NO_ERROR) {
+        ATB_SPEED_LOG_ERROR("The size of param.layerLinearQuantType is wrong, please check");
+        return atb::ERROR_INVALID_PARAM;
+    }
+    if (CheckParamVectorSize(param.layerLinearTransposeType, DOWN_LINEAR_INDEX + 1) != atb::NO_ERROR) {
+        ATB_SPEED_LOG_ERROR("The size of param.layerLinearTransposeType is wrong, please check");
+        return atb::ERROR_INVALID_PARAM;
+    }
+    return atb::NO_ERROR;
+}
+
+template <typename NormParamType>
+atb::Status CreateMlp(
+    const MlpParam<NormParamType> &param,
+    atb::GraphOpBuilder* &graphOpBuilder,
+    atb::Operation **operation, bool isSwiGLU)
+{
+    bool isAntiOutlier = CheckAntiOutlier(param.packQuantType);
+    isAntiOutlier = isAntiOutlier || param.isAntiOutlier;
+    CHECK_OPERATION_STATUS_RETURN(CheckMlpParam(param));
+
+    std::string graphName = isSwiGLU ? "MlpSwiGLU" : "Mlp";
+    if (param.mlpPackType == MlpPackType::GATE_UP_WEIGHT_PACK) {
+        graphName += "GateUpWeightPack";
+    } else if (param.mlpPackType == MlpPackType::GATE_UP_WEIGHT_NO_PACK) {
+        graphName += "GateUpWeightNoPack";
+    } else {
+        graphName += "UpWeightOnly";
+    }
+
+    atb::InferShapeFunc inferShapeFunc = [param](const atb::SVector<atb::TensorDesc> &inTensorDescs,
+                                                 atb::SVector<atb::TensorDesc> &outTensorDescs) {
+        outTensorDescs.at(0) = inTensorDescs.at(0);
+        if (param.enableAddNorm) { outTensorDescs.at(1) = inTensorDescs.at(0); }
+        return atb::NO_ERROR;
+    };
+
+    CHECK_OPERATION_STATUS_RETURN(graphOpBuilder->Init(
+        graphName, inferShapeFunc, ConstructMlpInTensorList(param), ConstructMlpOutTensorList(param)
+    ));
+
+    // Gate Up
+    CHECK_OPERATION_STATUS_RETURN(AddMlpNormLinearGateUp(param, isAntiOutlier, graphOpBuilder));
+    if (param.mlpPackType == MlpPackType::GATE_UP_WEIGHT_NO_PACK) {
+        CHECK_OPERATION_STATUS_RETURN(AddMlpNormLinearUp(param, isAntiOutlier, graphOpBuilder));
+    }
+    // Activation
+    if (param.isEdgeHardware) {
+        CHECK_OPERATION_STATUS_RETURN(AddMlpEdgeActivation(param, graphOpBuilder));
+    } else if (isSwiGLU) {
+        bool downIsQuant = IsLinearDescQuant(param, DOWN_LINEAR_INDEX);
+        if (param.enableSwigluQuant && downIsQuant) {
+            CHECK_OPERATION_STATUS_RETURN(AddDequantSwigluQuantNode(param, graphOpBuilder));
+        } else {
+            CHECK_OPERATION_STATUS_RETURN(AddMlpSwiGLUActivation(param, graphOpBuilder));
+        }
+    } else {
+        CHECK_OPERATION_STATUS_RETURN(AddMlpActivation(param, graphOpBuilder));
+    }
+    // Down
+    CHECK_OPERATION_STATUS_RETURN(AddMlpLinearDown(param, graphOpBuilder));
+
+    *operation = graphOpBuilder->Build();
+    return atb::NO_ERROR;
+}
+
+MlpPackType GetMlpPackType(
+    const int &packQuantType, bool upWeightOnly, const std::vector<LinearDesc> &linearDescs)
+{
+    if (upWeightOnly) {
+        return atb_speed::common::UP_WEIGHT_ONLY;
+    }
+    std::vector<uint64_t> gateUpLinearIndex = {GATE_LINEAR_INDEX, UP_LINEAR_INDEX};
+    bool isPack = CheckPack(packQuantType, linearDescs, gateUpLinearIndex);
+    if (isPack) {
+        return atb_speed::common::GATE_UP_WEIGHT_PACK;
+    } else {
+        return atb_speed::common::GATE_UP_WEIGHT_NO_PACK;
+    }
+}
+
+template <typename NormParamType>
+bool IsLinearDescQuant(const MlpParam<NormParamType> &param, const uint64_t index)
+{
+    return param.layerLinearDescs[index] != common::LinearDesc::INVALID_DESC && \
+ param.layerLinearDescs[index] != common::LinearDesc::FLOAT16_DESC && \ + param.layerLinearDescs[index] != common::LinearDesc::BFLOAT16_DESC; +} + +template bool IsLinearDescQuant(const MlpParam ¶m, const uint64_t index); + +template bool IsLinearDescQuant(const MlpParam ¶m, const uint64_t index); + +template void SetDownLinearParallelParam(const MlpParam ¶m, + atb_speed::common::LinearParallelParam &downLinearParallelParam); + +template void SetDownLinearParallelParam(const MlpParam ¶m, + atb_speed::common::LinearParallelParam &downLinearParallelParam); + +template atb::Status CheckMlpParam(const MlpParam ¶m); + +template atb::Status CheckMlpParam(const MlpParam ¶m); + +template atb::SVector ConstructMlpInTensorList(const MlpParam ¶m); + +template atb::SVector ConstructMlpInTensorList(const MlpParam ¶m); + +template atb::SVector ConstructMlpOutTensorList(const MlpParam ¶m); + +template atb::SVector ConstructMlpOutTensorList(const MlpParam ¶m); + +template void SetGateUpNormLinearParam( + atb_speed::common::NormLinearParam &gateUpNormLinearParam, + const MlpParam ¶m, bool isAntiOutlier); + +template void SetGateUpNormLinearParam( + atb_speed::common::NormLinearParam &gateUpNormLinearParam, + const MlpParam ¶m, bool isAntiOutlier); + +template atb::Status AddMlpNormLinearGateUp(const MlpParam ¶m, + bool isAntiOutlier, atb::GraphOpBuilder* &graphBuilder); + +template atb::Status AddMlpNormLinearGateUp(const MlpParam ¶m, + bool isAntiOutlier, atb::GraphOpBuilder* &graphBuilder); + +template atb::Status AddMlpLinearDown(const MlpParam ¶m, + atb::GraphOpBuilder* &graphBuilder); + +template atb::Status AddMlpLinearDown(const MlpParam ¶m, + atb::GraphOpBuilder* &graphBuilder); + +template atb::Status AddMlpNormLinearUp(const MlpParam ¶m, + bool isAntiOutlier, atb::GraphOpBuilder* &graphBuilder); + +template atb::Status AddMlpNormLinearUp(const MlpParam ¶m, + bool isAntiOutlier, atb::GraphOpBuilder* &graphBuilder); + +template atb::Status AddMlpActivation(const MlpParam ¶m, + atb::GraphOpBuilder* &graphBuilder); + +template atb::Status AddMlpActivation(const MlpParam ¶m, + atb::GraphOpBuilder* &graphBuilder); + +template atb::Status AddMlpEdgeActivation(const MlpParam ¶m, + atb::GraphOpBuilder* &graphBuilder); + +template atb::Status AddMlpSwiGLUActivation(const MlpParam ¶m, + atb::GraphOpBuilder* &graphBuilder); + +template atb::Status AddMlpSwiGLUActivation(const MlpParam ¶m, + atb::GraphOpBuilder* &graphBuilder); + +template atb::Status AddDequantSwigluQuantNode(const MlpParam ¶m, + atb::GraphOpBuilder* &graphBuilder); + +template atb::Status AddDequantSwigluQuantNode(const MlpParam ¶m, + atb::GraphOpBuilder* &graphBuilder); + +template atb::Status Mlp(const MlpParam ¶m, atb::Operation **operation); + +template atb::Status Mlp(const MlpParam ¶m, atb::Operation **operation); + +template atb::Status MlpSwiGLU(const MlpParam ¶m, atb::Operation **operation); + +template atb::Status MlpSwiGLU(const MlpParam ¶m, atb::Operation **operation); + +template atb::Status CreateMlp( + const MlpParam ¶m, + atb::GraphOpBuilder* &graphOpBuilder, atb::Operation **operation, bool isSwiGLU); + +template atb::Status CreateMlp( + const MlpParam ¶m, + atb::GraphOpBuilder* &graphOpBuilder, atb::Operation **operation, bool isSwiGLU); + +} // namespace common +} // namespace atb_speed \ No newline at end of file diff --git a/tests/proftest/layer_test_framework/operations/fusion/mlp/mlp.h b/tests/proftest/layer_test_framework/operations/fusion/mlp/mlp.h new file mode 100644 index 00000000..53a2712b --- /dev/null +++ 
b/tests/proftest/layer_test_framework/operations/fusion/mlp/mlp.h
@@ -0,0 +1,234 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef ATB_SPEED_MODELS_COMMON_MLP_OPERATION_H
+#define ATB_SPEED_MODELS_COMMON_MLP_OPERATION_H
+#include
+#include
+#include "atb_speed/utils/operation_util.h"
+#include "operations/fusion/linear/linear.h"
+#include "operations/fusion/linear/linear_parallel.h"
+#include "operations/fusion/norm/norm_linear.h"
+
+namespace atb_speed {
+namespace common {
+
+/// The categories of the mlp module's input tensors.
+/// Input tensors will be arranged according to the order of their categories.
+enum MlpInTensorCategory : unsigned int {
+    /// Default tensors
+    MLP_DEFAULT = 0,
+    /// Tensors required by addRmsNormQuant and addRmsNormDynamicQuant
+    MLP_ADD_RMS_NORM_QUANT,
+    /// Tensors required by the add norm fusion operation
+    MLP_ADD_NORM,
+    /// The mask tensor before applying lora adapters
+    MLP_LORA_MASK,
+    /// Tensors needed for LoRA
+    MLP_LORA,
+    /// Tensors required by the quantization of the all reduce operation
+    MLP_REDUCE_QUANT,
+    /// Tensors needed for Flashcomm 1.0
+    MLP_FC,
+    /// A flag signifying the end of all categories
+    MLP_END
+};
+
+/// The pack type of the gate and up linear
+enum MlpPackType : unsigned int {
+    /// The gate and up linears are packed
+    GATE_UP_WEIGHT_PACK = 0,
+    /// The gate and up linears are not packed
+    GATE_UP_WEIGHT_NO_PACK = 1,
+    /// No gate linear
+    UP_WEIGHT_ONLY = 2,
+};
+
+/// The index of the gate linear within the layer
+const uint64_t GATE_LINEAR_INDEX = 4;
+/// The index of the up linear within the layer
+const uint64_t UP_LINEAR_INDEX = 5;
+/// The index of the down linear within the layer
+const uint64_t DOWN_LINEAR_INDEX = 6;
+
+/// Parameters for the mlp module
+/// \tparam NormParamType Types of the normalization parameters. Available options are `atb::infer::RmsNormParam`
+/// and `atb::infer::LayerNormParam`.
+template <typename NormParamType>
+struct MlpParam {
+    /// When `isBF16` is true, bfloat16 precision is used; otherwise, float16 precision is used.
+    bool isBF16 = false;
+    /// A flag indicating the prefill and decode phases
+    bool isPrefill = false;
+    bool isEdgeHardware = false;
+    /// A flag indicating whether the gate and up linears have bias
+    bool gateUpHasBias = false;
+    /// A flag indicating whether the down linear has bias
+    bool downHasBias = false;
+    /// A flag that indicates whether low-latency computation over communication is enabled
+    bool supportLcoc = false;
+    bool enableMC2 = false;
+    /// A flag indicating whether normalization is skipped
+    bool skipNorm = false;
+    /// A flag indicating whether normalization has bias
+    bool normHasBias = false;
+    /// A flag indicating whether to use the AddNorm fusion operation
+    bool enableAddNorm = false;
+    /// A flag indicating whether to use the NormQuant fusion operation
+    bool enableNormQuantOp = true;
+    /// A flag indicating whether lora is enabled.
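+    /// When enabled, the LoRA tensors from `GetMlpInTensorCandidates` are appended to the module's inputs.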
+    bool supportLora = false;
+    /// A flag indicating whether a mask is used before applying the lora adapter.
+    bool useImMask = false;
+    /// It should be activated when batch inputs include multiple LoRA adapters.
+    bool loraEnableGMM = false;
+    /// A flag indicating whether to use swigluQuant
+    bool enableSwigluQuant = false;
+    /// A flag indicating whether to use Flashcomm 1.0
+    bool enableFlashComm = false;
+    /// The pack type of the gate and up linear. Refer to `MlpPackType` in the `operations/mlp/mlp.h`.
+    MlpPackType mlpPackType = GATE_UP_WEIGHT_PACK;
+    /// Specifies the quantization type for the following linear modules:
+    /// q linear, k linear, v linear, dense linear, gate linear, up linear, and down linear.
+    std::vector<int> layerLinearQuantType = {};
+    /// Specifies the weight description of the following linear modules:
+    /// qkv linear, dense linear, gateup linear and down linear.
+    std::vector<LinearDesc> layerLinearDescs = {
+        common::LinearDesc::INVALID_DESC, common::LinearDesc::INVALID_DESC, common::LinearDesc::INVALID_DESC,
+        common::LinearDesc::INVALID_DESC, common::LinearDesc::INVALID_DESC, common::LinearDesc::INVALID_DESC,
+        common::LinearDesc::INVALID_DESC
+    };
+    /// Defines the transpose type of the second matrix in the matmul operation for the following linear modules:
+    /// q linear, k linear, v linear, dense linear, gate linear, up linear, and down linear.
+    std::vector<int> layerLinearTransposeType = {};
+    /// Indicates the pack type and the quantization type of the gate up linear.
+    int packQuantType = atb_speed::common::PackQuantType::PACK_QUANT_UNDEFINED;
+    /// The group size used for dequantizing the weight tensor in the per-group quantization approach
+    int quantGroupSize = 0;
+    /// Normalization parameters for the float operation
+    NormParamType normParamType;
+    /// Normalization parameters for the quantization operation
+    NormParamType normQuantParamType;
+    /// Parameters for the activation operation
+    atb::infer::ActivationParam activationParam;
+    /// The quantization type of the down linear. Refer to `PackQuantType` in the `operations/utils.h`.
+    int downQuantType = atb_speed::common::PackQuantType::PACK_QUANT_UNDEFINED;
+    /// Details about tensor parallelism
+    atb_speed::common::TensorParallelInfo downLinearTensorParallelInfo;
+    /// A flag indicating whether to use the atb matmul backend
+    int matmulBackend = atb_speed::common::OpBackend::ATB;
+    /// Specifies whether the post attention norm enables antioutlier
+    bool isAntiOutlier = false;
+};
+
+/// Get the `MlpPackType` based on the quantization type of the gate-up linear and the structure of the model.
+/// \param packQuantType Parameters to determine whether the gate-up linear is packed.
+/// Refer to `PackQuantType` in the `operations/utils.h`.
+/// \param upWeightOnly A flag indicating if the structure of the layer only has an up linear.
+/// \param linearDescs Weight description of the linear modules.
+/// \return Refer to `MlpPackType` in the `operations/fusion/mlp.h`.
+MlpPackType GetMlpPackType(
+    const int &packQuantType = PackQuantType::PACK_QUANT_UNDEFINED,
+    bool upWeightOnly = false,
+    const std::vector<LinearDesc> &linearDescs = {});
+
+/// The mlp module.
+/// It consists of the following operations: NormLinear operations for the gate-up linear,
+/// a Split operation if the gate-up linear is packed, an Activation operation, an Elementwise mul operation,
+/// and a LinearParallel operation for the down linear.
+///
+/// \tparam NormParamType Types of the normalization parameters. Available options are `atb::infer::RmsNormParam`
+/// and `atb::infer::LayerNormParam`.
+/// \param param Parameters for the mlp module
+/// \param operation The address of a pointer to a default operation
+/// \return A flag that indicates whether the operation has been successfully created.
+///
+/// Operation's inputs:
+/// Name                   | Requirements | Dtype            | Shape | Description |
+/// -----------------------|--------------|------------------|-------|----------|
+/// in_input               | Required     | float16/bfloat16 | paged attention: [len(all_seq),hidden_size] | Hidden states |
+/// ^                      | ^            | ^                | flash attention: [bsz,seq_len,hidden_size] | ^ |
+/// in_norm_weight         | ^ | Refer to `NormLinear` in the `operations/fusion/norm/norm_linear.h` for more details. |||
+/// in_norm_bias           | ^ | ^ |||
+/// in_norm_new_weight     | ^ | ^ |||
+/// in_norm_new_bias       | ^ | ^ |||
+/// in_weight_0            | ^ | If the gate-up linears are packed, these are the concatenated gate-up weights. If the gate-up linear is unpacked, these are weights for the gate linear. Weights for the up linear will be passed in if the layer only has up weights. Refer to `NormLinear` in the `operations/fusion/norm/norm_linear.h` for more details. |||
+/// in_scale_0             | ^ | ^ |||
+/// in_offset_0            | ^ | ^ |||
+/// in_descale_0           | ^ | ^ |||
+/// in_bias_0              | ^ | ^ |||
+/// in_compress_idx_0      | ^ | ^ |||
+/// in_weight_1            | ^ | If the gate-up linears are not packed, these are weights for the up linear operation; otherwise, placeholders should be provided. |||
+/// in_scale_1             | ^ | ^ |||
+/// in_offset_1            | ^ | ^ |||
+/// in_descale_1           | ^ | ^ |||
+/// in_bias_1              | ^ | ^ |||
+/// in_compress_idx_1      | ^ | ^ |||
+/// in_weight_down         | ^ | Weights for the down linear operation. |||
+/// in_scale_down          | ^ | ^ |||
+/// in_offset_down         | ^ | ^ |||
+/// in_descale_down        | ^ | ^ |||
+/// in_bias_down           | ^ | ^ |||
+/// in_compress_idx_down   | ^ | ^ |||
+/// in_residual_add        | `param.enableAddNorm` is true | The same as in_input | The same as in_input | |
+/// in_im_mask             | when `param.supportLora` and `param.useImMask` are true | Refer to `FusionLinear` in the `operations/fusion/linear/linear.h`. |||
+/// in_seq_len_cum_sum     | `param.supportLora` is true | Refer to `FusionLinear` in the `operations/fusion/linear/linear.h`. |||
+/// in_lora_a_0            | ^ | ^ |||
+/// in_lora_b_0            | ^ | ^ |||
+/// in_lora_a_1            | ^ | ^ |||
+/// in_lora_b_1            | ^ | ^ |||
+/// in_down_lora_a         | ^ | ^ |||
+/// in_down_lora_b         | ^ | ^ |||
+/// in_reduce_quant_scale  | `param.tensorParallelInfo.quantType == atb::infer::AllReduceParam::QuantType::QUANT_TYPE_PER_CHANNEL` | Refer to `LinearParallel` in the `operations/fusion/linear/linear_parallel.h`. |||
+/// in_reduce_quant_offset | ^ | ^ |||
+/// in_gather_quant_scale  | ^ | ^ |||
+/// in_gather_quant_offset | ^ | ^ |||
+///
+/// Operation's Outputs:
+/// Name       | Dtype                | Shape | Description |
+/// -----------|----------------------|-------|----------|
+/// out_linear | The same as in_input | The same as in_input | Output tensor of the mlp module |
+/// out_add    | The same as in_input | The same as in_input | The tensor resulting from adding the input and output of the attention module in a residual connection. Exists when `enableAddNorm` is true. |
+///
+/// Example:
+/// \code
+/// atb::Node mlpNode;
+/// atb_speed::common::MlpParam<atb::infer::RmsNormParam> mlpParam;
+/// // Modify mlpParam's attributes if needed.
+/// Mlp(mlpParam, &mlpNode.operation);
+/// mlpNode.inTensorIds = {...};  // Passing inputs for the operation in order
+/// mlpNode.outTensorIds = {...};  // Tensor index for the output
+/// atb::GraphParam opGraph;
+/// opGraph.nodes.push_back(mlpNode);
+/// \endcode
+template <typename NormParamType>
+atb::Status Mlp(const MlpParam<NormParamType> &param, atb::Operation **operation);
+
+/// The mlp module implemented with the SwiGLU fusion operation.
+/// The SwiGLU fusion operation processes the combined output of the gate and up linear operations as input,
+/// integrating the activation and multiplication operations into a single operation.
+///
+/// \tparam NormParamType Types of the normalization parameters. Available options are `atb::infer::RmsNormParam`
+/// and `atb::infer::LayerNormParam`.
+/// \param param Parameters for the mlp module
+/// \param operation The address of a pointer to a default operation
+/// \return A flag that indicates whether the operation has been successfully created.
+///
+/// Inputs and outputs adhere to the same specification as those of the `Mlp`.
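+///
+/// Example (a minimal sketch; the `RmsNormParam` instantiation is an assumption, and the wiring mirrors the `Mlp` example above):
+/// \code
+/// atb::Node mlpSwiGLUNode;
+/// atb_speed::common::MlpParam<atb::infer::RmsNormParam> mlpSwiGLUParam;
+/// // Modify mlpSwiGLUParam's attributes if needed; GATE_UP_WEIGHT_PACK feeds the fused SwiGLU kernel directly.
+/// MlpSwiGLU(mlpSwiGLUParam, &mlpSwiGLUNode.operation);
+/// mlpSwiGLUNode.inTensorIds = {...};  // Same input order as `Mlp`
+/// mlpSwiGLUNode.outTensorIds = {...};
+/// \endcode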
+template <typename NormParamType>
+atb::Status MlpSwiGLU(const MlpParam<NormParamType> &param, atb::Operation **operation);
+} // namespace common
+} // namespace atb_speed
+#endif
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/fusion/mlp_gate.cpp b/tests/proftest/layer_test_framework/operations/fusion/mlp_gate.cpp
new file mode 100644
index 00000000..1ff5b705
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/fusion/mlp_gate.cpp
@@ -0,0 +1,151 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include
+#include "atb_speed/log.h"
+#include "parallel_layer.h"
+#include "mlp_gate.h"
+
+namespace atb_speed {
+namespace common {
+template <typename T> atb::Status AddmatmulUpNode(atb::Node &matmulUpNode, const MlpGateParam &param, T &config)
+{
+    atb::infer::LinearParam matmulUpParam = { false, param.transposeB, param.isBias };
+    CREATE_OPERATION(matmulUpParam, &matmulUpNode.operation);
+    if (param.isBias) {
+        matmulUpNode.inTensorIds = { config.IN_HIDDENSTATES_ID, config.IN_WEIGHT_UP_ID, config.IN_BIAS_UP_ID };
+    } else {
+        matmulUpNode.inTensorIds = { config.IN_HIDDENSTATES_ID, config.IN_WEIGHT_UP_ID };
+    }
+    matmulUpNode.outTensorIds = { config.INTERMEDIATE_MATMUL_UP_OUT_ND_ID };
+    return atb::NO_ERROR;
+}
+
+
+template <typename T>
+atb::Status AddsplitNode(atb::Node &splitNode, atb::Node &matmulGateNode, const MlpGateParam &param, T &config,
+    atb::GraphParam &opGraph)
+{
+    if (param.isPack) {
+        atb::infer::SplitParam splitParam;
+        splitParam.splitDim = -1;  // -1: split along the last dimension
+        splitParam.splitNum = 2;  // 2: split into two equal parts
+        CREATE_OPERATION(splitParam, &splitNode.operation);
+        splitNode.inTensorIds = { config.INTERMEDIATE_MATMUL_UP_OUT_ND_ID };
+        splitNode.outTensorIds = { config.INTERMEDIATE_MATMUL_GATE_OUT_ND_ID, config.INTERMEDIATE_SPLIT_OUT_ND_ID };
+        opGraph.nodes.push_back(splitNode);
+    } else {
+        atb::infer::LinearParam matmulGateParam = { false, param.transposeB, param.isBias };
+        CREATE_OPERATION(matmulGateParam, &matmulGateNode.operation);
+        if (param.isBias) {
+            matmulGateNode.inTensorIds = { config.IN_HIDDENSTATES_ID, config.IN_WEIGHT_GATE_ID,
+                                           config.IN_BIAS_GATE_ID };
+        } else {
+            matmulGateNode.inTensorIds = { config.IN_HIDDENSTATES_ID, config.IN_WEIGHT_GATE_ID };
+        }
+        matmulGateNode.outTensorIds = { config.INTERMEDIATE_MATMUL_GATE_OUT_ND_ID };
+        opGraph.nodes.push_back(matmulGateNode);
+    }
+    return atb::NO_ERROR;
+}
+
+template <typename T> atb::Status AddactNode(atb::Node &actNode, const MlpGateParam &param, T &config)
+{
+    atb::infer::ActivationParam actParam;
+    actParam.activationType = param.activationType;
+    CREATE_OPERATION(actParam, &actNode.operation);
+    actNode.inTensorIds = { config.INTERMEDIATE_MATMUL_GATE_OUT_ND_ID };
+    actNode.outTensorIds = { config.INTERMEDIATE_ACTIVATION_OUT_ID };
+    return atb::NO_ERROR;
+}
+
+template <typename T> atb::Status AddmulNode(atb::Node &mulNode, const MlpGateParam &param, T &config)
+{
+    atb::infer::ElewiseParam mulParam;
+    mulParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_MUL;
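+    // Gating step: multiply the activated gate output with the up branch (split output when packed).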
+    CREATE_OPERATION(mulParam, &mulNode.operation);
+    if (param.isPack) {
+        mulNode.inTensorIds = { config.INTERMEDIATE_ACTIVATION_OUT_ID, config.INTERMEDIATE_SPLIT_OUT_ND_ID };
+    } else {
+        mulNode.inTensorIds = { config.INTERMEDIATE_ACTIVATION_OUT_ID, config.INTERMEDIATE_MATMUL_UP_OUT_ND_ID };
+    }
+    mulNode.outTensorIds = { config.INTERMEDIATE_MUL_OUT_ID };
+    return atb::NO_ERROR;
+}
+
+template <class T>
+atb::Status AddmatmulDownNode(atb::Node &matmulDownNode, const MlpGateParam &param, T &config)
+{
+    atb_speed::common::ParallelParam linearParallelParam = { param.rank, param.rankSize, 0,
+                                                             nullptr, param.isBias, false,
+                                                             param.transposeB, param.backend, param.isBF16 };
+
+    atb_speed::common::RowParallelLinear(linearParallelParam, &matmulDownNode.operation);
+    if (param.isBias) {
+        matmulDownNode.inTensorIds = { config.INTERMEDIATE_MUL_OUT_ID, config.IN_WEIGHT_DOWN_ID,
+                                       config.IN_BIAS_DOWN_ID };
+    } else {
+        matmulDownNode.inTensorIds = { config.INTERMEDIATE_MUL_OUT_ID, config.IN_WEIGHT_DOWN_ID };
+    }
+    matmulDownNode.outTensorIds = { config.OUT_RESULT_ID };
+    return atb::NO_ERROR;
+}
+
+template <class T>
+atb::Status MlpGateLayerBase(const MlpGateParam &param, atb::Operation **operation, T config)
+{
+    atb::GraphParam opGraph;
+    opGraph.name = "MlpGateLayerBase";
+    opGraph.inTensorNum = static_cast<uint32_t>(config.inTensorNum);
+    opGraph.outTensorNum = static_cast<uint32_t>(config.outTensorNum);
+    opGraph.internalTensorNum = static_cast<uint32_t>(config.interTensorNum);
+    atb::Node matmulUpNode;
+    atb::Node splitNode;
+    atb::Node matmulGateNode;
+    atb::Node actNode;
+    atb::Node mulNode;
+    atb::Node matmulDownNode;
+    CHECK_OPERATION_STATUS_RETURN(AddmatmulUpNode(matmulUpNode, param, config));
+    opGraph.nodes.push_back(matmulUpNode);
+    CHECK_OPERATION_STATUS_RETURN(AddsplitNode(splitNode, matmulGateNode, param, config, opGraph));
+    CHECK_OPERATION_STATUS_RETURN(AddactNode(actNode, param, config));
+    opGraph.nodes.push_back(actNode);
+    CHECK_OPERATION_STATUS_RETURN(AddmulNode(mulNode, param, config));
+    opGraph.nodes.push_back(mulNode);
+    CHECK_OPERATION_STATUS_RETURN(AddmatmulDownNode(matmulDownNode, param, config));
+    opGraph.nodes.push_back(matmulDownNode);
+    opGraph.inferShapeFunc = [=](const atb::SVector<atb::TensorDesc> &inTensorDescs,
+                                 atb::SVector<atb::TensorDesc> &outTensorDescs) {
+        outTensorDescs.at(0) = inTensorDescs.at(0);
+        return atb::NO_ERROR;
+    };
+    CREATE_OPERATION(opGraph, operation);
+    return atb::NO_ERROR;
+}
+
+
+atb::Status MlpGateLayer(const MlpGateParam &param, atb::Operation **operation)
+{
+    if (param.isBias && param.isPack) {
+        return MlpGateLayerBase(param, operation, MlpGateWithPackAndBias(5, 1, 5, 5)); // 5:in 1:out 5:inter 5:node
+    } else if (param.isBias) {
+        return MlpGateLayerBase(param, operation, MlpGateWithBias(7, 1, 4, 5)); // 7:in 1:out 4:inter 5:node
+    } else if (param.isPack) {
+        return MlpGateLayerBase(param, operation, MlpGateWithPack(3, 1, 5, 5)); // 3:in 1:out 5:inter 5:node
+    } else {
+        return MlpGateLayerBase(param, operation, MlpGate(4, 1, 4, 5)); // 4:in 1:out 4:inter 5:node
+    }
+}
+} // namespace common
+} // namespace atb_speed
diff --git a/tests/proftest/layer_test_framework/operations/fusion/mlp_gate.h b/tests/proftest/layer_test_framework/operations/fusion/mlp_gate.h
new file mode 100644
index 00000000..81e68a59
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/fusion/mlp_gate.h
@@ -0,0 +1,129 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef ATB_SPEED_LAYERS_MLP_GATE_H
+#define ATB_SPEED_LAYERS_MLP_GATE_H
+
+#include <atb/atb_infer.h>
+#include "atb_speed/log.h"
+#include "atb_speed/utils/operation_util.h"
+#include "nlohmann/json.hpp"
+
+#include "common_op_base.h"
+
+namespace atb_speed {
+namespace common {
+class MlpGateWithBias : public CommonOpBase {
+public:
+    using CommonOpBase::CommonOpBase;
+
+    enum MlpGateWithBiasId : int {
+        IN_HIDDENSTATES_ID = 0,             // [batch, seqLen, hiddenSize], half
+        IN_WEIGHT_UP_ID,                    // [hiddenSize, ffnHiddenSize], half
+        IN_WEIGHT_GATE_ID,                  // [hiddenSize, ffnHiddenSize], half
+        IN_WEIGHT_DOWN_ID,                  // [ffnHiddenSize, hiddenSize], half
+        IN_BIAS_UP_ID,
+        IN_BIAS_GATE_ID,
+        IN_BIAS_DOWN_ID,
+        OUT_RESULT_ID,                      // [batch, seqLen, hiddenSize], half
+        INTERMEDIATE_MATMUL_GATE_OUT_ND_ID, // [batch, seqLen, ffnHiddenSize], half
+        INTERMEDIATE_MATMUL_UP_OUT_ND_ID,   // [batch, seqLen, ffnHiddenSize], half
+        INTERMEDIATE_ACTIVATION_OUT_ID,     // [batch, seqLen, ffnHiddenSize], half
+        INTERMEDIATE_MUL_OUT_ID,            // [batch, seqLen, ffnHiddenSize], half
+        INTERMEDIATE_SPLIT_OUT_ND_ID,
+    };
+};
+
+class MlpGate : public CommonOpBase {
+public:
+    using CommonOpBase::CommonOpBase;
+
+    enum MlpGateId : int {
+        IN_HIDDENSTATES_ID = 0,             // [batch, seqLen, hiddenSize], half
+        IN_WEIGHT_UP_ID,                    // [hiddenSize, ffnHiddenSize], half
+        IN_WEIGHT_GATE_ID,                  // [hiddenSize, ffnHiddenSize], half
+        IN_WEIGHT_DOWN_ID,                  // [ffnHiddenSize, hiddenSize], half
+        OUT_RESULT_ID,                      // [batch, seqLen, hiddenSize], half
+        INTERMEDIATE_MATMUL_GATE_OUT_ND_ID, // [batch, seqLen, ffnHiddenSize], half
+        INTERMEDIATE_MATMUL_UP_OUT_ND_ID,   // [batch, seqLen, ffnHiddenSize], half
+        INTERMEDIATE_ACTIVATION_OUT_ID,     // [batch, seqLen, ffnHiddenSize], half
+        INTERMEDIATE_MUL_OUT_ID,            // [batch, seqLen, ffnHiddenSize], half
+        IN_BIAS_UP_ID,                      // not used in this configuration
+        IN_BIAS_GATE_ID,                    // not used in this configuration
+        IN_BIAS_DOWN_ID,                    // not used in this configuration
+        INTERMEDIATE_SPLIT_OUT_ND_ID,
+    };
+};
+
+class MlpGateWithPackAndBias : public CommonOpBase {
+public:
+    using CommonOpBase::CommonOpBase;
+
+    enum MlpGateWithPackId : int {
+        IN_HIDDENSTATES_ID = 0,             // [batch, seqLen, hiddenSize], half
+        IN_WEIGHT_UP_ID,                    // [hiddenSize, ffnHiddenSize], half
+        IN_WEIGHT_DOWN_ID,                  // [ffnHiddenSize, hiddenSize], half
+        IN_BIAS_UP_ID,
+        IN_BIAS_DOWN_ID,
+        OUT_RESULT_ID,                      // [batch, seqLen, hiddenSize], half
+        INTERMEDIATE_MATMUL_GATE_OUT_ND_ID, // [batch, seqLen, ffnHiddenSize], half
+        INTERMEDIATE_MATMUL_UP_OUT_ND_ID,   // [batch, seqLen, ffnHiddenSize], half
+        INTERMEDIATE_ACTIVATION_OUT_ID,     // [batch, seqLen, ffnHiddenSize], half
+        INTERMEDIATE_MUL_OUT_ID,            // [batch, seqLen, ffnHiddenSize], half
+        INTERMEDIATE_SPLIT_OUT_ND_ID,       // [batch, seqLen, ffnHiddenSize], half
+        IN_WEIGHT_GATE_ID,                  // not used in this configuration
+        IN_BIAS_GATE_ID,                    // not used in this configuration
+    };
+};
+
+class MlpGateWithPack : public CommonOpBase {
+public:
+    using CommonOpBase::CommonOpBase;
+
+    enum MlpGateWithPackId : int {
+        IN_HIDDENSTATES_ID = 0,             // [batch, seqLen, hiddenSize], half
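+        // In pack mode IN_WEIGHT_UP_ID holds the packed gate+up projection;
+        // its matmul output is split into gate and up halves by the split node.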
+        IN_WEIGHT_UP_ID,                    // [hiddenSize, ffnHiddenSize], half
+        IN_WEIGHT_DOWN_ID,                  // [ffnHiddenSize, hiddenSize], half
+        OUT_RESULT_ID,                      // [batch, seqLen, hiddenSize], half
+        INTERMEDIATE_MATMUL_GATE_OUT_ND_ID, // [batch, seqLen, ffnHiddenSize], half
+        INTERMEDIATE_MATMUL_UP_OUT_ND_ID,   // [batch, seqLen, ffnHiddenSize], half
+        INTERMEDIATE_ACTIVATION_OUT_ID,     // [batch, seqLen, ffnHiddenSize], half
+        INTERMEDIATE_MUL_OUT_ID,            // [batch, seqLen, ffnHiddenSize], half
+        INTERMEDIATE_SPLIT_OUT_ND_ID,       // [batch, seqLen, ffnHiddenSize], half
+        IN_WEIGHT_GATE_ID,                  // not used in this configuration
+        IN_BIAS_UP_ID,                      // not used in this configuration
+        IN_BIAS_GATE_ID,                    // not used in this configuration
+        IN_BIAS_DOWN_ID,                    // not used in this configuration
+    };
+};
+
+struct MlpGateParam {
+    int rank = 0;
+    int rankSize = 1;
+    int rankRoot = 0;
+    void *hcclComm = nullptr;
+    atb::infer::ActivationType activationType;
+    bool transposeB = true;
+    bool isBias = false;
+    bool isPack = false;
+    std::string backend = "hccl";
+    bool isBF16 = false;
+};
+
+atb::Status MlpGateLayer(const MlpGateParam &param, atb::Operation **operation);
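+
+// Illustrative call sequence (the activation choice and pack layout here are
+// assumptions for demonstration, not part of the original change):
+//   atb_speed::common::MlpGateParam mlpParam;
+//   mlpParam.activationType = atb::infer::ActivationType::ACTIVATION_SWISH;
+//   mlpParam.isPack = true; // packed gate+up weight
+//   atb::Operation *mlpOp = nullptr;
+//   atb::Status st = MlpGateLayer(mlpParam, &mlpOp); // atb::NO_ERROR on success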
+} // namespace common
+} // namespace atb_speed
+#endif
diff --git a/tests/proftest/layer_test_framework/operations/fusion/mlp_gate_v2.cpp b/tests/proftest/layer_test_framework/operations/fusion/mlp_gate_v2.cpp
new file mode 100644
index 00000000..ef72093c
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/fusion/mlp_gate_v2.cpp
@@ -0,0 +1,202 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "mlp_gate_v2.h"
+
+#include <atb/atb_infer.h>
+
+#include "parallel_layer_v2.h"
+
+namespace atb_speed {
+namespace common {
+enum InTensorId : uint32_t {
+    IN_HIDDENSTATES_ID = 0,  // [batch, seqLen, hiddenSize], half
+    IN_WEIGHT_UP_ID,         // [hiddenSize, ffnHiddenSize], half
+    IN_WEIGHT_GATE_ID,       // [hiddenSize, ffnHiddenSize], half
+    IN_WEIGHT_DOWN_ID,       // [ffnHiddenSize, hiddenSize], half
+    IN_DEQSCALE_UP,          // quant scale up
+    IN_DEQSCALE_GATE,        // quant scale gate
+    IN_DEQSCALE_DOWN,        // quant scale down
+    IN_BIAS_UP_ID,
+    IN_BIAS_GATE_ID,
+    IN_BIAS_DOWN_ID,
+
+    IN_INDEX_UP,
+    IN_INDEX_GATE,
+    IN_INDEX_DOWN,
+    IN_OFFSETX_UP,
+    IN_OFFSETX_GATE,
+    IN_OFFSETX_DOWN,
+    IN_COMPRESSINFO_UP,
+    IN_COMPRESSINFO_GATE,
+    IN_COMPRESSINFO_DOWN,
+
+    OUT_RESULT_ID,
+    INTERMEDIATE_MATMUL_UP_OUT_ID,
+    INTERMEDIATE_ACTIVATION_OUT_ID,
+    INTERMEDIATE_MUL_OUT_ID,
+    INTERMEDIATE_MATMUL_OUT_ID,
+    INTERMEDIATE_SPLIT_OUT_ID,
+};
+
+atb::Status AddmatmulUpV2Node(atb::Node &matmulUpNode, const MlpGateParamV2 &param)
+{
+    atb_speed::common::ParallelParamV2 linearUpParam;
+    linearUpParam.isBias = param.isBias;
+    linearUpParam.transposeA = false;
+    linearUpParam.transposeB = param.transposeB;
+    linearUpParam.isQuant = param.isQuant;
+    linearUpParam.isSparse = param.isSparse;
+    linearUpParam.isAllGatherTranspose = false;
+    linearUpParam.isBF16 = param.isBF16;
+
+    linearUpParam.quantParam = param.quantUpParam;
+    atb_speed::common::RowParallelLinearV2(linearUpParam, &matmulUpNode.operation);
+    matmulUpNode.inTensorIds = { IN_HIDDENSTATES_ID, IN_WEIGHT_UP_ID, IN_BIAS_UP_ID, IN_DEQSCALE_UP,
+                                 IN_INDEX_UP, IN_OFFSETX_UP, IN_COMPRESSINFO_UP };
+    matmulUpNode.outTensorIds = { INTERMEDIATE_MATMUL_UP_OUT_ID };
+    return atb::NO_ERROR;
+}
+
+atb::Status AddmatmulGateV2Node(atb::Node &splitNode, atb::Node &matmulGateNode, const MlpGateParamV2 &param,
+    atb::GraphParam &opGraph)
+{
+    if (!param.noGate) {
+        if (param.isPack) {
+            atb::infer::SplitParam splitParam;
+            splitParam.splitDim = -1; // -1: split the last dim of [bs, seq, 2 * hidden_size]
+            splitParam.splitNum = 2;  // 2: split into two equal parts
+            CREATE_OPERATION(splitParam, &splitNode.operation);
+            splitNode.inTensorIds = { INTERMEDIATE_MATMUL_UP_OUT_ID };
+            splitNode.outTensorIds = { INTERMEDIATE_MATMUL_OUT_ID, INTERMEDIATE_SPLIT_OUT_ID };
+            opGraph.nodes.push_back(splitNode);
+        } else {
+            atb_speed::common::ParallelParamV2 linearGateParam;
+            linearGateParam.isBias = param.isBias;
+            linearGateParam.transposeA = false;
+            linearGateParam.transposeB = param.transposeB;
+            linearGateParam.isQuant = param.isQuant;
+            linearGateParam.isSparse = param.isSparse;
+            linearGateParam.isAllGatherTranspose = false;
+            linearGateParam.isBF16 = param.isBF16;
+
+            linearGateParam.quantParam = param.quantGateParam;
+            atb_speed::common::RowParallelLinearV2(linearGateParam, &matmulGateNode.operation);
+            matmulGateNode.inTensorIds = { IN_HIDDENSTATES_ID, IN_WEIGHT_GATE_ID, IN_BIAS_GATE_ID, IN_DEQSCALE_GATE,
+                                           IN_INDEX_GATE, IN_OFFSETX_GATE, IN_COMPRESSINFO_GATE };
+            matmulGateNode.outTensorIds = { INTERMEDIATE_MATMUL_OUT_ID };
+            opGraph.nodes.push_back(matmulGateNode);
+        }
+    }
+    return atb::NO_ERROR;
+}
+
+atb::Status AddactV2Node(atb::Node &actNode, const MlpGateParamV2 &param)
+{
+    atb::infer::ActivationParam actParam;
+    actParam.activationType = param.activationType;
+    CREATE_OPERATION(actParam, &actNode.operation);
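+    // With noGate the activation consumes the up projection directly;
+    // otherwise it takes the gate branch produced by the split or gate matmul.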
+    actNode.inTensorIds = { param.noGate ? INTERMEDIATE_MATMUL_UP_OUT_ID : INTERMEDIATE_MATMUL_OUT_ID };
+    actNode.outTensorIds = { INTERMEDIATE_ACTIVATION_OUT_ID };
+    return atb::NO_ERROR;
+}
+
+atb::Status AddmulV2Node(atb::Node &mulNode, const MlpGateParamV2 &param, atb::GraphParam &opGraph)
+{
+    if (!param.noGate) {
+        atb::infer::ElewiseParam mulParam;
+        mulParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_MUL;
+        CREATE_OPERATION(mulParam, &mulNode.operation);
+        if (param.isPack) {
+            mulNode.inTensorIds = { INTERMEDIATE_ACTIVATION_OUT_ID, INTERMEDIATE_SPLIT_OUT_ID };
+        } else {
+            mulNode.inTensorIds = { INTERMEDIATE_ACTIVATION_OUT_ID, INTERMEDIATE_MATMUL_UP_OUT_ID };
+        }
+        mulNode.outTensorIds = { INTERMEDIATE_MUL_OUT_ID };
+        opGraph.nodes.push_back(mulNode);
+    }
+    return atb::NO_ERROR;
+}
+
+atb::Status AddmatmulDownV2Node(atb::Node &matmulDownNode, const MlpGateParamV2 &param)
+{
+    atb_speed::common::ParallelParamV2 linearDownParam;
+    linearDownParam.isBias = param.isBias;
+    linearDownParam.transposeA = false;
+    linearDownParam.transposeB = param.transposeB;
+    linearDownParam.isQuant = param.isQuant;
+    linearDownParam.isSparse = param.isSparse;
+    linearDownParam.isAllGatherTranspose = false;
+    linearDownParam.isBF16 = param.isBF16;
+
+    linearDownParam.commParam = param.commDownParam;
+    linearDownParam.quantParam = param.quantDownParam;
+    linearDownParam.quantParam.isQuantOp = true;
+    atb_speed::common::RowParallelLinearV2(linearDownParam, &matmulDownNode.operation);
+    matmulDownNode.inTensorIds = { param.noGate ? INTERMEDIATE_ACTIVATION_OUT_ID : INTERMEDIATE_MUL_OUT_ID,
+                                   IN_WEIGHT_DOWN_ID,
+                                   IN_BIAS_DOWN_ID,
+                                   IN_DEQSCALE_DOWN,
+                                   IN_INDEX_DOWN,
+                                   IN_OFFSETX_DOWN,
+                                   IN_COMPRESSINFO_DOWN };
+    matmulDownNode.outTensorIds = { OUT_RESULT_ID };
+    return atb::NO_ERROR;
+}
+
+atb::Status MlpGateLayerV2(const MlpGateParamV2 &param, atb::Operation **operation)
+{
+    atb::GraphParam opGraph;
+    opGraph.name = "MlpGateLayerV2";
+    opGraph.inTensorNum = 19; // 19: number of input tensors
+    opGraph.outTensorNum = 1;
+    size_t interTensorNum = 0;
+    if (param.noGate) {
+        interTensorNum = 2; // 2: number of intermediate tensors
+    } else if (param.isPack) {
+        interTensorNum = 5; // 5: number of intermediate tensors
+    } else {
+        interTensorNum = 4; // 4: number of intermediate tensors
+    }
+    opGraph.internalTensorNum = interTensorNum;
+    atb::Node matmulUpNode;
+    CHECK_OPERATION_STATUS_RETURN(AddmatmulUpV2Node(matmulUpNode, param));
+    opGraph.nodes.push_back(matmulUpNode);
+    atb::Node splitNode;
+    atb::Node matmulGateNode;
+    CHECK_OPERATION_STATUS_RETURN(AddmatmulGateV2Node(splitNode, matmulGateNode, param, opGraph));
+    atb::Node actNode;
+    CHECK_OPERATION_STATUS_RETURN(AddactV2Node(actNode, param));
+    opGraph.nodes.push_back(actNode);
+    atb::Node mulNode;
+    CHECK_OPERATION_STATUS_RETURN(AddmulV2Node(mulNode, param, opGraph));
+    atb::Node matmulDownNode;
+    CHECK_OPERATION_STATUS_RETURN(AddmatmulDownV2Node(matmulDownNode, param));
+    opGraph.nodes.push_back(matmulDownNode);
+    opGraph.inferShapeFunc = [=](const atb::SVector<atb::TensorDesc> &inTensorDescs,
+                                 atb::SVector<atb::TensorDesc> &outTensorDescs) {
+        outTensorDescs.at(0) = inTensorDescs.at(0);
+        if (inTensorDescs.at(0).dtype == ACL_INT8) {
+            outTensorDescs.at(0).dtype = ACL_FLOAT16; // quantized int8 input yields fp16 output
+        }
+        return atb::NO_ERROR;
+    };
+
+    CREATE_OPERATION(opGraph, operation);
+    return atb::NO_ERROR;
+}
+} // namespace common
+} // namespace atb_speed
diff --git a/tests/proftest/layer_test_framework/operations/fusion/mlp_gate_v2.h b/tests/proftest/layer_test_framework/operations/fusion/mlp_gate_v2.h
new file mode 100644
index 00000000..91ac9211
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/fusion/mlp_gate_v2.h
@@ -0,0 +1,45 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef ATB_SPEED_LAYER_MLP_GATE_V2_H
+#define ATB_SPEED_LAYER_MLP_GATE_V2_H
+
+#include <atb/atb_infer.h>
+#include "atb_speed/utils/operation_util.h"
+#include "parallel_layer_v2.h"
+
+namespace atb_speed {
+namespace common {
+struct MlpGateParamV2 {
+    atb::infer::ActivationType activationType;
+    bool transposeB = true;
+    bool isBias = false;
+    bool isPack = false;
+    bool isQuant = false;
+    bool isSparse = false;
+    bool noGate = false;
+    bool isBF16 = false;
+    CommParam commDownParam;
+    QuantParam quantUpParam;
+    QuantParam quantGateParam;
+    QuantParam quantDownParam;
+};
+
+atb::Status MlpGateLayerV2(const MlpGateParamV2 &param, atb::Operation **operation);
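+
+// Note: CommParam and QuantParam are defined in parallel_layer_v2.h. Inside
+// MlpGateLayerV2 the down projection sets quantParam.isQuantOp = true, so the
+// quantize step is performed by the down linear itself.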
+
+} // namespace common
+} // namespace atb_speed
+#endif
diff --git a/tests/proftest/layer_test_framework/operations/fusion/moe/device_limited_routing.cpp b/tests/proftest/layer_test_framework/operations/fusion/moe/device_limited_routing.cpp
new file mode 100644
index 00000000..f00f131e
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/fusion/moe/device_limited_routing.cpp
@@ -0,0 +1,211 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "device_limited_routing.h"
+#include <atb/atb_infer.h>
+#include <memory>
+
+namespace atb_speed {
+namespace deviceLimitedRouting {
+enum DeviceLimitedRoutingTensorId : int {
+    IN_ROUTER_LOGITS,
+    IN_EXPERT_GROUP,
+    IN_ONE_HOT,
+    IN_ZERO_HOT,
+    OUT_ROUTER_LOGITS,
+    INTERMIDATE_GROUP_MAX_LOGITS,
+    DUMMY_IDX,
+    DUMMY_LOGITS,
+    INTERMIDATE_TOP_GROUP_IDX,
+    INTERMIDATE_TOP_EXPERTS_IDX,
+    INTERMIDATE_EXPERT_MASK,
+    INTERMIDATE_EXPERT_MASK_FLOAT16,
+    INTERMIDATE_EXPERT_MASK_FINAL
+};
+
+static const uint64_t IN_TENSOR_COUNT = 4;
+static const uint64_t OUT_TENSOR_COUNT = 1;
+static const uint64_t INTERMEDIATE_TENSOR_COUNT = 8;
+static const uint64_t OPERATION_COUNT = 7;
+
+// Op0 - Sort0
+// Note: opGraph is passed by value here; the node itself is populated through
+// the reference, while the push_back only touches the local copy.
+atb::Status CreateGroupTopOne(
+    atb::Node &groupTopOneNode, atb::GraphParam opGraph,
+    const DeviceLimitedRoutingParam &param,
+    std::shared_ptr<int64_t> batchDimPtr)
+{
+    CHECK_PARAM_NE(param.numOfGroups, 0);
+    atb::infer::SortParam topKExpertParam;
+    topKExpertParam.num = {1};
+    CHECK_OPERATION_STATUS_RETURN(CreateOperation(topKExpertParam, &groupTopOneNode.operation));
+    groupTopOneNode.inTensorIds = {IN_ROUTER_LOGITS};
+    groupTopOneNode.outTensorIds = {INTERMIDATE_GROUP_MAX_LOGITS, DUMMY_IDX};
+    groupTopOneNode.inTensorReshapeFuncs.resize(groupTopOneNode.inTensorIds.size());
+    groupTopOneNode.inTensorReshapeFuncs[0] = [batchDimPtr, param](const atb::Dims &oldShape, atb::Dims &newShape) {
+        newShape.dimNum = 3; // 3: number of dimensions of the new shape
+        newShape.dims[0] = oldShape.dims[0];
+        newShape.dims[1] = param.numOfGroups;
+        newShape.dims[2] = oldShape.dims[1] / param.numOfGroups; // 2: second dimension
+    };
+    opGraph.nodes.push_back(groupTopOneNode);
+    ATB_SPEED_LOG_DEBUG("Reduction calculation success");
+    return atb::NO_ERROR;
+}
+
+// Op1 - Sort1
+atb::Status CreateTopkGroup(
+    const DeviceLimitedRoutingParam &param,
+    atb::Node &topKGroupNode, atb::GraphParam opGraph,
+    std::shared_ptr<int64_t> batchDimPtr)
+{
+    atb::infer::SortParam topKGroupParam;
+    topKGroupParam.num = param.topkGroups;
+    CHECK_OPERATION_STATUS_RETURN(CreateOperation(topKGroupParam, &topKGroupNode.operation));
+    topKGroupNode.inTensorIds = {INTERMIDATE_GROUP_MAX_LOGITS};
+    topKGroupNode.outTensorIds = {DUMMY_LOGITS, INTERMIDATE_TOP_GROUP_IDX};
+    topKGroupNode.inTensorReshapeFuncs.resize(topKGroupNode.inTensorIds.size());
+    topKGroupNode.inTensorReshapeFuncs[0] = [batchDimPtr](const atb::Dims &oldShape, atb::Dims &newShape) {
+        newShape.dimNum = 2; // 2: number of dimensions of the new shape
+        newShape.dims[0] = oldShape.dims[0];
+        newShape.dims[1] = oldShape.dims[1];
+    };
+    opGraph.nodes.push_back(topKGroupNode);
+    ATB_SPEED_LOG_DEBUG("Group selection success");
+    return atb::NO_ERROR;
+}
+
+// Op2 - GroupId -> ExpertId
+atb::Status CreateGather(std::shared_ptr<int64_t> batchDimPtr, atb::Node &gatherNode, atb::GraphParam opGraph)
+{
+    atb::infer::GatherParam gatherParam;
+    CHECK_OPERATION_STATUS_RETURN(CreateOperation(gatherParam, &gatherNode.operation));
+    gatherNode.inTensorIds = {IN_EXPERT_GROUP, INTERMIDATE_TOP_GROUP_IDX};
+    gatherNode.outTensorIds = {INTERMIDATE_TOP_EXPERTS_IDX};
+    gatherNode.inTensorReshapeFuncs.resize(gatherNode.inTensorIds.size());
+    gatherNode.inTensorReshapeFuncs[1] = [batchDimPtr](const atb::Dims &oldShape, atb::Dims &newShape) {
+        newShape.dimNum = 1; // 1: number of dimensions of the new shape
+        newShape.dims[0] = oldShape.dims[0] * oldShape.dims[1];
+    };
+    opGraph.nodes.push_back(gatherNode);
+    ATB_SPEED_LOG_DEBUG("Gather0 calculation success");
+    return atb::NO_ERROR;
+}
+
+
+// Op3 - ExpertId -> OneHotMask
+atb::Status CreateOneHot(
+    const DeviceLimitedRoutingParam &param,
+    std::shared_ptr<int64_t> batchDimPtr,
+    atb::Node &oneHotNode, atb::GraphParam opGraph)
+{
+    CHECK_PARAM_NE(param.topkGroups.at(0), 0);
+    atb::infer::OnehotParam onehotParam;
+    onehotParam.axis = 2; // 2: specify axis for oneHotOperation
+    onehotParam.depth = param.numOfExperts;
+    CHECK_OPERATION_STATUS_RETURN(CreateOperation(onehotParam, &oneHotNode.operation));
+    oneHotNode.inTensorIds = {INTERMIDATE_TOP_EXPERTS_IDX, IN_ONE_HOT, IN_ZERO_HOT};
+    oneHotNode.outTensorIds = {INTERMIDATE_EXPERT_MASK};
+    oneHotNode.inTensorReshapeFuncs.resize(oneHotNode.inTensorIds.size());
+    oneHotNode.inTensorReshapeFuncs[0] = [batchDimPtr, param](const atb::Dims &oldShape, atb::Dims &newShape) {
+        newShape.dimNum = 2; // 2: number of dimensions of the new shape
+        newShape.dims[0] = oldShape.dims[0] / param.topkGroups.at(0);
+        newShape.dims[1] = oldShape.dims[1] * param.topkGroups.at(0);
+    };
+    opGraph.nodes.push_back(oneHotNode);
+    ATB_SPEED_LOG_DEBUG("Expert Mask created success");
+    return atb::NO_ERROR;
+}
+
+
+// Op4 - CastforReduceSum
+atb::Status CreateCast(atb::Node &castNode, atb::GraphParam opGraph)
+{
+    atb::infer::ElewiseParam castParam;
+    castParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_CAST;
+    castParam.outTensorType = ACL_FLOAT16;
+    CHECK_OPERATION_STATUS_RETURN(CreateOperation(castParam, &castNode.operation));
+    castNode.inTensorIds = {INTERMIDATE_EXPERT_MASK};
+    castNode.outTensorIds = {INTERMIDATE_EXPERT_MASK_FLOAT16};
+    opGraph.nodes.push_back(castNode);
+    ATB_SPEED_LOG_DEBUG("Cast calculation success");
+    return atb::NO_ERROR;
+}
+
+// Op5 - Finalize Device-limited Mask
+atb::Status CreateMask(atb::Node &maskNode, atb::GraphParam opGraph)
+{
+    atb::infer::ReduceParam reduceParam;
+    reduceParam.reduceType = atb::infer::ReduceParam::ReduceType::REDUCE_SUM;
+    reduceParam.axis = {1};
+    CHECK_OPERATION_STATUS_RETURN(CreateOperation(reduceParam, &maskNode.operation));
+    maskNode.inTensorIds = {INTERMIDATE_EXPERT_MASK_FLOAT16};
+    maskNode.outTensorIds = {INTERMIDATE_EXPERT_MASK_FINAL};
+    opGraph.nodes.push_back(maskNode);
+    ATB_SPEED_LOG_DEBUG("Mask reduction calculation success");
+    return atb::NO_ERROR;
+}
+
+// Op6 - Finalize Router Logits
+atb::Status CreateElewiseMul(atb::Node &mulNode, atb::GraphParam opGraph)
+{
+    atb::infer::ElewiseParam elewiseParam;
+    elewiseParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_MUL;
+    CHECK_OPERATION_STATUS_RETURN(CreateOperation(elewiseParam, &mulNode.operation));
+    mulNode.inTensorIds = {IN_ROUTER_LOGITS, INTERMIDATE_EXPERT_MASK_FINAL};
+    mulNode.outTensorIds = {OUT_ROUTER_LOGITS};
+    opGraph.nodes.push_back(mulNode);
+    ATB_SPEED_LOG_DEBUG("ElewiseMul0 calculation success");
+    return atb::NO_ERROR;
+}
+
+atb::Status CreateDeviceLimitedRoutingOperation(const DeviceLimitedRoutingParam &param, atb::Operation **operation)
+{
+    std::shared_ptr<int64_t> batchDimPtr = std::make_shared<int64_t>(0);
+    atb::GraphParam opGraph;
+    opGraph.name = "DeviceLimitedRouting";
+    opGraph.inTensorNum = IN_TENSOR_COUNT;
+    opGraph.outTensorNum = OUT_TENSOR_COUNT;
+    opGraph.internalTensorNum = INTERMEDIATE_TENSOR_COUNT;
+    uint64_t nodeSize = OPERATION_COUNT;
+    opGraph.nodes.resize(nodeSize);
+    size_t nodeId = 0;
+
+    atb::Node &groupTopOneNode = opGraph.nodes.at(nodeId++);
+    atb::Node &topKGroupNode = opGraph.nodes.at(nodeId++);
+    atb::Node &gatherNode = opGraph.nodes.at(nodeId++);
+    atb::Node &oneHotNode = opGraph.nodes.at(nodeId++);
+    atb::Node &castNode = opGraph.nodes.at(nodeId++);
+    atb::Node &maskNode = opGraph.nodes.at(nodeId++);
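+    // Pipeline: per-group max -> top-k groups -> group ids mapped to expert ids
+    // -> one-hot mask -> fp16 cast -> reduce to a per-expert mask -> mask logits.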
+    atb::Node &mulNode = opGraph.nodes.at(nodeId++);
+
+    CHECK_OPERATION_STATUS_RETURN(CreateGroupTopOne(groupTopOneNode, opGraph, param, batchDimPtr));
+    CHECK_OPERATION_STATUS_RETURN(CreateTopkGroup(param, topKGroupNode, opGraph, batchDimPtr));
+    CHECK_OPERATION_STATUS_RETURN(CreateGather(batchDimPtr, gatherNode, opGraph));
+    CHECK_OPERATION_STATUS_RETURN(CreateOneHot(param, batchDimPtr, oneHotNode, opGraph));
+    CHECK_OPERATION_STATUS_RETURN(CreateCast(castNode, opGraph));
+    CHECK_OPERATION_STATUS_RETURN(CreateMask(maskNode, opGraph));
+    CHECK_OPERATION_STATUS_RETURN(CreateElewiseMul(mulNode, opGraph));
+
+    opGraph.inferShapeFunc = [=](const atb::SVector<atb::TensorDesc> &inTensorDescs,
+                                 atb::SVector<atb::TensorDesc> &outTensorDescs) {
+        outTensorDescs.at(0) = inTensorDescs.at(0);
+        return atb::NO_ERROR;
+    };
+
+    return atb::CreateOperation(opGraph, operation);
+}
+} // namespace deviceLimitedRouting
+} // namespace atb_speed
diff --git a/tests/proftest/layer_test_framework/operations/fusion/moe/device_limited_routing.h b/tests/proftest/layer_test_framework/operations/fusion/moe/device_limited_routing.h
new file mode 100644
index 00000000..7244f856
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/fusion/moe/device_limited_routing.h
@@ -0,0 +1,38 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef ATB_SPEED_MODELS_DEVICE_LIMITED_ROUTING_OPERATION_H
+#define ATB_SPEED_MODELS_DEVICE_LIMITED_ROUTING_OPERATION_H
+#include <atb/atb_infer.h>
+#include <atb/svector.h>
+#include "atb_speed/log.h"
+#include "atb_speed/utils/operation_util.h"
+
+namespace atb_speed {
+namespace deviceLimitedRouting {
+struct DeviceLimitedRoutingParam {
+    int numOfExperts = 64;                  /// number of experts in total
+    int numOfGroups = 8;                    /// number of groups/devices in total
+    atb::SVector<int32_t> topkGroups = {3}; /// number of groups/devices selected
+};
+
+/// This function creates a sub-graph that implements the device-limited expert
+/// selection mechanism first designed for DeepseekV2.
+/// \return A flag that indicates whether the operation was created successfully.
+atb::Status CreateDeviceLimitedRoutingOperation(const DeviceLimitedRoutingParam &param, atb::Operation **operation);
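+
+/// Illustrative call (values shown are the struct defaults; for orientation only):
+///   DeviceLimitedRoutingParam routingParam;
+///   routingParam.topkGroups = {3};
+///   atb::Operation *routingOp = nullptr;
+///   CreateDeviceLimitedRoutingOperation(routingParam, &routingOp);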
+
+} // namespace deviceLimitedRouting
+} // namespace atb_speed
+#endif
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/fusion/moe/ep/all_to_all_collect.cpp b/tests/proftest/layer_test_framework/operations/fusion/moe/ep/all_to_all_collect.cpp
new file mode 100644
index 00000000..46a9b4d2
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/fusion/moe/ep/all_to_all_collect.cpp
@@ -0,0 +1,219 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "all_to_all_collect.h"
+#include <atb/atb_infer.h>
+#include <memory>
+#include "operations/fusion/utils.h"
+
+namespace atb_speed {
+namespace common {
+
+std::map<std::string, std::vector<std::string>> GetAllToAllCollectInTensorCandidates()
+{
+    std::map<std::string, std::vector<std::string>> allToAllCollectInTensorCandidates = {
+        {"default", {
+            "in_hiddenstates", "in_moe_out", "in_mask", "in_shuffle_idx", "in_valid_idx"}
+        },
+    };
+    return allToAllCollectInTensorCandidates;
+}
+
+std::map<std::string, std::vector<std::string>> GetAllToAllCollectIntermediateTensorCandidates()
+{
+    std::map<std::string, std::vector<std::string>> allToAllCollectIntermediateTensorCandidates = {
+        {"default", {
+            "intermediate_recv_output", "intermediate_filtered_output"}
+        },
+        {"has_tp", {
+            "intermediate_moe_output_partial", "intermediate_moe_output"}
+        }
+    };
+    return allToAllCollectIntermediateTensorCandidates;
+}
+
+std::map<std::string, std::vector<std::string>> GetAllToAllCollectOutTensorCandidates()
+{
+    std::map<std::string, std::vector<std::string>> allToAllCollectOutTensorCandidates = {
+        {"default", {
+            "out"}
+        },
+    };
+    return allToAllCollectOutTensorCandidates;
+}
+
+std::map<std::string, uint32_t> ConstructTensorMap(
+    const AllToAllCollectParam &param,
+    uint32_t &inTensorNum, uint32_t &outTensorNum, uint32_t &interTensorNum)
+{
+    auto allToAllCollectInTensorCandidates = GetAllToAllCollectInTensorCandidates();
+    auto allToAllCollectIntermediateTensorCandidates = GetAllToAllCollectIntermediateTensorCandidates();
+    auto allToAllCollectOutTensorCandidates = GetAllToAllCollectOutTensorCandidates();
+
+    std::vector<std::string> inTensorList = {};
+    std::vector<std::string> interTensorList = {};
+    std::vector<std::string> outTensorList = {};
+
+    AddTensorToList(allToAllCollectInTensorCandidates, "default", inTensorList);
+    AddTensorToList(allToAllCollectIntermediateTensorCandidates, "default", interTensorList);
+    if (param.hasMoeTp) {
+        AddTensorToList(allToAllCollectIntermediateTensorCandidates, "has_tp", interTensorList);
+    }
+    AddTensorToList(allToAllCollectOutTensorCandidates, "default", outTensorList);
+
+    inTensorNum = inTensorList.size();
+    outTensorNum = outTensorList.size();
+    interTensorNum = interTensorList.size();
+
+    return GetTensorMap(inTensorList, outTensorList, interTensorList);
+}
+
+atb::Status CreateReduceScatterData(std::map<std::string, uint32_t> &tensorMap, const AllToAllCollectParam &param,
+    atb::GraphParam &opGraph)
+{
+    atb::Node allToAllNode;
+    atb::infer::AllToAllParam allToAllParam;
+    allToAllParam.rank = param.mlpTpRank;
+    allToAllParam.rankSize = param.mlpTpSize;
+    allToAllParam.backend = param.backend;
+    allToAllParam.hcclComm = param.hcclComm;
+    allToAllParam.rankTableFile = param.mlpTpRankTableFile;
+    allToAllParam.commDomain = param.mlpTpDomain;
+    CHECK_OPERATION_STATUS_RETURN(CreateOperation(allToAllParam, &allToAllNode.operation));
+    allToAllNode.inTensorIds = {GetTensorIdx(tensorMap, "in_moe_out")};
+    allToAllNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_moe_output_partial")};
+    opGraph.nodes.push_back(allToAllNode);
+
+    atb::Node reduceNode;
+    atb::infer::ReduceParam reduceParam;
+    reduceParam.reduceType = atb::infer::ReduceParam::ReduceType::REDUCE_SUM;
+    reduceParam.axis.resize(1); // one reduction axis
+    reduceParam.axis[0] = 0;    // 0: reduce over the rank dimension added by the reshape below
+    CHECK_OPERATION_STATUS_RETURN(CreateOperation(reduceParam, &reduceNode.operation));
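+    // AllToAll followed by a local sum over the rank axis emulates a
+    // reduce-scatter across the MLP TP group: each rank accumulates the
+    // partial MoE outputs it received for its own token slice.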
+    reduceNode.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_moe_output_partial")};
+    reduceNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_moe_output")};
+    reduceNode.inTensorReshapeFuncs.resize(reduceNode.inTensorIds.size());
+    reduceNode.inTensorReshapeFuncs[0] = [=] (const atb::Dims &oldShape, atb::Dims &newShape) {
+        newShape.dimNum = 3; // 3: dimNum
+        newShape.dims[0] = param.mlpTpSize;
+        newShape.dims[1] = oldShape.dims[0] / param.mlpTpSize;
+        newShape.dims[2] = oldShape.dims[1]; // 2: dim 2
+    };
+    opGraph.nodes.push_back(reduceNode);
+    return atb::NO_ERROR;
+}
+
+atb::Status CreateAll2AllCollectData(std::map<std::string, uint32_t> &tensorMap, const AllToAllCollectParam &param,
+    atb::GraphParam &opGraph)
+{
+    atb::Node allToAllNode;
+    atb::infer::AllToAllParam allToAllParam;
+    allToAllParam.rank = param.moeEpRank;
+    allToAllParam.rankSize = param.moeEpSize;
+    allToAllParam.backend = param.backend;
+    allToAllParam.hcclComm = param.hcclComm;
+    allToAllParam.rankTableFile = param.moeEpRankTableFile;
+    allToAllParam.commDomain = param.moeEpDomain;
+    CHECK_OPERATION_STATUS_RETURN(CreateOperation(allToAllParam, &allToAllNode.operation));
+    if (param.hasMoeTp) {
+        allToAllNode.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_moe_output")};
+    } else {
+        allToAllNode.inTensorIds = {GetTensorIdx(tensorMap, "in_moe_out")};
+    }
+    allToAllNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_recv_output")};
+    opGraph.nodes.push_back(allToAllNode);
+    return atb::NO_ERROR;
+}
+
+atb::Status CreateFilteredRecvData(std::map<std::string, uint32_t> &tensorMap, atb::GraphParam &opGraph)
+{
+    atb::Node node;
+    atb::infer::ElewiseParam param;
+    param.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_MUL;
+    CHECK_OPERATION_STATUS_RETURN(CreateOperation(param, &node.operation));
+    node.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_recv_output"),
+                        GetTensorIdx(tensorMap, "in_mask")};
+    node.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_filtered_output")};
+    node.inTensorReshapeFuncs.resize(node.inTensorIds.size());
+    node.inTensorReshapeFuncs[1] = [=] (const atb::Dims &oldShape, atb::Dims &newShape) {
+        newShape.dimNum = 2; // 2: dimNum
+        newShape.dims[0] = oldShape.dims[0];
+        newShape.dims[1] = 1;
+    };
+    opGraph.nodes.push_back(node);
+    return atb::NO_ERROR;
+}
+
+
+atb::Status CreateEmptyTensor(std::map<std::string, uint32_t> &tensorMap, atb::GraphParam &opGraph)
+{
+    atb::Node mapNode;
+    atb::infer::ElewiseParam mapParam;
+    mapParam.mulsParam.varAttr = 0.0; // multiply by 0 to produce a zero-initialized output buffer
+    mapParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_MULS;
+    CHECK_OPERATION_STATUS_RETURN(CreateOperation(mapParam, &mapNode.operation));
+    mapNode.inTensorIds = {GetTensorIdx(tensorMap, "in_hiddenstates")};
+    mapNode.outTensorIds = {GetTensorIdx(tensorMap, "out")};
+    opGraph.nodes.push_back(mapNode);
+    return atb::NO_ERROR;
+}
+
+
+atb::Status CreateIndexAdd(std::map<std::string, uint32_t> &tensorMap, atb::GraphParam &opGraph)
+{
+    atb::Node indexAddNode;
+    atb::infer::IndexAddParam indexAddParam;
+    indexAddParam.indexType = atb::infer::IndexAddParam::IndexType::INDEX_ADD_VALID;
+    indexAddParam.axis = 0;
+    CHECK_OPERATION_STATUS_RETURN(CreateOperation(indexAddParam, &indexAddNode.operation));
+    indexAddNode.inTensorIds = {GetTensorIdx(tensorMap, "out"),
+                                GetTensorIdx(tensorMap, "in_shuffle_idx"),
+                                GetTensorIdx(tensorMap, "intermediate_filtered_output"),
+                                GetTensorIdx(tensorMap, "in_valid_idx")};
+    indexAddNode.outTensorIds = {GetTensorIdx(tensorMap, "out")};
+    opGraph.nodes.push_back(indexAddNode);
+    return atb::NO_ERROR;
+}
+
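+// Collect path: zero an output buffer (multiply by 0), mask the received rows,
+// then index-add the valid rows back to their pre-shuffle positions using
+// in_shuffle_idx and in_valid_idx.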
+atb::Status CreateAllToAllCollectOperation(const AllToAllCollectParam &param, atb::Operation **operation)
+{
+    ATB_SPEED_LOG_DEBUG("CreateAllToAllCollectOperation Start");
+    atb::GraphParam opGraph;
+    opGraph.name = "AllToAllCollect";
+    std::map<std::string, uint32_t> tensorMap = ConstructTensorMap(
+        param, opGraph.inTensorNum, opGraph.outTensorNum, opGraph.internalTensorNum);
+
+    CHECK_OPERATION_STATUS_RETURN(common::AddDapEventsBeforeComm(opGraph));
+    if (param.hasMoeTp) {
+        CHECK_OPERATION_STATUS_RETURN(CreateReduceScatterData(tensorMap, param, opGraph));
+    }
+    CHECK_OPERATION_STATUS_RETURN(CreateAll2AllCollectData(tensorMap, param, opGraph));
+    CHECK_OPERATION_STATUS_RETURN(common::AddDapEventsAfterComm(opGraph));
+    CHECK_OPERATION_STATUS_RETURN(CreateFilteredRecvData(tensorMap, opGraph));
+    CHECK_OPERATION_STATUS_RETURN(CreateEmptyTensor(tensorMap, opGraph));
+    CHECK_OPERATION_STATUS_RETURN(CreateIndexAdd(tensorMap, opGraph));
+
+    opGraph.inferShapeFunc = [=] (const atb::SVector<atb::TensorDesc> &inTensorDescs,
+                                  atb::SVector<atb::TensorDesc> &outTensorDescs) {
+        outTensorDescs.at(0) = inTensorDescs.at(0);
+        return atb::NO_ERROR;
+    };
+    CREATE_OPERATION(opGraph, operation);
+    ATB_SPEED_LOG_DEBUG("CreateAllToAllCollectOperation success");
+    return atb::NO_ERROR;
+}
+} // namespace common
+} // namespace atb_speed
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/fusion/moe/ep/all_to_all_collect.h b/tests/proftest/layer_test_framework/operations/fusion/moe/ep/all_to_all_collect.h
new file mode 100644
index 00000000..57220b7f
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/fusion/moe/ep/all_to_all_collect.h
@@ -0,0 +1,53 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef ATB_SPEED_MODELS_ALL_TO_ALL_COLLECT_OPERATION_H
+#define ATB_SPEED_MODELS_ALL_TO_ALL_COLLECT_OPERATION_H
+#include <atb/atb_infer.h>
+#include <hccl/hccl_types.h>
+#include "atb_speed/utils/operation_util.h"
+#include "atb_speed/log.h"
+
+namespace atb_speed {
+namespace common {
+struct AllToAllCollectParam {
+    int32_t topk = 2;
+    int numOfExperts = 8;
+
+    std::string backend = "hccl";
+    HcclComm hcclComm = nullptr;
+    bool hasMoeEp = false;
+    int moeEpRank = 0;
+    int moeEpSize = 1;
+    std::string moeEpDomain = "";
+    std::string moeEpRankTableFile = "";
+
+    bool hasMoeTp = false;
+    int moeTpRank = 0;
+    int moeTpSize = 1;
+    std::string moeTpDomain = "";
+    std::string moeTpRankTableFile = "";
+
+    bool hasMlpTp = false;
+    int mlpTpRank = 0;
+    int mlpTpSize = 1;
+    std::string mlpTpDomain = "";
+    std::string mlpTpRankTableFile = "";
+};
+
+atb::Status CreateAllToAllCollectOperation(const AllToAllCollectParam &param, atb::Operation **operation);
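+
+// The rank/size/domain triples configure the MoE expert-parallel (EP), MoE
+// tensor-parallel and MLP tensor-parallel communication domains; e.g. hasMoeTp
+// gates the reduce-scatter stage in CreateAllToAllCollectOperation.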
+} // namespace common
+} // namespace atb_speed
+#endif
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/fusion/moe/ep/all_to_all_dispatch.cpp b/tests/proftest/layer_test_framework/operations/fusion/moe/ep/all_to_all_dispatch.cpp
new file mode 100644
index 00000000..79a4b578
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/fusion/moe/ep/all_to_all_dispatch.cpp
@@ -0,0 +1,196 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "all_to_all_dispatch.h"
+#include <atb/atb_infer.h>
+#include <memory>
+#include "operations/fusion/utils.h"
+namespace atb_speed {
+namespace common {
+
+std::map<std::string, std::vector<std::string>> GetAllToAllDispatchInTensorCandidates()
+{
+    std::map<std::string, std::vector<std::string>> allToAllDispatchInTensorCandidates = {
+        {"default", {
+            "in_hiddenstatus", "in_selected_experts", "in_expert_weight",
+            "in_hidden_shuffle_idx", "in_expert_shuffle_idx",
+            "in_zero_hot", "in_one_hot"}
+        },
+    };
+    return allToAllDispatchInTensorCandidates;
+}
+
+
+std::map<std::string, std::vector<std::string>> GetAllToAllDispatchInterTensorCandidates()
+{
+    std::map<std::string, std::vector<std::string>> allToAllDispatchInterTensorCandidates = {
+        {"default", {
+            "intermediate_send_expert", "intermediate_send_hiddenstatus"}
+        },
+    };
+    return allToAllDispatchInterTensorCandidates;
+}
+
+std::map<std::string, std::vector<std::string>> GetAllToAllDispatchOutTensorCandidates()
+{
+    std::map<std::string, std::vector<std::string>> allToAllDispatchOutTensorCandidates = {
+        {"default", {
+            "out_hiddenstates", "out_selected_experts", "out_expert_weight"}
+        },
+    };
+    return allToAllDispatchOutTensorCandidates;
+}
+
+std::map<std::string, uint32_t> ConstructAllToAllDispatchTensorMap(
+    uint32_t &inTensorNum, uint32_t &outTensorNum, uint32_t &interTensorNum)
+{
+    auto allToAllDispatchInTensorCandidates = GetAllToAllDispatchInTensorCandidates();
+    auto allToAllDispatchInterTensorCandidates = GetAllToAllDispatchInterTensorCandidates();
+    auto allToAllDispatchOutTensorCandidates = GetAllToAllDispatchOutTensorCandidates();
+
+    std::vector<std::string> inTensorList = {};
+    std::vector<std::string> interTensorList = {};
+    std::vector<std::string> outTensorList = {};
+
+    AddTensorToList(allToAllDispatchInTensorCandidates, "default", inTensorList);
+    AddTensorToList(allToAllDispatchInterTensorCandidates, "default", interTensorList);
+    AddTensorToList(allToAllDispatchOutTensorCandidates, "default", outTensorList);
+
+    inTensorNum = inTensorList.size();
+    outTensorNum = outTensorList.size();
+    interTensorNum = interTensorList.size();
+
+    return GetTensorMap(inTensorList, outTensorList, interTensorList);
+}
+
+atb::Status CreateExpertGather(std::map<std::string, uint32_t> &tensorMap, atb::GraphParam &opGraph)
+{
+    atb::Node gatherNode;
+    atb::infer::GatherParam gatherParam;
+    CHECK_OPERATION_STATUS_RETURN(CreateOperation(gatherParam, &gatherNode.operation));
+    gatherNode.inTensorIds = {GetTensorIdx(tensorMap, "in_selected_experts"),
+                              GetTensorIdx(tensorMap, "in_expert_shuffle_idx")};
+    gatherNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_send_expert")};
+    gatherNode.inTensorReshapeFuncs.resize(gatherNode.inTensorIds.size());
+    gatherNode.inTensorReshapeFuncs[0] = [=](const atb::Dims &oldShape, atb::Dims &newShape) {
+        newShape.dimNum = 1;
+        newShape.dims[0] = oldShape.dims[0] * oldShape.dims[1];
+    };
+    opGraph.nodes.push_back(gatherNode);
+    return atb::NO_ERROR;
+}
+
+atb::Status CreateExpertWeightGather(std::map<std::string, uint32_t> &tensorMap,
+    atb::GraphParam &opGraph)
+{
+    atb::Node gatherNode;
+    atb::infer::GatherParam gatherParam;
+    CHECK_OPERATION_STATUS_RETURN(CreateOperation(gatherParam, &gatherNode.operation));
+    gatherNode.inTensorIds = {GetTensorIdx(tensorMap, "in_expert_weight"),
+                              GetTensorIdx(tensorMap, "in_expert_shuffle_idx")};
+    gatherNode.outTensorIds = {GetTensorIdx(tensorMap, "out_expert_weight")};
+    gatherNode.inTensorReshapeFuncs.resize(gatherNode.inTensorIds.size());
+    gatherNode.inTensorReshapeFuncs[0] = [=] (const atb::Dims &oldShape, atb::Dims &newShape) {
+        newShape.dimNum = 1;
+        newShape.dims[0] = oldShape.dims[0] * oldShape.dims[1];
+    };
+    opGraph.nodes.push_back(gatherNode);
+    return atb::NO_ERROR;
+}
+
+atb::Status CreateHiddenGather(std::map<std::string, uint32_t> &tensorMap,
+    atb::GraphParam &opGraph)
+{
+    atb::Node gatherNode;
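+    // Reorders the hidden states into send order: row i of the gather output is
+    // the token chosen by in_hidden_shuffle_idx[i], ready for the AllToAll step.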
+    atb::infer::GatherParam gatherParam;
+    CHECK_OPERATION_STATUS_RETURN(CreateOperation(gatherParam, &gatherNode.operation));
+    gatherNode.inTensorIds = {GetTensorIdx(tensorMap, "in_hiddenstatus"),
+                              GetTensorIdx(tensorMap, "in_hidden_shuffle_idx")};
+    gatherNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_send_hiddenstatus")};
+    opGraph.nodes.push_back(gatherNode);
+    return atb::NO_ERROR;
+}
+
+atb::Status CreateAll2AllDispatchData(std::map<std::string, uint32_t> &tensorMap,
+    const AllToAllDispatchParam &param, atb::GraphParam &opGraph)
+{
+    atb::Node allToAllNode;
+    atb::infer::AllToAllParam allToAllParam;
+    allToAllParam.rank = param.moeEpRank;
+    allToAllParam.rankSize = param.moeEpSize;
+    allToAllParam.backend = param.backend;
+    allToAllParam.hcclComm = param.hcclComm;
+    allToAllParam.rankTableFile = param.moeEpRankTableFile;
+    allToAllParam.commDomain = param.moeEpDomain;
+    CHECK_OPERATION_STATUS_RETURN(CreateOperation(allToAllParam, &allToAllNode.operation));
+    allToAllNode.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_send_hiddenstatus")};
+    allToAllNode.outTensorIds = {GetTensorIdx(tensorMap, "out_hiddenstates")};
+    opGraph.nodes.push_back(allToAllNode);
+    return atb::NO_ERROR;
+}
+
+atb::Status CreateAll2AllDispatchExpert(std::map<std::string, uint32_t> &tensorMap,
+    const AllToAllDispatchParam &param, atb::GraphParam &opGraph)
+{
+    atb::Node allToAllNode;
+    atb::infer::AllToAllParam allToAllParam;
+    allToAllParam.rank = param.moeEpRank;
+    allToAllParam.rankSize = param.moeEpSize;
+    allToAllParam.backend = param.backend;
+    allToAllParam.hcclComm = param.hcclComm;
+    allToAllParam.rankTableFile = param.moeEpRankTableFile;
+    allToAllParam.commDomain = param.moeEpDomain;
+    CHECK_OPERATION_STATUS_RETURN(CreateOperation(allToAllParam, &allToAllNode.operation));
+    allToAllNode.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_send_expert")};
+    allToAllNode.outTensorIds = {GetTensorIdx(tensorMap, "out_selected_experts")};
+    opGraph.nodes.push_back(allToAllNode);
+    return atb::NO_ERROR;
+}
+
+atb::Status CreateAllToAllDispatchOperation(const AllToAllDispatchParam &param, atb::Operation **operation)
+{
+    ATB_SPEED_LOG_DEBUG("CreateAllToAllDispatchOperation Start");
+    atb::GraphParam opGraph;
+    opGraph.name = "AllToAllDispatch";
+    std::map<std::string, uint32_t> tensorMap = ConstructAllToAllDispatchTensorMap(
+        opGraph.inTensorNum, opGraph.outTensorNum, opGraph.internalTensorNum);
+
+    CHECK_OPERATION_STATUS_RETURN(CreateExpertGather(tensorMap, opGraph));
+    CHECK_OPERATION_STATUS_RETURN(CreateExpertWeightGather(tensorMap, opGraph));
+    CHECK_OPERATION_STATUS_RETURN(CreateHiddenGather(tensorMap, opGraph));
+    CHECK_OPERATION_STATUS_RETURN(common::AddDapEventsBeforeComm(opGraph));
+    CHECK_OPERATION_STATUS_RETURN(CreateAll2AllDispatchData(tensorMap, param, opGraph));
+    CHECK_OPERATION_STATUS_RETURN(CreateAll2AllDispatchExpert(tensorMap, param, opGraph));
+    CHECK_OPERATION_STATUS_RETURN(common::AddDapEventsAfterComm(opGraph));
+
+    opGraph.inferShapeFunc = [=] (const atb::SVector<atb::TensorDesc> &inTensorDescs,
+                                  atb::SVector<atb::TensorDesc> &outTensorDescs) {
+        outTensorDescs.at(0) = inTensorDescs.at(0);
+        outTensorDescs.at(1) = inTensorDescs.at(1);
+        outTensorDescs.at(2) = inTensorDescs.at(2); // 2: out_expert_weight
+
+        outTensorDescs.at(0).shape.dims[0] = inTensorDescs.at(3).shape.dims[0]; // 3: in_hidden_shuffle_idx
+        outTensorDescs.at(1).shape.dims[0] = inTensorDescs.at(3).shape.dims[0];
+        outTensorDescs.at(2).shape.dims[0] = inTensorDescs.at(3).shape.dims[0];
+
+        outTensorDescs.at(1).shape.dims[1] = 1;
+        outTensorDescs.at(2).shape.dims[1] = 1;
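+        // Expert ids and weights are exchanged as flat [tokens, 1] columns, so
+        // dim 1 of outputs 1 and 2 collapses to 1.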
+        return atb::NO_ERROR;
+    };
+    CREATE_OPERATION(opGraph, operation);
+    return atb::NO_ERROR;
+}
+} // namespace common
+} // namespace atb_speed
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/fusion/moe/ep/all_to_all_dispatch.h b/tests/proftest/layer_test_framework/operations/fusion/moe/ep/all_to_all_dispatch.h
new file mode 100644
index 00000000..f147253a
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/fusion/moe/ep/all_to_all_dispatch.h
@@ -0,0 +1,46 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef ATB_SPEED_MODELS_ALL_TO_ALL_DISPATCH_OPERATION_H
+#define ATB_SPEED_MODELS_ALL_TO_ALL_DISPATCH_OPERATION_H
+#include <atb/atb_infer.h>
+#include <hccl/hccl_types.h>
+#include "atb_speed/utils/operation_util.h"
+#include "atb_speed/log.h"
+
+namespace atb_speed {
+namespace common {
+struct AllToAllDispatchParam {
+    int topk = 1;
+    int numOfExperts = 8;
+    std::string backend = "hccl";
+    HcclComm hcclComm = nullptr;
+    bool hasMoeEp = false;
+    int moeEpRank = 0;
+    int moeEpSize = 1;
+    std::string moeEpDomain = "";
+    std::string moeEpRankTableFile = "";
+
+    bool hasMlpTp = false;
+    int mlpTpRank = 0;
+    int mlpTpSize = 1;
+    std::string mlpTpDomain = "";
+    std::string mlpTpRankTableFile = "";
+};
+
+atb::Status CreateAllToAllDispatchOperation(const AllToAllDispatchParam &param, atb::Operation **operation);
+} // namespace common
+} // namespace atb_speed
+#endif
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/fusion/moe/ep/all_to_all_meta.cpp b/tests/proftest/layer_test_framework/operations/fusion/moe/ep/all_to_all_meta.cpp
new file mode 100644
index 00000000..ff0cf98a
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/fusion/moe/ep/all_to_all_meta.cpp
@@ -0,0 +1,330 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "all_to_all_meta.h"
+#include <atb/atb_infer.h>
+#include <memory>
+#include "operations/fusion/utils.h"
+#include "operations/aclnn/ops/cast_operation.h"
+
+namespace atb_speed {
+namespace common {
+std::map<std::string, std::vector<std::string>> GetAllToAllMetaInTensorCandidates()
+{
+    std::map<std::string, std::vector<std::string>> allToAllMetaInTensorCandidates = {
+        {"default", {
+            "in_group_count", "in_idx", "in_zero_hot"}
+        },
+    };
+    return allToAllMetaInTensorCandidates;
+}
+
+std::map<std::string, std::vector<std::string>> GetAllToAllMetaIntermediateTensorCandidates()
+{
+    std::map<std::string, std::vector<std::string>> allToAllMetaIntermediateTensorCandidates = {
+        {"default", {
+            "intermediate_buffer_idx", "intermediate_buffer_idx_int64", "intermediate_group_count_int64",
+            "intermediate_shuffle_idx_int64", "intermediate_zero_hot_int64", "intermediate_shuffle_filter_mask",
+            "intermediate_shuffle_idx_int32", "intermediate_shuffle_idx_float16", "intermediate_one_mask",
+            "intermediate_shuffle_weight"}
+        },
+    };
+    return allToAllMetaIntermediateTensorCandidates;
+}
+
+std::map<std::string, std::vector<std::string>> GetAllToAllMetaOutTensorCandidates()
+{
+    std::map<std::string, std::vector<std::string>> allToAllMetaOutTensorCandidates = {
+        {"default", {
+            "out_shuffle_idx_for_device_buffer", "out_shuffle_weight_mask", "out_valid_idx"}
+        },
+    };
+    return allToAllMetaOutTensorCandidates;
+}
+
+std::map<std::string, uint32_t> ConstructAllToAllMetaTensorMap(
+    uint32_t &inTensorNum, uint32_t &outTensorNum, uint32_t &interTensorNum)
+{
+    auto allToAllMetaInTensorCandidates = GetAllToAllMetaInTensorCandidates();
+    auto allToAllMetaIntermediateTensorCandidates = GetAllToAllMetaIntermediateTensorCandidates();
+    auto allToAllMetaOutTensorCandidates = GetAllToAllMetaOutTensorCandidates();
+
+    std::vector<std::string> inTensorList = {};
+    std::vector<std::string> interTensorList = {};
+    std::vector<std::string> outTensorList = {};
+
+    AddTensorToList(allToAllMetaInTensorCandidates, "default", inTensorList);
+    AddTensorToList(allToAllMetaIntermediateTensorCandidates, "default", interTensorList);
+    AddTensorToList(allToAllMetaOutTensorCandidates, "default", outTensorList);
+
+    inTensorNum = inTensorList.size();
+    outTensorNum = outTensorList.size();
+    interTensorNum = interTensorList.size();
+
+    return GetTensorMap(inTensorList, outTensorList, interTensorList);
+}
+
+atb::Status CreateOutValidIdx(std::map<std::string, uint32_t> &tensorMap, size_t &nodeId, atb::GraphParam &opGraph)
+{
+    // The required +1 offset is already applied on the Python side.
+    atb::infer::SliceParam sliceParam;
+    atb::Node &sliceNode = opGraph.nodes.at(nodeId++);
+    sliceParam.offsets.resize(1);
+    sliceParam.offsets[0] = -1;
+    sliceParam.size.resize(1);
+    sliceParam.size[0] = 1;
+    CreateOperation(sliceParam, &sliceNode.operation);
+    sliceNode.inTensorIds = {GetTensorIdx(tensorMap, "in_idx")};
+    sliceNode.outTensorIds = {GetTensorIdx(tensorMap, "out_valid_idx")};
+    return atb::NO_ERROR;
+}
+
+atb::Status CreateBufferIdx(std::map<std::string, uint32_t> &tensorMap,
+    const AllToAllMetaParam &param, size_t &nodeId, atb::GraphParam &opGraph)
+{
+    atb::infer::SliceParam sliceParam;
+    atb::Node &sliceNode = opGraph.nodes.at(nodeId++);
+    sliceParam.offsets.resize(2); // 2: dimNum
+    sliceParam.offsets[0] = 0;
+    sliceParam.offsets[1] = 0;
+    sliceParam.size.resize(2); // 2: dimNum
+    sliceParam.size[0] = 1;
+    sliceParam.size[1] = -1;
+    CreateOperation(sliceParam, &sliceNode.operation);
+    sliceNode.inTensorIds = {GetTensorIdx(tensorMap, "in_idx")};
+    sliceNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_buffer_idx")};
+    sliceNode.inTensorReshapeFuncs.resize(sliceNode.inTensorIds.size());
+    sliceNode.inTensorReshapeFuncs[0] = [=](const atb::Dims &oldShape, atb::Dims &newShape) {
+        newShape.dimNum = 2; // 2: dimNum
+        newShape.dims[0] = param.worldSize;
+        newShape.dims[1] = oldShape.dims[0] / param.worldSize;
+    };
+    return atb::NO_ERROR;
+}
+
+atb::Status CreateBufferIdx64(std::map<std::string, uint32_t> &tensorMap, size_t &nodeId, atb::GraphParam &opGraph)
+{
+    atb::Node &castNode = opGraph.nodes.at(nodeId++);
+    atb::infer::ElewiseParam castParam;
+    castParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_CAST;
+    castParam.outTensorType = ACL_INT64;
+    CHECK_OPERATION_STATUS_RETURN(CreateOperation(castParam, &castNode.operation));
+    castNode.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_buffer_idx")};
+    castNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_buffer_idx_int64")};
+    return atb::NO_ERROR;
+}
+
+atb::Status CreateGroupCount64(std::map<std::string, uint32_t> &tensorMap, size_t &nodeId, atb::GraphParam &opGraph)
+{
+    atb::Node &castNode = opGraph.nodes.at(nodeId++);
+    atb::infer::ElewiseParam castParam;
+    castParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_CAST;
+    castParam.outTensorType = ACL_INT64;
+    CHECK_OPERATION_STATUS_RETURN(CreateOperation(castParam, &castNode.operation));
+    castNode.inTensorIds = {GetTensorIdx(tensorMap, "in_group_count")};
+    castNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_group_count_int64")};
+    return atb::NO_ERROR;
+}
+
+atb::Status CreateShuffleIdx64(std::map<std::string, uint32_t> &tensorMap, size_t &nodeId, atb::GraphParam &opGraph)
+{
+    // The resulting shuffle idx may be negative; it is filtered by the mask below.
+    atb::Node &node = opGraph.nodes.at(nodeId++);
+    atb::infer::ElewiseParam param;
+    param.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_SUB;
+    CreateOperation(param, &node.operation);
+    node.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_group_count_int64"),
+                        GetTensorIdx(tensorMap, "intermediate_buffer_idx_int64")};
+    node.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_shuffle_idx_int64")};
+    node.inTensorReshapeFuncs.resize(node.inTensorIds.size());
+    node.inTensorReshapeFuncs[0] = [=](const atb::Dims &oldShape, atb::Dims &newShape) {
+        newShape.dimNum = 2; // 2: dimNum
+        newShape.dims[0] = oldShape.dims[0];
+        newShape.dims[1] = 1;
+    };
+    return atb::NO_ERROR;
+}
+
+atb::Status CreateZeroHot64(std::map<std::string, uint32_t> &tensorMap, size_t &nodeId, atb::GraphParam &opGraph)
+{
+    atb::Node &castNode = opGraph.nodes.at(nodeId++);
+    atb::infer::ElewiseParam castParam;
+    castParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_CAST;
+    castParam.outTensorType = ACL_INT64;
+    CHECK_OPERATION_STATUS_RETURN(CreateOperation(castParam, &castNode.operation));
+    castNode.inTensorIds = {GetTensorIdx(tensorMap, "in_zero_hot")};
+    castNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_zero_hot_int64")};
+    return atb::NO_ERROR;
+}
+
+atb::Status CreateShuffleIdxFilterMask(std::map<std::string, uint32_t> &tensorMap,
+    size_t &nodeId, atb::GraphParam &opGraph)
+{
+    atb::Node &node = opGraph.nodes.at(nodeId++);
+    atb::infer::ElewiseParam lessParam;
+    lessParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_LESS;
+    CreateOperation(lessParam, &node.operation);
+    node.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_shuffle_idx_int64"),
+                        GetTensorIdx(tensorMap, "intermediate_zero_hot_int64")};
+    node.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_shuffle_filter_mask")};
+    return atb::NO_ERROR;
+}
+
+atb::Status CreateShuffleIdx32(std::map<std::string, uint32_t> &tensorMap, size_t &nodeId, atb::GraphParam &opGraph)
+{
+    atb::Node &castNode = opGraph.nodes.at(nodeId++);
+
+    atb::infer::ElewiseParam castParam;
+    castParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_CAST;
+    castParam.outTensorType = ACL_INT32;
+    CHECK_OPERATION_STATUS_RETURN(CreateOperation(castParam, &castNode.operation));
&castNode.operation)); + castNode.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_shuffle_idx_int64")}; + castNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_shuffle_idx_int32")}; + return atb::NO_ERROR; +} + +atb::Status CreateShuffleIdx(std::map &tensorMap, size_t &nodeId, atb::GraphParam &opGraph) +{ + atb::Node &node = opGraph.nodes.at(nodeId++); + atb::infer::FillParam param; + param.withMask = true; + param.value.resize(1); + param.value[0] = 0; + CreateOperation(param, &node.operation); + node.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_shuffle_idx_int32"), + GetTensorIdx(tensorMap, "intermediate_shuffle_filter_mask")}; + node.outTensorIds = {GetTensorIdx(tensorMap, "out_shuffle_idx_for_device_buffer")}; + node.inTensorReshapeFuncs.resize(node.inTensorIds.size()); + node.inTensorReshapeFuncs[0] = [=] (const atb::Dims &oldShape, atb::Dims &newShape) { + newShape.dimNum = 1; + newShape.dims[0] = oldShape.dims[0] * oldShape.dims[1]; + }; + node.inTensorReshapeFuncs[1] = [=] (const atb::Dims &oldShape, atb::Dims &newShape) { + newShape.dimNum = 1; + newShape.dims[0] = oldShape.dims[0] * oldShape.dims[1]; + }; + return atb::NO_ERROR; +} + +atb::Status CreateShuffleIdx16(std::map &tensorMap, + const AllToAllMetaParam ¶m, size_t &nodeId, atb::GraphParam &opGraph) +{ + atb::Node &castNode = opGraph.nodes.at(nodeId++); + atb_speed::common::AclNNCastParam castParam; + castParam.dtype = param.isBF16 ? aclDataType::ACL_BF16 : aclDataType::ACL_FLOAT16; + castNode.operation = new atb_speed::common::CastOperation("CastNode", castParam); + castNode.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_shuffle_idx_int32")}; + castNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_shuffle_idx_float16")}; + castNode.inTensorReshapeFuncs.resize(castNode.inTensorIds.size()); + castNode.inTensorReshapeFuncs[0] = [=](const atb::Dims &oldShape, atb::Dims &newShape) { + newShape.dimNum = 1; + newShape.dims[0] = oldShape.dims[0] * oldShape.dims[1]; + }; + return atb::NO_ERROR; +} + +atb::Status CreateOneMask(std::map &tensorMap, size_t &nodeId, atb::GraphParam &opGraph) +{ + atb::Node &node = opGraph.nodes.at(nodeId++); + atb::infer::ElewiseParam param; + param.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_LOGICAL_NOT; + CreateOperation(param, &node.operation); + node.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_shuffle_filter_mask")}; + node.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_one_mask")}; + node.inTensorReshapeFuncs.resize(node.inTensorIds.size()); + node.inTensorReshapeFuncs[0] = [=] (const atb::Dims &oldShape, atb::Dims &newShape) { + newShape.dimNum = 1; + newShape.dims[0] = oldShape.dims[0] * oldShape.dims[1]; + }; + return atb::NO_ERROR; +} + +atb::Status CreateShuffleWeightZero(std::map &tensorMap, + size_t &nodeId, atb::GraphParam &opGraph) +{ + atb::Node &node = opGraph.nodes.at(nodeId++); + atb::infer::FillParam param; + param.withMask = true; + param.value.resize(1); + param.value[0] = 0; + CreateOperation(param, &node.operation); + node.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_shuffle_idx_float16"), + GetTensorIdx(tensorMap, "intermediate_shuffle_filter_mask")}; + node.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_shuffle_weight")}; + node.inTensorReshapeFuncs.resize(node.inTensorIds.size()); + node.inTensorReshapeFuncs[1] = [=](const atb::Dims &oldShape, atb::Dims &newShape) { + newShape.dimNum = 1; + newShape.dims[0] = oldShape.dims[0] * oldShape.dims[1]; + }; + return atb::NO_ERROR; +} + +atb::Status 
CreateShuffleWeightOne(std::map &tensorMap, size_t &nodeId, atb::GraphParam &opGraph) +{ + atb::Node &node = opGraph.nodes.at(nodeId++); + atb::infer::FillParam param; + param.withMask = true; + param.value.resize(1); + param.value[0] = 1; + CreateOperation(param, &node.operation); + node.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_shuffle_weight"), + GetTensorIdx(tensorMap, "intermediate_one_mask")}; + node.outTensorIds = {GetTensorIdx(tensorMap, "out_shuffle_weight_mask")}; + return atb::NO_ERROR; +} + +atb::Status CreateAllToAllMetaOperation(const AllToAllMetaParam ¶m, atb::Operation **operation) +{ + ATB_SPEED_LOG_DEBUG("CreateAllToAllMetaOperation Start"); + atb::GraphParam opGraph; + opGraph.name = "AllToAllMeta"; + std::map tensorMap = ConstructAllToAllMetaTensorMap( + opGraph.inTensorNum, opGraph.outTensorNum, opGraph.internalTensorNum); + uint64_t nodeCount = 13; + opGraph.nodes.resize(nodeCount); + size_t nodeId = 0; + // output 1 + CHECK_OPERATION_STATUS_RETURN(CreateOutValidIdx(tensorMap, nodeId, opGraph)); + // output 2 + CHECK_OPERATION_STATUS_RETURN(CreateBufferIdx(tensorMap, param, nodeId, opGraph)); + CHECK_OPERATION_STATUS_RETURN(CreateBufferIdx64(tensorMap, nodeId, opGraph)); + CHECK_OPERATION_STATUS_RETURN(CreateGroupCount64(tensorMap, nodeId, opGraph)); + CHECK_OPERATION_STATUS_RETURN(CreateShuffleIdx64(tensorMap, nodeId, opGraph)); + CHECK_OPERATION_STATUS_RETURN(CreateZeroHot64(tensorMap, nodeId, opGraph)); + CHECK_OPERATION_STATUS_RETURN(CreateShuffleIdxFilterMask(tensorMap, nodeId, opGraph)); + CHECK_OPERATION_STATUS_RETURN(CreateShuffleIdx32(tensorMap, nodeId, opGraph)); + CHECK_OPERATION_STATUS_RETURN(CreateShuffleIdx(tensorMap, nodeId, opGraph)); + // output 3 + CHECK_OPERATION_STATUS_RETURN(CreateShuffleIdx16(tensorMap, param, nodeId, opGraph)); + CHECK_OPERATION_STATUS_RETURN(CreateOneMask(tensorMap, nodeId, opGraph)); + CHECK_OPERATION_STATUS_RETURN(CreateShuffleWeightZero(tensorMap, nodeId, opGraph)); + CHECK_OPERATION_STATUS_RETURN(CreateShuffleWeightOne(tensorMap, nodeId, opGraph)); + + opGraph.inferShapeFunc = [=] (const atb::SVector &inTensorDescs, + atb::SVector &outTensorDescs) { + outTensorDescs.at(0) = inTensorDescs.at(1); + outTensorDescs.at(1) = inTensorDescs.at(1); + outTensorDescs.at(1).dtype = param.isBF16 ? ACL_BF16 : ACL_FLOAT16; + outTensorDescs.at(2) = inTensorDescs.at(1); // 2: dim 2 + outTensorDescs.at(2).shape.dimNum = 1; // 2: dim 2 + outTensorDescs.at(2).shape.dims[0] = 1; // 2: dim 2 + return atb::NO_ERROR; + }; + CREATE_OPERATION(opGraph, operation); + ATB_SPEED_LOG_DEBUG("CreateAllToAllMetaOperation success"); + return atb::NO_ERROR; +} +} +} // namespace atb_speed \ No newline at end of file diff --git a/tests/proftest/layer_test_framework/operations/fusion/moe/ep/all_to_all_meta.h b/tests/proftest/layer_test_framework/operations/fusion/moe/ep/all_to_all_meta.h new file mode 100644 index 00000000..f9c25000 --- /dev/null +++ b/tests/proftest/layer_test_framework/operations/fusion/moe/ep/all_to_all_meta.h @@ -0,0 +1,38 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef ATB_SPEED_MODELS_ALL_TO_ALL_META_OPERATION_H +#define ATB_SPEED_MODELS_ALL_TO_ALL_META_OPERATION_H +#include +#include "atb_speed/utils/operation_util.h" +#include "atb_speed/log.h" + + +namespace atb_speed { +namespace common { +struct AllToAllMetaParam { + bool enableCompression = false; + + int32_t topk = 2; + int numOfExperts = 8; + bool isBF16 = false; + int rank = 0; + int worldSize = 1; +}; + +atb::Status CreateAllToAllMetaOperation(const AllToAllMetaParam ¶m, atb::Operation **operation); +} +} // namespace atb_speed +#endif \ No newline at end of file diff --git a/tests/proftest/layer_test_framework/operations/fusion/moe/ep/data_preparation.cpp b/tests/proftest/layer_test_framework/operations/fusion/moe/ep/data_preparation.cpp new file mode 100644 index 00000000..c1718cb9 --- /dev/null +++ b/tests/proftest/layer_test_framework/operations/fusion/moe/ep/data_preparation.cpp @@ -0,0 +1,166 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "data_preparation.h" +#include +#include +#include "operations/fusion/utils.h" +#include "operations/aclnn/ops/moe_init_routing_operation.h" +#include "operations/aclnn/ops/moe_compute_expert_tokens_operation.h" + +namespace atb_speed { +namespace common { +std::map> GetDataPreparationInTensorCandidates() +{ + std::map> dataPreparationInTensorCandidates = { + {"default", { + "in_selected_experts", "in_idx", "in_one_hot", "in_zero_hot"}}, + }; + return dataPreparationInTensorCandidates; +} + +std::map> GetDataPreparationInterTensorCandidates() +{ + std::map> dataPreparationInterTensorCandidates = { + {"default", { + "intermediate_group_count"}}, + }; + return dataPreparationInterTensorCandidates; +} + +std::map> GetDataPreparationOutTensorCandidates() +{ + std::map> dataPreparationOutTensorCandidates = { + {"default", { + "out_shuffle_idx", "out_expert_idx", "out_group_count"}}, + }; + return dataPreparationOutTensorCandidates; +} + +std::map ConstructDataPreparationTensorMap( + uint32_t &inTensorNum, uint32_t &outTensorNum, uint32_t &interTensorNum) +{ + auto dataPreparationInTensorCandidates = GetDataPreparationInTensorCandidates(); + auto dataPreparationInterTensorCandidates = GetDataPreparationInterTensorCandidates(); + auto dataPreparationOutTensorCandidates = GetDataPreparationOutTensorCandidates(); + + std::vector inTensorList = {}; + std::vector interTensorList = {}; + std::vector outTensorList = {}; + + AddTensorToList(dataPreparationInTensorCandidates, "default", inTensorList); + AddTensorToList(dataPreparationInterTensorCandidates, "default", interTensorList); + AddTensorToList(dataPreparationOutTensorCandidates, "default", outTensorList); + + inTensorNum = inTensorList.size(); + outTensorNum = outTensorList.size(); + interTensorNum = interTensorList.size(); + + return GetTensorMap(inTensorList, outTensorList, interTensorList); +} + +atb::Status CreateDeviceGating(std::map &tensorMap, const DataPreparationParam ¶m, + size_t &nodeId, atb::GraphParam &opGraph) +{ + atb::Node &gatingNode = opGraph.nodes.at(nodeId++); + atb::infer::GatingParam gatingParam; + gatingParam.topkExpertNum = param.topk; + gatingParam.cumSumNum = param.numOfExperts; + CHECK_OPERATION_STATUS_RETURN(CreateOperation(gatingParam, &gatingNode.operation)); + + gatingNode.inTensorIds = {GetTensorIdx(tensorMap, "in_selected_experts"), + GetTensorIdx(tensorMap, "in_idx")}; + gatingNode.outTensorIds = {GetTensorIdx(tensorMap, "out_shuffle_idx"), + GetTensorIdx(tensorMap, "intermediate_group_count"), + GetTensorIdx(tensorMap, "out_expert_idx")}; + + gatingNode.inTensorReshapeFuncs.resize(gatingNode.inTensorIds.size()); + gatingNode.inTensorReshapeFuncs[0] = [=] (const atb::Dims &oldShape, atb::Dims &newShape) { + newShape.dimNum = 1; // dimNum: 1 + newShape.dims[0] = oldShape.dims[0] * oldShape.dims[1]; + }; + + ATB_SPEED_LOG_DEBUG("Gating calculation success"); + return atb::NO_ERROR; +} + + +// 对 Group Count 进行处理,返回是worldsize的大小, 获取设备通信数 +atb::Status CreateGroupSlice(std::map &tensorMap, const DataPreparationParam ¶m, + size_t &nodeId, atb::GraphParam &opGraph) +{ + atb::infer::SliceParam sliceParam; + atb::Node &sliceNode = opGraph.nodes.at(nodeId++); + + sliceParam.offsets.resize(2); // 2: dimNum + sliceParam.offsets[0] = 0; + sliceParam.offsets[1] = -1; + + sliceParam.size.resize(2); // 2: dimNum + sliceParam.size[0] = param.worldSize; + sliceParam.size[1] = 1; + + CreateOperation(sliceParam, &sliceNode.operation); + + sliceNode.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_group_count")}; + 
sliceNode.outTensorIds = {GetTensorIdx(tensorMap, "out_group_count")}; + + sliceNode.inTensorReshapeFuncs.resize(sliceNode.inTensorIds.size()); + sliceNode.inTensorReshapeFuncs[0] = [=] (const atb::Dims &oldShape, atb::Dims &newShape) { + newShape.dimNum = 2; // 2: dimNum + newShape.dims[0] = param.worldSize; + newShape.dims[1] = oldShape.dims[0] / param.worldSize; + }; + + ATB_SPEED_LOG_DEBUG("CreateGroupSlice, Get Device Token Num"); + return atb::NO_ERROR; +} + +atb::Status CreateDataPreparationOperation(const DataPreparationParam ¶m, atb::Operation **operation) +{ + ATB_SPEED_LOG_DEBUG("CreateDataPreparationOperation Start"); + atb::GraphParam opGraph; + opGraph.name = "DataPreparation"; + std::map tensorMap = ConstructDataPreparationTensorMap( + opGraph.inTensorNum, opGraph.outTensorNum, opGraph.internalTensorNum); + + uint64_t nodeCount = 2; + opGraph.nodes.resize(nodeCount); + size_t nodeId = 0; + + CHECK_OPERATION_STATUS_RETURN(CreateDeviceGating(tensorMap, param, nodeId, opGraph)); + CHECK_OPERATION_STATUS_RETURN(CreateGroupSlice(tensorMap, param, nodeId, opGraph)); + + opGraph.inferShapeFunc = [=] (const atb::SVector &inTensorDescs, + atb::SVector &outTensorDescs) { + size_t shape = inTensorDescs.at(1).shape.dims[0]; + outTensorDescs.at(0) = inTensorDescs.at(1); + outTensorDescs.at(0).shape.dimNum = 1; + outTensorDescs.at(0).shape.dims[0] = shape; + outTensorDescs.at(1) = inTensorDescs.at(1); + outTensorDescs.at(1).shape.dimNum = 1; + outTensorDescs.at(1).shape.dims[0] = shape; + outTensorDescs.at(2) = inTensorDescs.at(1); // 2: dim 2 + outTensorDescs.at(2).shape.dimNum = 1; // 2: dim 2 + outTensorDescs.at(2).shape.dims[0] = param.worldSize; // 2: dim 2 + return atb::NO_ERROR; + }; + + CREATE_OPERATION(opGraph, operation); + ATB_SPEED_LOG_DEBUG("CreateDataPreparationOperation success"); + return atb::NO_ERROR; +} +} +} // namespace atb_speed \ No newline at end of file diff --git a/tests/proftest/layer_test_framework/operations/fusion/moe/ep/data_preparation.h b/tests/proftest/layer_test_framework/operations/fusion/moe/ep/data_preparation.h new file mode 100644 index 00000000..578ec448 --- /dev/null +++ b/tests/proftest/layer_test_framework/operations/fusion/moe/ep/data_preparation.h @@ -0,0 +1,36 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef ATB_SPEED_MODELS_DATA_PREPARATION_OPERATION_H +#define ATB_SPEED_MODELS_DATA_PREPARATION_OPERATION_H +#include +#include "atb_speed/utils/operation_util.h" +#include "atb_speed/log.h" + + +namespace atb_speed { +namespace common { +struct DataPreparationParam { + int32_t topk = 2; + int numOfExperts = 8; + int worldSize = 1; + int rank = 0; + bool mixSharedRouting = false; +}; + +atb::Status CreateDataPreparationOperation(const DataPreparationParam ¶m, atb::Operation **operation); +} +} // namespace atb_speed +#endif \ No newline at end of file diff --git a/tests/proftest/layer_test_framework/operations/fusion/moe/ep/dynamic_ep_moe.cpp b/tests/proftest/layer_test_framework/operations/fusion/moe/ep/dynamic_ep_moe.cpp new file mode 100644 index 00000000..140cb05b --- /dev/null +++ b/tests/proftest/layer_test_framework/operations/fusion/moe/ep/dynamic_ep_moe.cpp @@ -0,0 +1,482 @@ +/* +* Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. +* +* * Licensed under the Apache License, Version 2.0 (the "License"); +* * you may not use this file except in compliance with the License. +* * You may obtain a copy of the License at +* * +* * http://www.apache.org/licenses/LICENSE-2.0 +* * +* * Unless required by applicable law or agreed to in writing, software +* * distributed under the License is distributed on an "AS IS" BASIS, +* * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* * See the License for the specific language governing permissions and +* * limitations under the License. +* */ +#include "dynamic_ep_moe.h" +#include +#include +#include "operations/aclnn/ops/argsort_operation.h" +#include "operations/aclnn/ops/grouped_matmul_operation.h" +#include "operations/fusion/moe/moe_mlp.h" +#include "data_preparation.h" +#include "all_to_all_meta.h" +#include "all_to_all_dispatch.h" +#include "all_to_all_collect.h" +#include "operations/fusion/utils.h" +#include "fused_alltoall_gmm.h" + + +namespace atb_speed { +namespace common { + +std::map> GetDynamicEpMoEInTensorCandidates() +{ + std::map> dynamicEpMoEInTensorCandidates = { + {"default", { + "in_hiddenstatus", "in_mlp_gateup_weight_expert", "in_mlp_gateup_bias_expert", + "in_mlp_gateup_descale_expert", "in_mlp_gateup_offset_expert", "in_mlp_gateup_scale_expert", + "in_mlp_gateup_compress_idx_expert", "in_mlp_down_weight_expert", + "in_mlp_down_bias_expert", "in_mlp_down_descale_expert", "in_mlp_down_offset_expert", + "in_mlp_down_scale_expert", "in_mlp_down_compress_idx_expert", "in_expert_array", "in_selected_experts", + "in_expert_weight", "in_one_hot", "in_zero_hot"} + }, + {"ep", { + "in_start_expert_idx", "in_device_expert_count", "in_padding_idx"} + }, + {"dynamic_ep", { + "in_buffer_idx", "in_moe_idx"} + }, + }; + return dynamicEpMoEInTensorCandidates; +} + +std::map> GetDynamicEpMoEInterTensorCandidates() +{ + std::map> dynamicEpMoEInterTensorCandidates = { + {"default", { + "intermediate_shuffle_idx", "intermediate_expert_shuffle_idx", "intermediate_valid_idx", + "intermediate_buffer_idx", "intermediate_shuffle_idx_1", + "intermediate_shuffle_idx_2", "intermediate_expert_shuffle_idx_1", + "intermediate_group_count", "intermediate_shuffle_weight", "intermediate_recv_hiddenstatus", + "intermediate_recv_selected_experts", "intermediate_experts_weight", "intermediate_moe_output"} + }, + }; + return dynamicEpMoEInterTensorCandidates; +} + +std::map> GetDynamicEpMoEOutTensorCandidates() +{ + std::map> dynamicEpMoEOutTensorCandidates = { + {"default", { + "out_hiddenstates"} + }, + 
}; + return dynamicEpMoEOutTensorCandidates; +} + +std::map ConstructDynamicEpTensorMap( + const DynamicEpMoEParam ¶m, uint32_t &inTensorNum, uint32_t &outTensorNum, uint32_t &interTensorNum) +{ + auto dynamicEpMoEInTensorCandidates = GetDynamicEpMoEInTensorCandidates(); + auto dynamicEpMoEInterTensorCandidates = GetDynamicEpMoEInterTensorCandidates(); + auto dynamicEpMoEOutTensorCandidates = GetDynamicEpMoEOutTensorCandidates(); + std::vector inTensorList = {}; + std::vector interTensorList = {}; + std::vector outTensorList = {}; + AddTensorToList(dynamicEpMoEInTensorCandidates, "default", inTensorList); + if (param.hasMoeEp) { + AddTensorToList(dynamicEpMoEInTensorCandidates, "ep", inTensorList); + if (param.isDynamicEp) { + AddTensorToList(dynamicEpMoEInTensorCandidates, "dynamic_ep", inTensorList); + } + } + if (param.hasMoeEp && param.isDynamicEp && !param.enableMoeDistribute && !param.enableLcocAll2All) { + AddTensorToList(dynamicEpMoEInterTensorCandidates, "default", interTensorList); + } + AddTensorToList(dynamicEpMoEOutTensorCandidates, "default", outTensorList); + if (param.enableExpertCumSumOutput) { + outTensorList.push_back("out_gmm_cumsum_list"); + } + + inTensorNum = inTensorList.size(); + outTensorNum = outTensorList.size(); + interTensorNum = interTensorList.size(); + return GetTensorMap(inTensorList, outTensorList, interTensorList); +} + +atb::Status CreateFusedAllToAllMlp(std::map &tensorMap, const DynamicEpMoEParam ¶m, + size_t &nodeId, atb::GraphParam &opGraph) +{ + auto &expertNode = opGraph.nodes.at(nodeId++); + atb_speed::common::All2AllMatmulParam mlpExpertParam; + mlpExpertParam.topk = param.topk; + mlpExpertParam.scaledTopk = param.scaledTopk; + mlpExpertParam.numOfDeviceExperts = param.numOfDeviceExperts; + mlpExpertParam.numOfExperts = param.numOfExperts; + mlpExpertParam.gateUpTransposeB = param.gateUpTransposeB; + mlpExpertParam.downTransposeB = param.downTransposeB; + mlpExpertParam.moeEpRank = param.moeEpRank; + mlpExpertParam.moeEpSize = param.moeEpSize; + mlpExpertParam.lcclMoeEpDomain = param.lcclMoeEpDomain; + mlpExpertParam.lcclMoeEpHcclComm = param.lcclMoeEpHcclComm; + atb_speed::common::CreateAll2AllMatmulOperation(mlpExpertParam, &expertNode.operation); + + expertNode.inTensorIds = { + GetTensorIdx(tensorMap, "in_hiddenstatus"), + GetTensorIdx(tensorMap, "in_mlp_gateup_weight_expert"), + GetTensorIdx(tensorMap, "in_mlp_gateup_descale_expert"), + GetTensorIdx(tensorMap, "in_mlp_gateup_offset_expert"), + GetTensorIdx(tensorMap, "in_mlp_gateup_scale_expert"), + GetTensorIdx(tensorMap, "in_mlp_gateup_compress_idx_expert"), + GetTensorIdx(tensorMap, "in_mlp_down_weight_expert"), + GetTensorIdx(tensorMap, "in_mlp_down_descale_expert"), + GetTensorIdx(tensorMap, "in_mlp_down_offset_expert"), + GetTensorIdx(tensorMap, "in_mlp_down_scale_expert"), + GetTensorIdx(tensorMap, "in_mlp_down_compress_idx_expert"), + GetTensorIdx(tensorMap, "in_expert_array"), + GetTensorIdx(tensorMap, "in_selected_experts"), + GetTensorIdx(tensorMap, "in_expert_weight"), + GetTensorIdx(tensorMap, "in_moe_idx"), + }; + expertNode.outTensorIds = {GetTensorIdx(tensorMap, "out_hiddenstates")}; + + ATB_SPEED_LOG_DEBUG("Expert Group calculation success"); + return atb::NO_ERROR; +} + + +atb::Status CreateDataPreparation(std::map &tensorMap, + const DynamicEpMoEParam ¶m, size_t &nodeId, atb::GraphParam &opGraph) +{ + auto &dataPreparationNode = opGraph.nodes.at(nodeId++); + atb_speed::common::DataPreparationParam dataPreparationParam; + dataPreparationParam.topk = param.topk; + 
dataPreparationParam.numOfExperts = param.numOfExperts; + dataPreparationParam.rank = param.moeEpRank; + dataPreparationParam.worldSize = param.moeEpSize; + dataPreparationParam.mixSharedRouting = param.mixSharedRouting; + atb_speed::common::CreateDataPreparationOperation(dataPreparationParam, &dataPreparationNode.operation); + dataPreparationNode.inTensorIds = {GetTensorIdx(tensorMap, "in_selected_experts"), + GetTensorIdx(tensorMap, "in_buffer_idx"), + GetTensorIdx(tensorMap, "in_one_hot"), + GetTensorIdx(tensorMap, "in_zero_hot")}; + dataPreparationNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_shuffle_idx"), + GetTensorIdx(tensorMap, "intermediate_expert_shuffle_idx"), + GetTensorIdx(tensorMap, "intermediate_group_count")}; + dataPreparationNode.inTensorChunks.resize(dataPreparationNode.inTensorIds.size()); + ATB_SPEED_LOG_DEBUG("dataPreparationNode calculation success"); + return atb::NO_ERROR; +} + +atb::Status CreateAllToAllMeta(std::map &tensorMap, const DynamicEpMoEParam ¶m, + size_t &nodeId, atb::GraphParam &opGraph) +{ + auto &allToAllMetaNode = opGraph.nodes.at(nodeId++); + atb_speed::common::AllToAllMetaParam allToAllMetaParam; + allToAllMetaParam.topk = param.topk; + allToAllMetaParam.numOfExperts = param.numOfExperts; + allToAllMetaParam.rank = param.moeEpRank; + allToAllMetaParam.worldSize = param.moeEpSize; + allToAllMetaParam.isBF16 = param.isBF16; + atb_speed::common::CreateAllToAllMetaOperation(allToAllMetaParam, &allToAllMetaNode.operation); + allToAllMetaNode.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_group_count"), + GetTensorIdx(tensorMap, "in_moe_idx"), + GetTensorIdx(tensorMap, "in_zero_hot")}; + allToAllMetaNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_buffer_idx"), + GetTensorIdx(tensorMap, "intermediate_shuffle_weight"), + GetTensorIdx(tensorMap, "intermediate_valid_idx")}; + allToAllMetaNode.inTensorChunks.resize(allToAllMetaNode.inTensorIds.size()); + ATB_SPEED_LOG_DEBUG("allToAllMetaNode calculation success"); + return atb::NO_ERROR; +} + +atb::Status CreateShuffleIdxDE(std::map &tensorMap, size_t &nodeId, atb::GraphParam &opGraph) +{ + atb::Node &gatherNode = opGraph.nodes.at(nodeId++); + atb::infer::GatherParam gatherParam; + CHECK_OPERATION_STATUS_RETURN(CreateOperation(gatherParam, &gatherNode.operation)); + gatherNode.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_shuffle_idx"), + GetTensorIdx(tensorMap, "intermediate_buffer_idx")}; + gatherNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_shuffle_idx_1")}; + ATB_SPEED_LOG_DEBUG("CreateShuffleIdxDE"); + return atb::NO_ERROR; +} + +atb::Status CreateExpertShuffleIdx(std::map &tensorMap, size_t &nodeId, atb::GraphParam &opGraph) +{ + atb::Node &gatherNode = opGraph.nodes.at(nodeId++); + atb::infer::GatherParam gatherParam; + CHECK_OPERATION_STATUS_RETURN(CreateOperation(gatherParam, &gatherNode.operation)); + gatherNode.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_expert_shuffle_idx"), + GetTensorIdx(tensorMap, "intermediate_buffer_idx")}; + gatherNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_expert_shuffle_idx_1")}; + ATB_SPEED_LOG_DEBUG("CreateExpertShuffleIdx"); + return atb::NO_ERROR; +} + +atb::Status CreateAllToAllDispatch(std::map &tensorMap, const DynamicEpMoEParam ¶m, + size_t &nodeId, atb::GraphParam &opGraph) +{ + auto &allToAllDispatchNode = opGraph.nodes.at(nodeId++); + atb_speed::common::AllToAllDispatchParam allToAllDispatchParam; + + allToAllDispatchParam.topk = param.topk; + allToAllDispatchParam.numOfExperts = 
param.numOfExperts; + allToAllDispatchParam.backend = param.backend; + allToAllDispatchParam.hcclComm = param.hcclComm; + allToAllDispatchParam.hasMoeEp = param.hasMoeEp; + allToAllDispatchParam.moeEpRank = param.moeEpRank; + allToAllDispatchParam.moeEpSize = param.moeEpSize; + allToAllDispatchParam.moeEpRankTableFile = param.moeEpRankTableFile; + allToAllDispatchParam.moeEpDomain = param.moeEpDomain; + + allToAllDispatchParam.hasMlpTp = param.hasMlpTp; + allToAllDispatchParam.mlpTpRank = param.mlpTpRank; + allToAllDispatchParam.mlpTpSize = param.mlpTpSize; + allToAllDispatchParam.mlpTpRankTableFile = param.mlpTpRankTableFile; + allToAllDispatchParam.mlpTpDomain = param.mlpTpDomain; + + atb_speed::common::CreateAllToAllDispatchOperation(allToAllDispatchParam, &allToAllDispatchNode.operation); + allToAllDispatchNode.inTensorIds = { + GetTensorIdx(tensorMap, "in_hiddenstatus"), + GetTensorIdx(tensorMap, "in_selected_experts"), + GetTensorIdx(tensorMap, "in_expert_weight"), + GetTensorIdx(tensorMap, "intermediate_shuffle_idx_1"), + GetTensorIdx(tensorMap, "intermediate_expert_shuffle_idx_1"), + GetTensorIdx(tensorMap, "in_zero_hot"), + GetTensorIdx(tensorMap, "in_one_hot"), + }; + + allToAllDispatchNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_recv_hiddenstatus"), + GetTensorIdx(tensorMap, "intermediate_recv_selected_experts"), + GetTensorIdx(tensorMap, "intermediate_experts_weight"), + }; + + allToAllDispatchNode.inTensorChunks.resize(allToAllDispatchNode.inTensorIds.size()); + ATB_SPEED_LOG_DEBUG("allToAllDispatchNode calculation success"); + return atb::NO_ERROR; +} + +atb::Status SetMoeMlpParam(atb_speed::common::MoeMlpParam &mlpExpertParam, const DynamicEpMoEParam ¶m) +{ + if (param.hasMoeEp && param.isDynamicEp && !param.enableMoeDistribute) { + mlpExpertParam.topk = 1; + } else { + mlpExpertParam.topk = param.topk; + mlpExpertParam.scaledTopk = param.scaledTopk; + mlpExpertParam.enableInitRoutingCutoff = param.enableInitRoutingCutoff; + } + mlpExpertParam.numOfDeviceExperts = param.numOfDeviceExperts; + mlpExpertParam.hasMoeEp = param.hasMoeEp; + mlpExpertParam.deviceExpert = param.deviceExpert; + mlpExpertParam.expertParallelDegree = param.expertParallelDegree; + mlpExpertParam.transpose = param.transpose; + mlpExpertParam.numOfExperts = param.numOfExperts; + mlpExpertParam.supportSwiGLU = param.supportSwiGLU; + mlpExpertParam.moeLinearQuantType = param.moeLinearQuantType; + mlpExpertParam.packQuantType = param.packQuantType; + mlpExpertParam.denseQuantType = param.denseQuantType; + mlpExpertParam.isBF16 = param.isBF16; + mlpExpertParam.gateUpTransposeB = param.gateUpTransposeB; + mlpExpertParam.downTransposeB = param.downTransposeB; + mlpExpertParam.enableFusedRouting = param.enableFusedRouting; + mlpExpertParam.enableInitQuant = param.enableInitQuant; + mlpExpertParam.enableSwigluQuant = param.enableSwigluQuant; + mlpExpertParam.enableAtlasGMMFused = param.enableAtlasGMMFused; + mlpExpertParam.quantGroupSize = param.quantGroupSize; + mlpExpertParam.enableGMMSwigluQuant = param.enableGMMSwigluQuant; + mlpExpertParam.enableCVOverlap = param.enableCVOverlap; + mlpExpertParam.backend = param.backend; + mlpExpertParam.hasMoeEp = param.hasMoeEp; + mlpExpertParam.moeEpRank = param.moeEpRank; + mlpExpertParam.moeEpSize = param.moeEpSize; + mlpExpertParam.moeEpDomain = param.moeEpDomain; + mlpExpertParam.maxDecodeDpTokenSize = param.maxDecodeDpTokenSize; + if (param.expertParallelDegree == 1) { + mlpExpertParam.shiftedTopK = true; + } + mlpExpertParam.enableMoeDistribute = 
param.enableMoeDistribute && param.isDynamicEp; + mlpExpertParam.enableExpertCumSumOutput = param.enableExpertCumSumOutput; + mlpExpertParam.enableGatingDp = param.enableGatingDp; + mlpExpertParam.enableDispatchCombineV2 = param.enableDispatchCombineV2; + mlpExpertParam.numDanglingSharedExperts = param.numDanglingSharedExperts; + mlpExpertParam.numOfRedundantExpert = param.numOfRedundantExpert; + return atb::NO_ERROR; +} + +atb::Status CreateMoeMlp(std::map &tensorMap, const DynamicEpMoEParam ¶m, + size_t &nodeId, atb::GraphParam &opGraph) +{ + auto &expertNode = opGraph.nodes.at(nodeId++); + atb_speed::common::MoeMlpParam mlpExpertParam; + SetMoeMlpParam(mlpExpertParam, param); + atb_speed::common::CreateMoeMlpOperation(mlpExpertParam, &expertNode.operation); + expertNode.outTensorIds = {GetTensorIdx(tensorMap, + (param.hasMoeEp && param.isDynamicEp && !param.enableMoeDistribute) ? \ + "intermediate_moe_output" : "out_hiddenstates")}; + if (param.enableExpertCumSumOutput) { + expertNode.outTensorIds.push_back(GetTensorIdx(tensorMap, "out_gmm_cumsum_list")); + } + + expertNode.inTensorIds = {GetTensorIdx(tensorMap, + (param.hasMoeEp && param.isDynamicEp && !param.enableMoeDistribute) ? \ + "intermediate_recv_hiddenstatus" : "in_hiddenstatus"), + GetTensorIdx(tensorMap, "in_mlp_gateup_weight_expert"), + GetTensorIdx(tensorMap, "in_mlp_gateup_bias_expert"), + GetTensorIdx(tensorMap, "in_mlp_gateup_descale_expert"), + GetTensorIdx(tensorMap, "in_mlp_gateup_offset_expert"), + GetTensorIdx(tensorMap, "in_mlp_gateup_scale_expert"), + GetTensorIdx(tensorMap, "in_mlp_gateup_compress_idx_expert"), + GetTensorIdx(tensorMap, "in_mlp_down_weight_expert"), + GetTensorIdx(tensorMap, "in_mlp_down_bias_expert"), + GetTensorIdx(tensorMap, "in_mlp_down_descale_expert"), + GetTensorIdx(tensorMap, "in_mlp_down_offset_expert"), + GetTensorIdx(tensorMap, "in_mlp_down_scale_expert"), + GetTensorIdx(tensorMap, "in_mlp_down_compress_idx_expert"), + GetTensorIdx(tensorMap, "in_expert_array"), + GetTensorIdx(tensorMap, (param.hasMoeEp && param.isDynamicEp && !param.enableMoeDistribute) ? \ + "intermediate_recv_selected_experts" : "in_selected_experts"), + GetTensorIdx(tensorMap, (param.hasMoeEp && param.isDynamicEp && !param.enableMoeDistribute) ? 
\ + "in_expert_array" : "in_expert_weight")}; + if (param.hasMoeEp) { + expertNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_zero_hot")); + expertNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_start_expert_idx")); + expertNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_device_expert_count")); + expertNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_padding_idx")); + } + ATB_SPEED_LOG_DEBUG("Expert Group calculation success"); + return atb::NO_ERROR; +} + +atb::Status CreateShuffleIdxDE2(std::map &tensorMap, size_t &nodeId, atb::GraphParam &opGraph) +{ + atb::Node &gatherNode = opGraph.nodes.at(nodeId++); + atb::infer::GatherParam gatherParam; + CHECK_OPERATION_STATUS_RETURN(CreateOperation(gatherParam, &gatherNode.operation)); + + gatherNode.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_shuffle_idx"), + GetTensorIdx(tensorMap, "intermediate_buffer_idx")}; + gatherNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_shuffle_idx_2")}; + + ATB_SPEED_LOG_DEBUG("CreateShuffleIdxDE"); + return atb::NO_ERROR; +} + +atb::Status CreateShuffleWeight(std::map &tensorMap, + size_t &nodeId, atb::GraphParam &opGraph) +{ + atb::Node &node = opGraph.nodes.at(nodeId++); + atb::infer::ElewiseParam elewiseParam; + elewiseParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_MUL; + CHECK_OPERATION_STATUS_RETURN(CreateOperation(elewiseParam, &node.operation)); + + node.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_experts_weight"), + GetTensorIdx(tensorMap, "intermediate_shuffle_weight")}; + node.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_shuffle_weight")}; + + node.inTensorReshapeFuncs.resize(node.inTensorIds.size()); + node.inTensorReshapeFuncs[0] = [=](const atb::Dims &oldShape, atb::Dims &newShape) { + newShape.dimNum = 1; + newShape.dims[0] = oldShape.dims[0] * oldShape.dims[1]; + }; + ATB_SPEED_LOG_DEBUG("CreateShuffleWeight"); + return atb::NO_ERROR; +} + +atb::Status CreateAllToAllCollect(std::map &tensorMap, const DynamicEpMoEParam ¶m, + size_t &nodeId, atb::GraphParam &opGraph) +{ + auto &allToAllCollectNode = opGraph.nodes.at(nodeId++); + atb_speed::common::AllToAllCollectParam allToAllCollectParam; + + allToAllCollectParam.topk = param.topk; + allToAllCollectParam.numOfExperts = param.numOfExperts; + + allToAllCollectParam.backend = param.backend; + allToAllCollectParam.hcclComm = param.hcclComm; + allToAllCollectParam.hasMoeEp = param.hasMoeEp; + allToAllCollectParam.moeEpRank = param.moeEpRank; + allToAllCollectParam.moeEpSize = param.moeEpSize; + allToAllCollectParam.moeEpRankTableFile = param.moeEpRankTableFile; + allToAllCollectParam.moeEpDomain = param.moeEpDomain; + + allToAllCollectParam.hasMlpTp = param.hasMlpTp; + allToAllCollectParam.mlpTpRank = param.mlpTpRank; + allToAllCollectParam.mlpTpSize = param.mlpTpSize; + allToAllCollectParam.mlpTpRankTableFile = param.mlpTpRankTableFile; + allToAllCollectParam.mlpTpDomain = param.mlpTpDomain; + atb_speed::common::CreateAllToAllCollectOperation(allToAllCollectParam, &allToAllCollectNode.operation); + allToAllCollectNode.inTensorIds = {GetTensorIdx(tensorMap, "in_hiddenstatus"), + GetTensorIdx(tensorMap, "intermediate_moe_output"), + GetTensorIdx(tensorMap, "intermediate_shuffle_weight"), + GetTensorIdx(tensorMap, "intermediate_shuffle_idx_1"), + GetTensorIdx(tensorMap, "intermediate_valid_idx")}; + allToAllCollectNode.outTensorIds = {GetTensorIdx(tensorMap, "out_hiddenstates")}; + allToAllCollectNode.inTensorChunks.resize(allToAllCollectNode.inTensorIds.size()); + 
ATB_SPEED_LOG_DEBUG("allToAllCollectNode calculation success"); + return atb::NO_ERROR; +} + +atb::Status CreateDynamicEpMoEOperation(const DynamicEpMoEParam ¶m, atb::Operation **operation) +{ + ATB_SPEED_LOG_DEBUG("CreateDynamicEpMoEOperation Start"); + atb::GraphParam opGraph; + opGraph.name = "DynamicEpMoE"; + std::map tensorMap = ConstructDynamicEpTensorMap( + param, opGraph.inTensorNum, opGraph.outTensorNum, opGraph.internalTensorNum); + + uint64_t nodeCount = 1; + if (param.hasMoeEp && param.isDynamicEp && !param.enableMoeDistribute && !param.enableLcocAll2All) { + nodeCount = 9; // 9: ep level2 时使用9个节点 + } + opGraph.nodes.resize(nodeCount); + + size_t nodeId = 0; + if (param.enableLcocAll2All && param.isDynamicEp) { + // alltoall GMM operation + /* + 1. InitRoutingQuant + 2. Allgather Matrix + 3. AlltoallGMM + 4 DequantSwigluQuant + 5 GMMAlltoall + 6 MoeTokenUnpermute + */ + CHECK_OPERATION_STATUS_RETURN(CreateFusedAllToAllMlp(tensorMap, param, nodeId, opGraph)); + } else if (param.hasMoeEp && param.isDynamicEp && !param.enableMoeDistribute) { + CHECK_OPERATION_STATUS_RETURN(CreateDataPreparation(tensorMap, param, nodeId, opGraph)); + CHECK_OPERATION_STATUS_RETURN(CreateAllToAllMeta(tensorMap, param, nodeId, opGraph)); + + CHECK_OPERATION_STATUS_RETURN(CreateShuffleIdxDE(tensorMap, nodeId, opGraph)); + CHECK_OPERATION_STATUS_RETURN(CreateExpertShuffleIdx(tensorMap, nodeId, opGraph)); + + CHECK_OPERATION_STATUS_RETURN(CreateAllToAllDispatch(tensorMap, param, nodeId, opGraph)); + CHECK_OPERATION_STATUS_RETURN(CreateMoeMlp(tensorMap, param, nodeId, opGraph)); + + CHECK_OPERATION_STATUS_RETURN(CreateShuffleWeight(tensorMap, nodeId, opGraph)); + CHECK_OPERATION_STATUS_RETURN(CreateShuffleIdxDE2(tensorMap, nodeId, opGraph)); // 内存复写 + CHECK_OPERATION_STATUS_RETURN(CreateAllToAllCollect(tensorMap, param, nodeId, opGraph)); + } else { + CHECK_OPERATION_STATUS_RETURN(CreateMoeMlp(tensorMap, param, nodeId, opGraph)); + } + + opGraph.inferShapeFunc = [=] (const atb::SVector &inTensorDescs, + atb::SVector &outTensorDescs) { + outTensorDescs.at(0) = inTensorDescs.at(0); + if (param.enableExpertCumSumOutput) { + outTensorDescs.at(1) = atb::TensorDesc{}; + outTensorDescs.at(1).format = ACL_FORMAT_ND; + outTensorDescs.at(1).shape.dimNum = 1; + outTensorDescs.at(1).dtype = ACL_INT64; + outTensorDescs.at(1).shape.dims[0] = param.numOfDeviceExperts; + } + return atb::NO_ERROR; + }; + + CREATE_OPERATION(opGraph, operation); + ATB_SPEED_LOG_DEBUG("CreateDynamicEpMoEOperation seccess"); + return atb::NO_ERROR; +} +} +} // namespace atb_speed \ No newline at end of file diff --git a/tests/proftest/layer_test_framework/operations/fusion/moe/ep/dynamic_ep_moe.h b/tests/proftest/layer_test_framework/operations/fusion/moe/ep/dynamic_ep_moe.h new file mode 100644 index 00000000..7430d66c --- /dev/null +++ b/tests/proftest/layer_test_framework/operations/fusion/moe/ep/dynamic_ep_moe.h @@ -0,0 +1,91 @@ +/* +* Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. +* +* * Licensed under the Apache License, Version 2.0 (the "License"); +* * you may not use this file except in compliance with the License. +* * You may obtain a copy of the License at +* * +* * http://www.apache.org/licenses/LICENSE-2.0 +* * +* * Unless required by applicable law or agreed to in writing, software +* * distributed under the License is distributed on an "AS IS" BASIS, +* * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+* * See the License for the specific language governing permissions and
+* * limitations under the License.
+* */
+#ifndef ATB_SPEED_MODELS_DYNAMIC_EP_MOE_OPERATION_H
+#define ATB_SPEED_MODELS_DYNAMIC_EP_MOE_OPERATION_H
+#include
+#include
+#include "atb_speed/utils/operation_util.h"
+#include "atb_speed/log.h"
+#include "operations/fusion/linear/linear.h"
+#include "operations/fusion/linear/linear_parallel.h"
+#include "operations/fusion/norm/norm_linear.h"
+
+namespace atb_speed {
+namespace common {
+struct DynamicEpMoEParam {
+    bool transpose = true;
+    bool supportSwiGLU = true;
+    int32_t topk = 2;
+    int32_t scaledTopk = -1; /// The scaledTopk feature is disabled by default for non-DeepSeek models
+    bool enableInitRoutingCutoff = false; /// A flag indicating whether to use the scaled topk option
+    int gmmQuantType = 0;
+    uint32_t numOfExperts = 8;
+    uint32_t numOfDeviceExperts = 8;
+    std::vector<int> moeLinearQuantType = {};
+    bool hasBias = false;
+    bool isBF16 = false;
+    bool gateUpTransposeB = false;
+    bool downTransposeB = false;
+    bool isDynamicEp = false;
+    int packQuantType = atb_speed::common::PackQuantType::ALL_FP;
+    int denseQuantType = atb_speed::common::PackQuantType::PACK_QUANT_UNDEFINED;
+    int quantGroupSize = 0; /// Group size of per-group quantization
+
+    std::vector<int32_t> deviceExpert = {0, 1, 2, 3, 4, 5, 6, 7};
+    int expertParallelDegree = 0;
+    bool enableFusedRouting = false;
+    bool enableGMMSwigluQuant = false;
+    bool enableInitQuant = false;
+    bool enableSwigluQuant = false;
+    bool enableAtlasGMMFused = false;
+    bool enableFusedTopk = false;
+    bool enableDispatchCombineV2 = false; /// A flag indicating whether to use dispatch_v2 and combine_v2
+
+    std::string backend = "hccl";
+    HcclComm hcclComm = nullptr;
+
+    bool hasMoeEp = false;
+    int moeEpRank = 0;
+    int moeEpSize = 1;
+    int maxDecodeDpTokenSize = 0;
+    std::string moeEpDomain = "";
+    std::string moeEpRankTableFile = "";
+
+    bool hasMlpTp = false;
+    int mlpTpRank = 0;
+    int mlpTpSize = 1;
+    std::string mlpTpDomain = "";
+    std::string mlpTpRankTableFile = "";
+    bool enableCVOverlap = false; /// A flag indicating whether the model uses cube and vector parallel
+    bool enableMoeDistribute = false;
+    bool enableExpertCumSumOutput = false;
+    bool enableGatingDp = false;
+    int64_t numDanglingSharedExperts = 0;
+    uint32_t numOfRedundantExpert = 0;
+
+    bool enableLcocAll2All = false;
+    std::string routingMethod = "";
+
+    std::string lcclMoeEpDomain = "";
+    HcclComm lcclMoeEpHcclComm = nullptr;
+
+    bool mixSharedRouting = false;
+};
+
+atb::Status CreateDynamicEpMoEOperation(const DynamicEpMoEParam &param, atb::Operation **operation);
+}
+} // namespace atb_speed
+#endif
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/fusion/moe/ep/expert_filter.cpp b/tests/proftest/layer_test_framework/operations/fusion/moe/ep/expert_filter.cpp
new file mode 100644
index 00000000..fefed933
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/fusion/moe/ep/expert_filter.cpp
@@ -0,0 +1,264 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "expert_filter.h"
+#include
+#include
+#include "operations/fusion/utils.h"
+#include "operations/aclnn/ops/inplacemasked_filltensor_operation.h"
+
+
+namespace atb_speed {
+namespace common {
+std::map<std::string, std::vector<std::string>> GetExpertFilterInTensorCandidates()
+{
+    std::map<std::string, std::vector<std::string>> expertFilterInTensorCandidates = {
+        {"default", {
+            "in_selected_experts", "in_expert_weight", "in_start_expert_idx",
+            "in_device_expert_count", "in_zero_hot"}
+        },
+    };
+    return expertFilterInTensorCandidates;
+}
+
+
+std::map<std::string, std::vector<std::string>> GetExpertFilterInterTensorCandidates()
+{
+    std::map<std::string, std::vector<std::string>> expertFilterInterTensorCandidates = {
+        {"default", {
+            "intermediate_selected_experts_int64", "intermediate_selected_experts_mask",
+            "intermediate_selected_experts_shifted_int32", "intermediate_selected_experts_shifted_int64",
+            "intermediate_selected_experts_mask_1", "intermediate_zero_hot_int64"}
+        },
+        {"shifted", {
+            "intermediate_selected_experts_int64", "intermediate_selected_experts_mask_1"}
+        },
+    };
+    return expertFilterInterTensorCandidates;
+}
+
+std::map<std::string, std::vector<std::string>> GetExpertFilterOutTensorCandidates()
+{
+    std::map<std::string, std::vector<std::string>> expertFilterOutTensorCandidates = {
+        {"default", {
+            "out_selected_experts", "out_expert_weight"}
+        },
+        {"shifted", {
+            "out_expert_weight"}
+        },
+    };
+    return expertFilterOutTensorCandidates;
+}
+
+std::map<std::string, uint32_t> ConstructTensorMap(const ExpertFilterParam &param,
+    uint32_t &inTensorNum, uint32_t &outTensorNum, uint32_t &interTensorNum)
+{
+    auto expertFilterInTensorCandidates = GetExpertFilterInTensorCandidates();
+    auto expertFilterInterTensorCandidates = GetExpertFilterInterTensorCandidates();
+    auto expertFilterOutTensorCandidates = GetExpertFilterOutTensorCandidates();
+
+    std::vector<std::string> inTensorList = {};
+    std::vector<std::string> interTensorList = {};
+    std::vector<std::string> outTensorList = {};
+
+    AddTensorToList(expertFilterInTensorCandidates, "default", inTensorList);
+    if (param.shiftedTopK && !param.enableGatingDp) {
+        AddTensorToList(expertFilterInterTensorCandidates, "shifted", interTensorList);
+        AddTensorToList(expertFilterOutTensorCandidates, "shifted", outTensorList);
+    } else {
+        AddTensorToList(expertFilterInterTensorCandidates, "default", interTensorList);
+        AddTensorToList(expertFilterOutTensorCandidates, "default", outTensorList);
+    }
+    inTensorNum = inTensorList.size();
+    outTensorNum = outTensorList.size();
+    interTensorNum = interTensorList.size();
+    return GetTensorMap(inTensorList, outTensorList, interTensorList);
+}
+
+atb::Status CreateSelectedExpertInt64(std::map<std::string, uint32_t> &tensorMap,
+    size_t &nodeId, atb::GraphParam &opGraph)
+{
+    atb::Node &castNode = opGraph.nodes.at(nodeId++);
+    atb::infer::ElewiseParam castParam;
+    castParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_CAST;
+    castParam.outTensorType = ACL_INT64;
+    CHECK_OPERATION_STATUS_RETURN(CreateOperation(castParam, &castNode.operation));
+    castNode.inTensorIds = {GetTensorIdx(tensorMap, "in_selected_experts")};
+    castNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_selected_experts_int64")};
+    return atb::NO_ERROR;
+}
+
+atb::Status CreateSelectedExpertSub(std::map<std::string, uint32_t> &tensorMap,
+    size_t &nodeId, atb::GraphParam &opGraph)
+{
+    atb::Node &subNode = opGraph.nodes.at(nodeId++);
+    atb::infer::ElewiseParam subParam;
+    subParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_SUB;
+    CHECK_OPERATION_STATUS_RETURN(CreateOperation(subParam, &subNode.operation));
+    subNode.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_selected_experts_int64"),
+                           GetTensorIdx(tensorMap, "in_start_expert_idx")};
+    subNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_selected_experts_shifted_int64")};
+
+    return atb::NO_ERROR;
+}
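+
+// A scalar reference sketch (a hypothetical helper, not used by the graph; assumes <vector>
+// and <cstdint>) of the filter implemented by the cast/sub/less/greater/fill chain in this
+// file, assuming in_zero_hot is a zero tensor: expert ids are shifted into the local window,
+// ids below the window are redirected to the drop bucket `numOfExperts` with their routing
+// weight zeroed, and ids strictly above deviceExpertCount (ELEWISE_GREATER) keep the shifted
+// id but lose their weight.
+static void ExpertFilterRef(std::vector<int32_t> &experts, std::vector<float> &weights,
+                            int32_t startExpertIdx, int32_t deviceExpertCount, int32_t numOfExperts)
+{
+    for (size_t i = 0; i < experts.size(); ++i) {
+        int32_t shifted = experts[i] - startExpertIdx;     // CreateSelectedExpertSub
+        if (shifted < 0) {                                 // CreateSelectedExpertMask
+            experts[i] = numOfExperts;                     // CreateOutSelectedExpert
+            weights[i] = 0.0f;                             // CreateExpertWeightFilter
+        } else {
+            experts[i] = shifted;
+            if (shifted > deviceExpertCount) {             // CreateSelectedExpertMask1
+                weights[i] = 0.0f;                         // CreateOutExpertWeightFilter
+            }
+        }
+    }
+}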
{GetTensorIdx(tensorMap, "intermediate_selected_experts_int64"), + GetTensorIdx(tensorMap, "in_start_expert_idx")}; + subNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_selected_experts_shifted_int64")}; + + return atb::NO_ERROR; +} + +atb::Status CreateZeroHotInt64(std::map &tensorMap, size_t &nodeId, atb::GraphParam &opGraph) +{ + atb::Node &castNode = opGraph.nodes.at(nodeId++); + atb::infer::ElewiseParam castParam; + castParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_CAST; + castParam.outTensorType = ACL_INT64; + CHECK_OPERATION_STATUS_RETURN(CreateOperation(castParam, &castNode.operation)); + castNode.inTensorIds = {GetTensorIdx(tensorMap, "in_zero_hot")}; + castNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_zero_hot_int64")}; + return atb::NO_ERROR; +} + +atb::Status CreateSelectedExpertMask(std::map &tensorMap, + size_t &nodeId, atb::GraphParam &opGraph) +{ + atb::Node &node = opGraph.nodes.at(nodeId++); + atb::infer::ElewiseParam lessParam; + lessParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_LESS; + CreateOperation(lessParam, &node.operation); + node.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_selected_experts_shifted_int64"), + GetTensorIdx(tensorMap, "intermediate_zero_hot_int64")}; + node.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_selected_experts_mask")}; + return atb::NO_ERROR; +} + +atb::Status CreateSelectedExpertInt32(std::map &tensorMap, + size_t &nodeId, atb::GraphParam &opGraph) +{ + atb::Node &castNode = opGraph.nodes.at(nodeId++); + atb::infer::ElewiseParam castParam; + castParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_CAST; + castParam.outTensorType = ACL_INT32; + CHECK_OPERATION_STATUS_RETURN(CreateOperation(castParam, &castNode.operation)); + castNode.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_selected_experts_shifted_int64")}; + castNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_selected_experts_shifted_int32")}; + return atb::NO_ERROR; +} + +atb::Status CreateOutSelectedExpert(std::map &tensorMap, + const ExpertFilterParam ¶m, size_t &nodeId, atb::GraphParam &opGraph) +{ + atb::Node &node = opGraph.nodes.at(nodeId++); + atb::infer::FillParam fillParam; + fillParam.withMask = true; + fillParam.value.resize(1); + fillParam.value[0] = param.numOfExperts; + CreateOperation(fillParam, &node.operation); + node.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_selected_experts_shifted_int32"), + GetTensorIdx(tensorMap, "intermediate_selected_experts_mask")}; + node.outTensorIds = {GetTensorIdx(tensorMap, "out_selected_experts")}; + return atb::NO_ERROR; +} + +atb::Status CreateSelectedExpertMask1( + std::map &tensorMap, const ExpertFilterParam ¶m, size_t &nodeId, + atb::GraphParam &opGraph) +{ + atb::Node &node = opGraph.nodes.at(nodeId++); + atb::infer::ElewiseParam lessParam; + lessParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_GREATER; + CreateOperation(lessParam, &node.operation); + node.inTensorIds = { + GetTensorIdx(tensorMap, (param.shiftedTopK && !param.enableGatingDp) ? 
\ + "intermediate_selected_experts_int64" : "intermediate_selected_experts_shifted_int64"), + GetTensorIdx(tensorMap, "in_device_expert_count")}; + node.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_selected_experts_mask_1")}; + return atb::NO_ERROR; +} + +atb::Status CreateExpertWeightFilter( + std::map &tensorMap, size_t &nodeId, atb::GraphParam &opGraph) +{ + atb::Node &node = opGraph.nodes.at(nodeId++); + atb::infer::FillParam param; + param.withMask = true; + param.value.resize(1); + param.value[0] = 0; + + CreateOperation(param, &node.operation); + + node.inTensorIds = {GetTensorIdx(tensorMap, "in_expert_weight"), + GetTensorIdx(tensorMap, "intermediate_selected_experts_mask")}; + node.outTensorIds = {GetTensorIdx(tensorMap, "out_expert_weight")}; + + return atb::NO_ERROR; +} + +atb::Status CreateOutExpertWeightFilter( + std::map &tensorMap, const ExpertFilterParam ¶m, size_t &nodeId, + atb::GraphParam &opGraph) +{ + atb::Node &node = opGraph.nodes.at(nodeId++); + atb_speed::common::InplaceMaskedFillTensorParam inplaceMaskedFillTensorParam; + inplaceMaskedFillTensorParam.value = 0; + inplaceMaskedFillTensorParam.outDataType = param.isBF16 ? ACL_BF16 : ACL_FLOAT16; + node.operation = new atb_speed::common::InplaceMaskedFillTensorOperation("MaskFillNode", + inplaceMaskedFillTensorParam); + node.inTensorIds = {GetTensorIdx(tensorMap, (param.shiftedTopK && !param.enableGatingDp) ? \ + "in_expert_weight" : "out_expert_weight"), + GetTensorIdx(tensorMap, "intermediate_selected_experts_mask_1")}; + node.outTensorIds = {GetTensorIdx(tensorMap, "out_expert_weight")}; + + return atb::NO_ERROR; +} + +atb::Status CreateExpertFilterOperation(const ExpertFilterParam ¶m, atb::Operation **operation) +{ + atb::GraphParam opGraph; + opGraph.name = "ExpertFilter"; + std::map tensorMap = ConstructTensorMap(param, + opGraph.inTensorNum, opGraph.outTensorNum, opGraph.internalTensorNum); + uint64_t nodeCount = 9; // 9: default node count + if (param.shiftedTopK && !param.enableGatingDp) { + nodeCount = 3; // 3: node count + } + opGraph.nodes.resize(nodeCount); + size_t nodeId = 0; + if (!param.shiftedTopK || param.enableGatingDp) { + CHECK_OPERATION_STATUS_RETURN(CreateSelectedExpertInt64(tensorMap, nodeId, opGraph)); + CHECK_OPERATION_STATUS_RETURN(CreateSelectedExpertSub(tensorMap, nodeId, opGraph)); + CHECK_OPERATION_STATUS_RETURN(CreateZeroHotInt64(tensorMap, nodeId, opGraph)); + CHECK_OPERATION_STATUS_RETURN(CreateSelectedExpertMask(tensorMap, nodeId, opGraph)); + CHECK_OPERATION_STATUS_RETURN(CreateSelectedExpertInt32(tensorMap, nodeId, opGraph)); + CHECK_OPERATION_STATUS_RETURN(CreateOutSelectedExpert(tensorMap, param, nodeId, opGraph)); + CHECK_OPERATION_STATUS_RETURN(CreateSelectedExpertMask1(tensorMap, param, nodeId, opGraph)); + CHECK_OPERATION_STATUS_RETURN(CreateExpertWeightFilter(tensorMap, nodeId, opGraph)); + CHECK_OPERATION_STATUS_RETURN(CreateOutExpertWeightFilter(tensorMap, param, nodeId, opGraph)); + } else { + CHECK_OPERATION_STATUS_RETURN(CreateSelectedExpertInt64(tensorMap, nodeId, opGraph)); + CHECK_OPERATION_STATUS_RETURN(CreateSelectedExpertMask1(tensorMap, param, nodeId, opGraph)); + CHECK_OPERATION_STATUS_RETURN(CreateOutExpertWeightFilter(tensorMap, param, nodeId, opGraph)); + } + opGraph.inferShapeFunc = [=](const atb::SVector &inTensorDescs, + atb::SVector &outTensorDescs) { + if (param.shiftedTopK && !param.enableGatingDp) { + outTensorDescs.at(0) = inTensorDescs.at(1); + } else { + outTensorDescs.at(0) = inTensorDescs.at(0); + outTensorDescs.at(1) = 
inTensorDescs.at(1); + } + return atb::NO_ERROR; + }; + CREATE_OPERATION(opGraph, operation); + return atb::NO_ERROR; +} +} +} // namespace atb_speed \ No newline at end of file diff --git a/tests/proftest/layer_test_framework/operations/fusion/moe/ep/expert_filter.h b/tests/proftest/layer_test_framework/operations/fusion/moe/ep/expert_filter.h new file mode 100644 index 00000000..0298b156 --- /dev/null +++ b/tests/proftest/layer_test_framework/operations/fusion/moe/ep/expert_filter.h @@ -0,0 +1,35 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef ATB_SPEED_MODELS_EXPERT_FILTER_OPERATION_H +#define ATB_SPEED_MODELS_EXPERT_FILTER_OPERATION_H +#include +#include "atb_speed/utils/operation_util.h" +#include "atb_speed/log.h" + +namespace atb_speed { +namespace common { +struct ExpertFilterParam { + bool shiftedTopK = true; + bool isBF16 = false; + bool enableGatingDp = false; + long unsigned int numOfExperts = 8; + std::vector deviceExpert = {0, 1, 2, 3, 4, 5, 6, 7}; +}; + +atb::Status CreateExpertFilterOperation(const ExpertFilterParam ¶m, atb::Operation **operation); +} +} // namespace atb_speed +#endif \ No newline at end of file diff --git a/tests/proftest/layer_test_framework/operations/fusion/moe/ep/fused_alltoall_gmm.cpp b/tests/proftest/layer_test_framework/operations/fusion/moe/ep/fused_alltoall_gmm.cpp new file mode 100644 index 00000000..c6090fdd --- /dev/null +++ b/tests/proftest/layer_test_framework/operations/fusion/moe/ep/fused_alltoall_gmm.cpp @@ -0,0 +1,271 @@ +/* +* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +* +* * Licensed under the Apache License, Version 2.0 (the "License"); +* * you may not use this file except in compliance with the License. +* * You may obtain a copy of the License at +* * +* * http://www.apache.org/licenses/LICENSE-2.0 +* * +* * Unless required by applicable law or agreed to in writing, software +* * distributed under the License is distributed on an "AS IS" BASIS, +* * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* * See the License for the specific language governing permissions and +* * limitations under the License. 
+* */ +#include "fused_alltoall_gmm.h" +#include +#include +#include "operations/aclnn/ops/grouped_matmul_operation.h" +#include "operations/fusion/utils.h" +#include "operations/aclnn/ops/moetoken_unpermute_operation.h" +#include "operations/aclnn/ops/moe_init_routing_quant_operation.h" +#include "operations/aclnn/ops/dequant_swiglu_quant_operation.h" +#include "operations/aclnn/ops/dynamic_quant_operation.h" + +namespace atb_speed { +namespace common { + +std::map> GetAll2AllMatmulInTensorCandidates() +{ + std::map> all2AllMatmulInTensorCandidates = { + {"default", { + "in_hiddenstatus", "in_mlp_gateup_weight_expert", "in_mlp_gateup_descale_expert", + "in_mlp_gateup_offset_expert", "in_mlp_gateup_scale_expert", "in_mlp_gateup_compress_idx_expert", + "in_mlp_down_weight_expert", "in_mlp_down_descale_expert", "in_mlp_down_offset_expert", + "in_mlp_down_scale_expert", "in_mlp_down_compress_idx_expert", "in_expert_array", "in_selected_experts", + "in_expert_weight", "in_moe_idx" + } + }, + }; + return all2AllMatmulInTensorCandidates; +} + +std::map> GetAll2AllMatmulInterTensorCandidates() +{ + std::map> all2AllMatmulInterTensorCandidates = { + {"default", { + "intermediate_hiddenstates", + "intermediate_idx", + "intermediate_group_list", + "intermediate_dynamic_scale", + "intermediate_group_list_full", + "intermediate_gate_up_out", + "intermediate_quant_swish_out", + "intermediate_swish_out_scale", + "intermediate_mlp_out", + "intermediate_tokens_before_capacity" + } + }, + }; + return all2AllMatmulInterTensorCandidates; +} + +std::map> GetAll2AllMatmulOutTensorCandidates() +{ + std::map> all2AllMatmulOutTensorCandidates = { + {"default", { + "out_hiddenstates"} + }, + }; + return all2AllMatmulOutTensorCandidates; +} + +std::map ConstructDynamicEpTensorMap( + uint32_t &inTensorNum, uint32_t &outTensorNum, uint32_t &interTensorNum) +{ + auto all2AllMatmulInTensorCandidates = GetAll2AllMatmulInTensorCandidates(); + auto all2AllMatmulInterTensorCandidates = GetAll2AllMatmulInterTensorCandidates(); + auto all2AllMatmulOutTensorCandidates = GetAll2AllMatmulOutTensorCandidates(); + std::vector inTensorList = {}; + std::vector interTensorList = {}; + std::vector outTensorList = {}; + AddTensorToList(all2AllMatmulInTensorCandidates, "default", inTensorList); + AddTensorToList(all2AllMatmulInterTensorCandidates, "default", interTensorList); + AddTensorToList(all2AllMatmulOutTensorCandidates, "default", outTensorList); + inTensorNum = inTensorList.size(); + outTensorNum = outTensorList.size(); + interTensorNum = interTensorList.size(); + return GetTensorMap(inTensorList, outTensorList, interTensorList); +} + + +atb::Status CreateInitRoutingQuant( + std::map &tensorMap, const All2AllMatmulParam ¶m, + atb::GraphParam &opGraph, size_t &nodeId) +{ + atb::Node &initRoutingNode = opGraph.nodes.at(nodeId++); + atb_speed::common::MoeInitRoutingQuantParam initRoutingParam; + initRoutingParam.topkNum = param.topk; + initRoutingParam.scaledTopk = + (param.scaledTopk == atb_speed::common::DEFAULT_TOPK_SCALE) + ? 
+atb::Status CreateInitRoutingQuant(
+    std::map<std::string, uint32_t> &tensorMap, const All2AllMatmulParam &param,
+    atb::GraphParam &opGraph, size_t &nodeId)
+{
+    atb::Node &initRoutingNode = opGraph.nodes.at(nodeId++);
+    atb_speed::common::MoeInitRoutingQuantParam initRoutingParam;
+    initRoutingParam.topkNum = param.topk;
+    initRoutingParam.scaledTopk =
+        (param.scaledTopk == atb_speed::common::DEFAULT_TOPK_SCALE) ? param.topk : param.scaledTopk;
+    initRoutingParam.expertNum = param.numOfExperts;
+    initRoutingParam.expertTokensCoutOrCumsumFlag = NUM2; // count tokens per expert separately, not as a cumulative sum
+    initRoutingNode.operation = new atb_speed::common::MoeInitRoutingQuantOperation("MoeInitRoutingQuantOperation",
+        initRoutingParam);
+    initRoutingNode.inTensorIds = {
+        GetTensorIdx(tensorMap, "in_hiddenstates"),
+        GetTensorIdx(tensorMap, "in_selected_experts")
+    };
+    initRoutingNode.outTensorIds = {
+        GetTensorIdx(tensorMap, "intermediate_hiddenstates"),
+        GetTensorIdx(tensorMap, "intermediate_idx"),
+        GetTensorIdx(tensorMap, "intermediate_group_list"),
+        GetTensorIdx(tensorMap, "intermediate_tokens_before_capacity"),
+        GetTensorIdx(tensorMap, "intermediate_dynamic_scale")
+    };
+    ATB_SPEED_LOG_DEBUG("FusedAlltoallGMM Create InitRouting success");
+    return atb::NO_ERROR;
+}
+
+atb::Status CreateGrouplistAllGather(
+    std::map<std::string, uint32_t> &tensorMap, const All2AllMatmulParam &param,
+    atb::GraphParam &opGraph, size_t &nodeId)
+{
+    atb::Node &allGatherGrouplistNode = opGraph.nodes.at(nodeId++);
+    atb::infer::AllGatherParam allGatherParam;
+    allGatherParam.rank = param.moeEpRank;
+    allGatherParam.rankSize = param.moeEpSize;
+    allGatherParam.backend = "lccl";
+    allGatherParam.rankTableFile = "";
+    allGatherParam.commDomain = param.lcclMoeEpDomain;
+    allGatherGrouplistNode.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_group_list")};
+    allGatherGrouplistNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_group_list_full")};
+    CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(allGatherParam, &allGatherGrouplistNode.operation));
+    ATB_SPEED_LOG_DEBUG("FusedAlltoallGMM Create AllGather success");
+    return atb::NO_ERROR;
+}
+
+atb::Status CreateAll2AllMatmul(std::map<std::string, uint32_t> &tensorMap, const All2AllMatmulParam &param,
+    atb::GraphParam &opGraph, size_t &nodeId)
+{
+    atb::Node &allToAllMatmulNode = opGraph.nodes.at(nodeId++);
+    atb::infer::LinearParallelParam allToAllMatmulParam;
+
+    allToAllMatmulParam.rank = param.moeEpRank;
+    allToAllMatmulParam.rankSize = param.moeEpSize;
+    allToAllMatmulParam.backend = "lcoc";
+    allToAllMatmulParam.rankTableFile = "";
+    allToAllMatmulParam.commDomain = "";
+    allToAllMatmulParam.transWeight = param.gateUpTransposeB;
+    allToAllMatmulParam.type = atb::infer::LinearParallelParam::ParallelType::ALLTOALLVC_ALL_GATHER_GMM;
+    allToAllMatmulParam.quantType = atb::infer::LinearParallelParam::QuantType::QUANT_TYPE_PER_TOKEN;
+    allToAllMatmulParam.moeInfo.epSize = param.moeEpSize;
+    allToAllMatmulParam.moeInfo.tpSize = 1;
+    allToAllMatmulParam.moeInfo.localExpertNums = param.numOfDeviceExperts;
+    allToAllMatmulParam.outDataType = aclDataType::ACL_FLOAT16;
+
+    CHECK_OPERATION_STATUS_RETURN(CreateOperation(allToAllMatmulParam, &allToAllMatmulNode.operation));
+
+    allToAllMatmulNode.inTensorIds = {
+        GetTensorIdx(tensorMap, "intermediate_hiddenstates"),
+        GetTensorIdx(tensorMap, "in_mlp_gateup_weight_expert"),
+        GetTensorIdx(tensorMap, "in_mlp_gateup_scale_expert"),   // per channel
+        GetTensorIdx(tensorMap, "intermediate_dynamic_scale"),   // per token
+        GetTensorIdx(tensorMap, "intermediate_group_list_full"), // expert_per_token_matrix [ep, 256]
+        GetTensorIdx(tensorMap, "in_moe_idx")
+    };
+    allToAllMatmulNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_gate_up_out")};
+    ATB_SPEED_LOG_DEBUG("FusedAlltoallGMM Create All2AllMatmul success");
+    return atb::NO_ERROR;
+}
+
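+// Reviewer note (not part of the original change): a worked shape example for the two
+// nodes above, assuming numOfExperts == 256 and moeEpSize == 8. Each rank emits a local
+// per-expert token count "intermediate_group_list" of shape [256]; the lccl AllGather
+// stacks the ranks into "intermediate_group_list_full" of shape [8, 256], which the
+// fused ALLTOALLVC_ALL_GATHER_GMM node above consumes as its expert-per-token matrix.
+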
+atb::Status CreateSwigluQuant(std::map<std::string, uint32_t> &tensorMap, atb::GraphParam &opGraph, size_t &nodeId)
+{
+    atb::Node &swigluQuantNode = opGraph.nodes.at(nodeId++);
+    AclNNDequantSwigluQuantParam aclnnParam;
+    aclnnParam.activateLeft = true;
+    aclnnParam.quantMode = "dynamic";
+    aclnnParam.inTensorsNum = 1;
+    swigluQuantNode.operation = new atb_speed::common::DequantSwigluQuantOperation("swigluQuantNode", aclnnParam);
+    swigluQuantNode.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_gate_up_out")};
+    swigluQuantNode.outTensorIds = {
+        GetTensorIdx(tensorMap, "intermediate_quant_swish_out"),
+        GetTensorIdx(tensorMap, "intermediate_swish_out_scale")
+    };
+    ATB_SPEED_LOG_DEBUG("FusedAlltoallGMM Create SwigluQuant success");
+    return atb::NO_ERROR;
+}
+
+atb::Status CreateMatmulAll2All(std::map<std::string, uint32_t> &tensorMap,
+    const All2AllMatmulParam &param, atb::GraphParam &opGraph, size_t &nodeId)
+{
+    atb::Node &matmulAllToAllNode = opGraph.nodes.at(nodeId++);
+    atb::infer::LinearParallelParam matmulAllToAllParam;
+
+    matmulAllToAllParam.rank = param.moeEpRank;
+    matmulAllToAllParam.rankSize = param.moeEpSize;
+    matmulAllToAllParam.backend = "lcoc";
+    matmulAllToAllParam.rankTableFile = "";
+    matmulAllToAllParam.commDomain = "";
+    matmulAllToAllParam.transWeight = param.downTransposeB;
+    matmulAllToAllParam.type = atb::infer::LinearParallelParam::ParallelType::GMM_REDUCE_SCATTER_ALLTOALLVC;
+    matmulAllToAllParam.quantType = atb::infer::LinearParallelParam::QuantType::QUANT_TYPE_PER_TOKEN;
+    matmulAllToAllParam.moeInfo.localExpertNums = param.numOfDeviceExperts;
+    matmulAllToAllParam.moeInfo.tpSize = 1;
+    matmulAllToAllParam.moeInfo.epSize = param.moeEpSize;
+    matmulAllToAllParam.outDataType = aclDataType::ACL_FLOAT16;
+
+    CHECK_OPERATION_STATUS_RETURN(CreateOperation(matmulAllToAllParam, &matmulAllToAllNode.operation));
+
+    matmulAllToAllNode.inTensorIds = {
+        GetTensorIdx(tensorMap, "intermediate_quant_swish_out"),
+        GetTensorIdx(tensorMap, "in_mlp_down_weight_expert"),
+        GetTensorIdx(tensorMap, "in_mlp_down_scale_expert"),
+        GetTensorIdx(tensorMap, "intermediate_swish_out_scale"),
+        GetTensorIdx(tensorMap, "intermediate_group_list_full"),
+        GetTensorIdx(tensorMap, "intermediate_idx")
+    };
+    matmulAllToAllNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_mlp_out")};
+
+    return atb::NO_ERROR;
+}
+
+atb::Status CreateMoeTokenUnpermute(
+    std::map<std::string, uint32_t> &tensorMap, atb::GraphParam &opGraph, size_t &nodeId)
+{
+    atb::Node &unpermuteNode = opGraph.nodes.at(nodeId++);
+    unpermuteNode.operation = new atb_speed::common::MoeTokenUnpermuteOperation("MoeTokenUnpermuteNode");
+    unpermuteNode.outTensorIds = {GetTensorIdx(tensorMap, "out_hiddenstates")};
+    unpermuteNode.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_mlp_out"),
+                                 GetTensorIdx(tensorMap, "intermediate_idx"),
+                                 GetTensorIdx(tensorMap, "in_expert_weight")};
+
+    ATB_SPEED_LOG_DEBUG("UnpermuteNode calculation success");
+    return atb::NO_ERROR;
+}
+
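+// Reviewer note (not part of the original change): a hypothetical call site for the
+// factory below; every value here is made up for illustration.
+//   atb_speed::common::All2AllMatmulParam p;
+//   p.topk = 8;
+//   p.numOfExperts = 256;
+//   p.numOfDeviceExperts = 32;  // 256 experts spread over 8 EP ranks
+//   p.moeEpRank = rank;         // hypothetical rank variable
+//   p.moeEpSize = 8;
+//   atb::Operation *op = nullptr;
+//   CHECK_OPERATION_STATUS_RETURN(CreateAll2AllMatmulOperation(p, &op));
+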
+atb::Status CreateAll2AllMatmulOperation(const All2AllMatmulParam &param, atb::Operation **operation)
+{
+    ATB_SPEED_LOG_DEBUG("CreateAll2AllMatmulOperation Start");
+    atb::GraphParam opGraph;
+    opGraph.name = "All2AllMatmul";
+    std::map<std::string, uint32_t> tensorMap = ConstructDynamicEpTensorMap(
+        opGraph.inTensorNum, opGraph.outTensorNum, opGraph.internalTensorNum);
+
+    uint64_t nodeCount = 6;
+    opGraph.nodes.resize(nodeCount);
+
+    size_t nodeId = 0;
+    CHECK_OPERATION_STATUS_RETURN(CreateInitRoutingQuant(tensorMap, param, opGraph, nodeId));
+    CHECK_OPERATION_STATUS_RETURN(CreateGrouplistAllGather(tensorMap, param, opGraph, nodeId));
+    CHECK_OPERATION_STATUS_RETURN(CreateAll2AllMatmul(tensorMap, param, opGraph, nodeId));
+    CHECK_OPERATION_STATUS_RETURN(CreateSwigluQuant(tensorMap, opGraph, nodeId));
+    CHECK_OPERATION_STATUS_RETURN(CreateMatmulAll2All(tensorMap, param, opGraph, nodeId));
+    CHECK_OPERATION_STATUS_RETURN(CreateMoeTokenUnpermute(tensorMap, opGraph, nodeId));
+
+    opGraph.inferShapeFunc = [=](const atb::SVector<atb::TensorDesc> &inTensorDescs,
+                                 atb::SVector<atb::TensorDesc> &outTensorDescs) {
+        outTensorDescs.at(0) = inTensorDescs.at(0);
+        return atb::NO_ERROR;
+    };
+
+    CREATE_OPERATION(opGraph, operation);
+    ATB_SPEED_LOG_DEBUG("CreateAll2AllMatmulOperation success");
+    return atb::NO_ERROR;
+}
+}
+} // namespace atb_speed
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/fusion/moe/ep/fused_alltoall_gmm.h b/tests/proftest/layer_test_framework/operations/fusion/moe/ep/fused_alltoall_gmm.h
new file mode 100644
index 00000000..9e142512
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/fusion/moe/ep/fused_alltoall_gmm.h
@@ -0,0 +1,45 @@
+/*
+* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
+*
+* * Licensed under the Apache License, Version 2.0 (the "License");
+* * you may not use this file except in compliance with the License.
+* * You may obtain a copy of the License at
+* *
+* * http://www.apache.org/licenses/LICENSE-2.0
+* *
+* * Unless required by applicable law or agreed to in writing, software
+* * distributed under the License is distributed on an "AS IS" BASIS,
+* * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* * See the License for the specific language governing permissions and
+* * limitations under the License.
+* */
+#ifndef ATB_SPEED_MODELS_ALL2ALL_MATMUL_OPERATION_H
+#define ATB_SPEED_MODELS_ALL2ALL_MATMUL_OPERATION_H
+#include
+#include "atb_speed/utils/operation_util.h"
+#include "atb_speed/log.h"
+#include "operations/fusion/linear/linear.h"
+#include "operations/fusion/linear/linear_parallel.h"
+#include "operations/fusion/norm/norm_linear.h"
+
+namespace atb_speed {
+namespace common {
+constexpr int DEFAULT_TOPK_SCALE = -1;
+
+struct All2AllMatmulParam {
+    int32_t topk = 2;
+    uint32_t numOfExperts = 8;
+    uint32_t numOfDeviceExperts = 8;
+    bool gateUpTransposeB = false;
+    bool downTransposeB = false;
+    int32_t scaledTopk = -1;
+    int moeEpRank = 0;
+    int moeEpSize = 1;
+    std::string lcclMoeEpDomain = "";
+    HcclComm lcclMoeEpHcclComm = nullptr;
+};
+
+atb::Status CreateAll2AllMatmulOperation(const All2AllMatmulParam &param, atb::Operation **operation);
+}
+} // namespace atb_speed
+#endif
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/fusion/moe/integrated_gmm.cpp b/tests/proftest/layer_test_framework/operations/fusion/moe/integrated_gmm.cpp
new file mode 100644
index 00000000..142bcf74
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/fusion/moe/integrated_gmm.cpp
@@ -0,0 +1,393 @@
+/*
+* Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved.
+*
+* * Licensed under the Apache License, Version 2.0 (the "License");
+* * you may not use this file except in compliance with the License.
+* * You may obtain a copy of the License at
+* *
+* * http://www.apache.org/licenses/LICENSE-2.0
+* *
+* * Unless required by applicable law or agreed to in writing, software
+* * distributed under the License is distributed on an "AS IS" BASIS,
+* * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* * See the License for the specific language governing permissions and
+* * limitations under the License.
+* */
+#include "integrated_gmm.h"
+#include
+#include
+#include "operations/aclnn/ops/grouped_matmul_operation.h"
+#include "operations/aclnn/ops/grouped_matmul_swiglu_operation.h"
+#include "operations/aclnn/ops/dynamic_quant_operation.h"
+#include "atb_speed/base/event_manager.h"
+
+namespace atb_speed {
+namespace common {
+
+static const int IDX2 = 2;
+static const int IDX3 = 3;
+static const int IDX6 = 6;
+
+std::map<std::string, std::vector<std::string>> GetInteGmmInTensorCandidates()
+{
+    std::map<std::string, std::vector<std::string>> inteGmmInTensorCandidates = {
+        {"default", {
+            "in_hiddenstates", "in_weight_expert", "in_bias_expert", "in_descale_expert",
+            "in_offset_expert", "in_scale_expert", "in_compress_idx_expert", "in_group_list"},
+        },
+        {"skip_quant", {"in_dynamic_scale"}},
+        {"gmm_swiglu_quant", {
+            "in_mlp_down_weight_expert", "in_mlp_down_bias_expert", "in_mlp_down_descale_expert",
+            "in_mlp_down_offset_expert", "in_mlp_down_scale_expert", "in_mlp_down_compress_idx_expert"},
+        }
+    };
+    return inteGmmInTensorCandidates;
+}
+
+int CalcGmmQuantType(const IntegratedGmmParam &param)
+{
+    int gmmQuantType = 0;
+    int tempQuantType = 0;
+    if (param.isUp) {
+        tempQuantType = atb_speed::common::GetLinearQuantType(
+            param.denseQuantType == atb_speed::common::PackQuantType::PACK_QUANT_UNDEFINED ?
+                param.packQuantType : param.denseQuantType,
+            param.moeLinearQuantType[IntegratedGmmIdx::MOE_MLP_GATE_IDX], false);
+    } else {
+        tempQuantType = atb_speed::common::GetLinearQuantType(
+            param.denseQuantType == atb_speed::common::PackQuantType::PACK_QUANT_UNDEFINED ?
+                param.packQuantType : param.denseQuantType,
+            param.moeLinearQuantType[IntegratedGmmIdx::MOE_MLP_DOWN_IDX], false);
+    }
+    switch (tempQuantType) {
+        case LinearQuantType::NO_QUANT:
+            gmmQuantType = GmmQuantType::NONE;
+            break;
+        case LinearQuantType::LINEAR_W8A8_DYNAMIC_QUANT:
+        case LinearQuantType::LINEAR_W8A8_DYNAMIC_DEQUANT:
+            gmmQuantType = GmmQuantType::W8A8_TOKEN;
+            break;
+        case LinearQuantType::W8A16:
+            gmmQuantType = GmmQuantType::W8A16_CHANNEL;
+            break;
+        case LinearQuantType::W4A16:
+            gmmQuantType = GmmQuantType::W4A16_CHANNEL;
+            break;
+        case LinearQuantType::LINEAR_W4A8_DYNAMIC_DEQUANT:
+        case LinearQuantType::LINEAR_W4A8_DYNAMIC_QUANT:
+            gmmQuantType = GmmQuantType::W4A8_GROUP;
+            break;
+        default:
+            gmmQuantType = GmmQuantType::W8A8_CHANNEL;
+            break;
+    }
+
+    ATB_SPEED_LOG_DEBUG(gmmQuantType);
+    return gmmQuantType;
+}
+
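+// Reviewer note (not part of the original change): a few concrete mappings computed by
+// CalcGmmQuantType() above, following the LinearQuantType/GmmQuantType enums used here:
+//   NO_QUANT                      -> GmmQuantType::NONE
+//   LINEAR_W8A8_DYNAMIC_QUANT     -> GmmQuantType::W8A8_TOKEN    (per-token scales)
+//   W4A16                         -> GmmQuantType::W4A16_CHANNEL (per-channel scales)
+//   any other value (plain W8A8)  -> GmmQuantType::W8A8_CHANNEL
+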
+std::map<std::string, uint32_t> ConstructTensorMap(
+    const IntegratedGmmParam &param,
+    uint32_t &inTensorNum, uint32_t &outTensorNum, uint32_t &interTensorNum)
+{
+    auto inteGmmInTensorCandidates = GetInteGmmInTensorCandidates();
+    std::vector<std::string> inTensorList = {};
+    std::vector<std::string> interTensorList = {};
+    std::vector<std::string> outTensorList = {"out_gmm_result"};
+
+    if (param.enableAtlasGMMFused && !param.enableGMMSwigluQuant) {
+        ATB_SPEED_LOG_WARN("Please check whether the world size is greater than 16. \
+            IntegratedGmmParam::enableAtlasGMMFused must only be set when enableGMMSwigluQuant is true! \
+            This may be unexpected behavior");
+    }
+    AddTensorToList(inteGmmInTensorCandidates, "default", inTensorList);
+    if (param.enableGMMSwigluQuant) {
+        AddTensorToList(inteGmmInTensorCandidates, "gmm_swiglu_quant", inTensorList);
+    }
+    if (param.skipQuant) {
+        AddTensorToList(inteGmmInTensorCandidates, "skip_quant", inTensorList);
+    }
+
+    int gmmQuantType = CalcGmmQuantType(param);
+    if ((gmmQuantType == GmmQuantType::W8A8_TOKEN || gmmQuantType == GmmQuantType::W4A8_GROUP) && !param.skipQuant) {
+        interTensorList.push_back("intermediate_quant_out");
+        interTensorList.push_back("intermediate_dynamic_scale");
+    }
+    if (param.enableGMMSwigluQuant && !param.enableAtlasGMMFused) {
+        interTensorList.push_back("intermediate_dynamic_scale_1");
+        interTensorList.push_back("intermediate_swiglu_quant_out");
+    }
+    inTensorNum = inTensorList.size();
+    outTensorNum = outTensorList.size();
+    interTensorNum = interTensorList.size();
+
+    return GetTensorMap(inTensorList, outTensorList, interTensorList);
+}
+
+int64_t SetAclnnDynamicQuantNode(
+    std::map<std::string, uint32_t> &tensorMap,
+    atb::GraphParam &opGraph, size_t &nodeId)
+{
+    atb::Node &dynamicQuantNode = opGraph.nodes.at(nodeId++);
+    dynamicQuantNode.operation = new atb_speed::common::DynamicQuantOperation("DynamicQuantNode");
+    dynamicQuantNode.inTensorIds = {GetTensorIdx(tensorMap, "in_hiddenstates")};
+    dynamicQuantNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_quant_out"),
+                                     GetTensorIdx(tensorMap, "intermediate_dynamic_scale")};
+    ATB_SPEED_LOG_DEBUG("create dynamic quant");
+    return atb::NO_ERROR;
+}
+
+atb::Status CreateW8A8Token(
+    std::map<std::string, uint32_t> &tensorMap,
+    const IntegratedGmmParam &param, atb::Node &gmmNode)
+{
+    ATB_SPEED_LOG_DEBUG("push back W8A8_TOKEN");
+    gmmNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_scale_expert"));
+    if (param.skipQuant) {
+        gmmNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_dynamic_scale"));
+    } else {
+        gmmNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "intermediate_dynamic_scale"));
+    }
+    gmmNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_group_list"));
+    gmmNode.inTensorReshapeFuncs.resize(gmmNode.inTensorIds.size());
+    gmmNode.inTensorReshapeFuncs[IDX2] = [=](const atb::Dims &oldShape, atb::Dims &newShape) {
+        ATB_SPEED_LOG_DEBUG(oldShape.dimNum);
+        newShape.dimNum = IDX2;
+        newShape.dims[0] = oldShape.dims[0];
+        newShape.dims[1] = oldShape.dims[1];
+    };
+    if (param.enableGMMSwigluQuant) {
+        gmmNode.outTensorIds.push_back(GetTensorIdx(tensorMap, "intermediate_dynamic_scale_1"));
+    }
+    ATB_SPEED_LOG_DEBUG("inTensorReshapeFuncs success");
+    return atb::NO_ERROR;
+}
+
+atb::Status CreateW4A8(
+    std::map<std::string, uint32_t> &tensorMap,
+    const IntegratedGmmParam &param, atb::Node &gmmNode)
+{
+    ATB_SPEED_LOG_DEBUG("push back W4A8");
+    gmmNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_scale_expert"));
+    if (param.skipQuant) {
+        gmmNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_dynamic_scale"));
+    } else {
+        gmmNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "intermediate_dynamic_scale"));
+    }
+    gmmNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_group_list"));
+    gmmNode.inTensorReshapeFuncs.resize(gmmNode.inTensorIds.size());
+    ATB_SPEED_LOG_DEBUG("CreateW4A8 success");
+    return atb::NO_ERROR;
+}
+
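+// Reviewer note (not part of the original change): for the W8A8_TOKEN and W4A8_GROUP
+// paths the graph quantizes activations on the fly, roughly:
+//   in_hiddenstates (fp16/bf16)
+//     -> DynamicQuantOperation -> intermediate_quant_out (int8) + intermediate_dynamic_scale
+//     -> grouped matmul, dequantized with in_scale_expert (per channel) and the
+//        per-token scale appended by CreateW8A8Token()/CreateW4A8() above.
+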
"in_offset_expert")); + gmmNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_group_list")); + gmmNode.inTensorReshapeFuncs.resize(gmmNode.inTensorIds.size()); + if (param.quantGroupSize == 0) { + ATB_SPEED_LOG_DEBUG("W4A16 or W8A16 quant per-channel"); + int kDim = param.transposeB ? 1 : 2; // number of dim k + gmmNode.inTensorReshapeFuncs[IDX2] = [=](const atb::Dims &oldShape, atb::Dims &newShape) { + ATB_SPEED_LOG_DEBUG(oldShape.dimNum); + newShape.dimNum = IDX2; // dimNum: 2 + newShape.dims[0] = oldShape.dims[0]; + newShape.dims[1] = oldShape.dims[kDim]; + }; + gmmNode.inTensorReshapeFuncs[IDX3] = [=](const atb::Dims &oldShape, atb::Dims &newShape) { + ATB_SPEED_LOG_DEBUG(oldShape.dimNum); + newShape.dimNum = IDX2; // dimNum: 2 + newShape.dims[0] = oldShape.dims[0]; + newShape.dims[1] = oldShape.dims[kDim]; + }; + } + + ATB_SPEED_LOG_DEBUG("inTensorReshapeFuncs success"); + return atb::NO_ERROR; +} + +atb::Status CreateGmmMixed(std::map &tensorMap, atb::GraphParam &opGraph, size_t &nodeId, + const IntegratedGmmParam ¶m, int gmmQuantType) +{ + atb::Node &gmmNode = opGraph.nodes.at(nodeId++); + atb::infer::GmmDeqSwigluQuantGmmDeqParam gmmSwigluParam; + CHECK_OPERATION_STATUS_RETURN(CreateOperation(gmmSwigluParam, &gmmNode.operation)); + gmmNode.outTensorIds = {GetTensorIdx(tensorMap, "out_gmm_result")}; + + gmmNode.inTensorIds = {GetTensorIdx(tensorMap, "in_hiddenstates"), + GetTensorIdx(tensorMap, "in_weight_expert")}; + gmmNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_scale_expert")); + gmmNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_dynamic_scale")); + gmmNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_group_list")); + if (gmmQuantType == GmmQuantType::W8A8_TOKEN) { + gmmNode.inTensorIds.push_back( + GetTensorIdx(tensorMap, "in_mlp_down_weight_expert")); + } + if (param.hasBias) { + gmmNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_mlp_down_bias_expert")); + } + gmmNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_mlp_down_scale_expert")); + gmmNode.inTensorReshapeFuncs.resize(gmmNode.inTensorIds.size()); + gmmNode.inTensorReshapeFuncs[IDX2] = [=](const atb::Dims &oldShape, atb::Dims &newShape) { + ATB_SPEED_LOG_DEBUG(oldShape.dimNum); + newShape.dimNum = IDX2; + newShape.dims[0] = oldShape.dims[0]; + newShape.dims[1] = oldShape.dims[1]; + }; + gmmNode.inTensorReshapeFuncs[IDX6] = [=](const atb::Dims &oldShape, atb::Dims &newShape) { + ATB_SPEED_LOG_DEBUG(oldShape.dimNum); + newShape.dimNum = IDX2; + newShape.dims[0] = oldShape.dims[0]; + newShape.dims[1] = oldShape.dims[1]; + }; + ATB_SPEED_LOG_DEBUG("GmmSwigluQuant calculation success"); + ATB_SPEED_LOG_DEBUG("GmmSwigluQuant calculation success" << param.outDataType); + return atb::NO_ERROR; +} + +// Op1 - GMM +atb::Status CreateGmm(std::map &tensorMap, atb::GraphParam &opGraph, size_t &nodeId, + const IntegratedGmmParam ¶m, int gmmQuantType) +{ + atb::Node &gmmNode = opGraph.nodes.at(nodeId++); + atb_speed::common::AclNNGroupedMatmulParam gmmParam; + gmmParam.quantType = gmmQuantType; + gmmParam.outDataType = param.outDataType; + gmmParam.transposeB = param.transposeB; + gmmParam.hasBias = param.hasBias; + if (param.enableGMMSwigluQuant) { + atb_speed::common::AclNNGroupedSwigluMatmulParam gmmSwigluParam; + gmmSwigluParam.quantType = gmmQuantType; + gmmSwigluParam.outDataType = param.outDataType; + gmmSwigluParam.transposeB = param.transposeB; + gmmNode.operation = new atb_speed::common::GroupedMatmulSwigluOperation("gmmSwigluNode", gmmSwigluParam); + gmmNode.outTensorIds = 
+// Op1 - GMM
+atb::Status CreateGmm(std::map<std::string, uint32_t> &tensorMap, atb::GraphParam &opGraph, size_t &nodeId,
+    const IntegratedGmmParam &param, int gmmQuantType)
+{
+    atb::Node &gmmNode = opGraph.nodes.at(nodeId++);
+    atb_speed::common::AclNNGroupedMatmulParam gmmParam;
+    gmmParam.quantType = gmmQuantType;
+    gmmParam.outDataType = param.outDataType;
+    gmmParam.transposeB = param.transposeB;
+    gmmParam.hasBias = param.hasBias;
+    if (param.enableGMMSwigluQuant) {
+        atb_speed::common::AclNNGroupedSwigluMatmulParam gmmSwigluParam;
+        gmmSwigluParam.quantType = gmmQuantType;
+        gmmSwigluParam.outDataType = param.outDataType;
+        gmmSwigluParam.transposeB = param.transposeB;
+        gmmNode.operation = new atb_speed::common::GroupedMatmulSwigluOperation("gmmSwigluNode", gmmSwigluParam);
+        gmmNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_swiglu_quant_out")};
+    } else {
+        gmmNode.operation = new atb_speed::common::GroupedMatmulOperation("gmmNode", gmmParam);
+        gmmNode.outTensorIds = {GetTensorIdx(tensorMap, "out_gmm_result")};
+    }
+    if ((gmmQuantType == GmmQuantType::W8A8_TOKEN || gmmQuantType == GmmQuantType::W4A8_GROUP) && !param.skipQuant) {
+        gmmNode.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_quant_out"),
+                               GetTensorIdx(tensorMap, "in_weight_expert")};
+    } else {
+        gmmNode.inTensorIds = {GetTensorIdx(tensorMap, "in_hiddenstates"),
+                               GetTensorIdx(tensorMap, "in_weight_expert")};
+    }
+    if (param.hasBias) {
+        gmmNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_bias_expert"));
+    }
+    switch (gmmParam.quantType) {
+        case GmmQuantType::W8A16_CHANNEL:
+        case GmmQuantType::W4A16_CHANNEL:
+            CHECK_OPERATION_STATUS_RETURN(CreateA16Channel(tensorMap, gmmNode, param));
+            break;
+        case GmmQuantType::W8A8_CHANNEL:
+            ATB_SPEED_LOG_ERROR("MoE does not support W8A8_CHANNEL");
+            gmmNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_scale_expert"));
+            gmmNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_compress_idx_expert"));
+            gmmNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_group_list"));
+            break;
+        case GmmQuantType::W8A8_TOKEN:
+            CHECK_OPERATION_STATUS_RETURN(CreateW8A8Token(tensorMap, param, gmmNode));
+            break;
+        case GmmQuantType::W4A8_GROUP:
+            CHECK_OPERATION_STATUS_RETURN(CreateW4A8(tensorMap, param, gmmNode));
+            break;
+        default:
+            gmmNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_group_list"));
+            break;
+    }
+    ATB_SPEED_LOG_DEBUG("GMM calculation success");
+    return atb::NO_ERROR;
+}
+
+atb::Status CreateGmm1(std::map<std::string, uint32_t> &tensorMap,
+    atb::GraphParam &opGraph, size_t &nodeId,
+    const IntegratedGmmParam &param, int gmmQuantType)
+{
+    atb::Node &gmmNode = opGraph.nodes.at(nodeId++);
+    atb_speed::common::AclNNGroupedMatmulParam gmmParam;
+    gmmParam.quantType = gmmQuantType;
+    gmmParam.outDataType = param.outDataType;
+    gmmParam.transposeB = param.downTransposeB;
+    gmmParam.hasBias = param.hasBias;
+    ATB_SPEED_LOG_DEBUG("Calc GmmQuantType success");
+    gmmNode.operation = new atb_speed::common::GroupedMatmulOperation("gmmNode1", gmmParam);
+    gmmNode.outTensorIds = {GetTensorIdx(tensorMap, "out_gmm_result")};
+    gmmNode.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_swiglu_quant_out"),
+                           GetTensorIdx(tensorMap, "in_mlp_down_weight_expert")};
+    if (param.hasBias) {
+        gmmNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_mlp_down_bias_expert"));
+    }
+    gmmNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_mlp_down_scale_expert"));
+    gmmNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "intermediate_dynamic_scale_1"));
+    gmmNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_group_list"));
+    gmmNode.inTensorReshapeFuncs.resize(gmmNode.inTensorIds.size());
+    gmmNode.inTensorReshapeFuncs[IDX2] = [=](const atb::Dims &oldShape, atb::Dims &newShape) {
+        ATB_SPEED_LOG_DEBUG(oldShape.dimNum);
+        newShape.dimNum = IDX2;
+        newShape.dims[0] = oldShape.dims[0];
+        newShape.dims[1] = oldShape.dims[1];
+    };
+    ATB_SPEED_LOG_DEBUG("GMM calculation success");
+    return atb::NO_ERROR;
+}
+
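+// Reviewer note (not part of the original change): when enableCVOverlap is set, the node
+// built below only records an event through EventManager; the matching wait is assumed
+// to be inserted by the peer graph so cube (matmul) work and vector work can overlap
+// instead of serializing. The exact record/wait pairing is inferred from EventManager
+// usage elsewhere in this patch, not stated here.
+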
+atb::Status CreateRecord(const IntegratedGmmParam &param, atb::GraphParam &opGraph, size_t &nodeId,
+    atb_speed::EventAction eventAction, const std::string &cvKey)
+{
+    if (param.enableCVOverlap) {
+        atb::Node &recordNode = opGraph.nodes.at(nodeId++);
+        recordNode.inTensorIds = {};
+        recordNode.outTensorIds = {};
+        CHECK_OPERATION_STATUS_RETURN(atb_speed::EventManager::GetInstance().RecordEvent(
+            recordNode.operation,
+            eventAction,
+            cvKey));
+        ATB_SPEED_LOG_DEBUG("Record event success");
+    }
+    return atb::NO_ERROR;
+}
+
+atb::Status CreateIntegratedGmmOperation(const IntegratedGmmParam &param, atb::Operation **operation)
+{
+    std::shared_ptr<int64_t> batchDimPtr = std::make_shared<int64_t>(0);
+    atb::GraphParam opGraph;
+    opGraph.name = "integrated_gmm";
+    std::map<std::string, uint32_t> tensorMap = ConstructTensorMap(
+        param, opGraph.inTensorNum, opGraph.outTensorNum, opGraph.internalTensorNum);
+    int gmmQuantType = CalcGmmQuantType(param);
+    uint64_t nodeCount = uint64_t(1);
+    if ((gmmQuantType == GmmQuantType::W8A8_TOKEN || gmmQuantType == GmmQuantType::W4A8_GROUP) && !param.skipQuant) {
+        nodeCount = uint64_t(2); // 2: the number of nodes needed to complete the calculation
+    }
+    if (param.enableGMMSwigluQuant && !param.enableAtlasGMMFused) {
+        nodeCount += 1;
+    }
+    opGraph.nodes.resize(nodeCount);
+    size_t nodeId = 0;
+    if ((gmmQuantType == GmmQuantType::W8A8_TOKEN || gmmQuantType == GmmQuantType::W4A8_GROUP) && !param.skipQuant) {
+        CHECK_OPERATION_STATUS_RETURN(SetAclnnDynamicQuantNode(tensorMap, opGraph, nodeId));
+    }
+    if (param.enableCVOverlap) {
+        CHECK_OPERATION_STATUS_RETURN(CreateRecord(
+            param, opGraph, nodeId, atb_speed::EventAction::POP, atb_speed::common::VECTOR_CONTROL));
+    }
+    if (param.enableAtlasGMMFused) {
+        CHECK_OPERATION_STATUS_RETURN(CreateGmmMixed(tensorMap, opGraph, nodeId, param, gmmQuantType));
+    } else {
+        CHECK_OPERATION_STATUS_RETURN(CreateGmm(tensorMap, opGraph, nodeId, param, gmmQuantType));
+        if (param.enableGMMSwigluQuant && gmmQuantType == GmmQuantType::W8A8_TOKEN) {
+            CHECK_OPERATION_STATUS_RETURN(CreateGmm1(tensorMap, opGraph, nodeId, param, gmmQuantType));
+        }
+    }
+    CREATE_OPERATION(opGraph, operation);
+    return atb::NO_ERROR;
+}
+
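+// Reviewer note (not part of the original change): a hypothetical call site for the
+// factory above, with made-up quantization settings:
+//   atb_speed::common::IntegratedGmmParam p;
+//   p.moeLinearQuantType = {0, 1, 1, 1};  // illustrative codes, indexed by IntegratedGmmIdx
+//   p.packQuantType = atb_speed::common::PackQuantType::ALL_FP;
+//   p.isUp = true;
+//   atb::Operation *op = nullptr;
+//   CHECK_OPERATION_STATUS_RETURN(CreateIntegratedGmmOperation(p, &op));
+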
+}
+} // namespace atb_speed
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/fusion/moe/integrated_gmm.h b/tests/proftest/layer_test_framework/operations/fusion/moe/integrated_gmm.h
new file mode 100644
index 00000000..c513200d
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/fusion/moe/integrated_gmm.h
@@ -0,0 +1,69 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef ATB_SPEED_MODELS_INTEGRATED_GMM_OPERATION_H
+#define ATB_SPEED_MODELS_INTEGRATED_GMM_OPERATION_H
+#include <vector>
+#include "atb_speed/utils/operation_util.h"
+#include "atb_speed/log.h"
+#include "operations/fusion/linear/linear.h"
+#include "operations/fusion/linear/linear_parallel.h"
+#include "operations/fusion/norm/norm_linear.h"
+
+namespace atb_speed {
+namespace common {
+enum IntegratedGmmIdx : int {
+    ROUTER_IDX = 0,
+    MOE_MLP_GATE_IDX,
+    MOE_MLP_UP_IDX,
+    MOE_MLP_DOWN_IDX
+};
+struct IntegratedGmmParam {
+    /// The quantization type of the linear transformation of this sub-graph
+    std::vector<int> moeLinearQuantType = {};
+    /// A flag indicating whether there is bias to the linear transformation of this sub-graph
+    bool hasBias = false;
+    /// A flag indicating whether the linear transformation is the `UP` or the `DOWN` stage of FFN
+    bool isUp = true;
+    /// The data type of the output of the linear transformation
+    aclDataType outDataType = ACL_FLOAT16;
+    /// A flag indicating whether the second matrix of the gate/up matrix multiplication needs to be transposed
+    bool transposeB = false;
+    /// A flag indicating whether the second matrix of the down-projection matrix multiplication needs to be transposed
+    bool downTransposeB = false;
+    /// The quantization type of the packed weights
+    int packQuantType = atb_speed::common::PackQuantType::ALL_FP;
+    /// The quantization type used to facilitate the calculation of the quantization type of the linear operation
+    int denseQuantType = atb_speed::common::PackQuantType::PACK_QUANT_UNDEFINED;
+    /// The group size used for dequantizing the weight tensor in the per-group quantization approach
+    int quantGroupSize = 0;
+    /// A flag indicating whether or not to skip the quantization step
+    bool skipQuant = false;
+    /// A flag indicating whether the model uses MoE parallelism
+    bool enableMoeParallel = false;
+    /// A flag indicating whether the model uses cube and vector parallelism
+    bool enableCVOverlap = false;
+    /// A flag indicating whether or not to use the integrated GMM+Swiglu+quant operators.
+    bool enableGMMSwigluQuant = false;
+    /// A flag indicating whether or not to use the fused atb GMM+Swiglu+quant operators instead of aclnn.
+    bool enableAtlasGMMFused = false;
+};
+
+/// This function creates a sub-graph that performs grouped matmul.
+/// \return A flag that indicates whether the operation is successfully created or not.
+atb::Status CreateIntegratedGmmOperation(const IntegratedGmmParam &param, atb::Operation **operation);
+}
+} // namespace atb_speed
+#endif
diff --git a/tests/proftest/layer_test_framework/operations/fusion/moe/moe_mlp.cpp b/tests/proftest/layer_test_framework/operations/fusion/moe/moe_mlp.cpp
new file mode 100644
index 00000000..021c42b2
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/fusion/moe/moe_mlp.cpp
@@ -0,0 +1,1065 @@
+/*
+* Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved.
+*
+* * Licensed under the Apache License, Version 2.0 (the "License");
+* * you may not use this file except in compliance with the License.
+* * You may obtain a copy of the License at
+* *
+* * http://www.apache.org/licenses/LICENSE-2.0
+* *
+* * Unless required by applicable law or agreed to in writing, software
+* * distributed under the License is distributed on an "AS IS" BASIS,
+* * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* * See the License for the specific language governing permissions and
+* * limitations under the License.
+* */
+#include "moe_mlp.h"
+#include
+#include
+#include
+#include "operations/fusion/moe/integrated_gmm.h"
+#include "operations/fusion/moe/ep/expert_filter.h"
+#include "operations/aclnn/ops/finalize_routing_operation.h"
+#include "operations/aclnn/ops/moe_init_routing_operation.h"
+#include "operations/aclnn/ops/moe_init_routing_quant_operation.h"
+#include "operations/aclnn/ops/moe_compute_expert_tokens_operation.h"
+#include "operations/aclnn/ops/moetoken_unpermute_operation.h"
+#include "operations/aclnn/ops/inplacemasked_filltensor_operation.h"
+#include "operations/aclnn/ops/grouped_matmul_operation.h"
+#include "atb_speed/base/event_manager.h"
+#include "operations/aclnn/utils/utils.h"
+#include "operations/aclnn/ops/moe_distribute_combine_operation.h"
+#include "operations/aclnn/ops/moe_distribute_dispatch_operation.h"
+#include "operations/aclnn/ops/quant_gmm_dequant_operation.h"
+#include "operations/aclnn/ops/moe_distribute_combine_v2_operation.h"
+#include "operations/aclnn/ops/moe_distribute_dispatch_v2_operation.h"
+#include "operations/aclnn/ops/len_operation.h"
+#include "operations/aclnn/ops/minimum_operation.h"
+
+namespace atb_speed {
+namespace common {
+
+int CalcUpGmmQuantType(const MoeMlpParam &param)
+{
+    int gmmQuantType = 0;
+    int tempQuantType = atb_speed::common::GetLinearQuantType(
+        param.denseQuantType == atb_speed::common::PackQuantType::PACK_QUANT_UNDEFINED ?
+            param.packQuantType : param.denseQuantType,
+        param.moeLinearQuantType[IntegratedGmmIdx::MOE_MLP_GATE_IDX], false);
+    switch (tempQuantType) {
+        case LinearQuantType::NO_QUANT:
+            gmmQuantType = GmmQuantType::NONE;
+            break;
+        case LinearQuantType::LINEAR_W8A8_DYNAMIC_QUANT:
+        case LinearQuantType::LINEAR_W8A8_DYNAMIC_DEQUANT:
+            gmmQuantType = GmmQuantType::W8A8_TOKEN;
+            break;
+        case LinearQuantType::W8A16:
+            gmmQuantType = GmmQuantType::W8A16_CHANNEL;
+            break;
+        case LinearQuantType::W4A16:
+            gmmQuantType = GmmQuantType::W4A16_CHANNEL;
+            break;
+        case LinearQuantType::LINEAR_W4A8_DYNAMIC_DEQUANT:
+        case LinearQuantType::LINEAR_W4A8_DYNAMIC_QUANT:
+            gmmQuantType = GmmQuantType::W4A8_GROUP;
+            break;
+        default:
+            gmmQuantType = GmmQuantType::W8A8_CHANNEL;
+            break;
+    }
+    return gmmQuantType;
+}
+
+bool IsGMMSwigluQuant(const int gmmQuantType, const MoeMlpParam &param)
+{
+    return gmmQuantType == GmmQuantType::W8A8_TOKEN && param.enableGMMSwigluQuant;
+}
+
+std::map<std::string, std::vector<std::string>> GetMoeMlpInTensorCandidates()
+{
+    std::map<std::string, std::vector<std::string>> moeMlpInTensorCandidates = {
+        {"default", {
+            "in_hiddenstates", "in_mlp_gateup_weight_expert", "in_mlp_gateup_bias_expert",
+            "in_mlp_gateup_descale_expert", "in_mlp_gateup_offset_expert", "in_mlp_gateup_scale_expert",
+            "in_mlp_gateup_compress_idx_expert", "in_mlp_down_weight_expert",
+            "in_mlp_down_bias_expert", "in_mlp_down_descale_expert", "in_mlp_down_offset_expert",
+            "in_mlp_down_scale_expert", "in_mlp_down_compress_idx_expert", "in_expert_array", "in_selected_experts",
+            "in_expert_weight"},
+        },
+        {"ep", {
+            "in_zero_hot", "in_start_expert_idx", "in_device_expert_count", "in_padding_idx"}
+        }
+    };
+    return moeMlpInTensorCandidates;
+}
+
"intermediate_sorted_weight"} + }, + {"default_w8a8_token", { + "intermediate_idx", "intermediate_weight_idx", "intermediate_dummy_zero", + "intermediate_dummy_one", "intermediate_rev_idx", "intermediate_group_list", + "intermediate_sorted_hiddenstates", "intermediate_rev_sorted_hiddenstates", + "intermediate_mlp_out", "intermediate_mlp_out_weighted", "intermediate_sorted_weight"} + }, + {"enableFusedRouting", { + "intermediate_idx", "intermediate_group_list", "intermediate_sorted_hiddenstates", + "intermediate_matmul_gate_up_out", "intermediate_swish_out", "intermediate_mlp_out"} + }, + {"disable_mc2", { + "intermediate_group_list_int64"} + }, + {"enable_mc2", { + "intermediate_ep_recv_counts", "intermediate_tp_recv_counts", "intermediate_gmm0_deqscale", + "intermediate_expand_expert_weight"} + }, + {"enableFusedRouting_w8a8_token", { + "intermediate_idx", "intermediate_group_list", "intermediate_sorted_hiddenstates", "intermediate_mlp_out"} + }, + {"disable_swiglu", { + "intermediate_matmul_gate_out", "intermediate_matmul_up_out", "intermediate_swish_out_internal"} + }, + {"enable_init_quant", { + "intermediate_tokens_before_capacity", "intermediate_sorted_hiddenstates_dequant_scale"} + }, + {"enable_swiglu_quant", { + "intermedaite_swiglu_dequant_scale"} + }, + {"ep", { + "intermediate_group_list_sliced", "intermediate_expert_weight"} + }, + {"dynanmic_ep", { + "intermediate_selected_experts"} + }, + {"gating_dp", { + "intermediate_group_list_sliced", "intermediate_expert_weight", "intermediate_selected_experts"} + }, + {"initrouting_cutoff", { + "intermediate_sorted_hiddenstates_len", "intermediate_group_list_filtered"}} + }; + return moeMlpInterTensorCandidates; +} + +atb::Status AddIntermediateTensor(const MoeMlpParam ¶m, std::vector& interTensorList) +{ + auto moeMlpInterTensorCandidates = GetMoeMlpInterTensorCandidates(); + bool isGMMSwigluQuant = IsGMMSwigluQuant(CalcUpGmmQuantType(param), param); + if (param.enableFusedRouting) { + AddTensorToList(moeMlpInterTensorCandidates, param.enableMoeDistribute ? + "enable_mc2" : "disable_mc2", interTensorList); + AddTensorToList(moeMlpInterTensorCandidates, isGMMSwigluQuant ? "enableFusedRouting_w8a8_token" : + "enableFusedRouting", interTensorList); + } else { + AddTensorToList(moeMlpInterTensorCandidates, isGMMSwigluQuant ? 
+ "default_w8a8_token" : "default", interTensorList); + } + if (param.enableInitQuant && !param.enableMoeDistribute) { + AddTensorToList(moeMlpInterTensorCandidates, "enable_init_quant", interTensorList); + } + if (param.enableSwigluQuant) { + AddTensorToList(moeMlpInterTensorCandidates, "enable_swiglu_quant", interTensorList); + } + if (!param.supportSwiGLU) { + AddTensorToList(moeMlpInterTensorCandidates, "disable_swiglu", interTensorList); + } + if (param.hasMoeEp && !param.enableMoeDistribute) { + if (!param.enableGatingDp) { + AddTensorToList(moeMlpInterTensorCandidates, "ep", interTensorList); + if (!param.shiftedTopK) { + AddTensorToList(moeMlpInterTensorCandidates, "dynanmic_ep", interTensorList); + } + } else { + AddTensorToList(moeMlpInterTensorCandidates, "gating_dp", interTensorList); + } + } + if (param.enableInitRoutingCutoff) { + AddTensorToList(moeMlpInterTensorCandidates, "initrouting_cutoff", interTensorList); + } + return atb::NO_ERROR; +} + +std::map ConstructTensorMap( + const MoeMlpParam ¶m, + uint32_t &inTensorNum, uint32_t &outTensorNum, uint32_t &interTensorNum) +{ + auto moeMlpInTensorCandidates = GetMoeMlpInTensorCandidates(); + std::vector inTensorList = {}; + std::vector interTensorList = {}; + std::vector outTensorList = {"out_moe_mlp_result"}; + AddTensorToList(moeMlpInTensorCandidates, "default", inTensorList); + if (param.hasMoeEp) { + AddTensorToList(moeMlpInTensorCandidates, "ep", inTensorList); + } + + AddIntermediateTensor(param, interTensorList); + + if (param.enableExpertCumSumOutput) { + if (param.enableFusedRouting && !param.enableMoeDistribute) { + interTensorList.erase(std::remove(interTensorList.begin(), interTensorList.end(), + "intermediate_group_list_int64"), interTensorList.end()); + outTensorList.push_back("intermediate_group_list_int64"); + } else { + interTensorList.erase(std::remove(interTensorList.begin(), interTensorList.end(), + "intermediate_group_list"), interTensorList.end()); + outTensorList.push_back("intermediate_group_list"); + } + } + inTensorNum = inTensorList.size(); + outTensorNum = outTensorList.size(); + interTensorNum = interTensorList.size(); + return GetTensorMap(inTensorList, outTensorList, interTensorList); +} + +atb::Status CreateExpertFilter( + std::map &tensorMap, const MoeMlpParam ¶m, + atb::GraphParam &opGraph) +{ + atb::Node expertFilterNode; + atb_speed::common::ExpertFilterParam expertFilterParam; + expertFilterParam.numOfExperts = param.numOfExperts; + expertFilterParam.deviceExpert = param.deviceExpert; + expertFilterParam.shiftedTopK = param.shiftedTopK; + expertFilterParam.isBF16 = param.isBF16; + expertFilterParam.enableGatingDp = param.enableGatingDp; + atb_speed::common::CreateExpertFilterOperation(expertFilterParam, &expertFilterNode.operation); + + expertFilterNode.inTensorIds = {GetTensorIdx(tensorMap, "in_selected_experts"), + GetTensorIdx(tensorMap, "in_expert_weight"), + GetTensorIdx(tensorMap, "in_start_expert_idx"), + GetTensorIdx(tensorMap, "in_device_expert_count"), + GetTensorIdx(tensorMap, "in_zero_hot")}; + + if (param.shiftedTopK && !param.enableGatingDp) { + expertFilterNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_expert_weight")}; + } else { + expertFilterNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_selected_experts"), + GetTensorIdx(tensorMap, "intermediate_expert_weight")}; + } + opGraph.nodes.push_back(expertFilterNode); + ATB_SPEED_LOG_DEBUG("InitRouting calculation success"); + return atb::NO_ERROR; +} + +// Step 1: hidden state permutation +atb::Status 
+// Step 1: hidden state permutation
+atb::Status CreateInitRoutingQuant(
+    std::map<std::string, uint32_t> &tensorMap, const MoeMlpParam &param,
+    atb::GraphParam &opGraph)
+{
+    atb::Node initRoutingNode;
+    atb_speed::common::MoeInitRoutingQuantParam initRoutingParam;
+    initRoutingParam.topkNum = param.topk;
+    /// scaledTopk is enabled for the deepseek model and disabled for all other models
+    initRoutingParam.scaledTopk = param.scaledTopk;
+    initRoutingParam.enableInitRoutingCutoff = param.enableInitRoutingCutoff;
+    initRoutingParam.expertNum = param.numOfExperts;
+    int gmmQuantType = CalcUpGmmQuantType(param);
+    if (gmmQuantType == GmmQuantType::W4A8_GROUP) {
+        initRoutingParam.expertTokensCoutOrCumsumFlag = 2; // 2: the W4A8_GROUP matmul does not take the cumulative-sum format
+        initRoutingParam.enableInitRoutingCutoff = false;
+    }
+    initRoutingNode.operation = new atb_speed::common::MoeInitRoutingQuantOperation("MoeInitRoutingQuantOperation"
+        + std::to_string(gmmQuantType), initRoutingParam);
+    initRoutingNode.inTensorIds = {GetTensorIdx(tensorMap, "in_hiddenstates"),
+                                   GetTensorIdx(tensorMap, param.hasMoeEp && !param.shiftedTopK ?
+                                       "intermediate_selected_experts" : (param.enableGatingDp ?
+                                       "intermediate_selected_experts" : "in_selected_experts"))};
+    initRoutingNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_sorted_hiddenstates"),
+                                    GetTensorIdx(tensorMap, "intermediate_idx"),
+                                    GetTensorIdx(tensorMap, "intermediate_group_list"),
+                                    GetTensorIdx(tensorMap, "intermediate_tokens_before_capacity"),
+                                    GetTensorIdx(tensorMap, "intermediate_sorted_hiddenstates_dequant_scale")};
+    opGraph.nodes.push_back(initRoutingNode);
+    ATB_SPEED_LOG_DEBUG("InitRouting calculation success");
+    return atb::NO_ERROR;
+}
+
+atb::Status CreateInitRouting(
+    std::map<std::string, uint32_t> &tensorMap, const MoeMlpParam &param,
+    atb::GraphParam &opGraph)
+{
+    atb::Node initRoutingNode;
+    atb_speed::common::MoeInitRoutingParam initRoutingParam;
+    initRoutingParam.topkNum = param.topk;
+    /// scaledTopk is enabled for the deepseek model and disabled for all other models
+    initRoutingParam.scaledTopk = param.scaledTopk;
+    initRoutingParam.enableInitRoutingCutoff = param.enableInitRoutingCutoff;
+    initRoutingParam.expertNum = param.numOfExperts;
+    int gmmQuantType = CalcUpGmmQuantType(param);
+    if (gmmQuantType == GmmQuantType::W4A8_GROUP) {
+        initRoutingParam.expertTokensCoutOrCumsumFlag = 2; // 2: the W4A8_GROUP matmul does not take the cumulative-sum format
+        initRoutingParam.enableInitRoutingCutoff = false;
+    }
+    initRoutingNode.operation = new atb_speed::common::MoeInitRoutingOperation("MoeInitRoutingOperation"
+        + std::to_string(gmmQuantType), initRoutingParam);
+    initRoutingNode.inTensorIds = {GetTensorIdx(tensorMap, "in_hiddenstates"),
+                                   GetTensorIdx(tensorMap, param.hasMoeEp && !param.shiftedTopK ?
+                                       "intermediate_selected_experts" : (param.enableGatingDp ?
+                                       "intermediate_selected_experts" : "in_selected_experts"))};
+    initRoutingNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_sorted_hiddenstates"),
+                                    GetTensorIdx(tensorMap, "intermediate_idx"),
+                                    GetTensorIdx(tensorMap, "intermediate_group_list")};
+    opGraph.nodes.push_back(initRoutingNode);
+    ATB_SPEED_LOG_DEBUG("InitRouting calculation success");
+    return atb::NO_ERROR;
+}
+
+atb::Status CreateComputeExpertSlice(
+    std::map<std::string, uint32_t> &tensorMap, const MoeMlpParam &param,
+    atb::GraphParam &opGraph)
+{
+    atb::Node sliceNode;
+    atb::infer::SliceParam sliceParam;
+    sliceParam.offsets.resize(1);
+    sliceParam.offsets[0] = 0;
+    sliceParam.size.resize(1);
+    sliceParam.size[0] = param.numOfDeviceExperts;
+    CHECK_OPERATION_STATUS_RETURN(CreateOperation(sliceParam, &sliceNode.operation));
+    sliceNode.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_group_list")};
+    sliceNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_group_list_sliced")};
+    opGraph.nodes.push_back(sliceNode);
+    ATB_SPEED_LOG_DEBUG("sliceNode calculation success");
+    return atb::NO_ERROR;
+}
+
+atb::Status CreateExpandedXLen(
+    std::map<std::string, uint32_t> &tensorMap,
+    atb::GraphParam &opGraph)
+{
+    atb::Node lenNode;
+    lenNode.operation = new atb_speed::common::LenOperation("LenOperation");
+    lenNode.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_sorted_hiddenstates")};
+    lenNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_sorted_hiddenstates_len")};
+    opGraph.nodes.push_back(lenNode);
+    ATB_SPEED_LOG_DEBUG("LenOperation calculation success");
+    return atb::NO_ERROR;
+}
+
"intermediate_group_list_sliced" : "intermediate_group_list") + ) + }; + castNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_group_list_int64")}; + opGraph.nodes.push_back(castNode); + ATB_SPEED_LOG_DEBUG("Cast calculation success"); + return atb::NO_ERROR; +} + +atb::Status CreateGating( + std::map &tensorMap, const MoeMlpParam ¶m, std::shared_ptr batchDimPtr, + atb::GraphParam &opGraph) +{ + atb::Node gatingNode; + CHECK_PARAM_NE(param.topk, 0); + CHECK_PARAM_NE(param.numOfExperts, 0); + atb::infer::GatingParam gatingParam; + gatingParam.topkExpertNum = param.topk; + gatingParam.cumSumNum = param.numOfExperts; + gatingParam.cumSumInt64 = true; + CHECK_OPERATION_STATUS_RETURN(CreateOperation(gatingParam, &gatingNode.operation)); + gatingNode.inTensorIds = { + GetTensorIdx(tensorMap, "in_selected_experts"), GetTensorIdx(tensorMap, "in_expert_array")}; + gatingNode.outTensorIds = { + GetTensorIdx(tensorMap, "intermediate_idx"), GetTensorIdx(tensorMap, "intermediate_group_list"), + GetTensorIdx(tensorMap, "intermediate_weight_idx")}; + gatingNode.inTensorReshapeFuncs.resize(gatingNode.inTensorIds.size()); + gatingNode.inTensorReshapeFuncs[0] = [batchDimPtr](const atb::Dims &oldShape, atb::Dims &newShape) { + newShape.dimNum = 1; + newShape.dims[0] = oldShape.dims[0] * oldShape.dims[1]; + }; + opGraph.nodes.push_back(gatingNode); + ATB_SPEED_LOG_DEBUG("Gating calculation success"); + return atb::NO_ERROR; +} + +atb::Status CreateGather0(std::map &tensorMap, atb::GraphParam &opGraph) +{ + atb::Node gatherNode0; + atb::infer::GatherParam gatherParam; + CHECK_OPERATION_STATUS_RETURN(CreateOperation(gatherParam, &gatherNode0.operation)); + gatherNode0.inTensorIds = {GetTensorIdx(tensorMap, "in_hiddenstates"), GetTensorIdx(tensorMap, "intermediate_idx")}; + gatherNode0.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_sorted_hiddenstates")}; + opGraph.nodes.push_back(gatherNode0); + ATB_SPEED_LOG_DEBUG("Gather0 calculation success"); + return atb::NO_ERROR; +} + +// Step 2: grouped matmul calculation & activation +atb::Status CreateGmm(std::map &tensorMap, + atb::GraphParam &opGraph, const MoeMlpParam ¶m) +{ + atb::Node gmmNode; + atb_speed::common::IntegratedGmmParam gmmParam; + gmmParam.hasBias = (param.packQuantType == atb_speed::common::PackQuantType::ALL_W4A8) ? \ + true : param.hasBias; + gmmParam.isUp = true; + gmmParam.moeLinearQuantType = param.moeLinearQuantType; + gmmParam.packQuantType = param.packQuantType; + gmmParam.transposeB = param.gateUpTransposeB; + gmmParam.downTransposeB = param.downTransposeB; + gmmParam.outDataType = param.isBF16 ? 
+// Step 2: grouped matmul calculation & activation
+atb::Status CreateGmm(std::map<std::string, uint32_t> &tensorMap,
+    atb::GraphParam &opGraph, const MoeMlpParam &param)
+{
+    atb::Node gmmNode;
+    atb_speed::common::IntegratedGmmParam gmmParam;
+    gmmParam.hasBias = (param.packQuantType == atb_speed::common::PackQuantType::ALL_W4A8) ?
+        true : param.hasBias;
+    gmmParam.isUp = true;
+    gmmParam.moeLinearQuantType = param.moeLinearQuantType;
+    gmmParam.packQuantType = param.packQuantType;
+    gmmParam.transposeB = param.gateUpTransposeB;
+    gmmParam.downTransposeB = param.downTransposeB;
+    gmmParam.outDataType = param.isBF16 ? ACL_BF16 : ACL_FLOAT16;
+    gmmParam.quantGroupSize = param.quantGroupSize;
+    gmmParam.enableGMMSwigluQuant = param.enableGMMSwigluQuant;
+    gmmParam.enableAtlasGMMFused = param.enableAtlasGMMFused;
+    if (param.enableInitQuant || (param.enableMoeDistribute &&
        (param.packQuantType == atb_speed::common::PackQuantType::ALL_W8A8_DYNAMIC ||
+         param.packQuantType == atb_speed::common::PackQuantType::ALL_W4A8))) {
+        gmmParam.skipQuant = true;
+    }
+    if (param.enableCVOverlap) {
+        gmmParam.enableCVOverlap = true;
+    }
+    CHECK_OPERATION_STATUS_RETURN(CreateIntegratedGmmOperation(gmmParam, &gmmNode.operation));
+    gmmNode.inTensorIds = {};
+    gmmNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "intermediate_sorted_hiddenstates"));
+    gmmNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_mlp_gateup_weight_expert"));
+    gmmNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_mlp_gateup_bias_expert"));
+    gmmNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_mlp_gateup_descale_expert"));
+    gmmNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_mlp_gateup_offset_expert"));
+    gmmNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_mlp_gateup_scale_expert"));
+    gmmNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_mlp_gateup_compress_idx_expert"));
+    if (param.enableFusedRouting && !param.enableMoeDistribute) {
+        gmmNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "intermediate_group_list_int64"));
+    } else {
+        gmmNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "intermediate_group_list"));
+    }
+    bool isGMMSwigluQuant = IsGMMSwigluQuant(CalcUpGmmQuantType(param), param);
+    if (isGMMSwigluQuant) {
+        gmmNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_mlp_down_weight_expert"));
+        gmmNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_mlp_down_bias_expert"));
+        gmmNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_mlp_down_descale_expert"));
+        gmmNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_mlp_down_offset_expert"));
+        gmmNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_mlp_down_scale_expert"));
+        gmmNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_mlp_down_compress_idx_expert"));
+    }
+    gmmNode.outTensorIds = {GetTensorIdx(tensorMap, isGMMSwigluQuant ? "intermediate_mlp_out" :
+        "intermediate_matmul_gate_up_out")};
+    if (gmmParam.skipQuant) {
+        gmmNode.inTensorIds.push_back(GetTensorIdx(tensorMap, param.enableMoeDistribute ? "intermediate_gmm0_deqscale"
+            : "intermediate_sorted_hiddenstates_dequant_scale"));
+    }
+    opGraph.nodes.push_back(gmmNode);
+    ATB_SPEED_LOG_DEBUG("GMM calculation success");
+    return atb::NO_ERROR;
+}
+
+atb::Status CreateQuantGMMDequant(
+    std::map<std::string, uint32_t> &tensorMap,
+    atb::GraphParam &opGraph, const MoeMlpParam &param)
+{
+    atb::Node quantGmmDequantNode;
+    atb_speed::common::AclNNQuantGMMDequantParam aclnnParam;
+    aclnnParam.outDataType = param.isBF16 ? ACL_BF16 : ACL_FLOAT16;
+    aclnnParam.transposeB = true; // param.gateUpTransposeB;
+    aclnnParam.quantMode = "pertoken";
+    quantGmmDequantNode.operation = new atb_speed::common::QuantGMMDequantOperation("QuantGMMDequantOperation",
+        aclnnParam);
+    quantGmmDequantNode.inTensorIds = { // four inputs in total; the group list is appended below
+        GetTensorIdx(tensorMap, "intermediate_sorted_hiddenstates"),
+        GetTensorIdx(tensorMap, "in_mlp_gateup_weight_expert"),
+        GetTensorIdx(tensorMap, "in_mlp_gateup_scale_expert"),
+    };
+    if (param.enableFusedRouting) {
+        quantGmmDequantNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "intermediate_group_list_int64"));
+    } else {
+        quantGmmDequantNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "intermediate_group_list"));
+    }
+    quantGmmDequantNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_matmul_gate_up_out")};
+    quantGmmDequantNode.inTensorReshapeFuncs.resize(quantGmmDequantNode.inTensorIds.size());
+    quantGmmDequantNode.inTensorReshapeFuncs[2] = [param](const atb::Dims &oldShape, atb::Dims &newShape) {
+        newShape.dimNum = 2; // [256, 4096, 1] -> [256, 4096]
+        newShape.dims[0] = param.numOfExperts;
+        newShape.dims[1] = oldShape.dims[0] / param.numOfExperts * oldShape.dims[1];
+    };
+    opGraph.nodes.push_back(quantGmmDequantNode);
+    ATB_SPEED_LOG_DEBUG("QuantGMMDequant calculation success");
+    return atb::NO_ERROR;
+}
+
+atb::Status CreateActivation(
+    std::map<std::string, uint32_t> &tensorMap, atb::GraphParam &opGraph)
+{
+    atb::Node swishNode;
+    atb::infer::ActivationParam activationParam;
+    activationParam.activationType = atb::infer::ActivationType::ACTIVATION_SWIGLU_FORWARD;
+    CHECK_OPERATION_STATUS_RETURN(CreateOperation(activationParam, &swishNode.operation));
+    swishNode.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_matmul_gate_up_out")};
+    swishNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_swish_out")};
+    opGraph.nodes.push_back(swishNode);
+    ATB_SPEED_LOG_DEBUG("Activation calculation success");
+    return atb::NO_ERROR;
+}
+
+atb::Status CreateActivationQuant(
+    std::map<std::string, uint32_t> &tensorMap, atb::GraphParam &opGraph)
+{
+    atb::Node swigluQuantNode;
+    atb::infer::SwigluQuantParam swigluQuantParam;
+    swigluQuantParam.quantType = atb::infer::SwigluQuantParam::QuantType::QUANT_TYPE_PER_TOKEN;
+    CHECK_OPERATION_STATUS_RETURN(CreateOperation(swigluQuantParam, &swigluQuantNode.operation));
+    swigluQuantNode.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_matmul_gate_up_out")};
+    swigluQuantNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_swish_out"),
+                                    GetTensorIdx(tensorMap, "intermediate_swiglu_dequant_scale")};
+    opGraph.nodes.push_back(swigluQuantNode);
+    ATB_SPEED_LOG_DEBUG("ActivationQuant calculation success");
+    return atb::NO_ERROR;
+}
+
+atb::Status CreateSplit(std::map<std::string, uint32_t> &tensorMap, atb::GraphParam &opGraph)
+{
+    atb::Node splitNode;
+    atb::infer::SplitParam splitParam = {1, 2};
+    CHECK_OPERATION_STATUS_RETURN(CreateOperation(splitParam, &splitNode.operation));
+    splitNode.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_matmul_gate_up_out")};
+    splitNode.outTensorIds = {
+        GetTensorIdx(tensorMap, "intermediate_matmul_gate_out"), GetTensorIdx(tensorMap, "intermediate_matmul_up_out")};
+    opGraph.nodes.push_back(splitNode);
+    ATB_SPEED_LOG_DEBUG("Split calculation success");
+    return atb::NO_ERROR;
+}
+
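+// Reviewer note (not part of the original change): the split above plus the swish and
+// mul nodes below are the unfused swiglu fallback; for a gate/up output of shape
+// [tokens, 2h] they compute gate, up = split(x) and out = swish(gate) * up, matching
+// what the fused ACTIVATION_SWIGLU_FORWARD path produces in a single node.
+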
+atb::Status CreateActivationO(std::map<std::string, uint32_t> &tensorMap, atb::GraphParam &opGraph)
+{
+    atb::Node swishNodeO;
+    atb::infer::ActivationParam activationParam;
+    activationParam.activationType = atb::infer::ActivationType::ACTIVATION_SWISH;
+    CHECK_OPERATION_STATUS_RETURN(CreateOperation(activationParam, &swishNodeO.operation));
+    swishNodeO.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_matmul_gate_out")};
+    swishNodeO.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_swish_out_internal")};
+    opGraph.nodes.push_back(swishNodeO);
+    ATB_SPEED_LOG_DEBUG("Activation calculation success");
+    return atb::NO_ERROR;
+}
+
+atb::Status CreateElewiseMul(std::map<std::string, uint32_t> &tensorMap, atb::GraphParam &opGraph)
+{
+    atb::Node mulNode;
+    atb::infer::ElewiseParam elewiseParam;
+    elewiseParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_MUL;
+    CHECK_OPERATION_STATUS_RETURN(CreateOperation(elewiseParam, &mulNode.operation));
+    mulNode.inTensorIds = {
+        GetTensorIdx(tensorMap, "intermediate_swish_out_internal"),
+        GetTensorIdx(tensorMap, "intermediate_matmul_up_out")};
+    mulNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_swish_out")};
+    opGraph.nodes.push_back(mulNode);
+    ATB_SPEED_LOG_DEBUG("ElewiseMul0 calculation success");
+    return atb::NO_ERROR;
+}
+
+atb::Status CreateGmm1(std::map<std::string, uint32_t> &tensorMap,
+    atb::GraphParam &opGraph, const MoeMlpParam &param)
+{
+    atb::Node gmmDownNode;
+    atb_speed::common::IntegratedGmmParam gmmParam;
+    gmmParam.hasBias = param.hasBias;
+    gmmParam.isUp = false;
+    gmmParam.moeLinearQuantType = param.moeLinearQuantType;
+    gmmParam.packQuantType = param.packQuantType;
+    gmmParam.transposeB = param.downTransposeB;
+    if (param.packQuantType == atb_speed::common::PackQuantType::ALL_W4A8) {
+        gmmParam.hasBias = true;
+    }
+    gmmParam.outDataType = param.isBF16 ? ACL_BF16 : ACL_FLOAT16;
+    if (param.enableSwigluQuant) {
+        gmmParam.skipQuant = true;
+    }
+    CHECK_OPERATION_STATUS_RETURN(CreateIntegratedGmmOperation(gmmParam, &gmmDownNode.operation));
+    gmmDownNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_mlp_out")};
+    gmmDownNode.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_swish_out"),
+                               GetTensorIdx(tensorMap, "in_mlp_down_weight_expert"),
+                               GetTensorIdx(tensorMap, "in_mlp_down_bias_expert"),
+                               GetTensorIdx(tensorMap, "in_mlp_down_descale_expert"),
+                               GetTensorIdx(tensorMap, "in_mlp_down_offset_expert"),
+                               GetTensorIdx(tensorMap, "in_mlp_down_scale_expert"),
+                               GetTensorIdx(tensorMap, "in_mlp_down_compress_idx_expert")};
+    if (param.enableFusedRouting && !param.enableMoeDistribute) {
+        gmmDownNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "intermediate_group_list_int64"));
+    } else {
+        gmmDownNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "intermediate_group_list"));
+    }
+    if (param.enableSwigluQuant) {
+        gmmDownNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "intermediate_swiglu_dequant_scale"));
+    }
+    opGraph.nodes.push_back(gmmDownNode);
+    ATB_SPEED_LOG_DEBUG("GMM calculation success");
+    return atb::NO_ERROR;
+}
+
+
+atb::Status CreateQuantGMMDequant1(
+    std::map<std::string, uint32_t> &tensorMap,
+    atb::GraphParam &opGraph, const MoeMlpParam &param)
+{
+    atb::Node quantGmmDequantDownNode;
+    atb_speed::common::AclNNQuantGMMDequantParam aclnnParam;
+    aclnnParam.outDataType = param.isBF16 ? ACL_BF16 : ACL_FLOAT16;
+    aclnnParam.transposeB = param.downTransposeB;
+    aclnnParam.quantMode = "pertoken";
+    quantGmmDequantDownNode.operation = new atb_speed::common::QuantGMMDequantOperation("QuantGMMDequantOperation",
+        aclnnParam);
+    quantGmmDequantDownNode.inTensorIds = { // four inputs in total; the group list is appended below
+        GetTensorIdx(tensorMap, "intermediate_swish_out"),
+        GetTensorIdx(tensorMap, "in_mlp_down_weight_expert"),
+        GetTensorIdx(tensorMap, "in_mlp_down_scale_expert"),
+    };
+    if (param.enableFusedRouting) {
+        quantGmmDequantDownNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "intermediate_group_list_int64"));
+    } else {
+        quantGmmDequantDownNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "intermediate_group_list"));
+    }
+    quantGmmDequantDownNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_mlp_out")};
+    quantGmmDequantDownNode.inTensorReshapeFuncs.resize(quantGmmDequantDownNode.inTensorIds.size());
+    quantGmmDequantDownNode.inTensorReshapeFuncs[2] = [param](const atb::Dims &oldShape, atb::Dims &newShape) {
+        newShape.dimNum = 2; // [256, 4096, 1] -> [256, 4096]
+        newShape.dims[0] = param.numOfExperts;
+        newShape.dims[1] = oldShape.dims[0] / param.numOfExperts * oldShape.dims[1];
+    };
+    opGraph.nodes.push_back(quantGmmDequantDownNode);
+    ATB_SPEED_LOG_DEBUG("QuantGMMDequant1 calculation success");
+    return atb::NO_ERROR;
+}
+
+// Step 3: hidden state reduction
+atb::Status CreateMoeTokenUnpermute(
+    std::map<std::string, uint32_t> &tensorMap, const MoeMlpParam &param, atb::GraphParam &opGraph)
+{
+    atb::Node unpermuteNode;
+    unpermuteNode.operation = new atb_speed::common::MoeTokenUnpermuteOperation("MoeTokenUnpermuteNode");
+    unpermuteNode.outTensorIds = {GetTensorIdx(tensorMap, "out_moe_mlp_result")};
+    unpermuteNode.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_mlp_out"),
+        GetTensorIdx(tensorMap, "intermediate_idx"),
+        // shiftedTopK is written in place
+        GetTensorIdx(tensorMap, (param.hasMoeEp && !param.shiftedTopK) || param.enableGatingDp ?
+            "intermediate_expert_weight" : "in_expert_weight")};
+    opGraph.nodes.push_back(unpermuteNode);
+    ATB_SPEED_LOG_DEBUG("UnpermuteNode calculation success");
+    return atb::NO_ERROR;
+}
+
+template <typename DispatchParam>
+atb::Status SetMoeDistributeDispatchParam(DispatchParam &dispatchParam, const MoeMlpParam &param)
+{
+    dispatchParam.epRankId = param.moeEpRank;
+    dispatchParam.epRankSize = param.moeEpSize;
+    dispatchParam.epCommName = param.moeEpDomain;
+    dispatchParam.maxDecodeDpTokenSize = param.maxDecodeDpTokenSize;
+    dispatchParam.moeExpertNum = param.numOfExperts;
+    dispatchParam.localMoeExpertNum = param.numOfDeviceExperts;
+    dispatchParam.sharedExpertRankNum = param.numDanglingSharedExperts;
+    dispatchParam.topk = param.topk;
+    if (param.packQuantType == atb_speed::common::PackQuantType::ALL_W8A8_DYNAMIC ||
+        param.packQuantType == atb_speed::common::PackQuantType::ALL_W8A8_DYNAMIC_ANTI) {
+        dispatchParam.quantMode = 2; // 2: quantization mode 2
+        dispatchParam.isQuant = true;
+        dispatchParam.quantSmooth = false;
+    } else {
+        dispatchParam.quantMode = 0; // 0: no quantization
+        dispatchParam.isQuant = false;
+        dispatchParam.quantSmooth = false;
+    }
+    return atb::NO_ERROR;
+}
+
+atb::Status CreateMoeDistributeDispatch(
+    std::map<std::string, uint32_t> &tensorMap, const MoeMlpParam &param,
+    atb::GraphParam &opGraph)
+{
+    atb::Node moeDistributeDispatchNode;
+    if (param.enableDispatchCombineV2) {
+        atb_speed::common::MoeDistributeDispatchV2Param dispatchParam;
+        SetMoeDistributeDispatchParam(dispatchParam, param);
+        moeDistributeDispatchNode.operation = new atb_speed::common::MoeDistributeDispatchV2Operation(
+            std::string("MoeDistributeDispatchV2") + std::to_string(param.packQuantType), dispatchParam);
+    } else {
+        atb_speed::common::MoeDistributeDispatchParam dispatchParam;
+        SetMoeDistributeDispatchParam(dispatchParam, param);
+        moeDistributeDispatchNode.operation = new atb_speed::common::MoeDistributeDispatchOperation(
+            std::string("MoeDistributeDispatch") + std::to_string(param.packQuantType), dispatchParam);
+    }
+
+    moeDistributeDispatchNode.inTensorIds = {GetTensorIdx(tensorMap, "in_hiddenstates"),
+        GetTensorIdx(tensorMap, "in_selected_experts"),
+        GetTensorIdx(tensorMap, "in_expert_weight"),
+        GetTensorIdx(tensorMap, "in_padding_idx")};
+
+    moeDistributeDispatchNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_sorted_hiddenstates"),
+        GetTensorIdx(tensorMap, "intermediate_gmm0_deqscale"),
+        GetTensorIdx(tensorMap, "intermediate_idx"),
+        GetTensorIdx(tensorMap, "intermediate_group_list"),
+        GetTensorIdx(tensorMap, "intermediate_ep_recv_counts"),
+        GetTensorIdx(tensorMap, "intermediate_tp_recv_counts"),
+        GetTensorIdx(tensorMap, "intermediate_expand_expert_weight"),
+    };
+    CHECK_OPERATION_STATUS_RETURN(common::AddDapEventsBeforeComm(opGraph));
+    opGraph.nodes.push_back(moeDistributeDispatchNode);
+    CHECK_OPERATION_STATUS_RETURN(common::AddDapEventsAfterComm(opGraph));
+    ATB_SPEED_LOG_DEBUG("MoeDistributeDispatch calculation success");
+    return atb::NO_ERROR;
+}
+
+template <typename CombineParam>
+atb::Status SetMoeDistributeCombineParam(CombineParam &combineParam, const MoeMlpParam &param)
+{
+    combineParam.epRankId = param.moeEpRank;
+    combineParam.epRankSize = param.moeEpSize;
+    combineParam.epCommName = param.moeEpDomain;
+    combineParam.moeExpertNum = param.numOfExperts;
+    combineParam.localMoeExpertNum = param.numOfDeviceExperts;
+    combineParam.maxDecodeDpTokenSize = param.maxDecodeDpTokenSize;
+    combineParam.sharedExpertRankNum = param.numDanglingSharedExperts;
+    combineParam.topk = param.topk;
+    const char *hcclIntraPcie = std::getenv("HCCL_INTRA_PCIE_ENABLE");
+    const char *hcclIntraRoce = std::getenv("HCCL_INTRA_ROCE_ENABLE");
+    if (hcclIntraPcie != nullptr && hcclIntraRoce != nullptr
+        && std::string(hcclIntraPcie) == "1" && std::string(hcclIntraRoce) == "0") {
+        combineParam.commQuantMode = 2; // 2: quantization mode 2
+    }
+    return atb::NO_ERROR;
+}
+
+atb::Status CreateMoeDistributeCombine(
+    std::map<std::string, uint32_t> &tensorMap, const MoeMlpParam &param,
+    atb::GraphParam &opGraph)
+{
+    atb::Node moeDistributeCombineNode;
+    if (param.enableDispatchCombineV2) {
+        atb_speed::common::MoeDistributeCombineV2Param combineParam;
+        SetMoeDistributeCombineParam(combineParam, param);
+        moeDistributeCombineNode.operation =
+            new atb_speed::common::MoeDistributeCombineV2Operation("MoeDistributeCombineV2", combineParam);
+    } else {
+        atb_speed::common::MoeDistributeCombineParam combineParam;
+        SetMoeDistributeCombineParam(combineParam, param);
+        moeDistributeCombineNode.operation =
+            new atb_speed::common::MoeDistributeCombineOperation("MoeDistributeCombine", combineParam);
+    }
+
+    moeDistributeCombineNode.outTensorIds = {GetTensorIdx(tensorMap, "out_moe_mlp_result")};
+    moeDistributeCombineNode.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_mlp_out"),
+        GetTensorIdx(tensorMap, "in_selected_experts"),
+        GetTensorIdx(tensorMap, "intermediate_idx"),
+        GetTensorIdx(tensorMap, "intermediate_ep_recv_counts"),
+        GetTensorIdx(tensorMap, "in_expert_weight"),
+        GetTensorIdx(tensorMap, "intermediate_tp_recv_counts"),
+        GetTensorIdx(tensorMap, "intermediate_expand_expert_weight"),
+    };
+
+    CHECK_OPERATION_STATUS_RETURN(common::AddDapEventsBeforeComm(opGraph));
+    opGraph.nodes.push_back(moeDistributeCombineNode);
+    CHECK_OPERATION_STATUS_RETURN(common::AddDapEventsAfterComm(opGraph));
+    ATB_SPEED_LOG_DEBUG("MoeDistributeCombine calculation success");
+    return atb::NO_ERROR;
+}
+
+// Op5 - Gather1
+atb::Status CreateGather1(std::map<std::string, uint32_t> &tensorMap,
+    std::shared_ptr<int64_t> batchDimPtr, atb::GraphParam &opGraph)
+{
+    atb::Node gatherNode1;
+    atb::infer::GatherParam gatherParam;
+    CHECK_OPERATION_STATUS_RETURN(CreateOperation(gatherParam, &gatherNode1.operation));
+    gatherNode1.inTensorIds = {GetTensorIdx(tensorMap, "in_expert_weight"),
+        GetTensorIdx(tensorMap, "intermediate_weight_idx")};
+    gatherNode1.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_sorted_weight")};
+    gatherNode1.inTensorReshapeFuncs.resize(gatherNode1.inTensorIds.size());
+    gatherNode1.inTensorReshapeFuncs[0] = [batchDimPtr](const atb::Dims &oldShape, atb::Dims &newShape) {
+        newShape.dimNum = 1;
+        newShape.dims[0] = oldShape.dims[0] * oldShape.dims[1];
+    };
+    opGraph.nodes.push_back(gatherNode1);
+    ATB_SPEED_LOG_DEBUG("Gather1 calculation success");
+    return atb::NO_ERROR;
+}
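+
+// NOTE: CreateGather1 above flattens the [num_tokens, topk] routing weights to
+// one dimension so they can be picked in sorted-token order; CreateElewiseMul1
+// below then reshapes them to a [num_tokens * topk, 1] column so the multiply
+// broadcasts over the hidden dimension of intermediate_mlp_out.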
+
+// Op6 - ElewiseMul1
+atb::Status CreateElewiseMul1(std::map<std::string, uint32_t> &tensorMap,
+    std::shared_ptr<int64_t> batchDimPtr, atb::GraphParam &opGraph)
+{
+    atb::Node weightMulNode;
+    atb::infer::ElewiseParam elewiseParam;
+    elewiseParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_MUL;
+    CHECK_OPERATION_STATUS_RETURN(CreateOperation(elewiseParam, &weightMulNode.operation));
+    weightMulNode.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_mlp_out"),
+        GetTensorIdx(tensorMap, "intermediate_sorted_weight")};
+    weightMulNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_mlp_out_weighted")};
+    weightMulNode.inTensorReshapeFuncs.resize(weightMulNode.inTensorIds.size());
+    weightMulNode.inTensorReshapeFuncs[1] = [batchDimPtr](const atb::Dims &oldShape, atb::Dims &newShape) {
+        newShape.dimNum = 2; // 2: number of dimensions of the new shape
+        newShape.dims[0] = oldShape.dims[0];
+        newShape.dims[1] = 1;
+    };
+    opGraph.nodes.push_back(weightMulNode);
+    ATB_SPEED_LOG_DEBUG("ElewiseMul1 calculation success");
+    return atb::NO_ERROR;
+}
+
+// Op7 - Argsort
+atb::Status CreateArgsort(std::map<std::string, uint32_t> &tensorMap, atb::GraphParam &opGraph)
+{
+    atb::Node argsortNode;
+    atb::infer::GatingParam gatingParam;
+    gatingParam.topkExpertNum = 1;
+    gatingParam.cumSumNum = 0;
+    CHECK_OPERATION_STATUS_RETURN(CreateOperation(gatingParam, &argsortNode.operation));
+    argsortNode.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_weight_idx"),
+        GetTensorIdx(tensorMap, "in_expert_array")};
+    argsortNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_dummy_zero"),
+        GetTensorIdx(tensorMap, "intermediate_dummy_one"),
+        GetTensorIdx(tensorMap, "intermediate_rev_idx")};
+    opGraph.nodes.push_back(argsortNode);
+    ATB_SPEED_LOG_DEBUG("Argsort calculation success");
+    return atb::NO_ERROR;
+}
+
+// Op8 - Gather2
+atb::Status CreateGather2(std::map<std::string, uint32_t> &tensorMap, atb::GraphParam &opGraph)
+{
+    atb::Node gatherNode2;
+    atb::infer::GatherParam gatherParam;
+    CHECK_OPERATION_STATUS_RETURN(CreateOperation(gatherParam, &gatherNode2.operation));
+    gatherNode2.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_mlp_out_weighted"),
+        GetTensorIdx(tensorMap, "intermediate_rev_idx")};
+    gatherNode2.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_rev_sorted_hiddenstates")};
+    opGraph.nodes.push_back(gatherNode2);
+    ATB_SPEED_LOG_DEBUG("Gather2 calculation success");
+    return atb::NO_ERROR;
+}
+
+// Op9 - Reduction
+atb::Status CreateReduction(std::map<std::string, uint32_t> &tensorMap,
+    const MoeMlpParam &param, std::shared_ptr<int64_t> batchDimPtr, atb::GraphParam &opGraph)
+{
+    CHECK_PARAM_NE(param.topk, 0);
+    atb::Node reduceNode;
+    atb::infer::ReduceParam reduceParam;
+    reduceParam.reduceType = atb::infer::ReduceParam::ReduceType::REDUCE_SUM;
+    reduceParam.axis = {1};
+    CHECK_OPERATION_STATUS_RETURN(CreateOperation(reduceParam, &reduceNode.operation));
+    reduceNode.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_rev_sorted_hiddenstates")};
+    reduceNode.outTensorIds = {GetTensorIdx(tensorMap, "out_moe_mlp_result")};
+    reduceNode.inTensorReshapeFuncs.resize(reduceNode.inTensorIds.size());
+    reduceNode.inTensorReshapeFuncs[0] = [batchDimPtr, param](const atb::Dims &oldShape, atb::Dims &newShape) {
+        newShape.dimNum = 3; // 3: number of dimensions of the new shape
+        newShape.dims[0] = oldShape.dims[0] / param.topk;
+        newShape.dims[1] = param.topk;
+        newShape.dims[2] = oldShape.dims[1]; // 2: index of the third dimension of the new shape
+    };
+    opGraph.nodes.push_back(reduceNode);
+    ATB_SPEED_LOG_DEBUG("Reduction calculation success");
+    return atb::NO_ERROR;
+}
+
+atb::Status CreateActivationBlock(std::map<std::string, uint32_t> &tensorMap,
+    const MoeMlpParam &param, atb::GraphParam &opGraph)
+{
+    if (param.supportSwiGLU) {
+        if (param.enableSwigluQuant) {
+            CHECK_OPERATION_STATUS_RETURN(CreateActivationQuant(tensorMap, opGraph));
+        } else {
+            CHECK_OPERATION_STATUS_RETURN(CreateActivation(tensorMap, opGraph));
+        }
+    } else {
+        CHECK_OPERATION_STATUS_RETURN(CreateSplit(tensorMap, opGraph));
+        CHECK_OPERATION_STATUS_RETURN(CreateActivationO(tensorMap, opGraph));
+        CHECK_OPERATION_STATUS_RETURN(CreateElewiseMul(tensorMap, opGraph));
+    }
+
+    ATB_SPEED_LOG_DEBUG("ActivationBlock calculation success");
+    return atb::NO_ERROR;
+}
+
+atb::Status CreateRecord(const MoeMlpParam &param, atb::GraphParam &opGraph,
+    atb_speed::EventAction eventAction, const std::string &cvKey)
+{
+    if (param.enableCVOverlap) {
+        atb::Node recordNode;
+        recordNode.inTensorIds = {};
+        recordNode.outTensorIds = {};
+        CHECK_OPERATION_STATUS_RETURN(atb_speed::EventManager::GetInstance().RecordEvent(
+            recordNode.operation,
+            eventAction,
+            cvKey));
+        opGraph.nodes.push_back(recordNode);
+        ATB_SPEED_LOG_DEBUG("Record event success");
+    }
+    return atb::NO_ERROR;
+}
+
+atb::Status CreateWait(const MoeMlpParam &param, atb::GraphParam &opGraph,
+    atb_speed::EventAction eventAction, const std::string &cvKey)
+{
+    if (param.enableCVOverlap) {
+        atb::Node waitNode;
+        waitNode.inTensorIds = {};
+        waitNode.outTensorIds = {};
+        CHECK_OPERATION_STATUS_RETURN(atb_speed::EventManager::GetInstance().WaitEvent(
+            waitNode.operation,
+            eventAction,
+            cvKey));
+        opGraph.nodes.push_back(waitNode);
+        ATB_SPEED_LOG_DEBUG("Wait event success");
+    }
+    return atb::NO_ERROR;
+}
+
+atb::Status CreateMoeDistribute(
+    std::map<std::string, uint32_t> &tensorMap,
+    const MoeMlpParam &param,
+    atb::GraphParam &opGraph)
+{
+    std::shared_ptr<int64_t> batchDimPtr = std::make_shared<int64_t>(0);
+    bool isGMMSwigluQuant = IsGMMSwigluQuant(CalcUpGmmQuantType(param), param);
+    CHECK_OPERATION_STATUS_RETURN(CreateMoeDistributeDispatch(tensorMap, param, opGraph));
+    CHECK_OPERATION_STATUS_RETURN(CreateGmm(tensorMap, opGraph, param));
+    if (!isGMMSwigluQuant) {
+        CHECK_OPERATION_STATUS_RETURN(CreateActivationBlock(tensorMap, param, opGraph));
+        CHECK_OPERATION_STATUS_RETURN(CreateGmm1(tensorMap, opGraph, param));
+    }
+    CHECK_OPERATION_STATUS_RETURN(CreateMoeDistributeCombine(tensorMap, param, opGraph));
+    return atb::NO_ERROR;
+}
+
+atb::Status CreateFusedRouting(
+    std::map<std::string, uint32_t> &tensorMap,
+    const MoeMlpParam &param,
+    atb::GraphParam &opGraph)
+{
+    std::shared_ptr<int64_t> batchDimPtr = std::make_shared<int64_t>(0);
+    bool isGMMSwigluQuant = IsGMMSwigluQuant(CalcUpGmmQuantType(param), param);
+    if (param.enableCVOverlap) {
+        CHECK_OPERATION_STATUS_RETURN(CreateWait(
+            param, opGraph, atb_speed::EventAction::POP, atb_speed::common::CUBE_CONTROL));
+    }
+    if (param.enableInitQuant) {
+        CHECK_OPERATION_STATUS_RETURN(CreateInitRoutingQuant(tensorMap, param, opGraph));
+    } else {
+        CHECK_OPERATION_STATUS_RETURN(CreateInitRouting(tensorMap, param, opGraph));
+    }
+    if (param.hasMoeEp) {
+        CHECK_OPERATION_STATUS_RETURN(CreateComputeExpertSlice(tensorMap, param, opGraph));
+    }
+    if (param.enableInitRoutingCutoff) {
+        CHECK_OPERATION_STATUS_RETURN(CreateExpandedXLen(tensorMap, opGraph));
+        CHECK_OPERATION_STATUS_RETURN(CreateGroupListFilter(tensorMap, param, opGraph));
+    }
+    CHECK_OPERATION_STATUS_RETURN(CreateCast(tensorMap, param, opGraph));
+    CHECK_OPERATION_STATUS_RETURN(CreateGmm(tensorMap, opGraph, param));
+    if (param.enableCVOverlap) {
+        CHECK_OPERATION_STATUS_RETURN(CreateWait(
+            param, opGraph, atb_speed::EventAction::POP, atb_speed::common::VECTOR_CONTROL));
+        CHECK_OPERATION_STATUS_RETURN(CreateRecord(
+            param, opGraph, atb_speed::EventAction::POP, atb_speed::common::CUBE_CONTROL));
+    }
+    if (!isGMMSwigluQuant) {
+        CHECK_OPERATION_STATUS_RETURN(CreateActivationBlock(tensorMap, param, opGraph));
+        CHECK_OPERATION_STATUS_RETURN(CreateGmm1(tensorMap, opGraph, param));
+    }
+    CHECK_OPERATION_STATUS_RETURN(CreateMoeTokenUnpermute(tensorMap, param, opGraph));
+    if (param.enableCVOverlap) {
+        CHECK_OPERATION_STATUS_RETURN(CreateWait(
+            param, opGraph, atb_speed::EventAction::POP, atb_speed::common::CUBE_CONTROL));
+    }
+    return atb::NO_ERROR;
+}
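+
+// NOTE: CreateDefault below is the fallback pipeline when neither
+// MoeDistribute nor FusedRouting is enabled: Gating -> Gather0 -> grouped
+// matmul (gate/up) -> activation -> grouped matmul (down) -> Gather1 and
+// ElewiseMul1 to apply routing weights -> Argsort and Gather2 to restore token
+// order -> Reduction to sum the topk expert outputs per token.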
+
+atb::Status CreateDefault(
+    std::map<std::string, uint32_t> &tensorMap,
+    const MoeMlpParam &param,
+    atb::GraphParam &opGraph)
+{
+    std::shared_ptr<int64_t> batchDimPtr = std::make_shared<int64_t>(0);
+    bool isGMMSwigluQuant = IsGMMSwigluQuant(CalcUpGmmQuantType(param), param);
+    CHECK_OPERATION_STATUS_RETURN(CreateGating(tensorMap, param, batchDimPtr, opGraph));
+    CHECK_OPERATION_STATUS_RETURN(CreateGather0(tensorMap, opGraph));
+    if (param.packQuantType != atb_speed::common::PackQuantType::ALL_FP && Is310P()) {
+        // On 310P, use QuantGMMDequantOperation for the quantized grouped matmuls
+        CHECK_OPERATION_STATUS_RETURN(CreateQuantGMMDequant(tensorMap, opGraph, param));
+        CHECK_OPERATION_STATUS_RETURN(CreateActivationBlock(tensorMap, param, opGraph));
+        CHECK_OPERATION_STATUS_RETURN(CreateQuantGMMDequant1(tensorMap, opGraph, param));
+    } else {
+        CHECK_OPERATION_STATUS_RETURN(CreateGmm(tensorMap, opGraph, param));
+        if (!isGMMSwigluQuant) {
+            CHECK_OPERATION_STATUS_RETURN(CreateActivationBlock(tensorMap, param, opGraph));
+            CHECK_OPERATION_STATUS_RETURN(CreateGmm1(tensorMap, opGraph, param));
+        }
+    }
+    CHECK_OPERATION_STATUS_RETURN(CreateGather1(tensorMap, batchDimPtr, opGraph));
+    CHECK_OPERATION_STATUS_RETURN(CreateElewiseMul1(tensorMap, batchDimPtr, opGraph));
+    CHECK_OPERATION_STATUS_RETURN(CreateArgsort(tensorMap, opGraph));
+    CHECK_OPERATION_STATUS_RETURN(CreateGather2(tensorMap, opGraph));
+    CHECK_OPERATION_STATUS_RETURN(CreateReduction(tensorMap, param, batchDimPtr, opGraph));
+    return atb::NO_ERROR;
+}
+
+atb::Status CreateMoeMlpOperation(const MoeMlpParam &param, atb::Operation **operation)
+{
+    atb::GraphParam opGraph;
+    opGraph.name = "MoeMlp";
+
+    std::map<std::string, uint32_t> tensorMap = ConstructTensorMap(
+        param, opGraph.inTensorNum, opGraph.outTensorNum, opGraph.internalTensorNum);
+    ATB_SPEED_LOG_DEBUG("opGraph.inTensorNum " << opGraph.inTensorNum);
+    ATB_SPEED_LOG_DEBUG("opGraph.outTensorNum " << opGraph.outTensorNum);
+    ATB_SPEED_LOG_DEBUG("opGraph.internalTensorNum " << opGraph.internalTensorNum);
+
+    if (param.hasMoeEp && !param.enableMoeDistribute) {
+        CHECK_OPERATION_STATUS_RETURN(CreateExpertFilter(tensorMap, param, opGraph));
+    }
+    if (param.enableMoeDistribute) {
+        CHECK_OPERATION_STATUS_RETURN(CreateMoeDistribute(tensorMap, param, opGraph));
+    } else if (param.enableFusedRouting) {
+        CHECK_OPERATION_STATUS_RETURN(CreateFusedRouting(tensorMap, param, opGraph));
+    } else {
+        CHECK_OPERATION_STATUS_RETURN(CreateDefault(tensorMap, param, opGraph));
+    }
+
+    opGraph.inferShapeFunc = [=](const atb::SVector<atb::TensorDesc> &inTensorDescs,
+        atb::SVector<atb::TensorDesc> &outTensorDescs) {
+        outTensorDescs.at(0) = inTensorDescs.at(0);
+        if (param.enableExpertCumSumOutput) {
+            outTensorDescs.at(1) = atb::TensorDesc{};
+            outTensorDescs.at(1).format = ACL_FORMAT_ND;
+            outTensorDescs.at(1).shape.dimNum = 1;
+            outTensorDescs.at(1).dtype = ACL_INT64;
+            outTensorDescs.at(1).shape.dims[0] = param.numOfDeviceExperts;
+        }
+
+        return atb::NO_ERROR;
+    };
+    CREATE_OPERATION(opGraph, operation);
+    return atb::NO_ERROR;
+}
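+
+// A minimal caller sketch for the entry point above (illustrative values only;
+// context setup, weight tensors and execution are handled elsewhere in the
+// proftest framework):
+//
+//     atb_speed::common::MoeMlpParam param;
+//     param.topk = 8;                  // experts selected per token
+//     param.numOfExperts = 64;         // global expert count
+//     param.supportSwiGLU = true;      // use the fused SwiGLU kernel
+//     atb::Operation *op = nullptr;
+//     CHECK_OPERATION_STATUS_RETURN(atb_speed::common::CreateMoeMlpOperation(param, &op));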
+} // namespace common
+} // namespace atb_speed
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/fusion/moe/moe_mlp.h b/tests/proftest/layer_test_framework/operations/fusion/moe/moe_mlp.h
new file mode 100644
index 00000000..4b6b31a4
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/fusion/moe/moe_mlp.h
@@ -0,0 +1,86 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef ATB_SPEED_MODELS_MOE_MLP_OPERATION_H
+#define ATB_SPEED_MODELS_MOE_MLP_OPERATION_H
+#include
+#include "atb_speed/utils/operation_util.h"
+#include "atb_speed/log.h"
+#include "operations/fusion/linear/linear.h"
+#include "operations/fusion/linear/linear_parallel.h"
+#include "operations/fusion/norm/norm_linear.h"
+
+namespace atb_speed {
+namespace common {
+/// The non-deepseek models do not have the scaledTopk feature enabled by default
+constexpr int TOPK_SCALE_DEFAULT_CLOSE = -1;
+struct MoeMlpParam {
+    bool transpose = true; /// A flag indicating whether matrices need to be transposed for matrix multiplications
+    bool supportSwiGLU = true; /// A flag indicating whether the device supports the SwiGLU operator
+    bool shiftedTopK = false; /// A flag indicating whether or not to shift topk
+    int32_t topk = 2; /// The number of experts selected for each token
+    int scaledTopk = -1; /// The number of experts selected for each token after scaling
+    bool enableInitRoutingCutoff = false; /// A flag indicating whether to use the scaled topk option
+    int gmmQuantType = 0; /// Quantization type of the grouped linear transformation
+    uint32_t numOfExperts = 8; /// The total number of experts utilized by the model
+    uint32_t numOfDeviceExperts = 8; /// The number of experts loaded to the device
+    std::vector<int> moeLinearQuantType = {}; /// The list of quantization types of linear operations in the MoE graph
+    std::vector<int32_t> deviceExpert = {0, 1, 2, 3, 4, 5, 6, 7}; /// The list of experts loaded on the device
+    int expertParallelDegree = 0; /// The specific realization of the expert parallelism strategy utilized by the model
+    bool hasBias = false; /// A flag indicating whether there is a bias for the linear operation weights
+    bool isBF16 = false; /// A flag indicating whether the model runs on bfloat16
+    bool gateUpTransposeB = false; /// A flag indicating whether the B matrix of the gateup operation should be transposed
+    bool downTransposeB = false; /// A flag indicating whether the B matrix of the down operation should be transposed
+    bool enableFusedRouting = false; /// A flag indicating whether to use integrated routing operators
+    /// A flag indicating whether or not to use the integrated GMM+Swiglu+quant operator.
+    bool enableGMMSwigluQuant = false;
+    /// A flag indicating whether or not to use the fused ATB GMM+Swiglu+quant operator instead of aclnn.
+    bool enableAtlasGMMFused = false;
+    bool enableInitQuant = false; /// A flag indicating whether to use the routing-quant integrated operator
+    bool enableSwigluQuant = false; /// A flag indicating whether to use the swiglu-quant integrated operator
+    bool enableMoeParallel = false; /// A flag indicating whether the model uses MoE parallelism
+    bool enableCVOverlap = false; /// A flag indicating whether the model overlaps cube and vector execution
+    bool hasMoeEp = false; /// A flag indicating whether the model uses expert parallelism
+    bool enableDispatchCombineV2 = false; /// A flag indicating whether to use dispatch_v2 and combine_v2
+    int packQuantType = atb_speed::common::PackQuantType::ALL_FP; /// The quantization type of the packed weights
+    /// The quantization type used to facilitate the calculation of the quantization type of the linear operation
+    int denseQuantType = atb_speed::common::PackQuantType::PACK_QUANT_UNDEFINED;
+    /// The group size used for dequantizing the weight tensor in the per-group quantization approach
+    int quantGroupSize = 0;
+
+    std::string backend = "hccl";
+    int moeEpRank = 0;
+    int moeEpSize = 1;
+    int maxDecodeDpTokenSize = 0;
+    std::string moeEpDomain = "";
+    std::string moeEpRankTableFile = "";
+    bool hasMlpTp = false;
+    int mlpTpRank = 0;
+    int mlpTpSize = 1;
+    std::string mlpTpDomain = "";
+    std::string mlpTpRankTableFile = "";
+
+    bool enableMoeDistribute = false;
+    bool enableExpertCumSumOutput = false;
+    bool enableGatingDp = false;
+    int64_t numDanglingSharedExperts = 0;
+    uint32_t numOfRedundantExpert = 0;
+};
+
+/// This function creates a sub-graph that performs the FFN of a model with MoE structure
+atb::Status CreateMoeMlpOperation(const MoeMlpParam &param, atb::Operation **operation);
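+
+// An illustrative expert-parallel configuration of the struct above; the field
+// names are real, the values are made up:
+//
+//     MoeMlpParam param;
+//     param.hasMoeEp = true;
+//     param.moeEpRank = 0;            // this rank's position in the EP group
+//     param.moeEpSize = 16;           // number of ranks sharing the experts
+//     param.moeEpDomain = "...";      // communication domain name
+//     param.enableMoeDistribute = true;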
+} // namespace common
+} // namespace atb_speed
+#endif
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/fusion/moe/moe_shared_expert.cpp b/tests/proftest/layer_test_framework/operations/fusion/moe/moe_shared_expert.cpp
new file mode 100644
index 00000000..57585d0b
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/fusion/moe/moe_shared_expert.cpp
@@ -0,0 +1,367 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include
+#include
+#include "atb_speed/log.h"
+#include "moe_shared_expert.h"
+
+namespace atb_speed {
+namespace common {
+
+std::map<std::string, std::vector<std::string>> GetSharedExpertInTensorCandidates()
+{
+    std::map<std::string, std::vector<std::string>> sharedExpertInTensorCandidates = {
+        {"default", {
+            "in_hidden_states",
+            "in_mlp_gate_up_weight", "in_mlp_gate_up_bias", "in_mlp_gate_up_descale",
+            "in_mlp_gate_up_offset", "in_mlp_gate_up_scale", "in_mlp_gate_up_compress_idx",
+            "in_mlp_down_weight", "in_mlp_down_bias", "in_mlp_down_descale", "in_mlp_down_offset",
+            "in_mlp_down_scale", "in_mlp_down_compress_idx",
+            "in_shared_expert_gate_weight", "in_shared_expert_gate_bias", "in_shared_expert_gate_descale",
+            "in_shared_expert_gate_offset", "in_shared_expert_gate_scale", "in_shared_expert_gate_compress_idx"
+            }
+        }
+    };
+    return sharedExpertInTensorCandidates;
+};
+
+std::map<std::string, std::vector<std::string>> GetSharedExpertIntermediateTensorCandidates()
+{
+    std::map<std::string, std::vector<std::string>> sharedExpertIntermediateTensorCandidates = {
+        {"default",
+            {"intermediate_matmul_gate_up_out", "intermediate_hidden_states"}
+        },
+        {"enable_swiglu_quant_for_shared_experts",
+            {"swiglu_quant_scale_out"}
+        },
+        {"not_support_swiglu",
+            {"intermediate_matmul_gate_out", "intermediate_swish_out", "intermediate_matmul_up_out"}
+        },
+        {"has_shared_expert_gate",
+            {
+                "intermediate_shared_expert_out", "intermediate_shared_expert_gate_logits",
+                "intermediate_shared_expert_weight"
+            }
+        }
+    };
+    return sharedExpertIntermediateTensorCandidates;
+};
+
+std::map<std::string, uint32_t> ConstructTensorMap(const SharedExpertParam &param,
+    uint32_t &inTensorNum, uint32_t &outTensorNum, uint32_t &internalTensorNum)
+{
+    auto sharedExpertInTensorCandidates = GetSharedExpertInTensorCandidates();
+    auto sharedExpertIntermediateTensorCandidates = GetSharedExpertIntermediateTensorCandidates();
+
+    std::vector<std::string> inTensorList = {};
+    std::vector<std::string> intermediateTensorList = {};
+    std::vector<std::string> outTensorList = {"out"};
+
+    // Add the default tensors
+    AddTensorToList(sharedExpertInTensorCandidates, "default", inTensorList);
+    AddTensorToList(sharedExpertIntermediateTensorCandidates, "default", intermediateTensorList);
+
+    // If the SwiGLUQuant operator is used
+    if (param.enableSwiGLUQuantForSharedExperts) {
+        AddTensorToList(sharedExpertIntermediateTensorCandidates,
+            "enable_swiglu_quant_for_shared_experts", intermediateTensorList);
+    }
+    // If SwiGLU is not supported
+    if (!param.supportSwiGLU) {
+        AddTensorToList(sharedExpertIntermediateTensorCandidates,
+            "not_support_swiglu", intermediateTensorList);
+    }
+    // If the shared expert gate is enabled
+    if (param.hasSharedExpertGate) {
+        AddTensorToList(sharedExpertIntermediateTensorCandidates,
+            "has_shared_expert_gate", intermediateTensorList);
+    }
+
+    inTensorNum = inTensorList.size();
+    outTensorNum = outTensorList.size();
+    internalTensorNum = intermediateTensorList.size();
+
+    return GetTensorMap(inTensorList, outTensorList, intermediateTensorList);
+}
+
+// Expert Ops
+atb::Status CreateLinear(const SharedExpertParam &param, atb::GraphParam &opGraph,
+    std::map<std::string, uint32_t> &tensorMap)
+{
+    atb::Node linearNode;
+    atb_speed::common::FusionLinearParam linearParam;
+    linearParam.hasBias = false;
+    linearParam.isBF16 = param.isBF16;
+    linearParam.transposeType = param.mlpLinearTransposeType.at(SHARED_MOE_GATE_LINEAR_INDEX);
+    linearParam.quantType = GetLinearQuantType(
+        param.denseQuantType == atb_speed::common::PackQuantType::PACK_QUANT_UNDEFINED
+            ? param.packQuantType : param.denseQuantType,
+        param.mlpLinearQuantType[SHARED_MOE_GATE_LINEAR_INDEX], false);
+    linearParam.quantGroupSize = param.quantGroupSize;
+    if (param.enableCVOverlap) {
+        linearParam.enableCVOverlap = true;
+    }
+    CHECK_OPERATION_STATUS_RETURN(FusionLinear(linearParam, &linearNode.operation));
+    linearNode.inTensorIds = {
+        GetTensorIdx(tensorMap, "in_hidden_states"),
+        GetTensorIdx(tensorMap, "in_mlp_gate_up_weight"),
+        GetTensorIdx(tensorMap, "in_mlp_gate_up_scale"),
+        GetTensorIdx(tensorMap, "in_mlp_gate_up_offset"),
+        GetTensorIdx(tensorMap, "in_mlp_gate_up_descale"),
+        GetTensorIdx(tensorMap, "in_mlp_gate_up_bias"),
+        GetTensorIdx(tensorMap, "in_mlp_gate_up_compress_idx"),
+    };
+    linearNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_matmul_gate_up_out")};
+    opGraph.nodes.push_back(linearNode);
+    ATB_SPEED_LOG_DEBUG("Gate up projection calculation success");
+    return atb::NO_ERROR;
+}
+
+atb::Status CreateSplit(atb::GraphParam &opGraph, std::map<std::string, uint32_t> &tensorMap)
+{
+    atb::Node splitNode;
+    atb::infer::SplitParam splitParam = {1, 2, {}};
+    CHECK_OPERATION_STATUS_RETURN(CreateOperation(splitParam, &splitNode.operation));
+    splitNode.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_matmul_gate_up_out")};
+    splitNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_matmul_gate_out"),
+        GetTensorIdx(tensorMap, "intermediate_matmul_up_out")};
+    opGraph.nodes.push_back(splitNode);
+    ATB_SPEED_LOG_DEBUG("Split calculation success");
+    return atb::NO_ERROR;
+}
+
+atb::Status CreateActivation(atb::GraphParam &opGraph, std::map<std::string, uint32_t> &tensorMap)
+{
+    atb::Node swishNode;
+    atb::infer::ActivationParam activationParam;
+    activationParam.activationType = atb::infer::ActivationType::ACTIVATION_SWISH;
+    CHECK_OPERATION_STATUS_RETURN(CreateOperation(activationParam, &swishNode.operation));
+    swishNode.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_matmul_gate_out")};
+    swishNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_swish_out")};
+    opGraph.nodes.push_back(swishNode);
+    ATB_SPEED_LOG_DEBUG("Swish calculation success");
+    return atb::NO_ERROR;
+}
+
+atb::Status CreateElewiseMul1(atb::GraphParam &opGraph, std::map<std::string, uint32_t> &tensorMap)
+{
+    atb::Node mulNode;
+    atb::infer::ElewiseParam elewiseParam;
+    elewiseParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_MUL;
+    CHECK_OPERATION_STATUS_RETURN(CreateOperation(elewiseParam, &mulNode.operation));
+    mulNode.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_swish_out"),
+        GetTensorIdx(tensorMap, "intermediate_matmul_up_out")};
+    mulNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_hidden_states")};
+    opGraph.nodes.push_back(mulNode);
+    ATB_SPEED_LOG_DEBUG("ElewiseMul1 success");
+    return atb::NO_ERROR;
+}
+
+// SwiGLU = Split + Activation + element-wise Mul
+atb::Status CreateSwiGLU(atb::GraphParam &opGraph, std::map<std::string, uint32_t> &tensorMap)
+{
+    atb::Node swishNode;
+    atb::infer::ActivationParam activationParam;
+    activationParam.activationType = atb::infer::ActivationType::ACTIVATION_SWIGLU_FORWARD;
+    CHECK_OPERATION_STATUS_RETURN(CreateOperation(activationParam, &swishNode.operation));
+    swishNode.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_matmul_gate_up_out")};
+    swishNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_hidden_states")};
+    opGraph.nodes.push_back(swishNode);
+    ATB_SPEED_LOG_DEBUG("SwiGLU calculation success");
+    return atb::NO_ERROR;
+}
+
+atb::Status CreateSwiGLUQuant(atb::GraphParam &opGraph, std::map<std::string, uint32_t> &tensorMap)
+{
+    atb::Node swishNode;
+    atb::infer::SwigluQuantParam activationParam;
+    activationParam.quantType = atb::infer::SwigluQuantParam::QuantType::QUANT_TYPE_PER_TOKEN;
+    CHECK_OPERATION_STATUS_RETURN(CreateOperation(activationParam, &swishNode.operation));
+    swishNode.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_matmul_gate_up_out")};
+    swishNode.outTensorIds = {
+        GetTensorIdx(tensorMap, "intermediate_hidden_states"),
+        GetTensorIdx(tensorMap, "swiglu_quant_scale_out")};
+    opGraph.nodes.push_back(swishNode);
+    ATB_SPEED_LOG_DEBUG("SwiGLUQuant calculation success");
+    return atb::NO_ERROR;
+}
+
+atb::Status CreateLinearDown(const SharedExpertParam &param, atb::GraphParam &opGraph,
+    std::map<std::string, uint32_t> &tensorMap)
+{
+    atb::Node linearDownNode;
+    atb_speed::common::FusionLinearParam linearDownParam;
+    linearDownParam.hasBias = false;
+    linearDownParam.enableSwiGLUQuantForSharedExperts = param.enableSwiGLUQuantForSharedExperts;
+    linearDownParam.isBF16 = param.isBF16;
+    linearDownParam.transposeType = param.mlpLinearTransposeType.at(SHARED_MOE_DOWN_LINEAR_INDEX);
+    linearDownParam.quantType = GetLinearQuantType(
+        param.denseQuantType == atb_speed::common::PackQuantType::PACK_QUANT_UNDEFINED
+            ? param.packQuantType : param.denseQuantType,
+        param.mlpLinearQuantType[SHARED_MOE_DOWN_LINEAR_INDEX], false);
+    linearDownParam.quantGroupSize = param.quantGroupSize;
+    if (param.enableCVOverlap) {
+        linearDownParam.enableCVOverlap = true;
+    }
+    CHECK_OPERATION_STATUS_RETURN(FusionLinear(linearDownParam, &linearDownNode.operation));
+    linearDownNode.inTensorIds = {
+        GetTensorIdx(tensorMap, "intermediate_hidden_states"),
+        GetTensorIdx(tensorMap, "in_mlp_down_weight"),
+        GetTensorIdx(tensorMap, "in_mlp_down_scale"),
+        GetTensorIdx(tensorMap, "in_mlp_down_offset"),
+        GetTensorIdx(tensorMap, "in_mlp_down_descale"),
+        GetTensorIdx(tensorMap, "in_mlp_down_bias"),
+        GetTensorIdx(tensorMap, "in_mlp_down_compress_idx"),
+    };
+    if (param.enableSwiGLUQuantForSharedExperts) {
+        linearDownNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "swiglu_quant_scale_out"));
+    }
+    if (param.hasSharedExpertGate) {
+        linearDownNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_shared_expert_out")};
+    } else {
+        linearDownNode.outTensorIds = {GetTensorIdx(tensorMap, "out")};
+    }
+    opGraph.nodes.push_back(linearDownNode);
+    ATB_SPEED_LOG_DEBUG("Projection down success");
+    return atb::NO_ERROR;
+}
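+
+// NOTE: with the optional gate enabled, the shared-expert path computes
+//     out = W_down(SwiGLU(W_gateup * x)) * sigmoid(W_gate * x)
+// and without it the down projection in CreateLinearDown above writes to
+// "out" directly.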
+
+// Expert Gate Ops
+
+atb::Status CreateSharedExpertGate(const SharedExpertParam &param,
+    atb::GraphParam &opGraph, std::map<std::string, uint32_t> &tensorMap)
+{
+    atb::Node sharedexpertgateNode;
+    atb_speed::common::FusionLinearParam linearParam;
+    linearParam.hasBias = false;
+    linearParam.quantType = GetLinearQuantType(
+        param.denseQuantType == atb_speed::common::PackQuantType::PACK_QUANT_UNDEFINED
+            ? param.packQuantType : param.denseQuantType,
+        param.mlpLinearQuantType[SHARED_MOE_SHAREGATE_LINEAR_INDEX], false);
+    linearParam.quantGroupSize = param.quantGroupSize;
+    CHECK_OPERATION_STATUS_RETURN(FusionLinear(linearParam, &sharedexpertgateNode.operation));
+    sharedexpertgateNode.inTensorIds = {
+        GetTensorIdx(tensorMap, "in_hidden_states"),
+        GetTensorIdx(tensorMap, "in_shared_expert_gate_weight"),
+        GetTensorIdx(tensorMap, "in_shared_expert_gate_scale"),
+        GetTensorIdx(tensorMap, "in_shared_expert_gate_offset"),
+        GetTensorIdx(tensorMap, "in_shared_expert_gate_descale"),
+        GetTensorIdx(tensorMap, "in_shared_expert_gate_bias"),
+        GetTensorIdx(tensorMap, "in_shared_expert_gate_compress_idx"),
+    };
+    sharedexpertgateNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_shared_expert_gate_logits")};
+    opGraph.nodes.push_back(sharedexpertgateNode);
+    ATB_SPEED_LOG_DEBUG("Shared Expert Gate calculation success");
+    return atb::NO_ERROR;
+}
+
+atb::Status CreateActivationSigmoid(atb::GraphParam &opGraph, std::map<std::string, uint32_t> &tensorMap)
+{
+    atb::Node sigmoidNode;
+    atb::infer::ActivationParam activationParam;
+    activationParam.activationType = atb::infer::ActivationType::ACTIVATION_SIGMOID;
+    CHECK_OPERATION_STATUS_RETURN(CreateOperation(activationParam, &sigmoidNode.operation));
+    sigmoidNode.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_shared_expert_gate_logits")};
+    sigmoidNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_shared_expert_weight")};
+    opGraph.nodes.push_back(sigmoidNode);
+    ATB_SPEED_LOG_DEBUG("Activation Sigmoid calculation success");
+    return atb::NO_ERROR;
+}
+
+atb::Status CreateElewiseMul2(atb::GraphParam &opGraph, std::map<std::string, uint32_t> &tensorMap)
+{
+    atb::Node sigmoidMulNode;
+    atb::infer::ElewiseParam elewiseParam;
+    elewiseParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_MUL;
+    CHECK_OPERATION_STATUS_RETURN(CreateOperation(elewiseParam, &sigmoidMulNode.operation));
+    sigmoidMulNode.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_shared_expert_out"),
+        GetTensorIdx(tensorMap, "intermediate_shared_expert_weight")};
+    sigmoidMulNode.outTensorIds = {GetTensorIdx(tensorMap, "out")};
+    opGraph.nodes.push_back(sigmoidMulNode);
+    ATB_SPEED_LOG_DEBUG("ElewiseMul2 calculation success");
+    return atb::NO_ERROR;
+}
+
+atb::Status CreateActivationBlock(const SharedExpertParam &param,
+    atb::GraphParam &opGraph, std::map<std::string, uint32_t> &tensorMap)
+{
+    if (param.supportSwiGLU) {
+        if (param.enableSwiGLUQuantForSharedExperts) {
+            CHECK_OPERATION_STATUS_RETURN(CreateSwiGLUQuant(opGraph, tensorMap));
+        } else {
+            CHECK_OPERATION_STATUS_RETURN(CreateSwiGLU(opGraph, tensorMap));
+        }
+    } else {
+        CHECK_OPERATION_STATUS_RETURN(CreateSplit(opGraph, tensorMap));
+        CHECK_OPERATION_STATUS_RETURN(CreateActivation(opGraph, tensorMap));
+        CHECK_OPERATION_STATUS_RETURN(CreateElewiseMul1(opGraph, tensorMap));
+    }
+
+    ATB_SPEED_LOG_DEBUG("ActivationBlock calculation success");
+    return atb::NO_ERROR;
+}
+
+atb::Status CreateSharedExpertOperation(const SharedExpertParam &param, atb::Operation **operation)
+{
+    std::shared_ptr<int64_t> batchDimPtr = std::make_shared<int64_t>(0);
+
+    atb::GraphParam opGraph;
+    opGraph.name = "SharedExpert";
+
+    std::map<std::string, uint32_t> tensorMap = ConstructTensorMap(
+        param, opGraph.inTensorNum, opGraph.outTensorNum, opGraph.internalTensorNum);
+
+    if (param.enableCVOverlap) {
+        CHECK_OPERATION_STATUS_RETURN(atb_speed::common::CreateWaitWithoutNodeId(
+            opGraph, atb_speed::EventAction::PUSH, atb_speed::common::CV_START));
+    }
+
+    CHECK_OPERATION_STATUS_RETURN(CreateLinear(param, opGraph, tensorMap));
+
+    if (param.enableCVOverlap) {
+        CHECK_OPERATION_STATUS_RETURN(atb_speed::common::CreateRecordWithoutNodeId(
+            opGraph, atb_speed::EventAction::PUSH, atb_speed::common::CUBE_CONTROL));
+        CHECK_OPERATION_STATUS_RETURN(atb_speed::common::CreateWaitWithoutNodeId(
+            opGraph, atb_speed::EventAction::PUSH, atb_speed::common::VECTOR_CONTROL));
+    }
+
+    CHECK_OPERATION_STATUS_RETURN(CreateActivationBlock(param, opGraph, tensorMap));
+    CHECK_OPERATION_STATUS_RETURN(CreateLinearDown(param, opGraph, tensorMap));
+
+    if (param.enableCVOverlap) {
+        CHECK_OPERATION_STATUS_RETURN(atb_speed::common::CreateRecordWithoutNodeId(
+            opGraph, atb_speed::EventAction::PUSH, atb_speed::common::CUBE_CONTROL));
+    }
+
+    if (param.hasSharedExpertGate) {
+        CHECK_OPERATION_STATUS_RETURN(CreateSharedExpertGate(param, opGraph, tensorMap));
+        CHECK_OPERATION_STATUS_RETURN(CreateActivationSigmoid(opGraph, tensorMap));
+        CHECK_OPERATION_STATUS_RETURN(CreateElewiseMul2(opGraph, tensorMap));
+    }
+
+    opGraph.inferShapeFunc = [=](const atb::SVector<atb::TensorDesc> &inTensorDescs,
+        atb::SVector<atb::TensorDesc> &outTensorDescs) {
+        outTensorDescs.at(0) = inTensorDescs.at(GetTensorIdx(tensorMap, "in_hidden_states"));
+        return atb::NO_ERROR;
+    };
+    CREATE_OPERATION(opGraph, operation);
+    return atb::NO_ERROR;
+}
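+
+// A minimal caller sketch for the shared-expert sub-graph above (illustrative
+// values only; weights and execution are handled elsewhere in the framework):
+//
+//     atb_speed::common::SharedExpertParam param;
+//     param.supportSwiGLU = true;
+//     param.hasSharedExpertGate = true;
+//     atb::Operation *op = nullptr;
+//     CHECK_OPERATION_STATUS_RETURN(
+//         atb_speed::common::CreateSharedExpertOperation(param, &op));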
+
+} // namespace common
+} // namespace atb_speed
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/fusion/moe/moe_shared_expert.h b/tests/proftest/layer_test_framework/operations/fusion/moe/moe_shared_expert.h
new file mode 100644
index 00000000..011cc1f3
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/fusion/moe/moe_shared_expert.h
@@ -0,0 +1,64 @@
+/*
+* Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved.
+*
+* * Licensed under the Apache License, Version 2.0 (the "License");
+* * you may not use this file except in compliance with the License.
+* * You may obtain a copy of the License at
+* *
+* * http://www.apache.org/licenses/LICENSE-2.0
+* *
+* * Unless required by applicable law or agreed to in writing, software
+* * distributed under the License is distributed on an "AS IS" BASIS,
+* * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* * See the License for the specific language governing permissions and
+* * limitations under the License.
+* */
+
+#ifndef ATB_SPEED_MODELS_MOE_SHARED_EXPERT_H
+#define ATB_SPEED_MODELS_MOE_SHARED_EXPERT_H
+
+#include
+#include "atb_speed/utils/operation_util.h"
+#include "operations/fusion/utils.h"
+#include "operations/fusion/linear/linear.h"
+#include "operations/fusion/linear/linear_parallel.h"
+#include "operations/fusion/norm/norm_linear.h"
+
+namespace atb_speed {
+namespace common {
+
+constexpr uint64_t SHARED_MOE_GATE_LINEAR_INDEX = 0;
+constexpr uint64_t SHARED_MOE_UP_LINEAR_INDEX = 1;
+constexpr uint64_t SHARED_MOE_DOWN_LINEAR_INDEX = 2;
+constexpr uint64_t SHARED_MOE_SHAREGATE_LINEAR_INDEX = 3;
+
+struct SharedExpertParam {
+    bool transposeGateup = true; /// A flag indicating whether the B matrix of the gateup operation should be transposed
+    bool transposeDown = false; /// A flag indicating whether the B matrix of the down operation should be transposed
+    bool hasSharedExpertGate = true; /// A flag indicating whether there is a routing mechanism for shared experts
+    bool enableSwiGLUQuantForSharedExperts = false;
+    bool supportSwiGLU = true; /// A flag indicating whether the device supports the SwiGLU operator
+    bool isBF16 = false; /// A flag indicating whether the model runs on bfloat16
+    bool enableCVOverlap = false; /// A flag indicating whether the model overlaps cube and vector execution
+    int packQuantType = atb_speed::common::PackQuantType::ALL_FP; /// The quantization type of the packed weights
+    int quantGroupSize = 0; /// Group size of per-group quantization
+    /// The quantization type used to facilitate the calculation of the quantization type of the linear operation
+    int denseQuantType = atb_speed::common::PackQuantType::PACK_QUANT_UNDEFINED;
+    /// A list of quantization types of the linear operations in this sub-graph
+    std::vector<int> mlpLinearQuantType = {};
+    /// A list of flags indicating whether the B matrices of the linear operations should be transposed
+    std::vector<int> mlpLinearTransposeType = {};
+};
+
+/// This function constructs the tensor map of this sub-graph.
+/// \return The tensor map between tensor names and tensor indices.
+std::map<std::string, uint32_t> ConstructTensorMap(
+    const SharedExpertParam &param,
+    uint32_t &inTensorNum, uint32_t &outTensorNum, uint32_t &internalTensorNum);
+
+/// This function creates a sub-graph that performs the shared-expert calculation on the input.
+/// \return A flag that indicates whether the operation is successfully created or not.
+atb::Status CreateSharedExpertOperation(
+    const SharedExpertParam &param, atb::Operation **operation);
+} // namespace common
+} // namespace atb_speed
+#endif
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/fusion/moe/sparse_moe.cpp b/tests/proftest/layer_test_framework/operations/fusion/moe/sparse_moe.cpp
new file mode 100644
index 00000000..0a643441
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/fusion/moe/sparse_moe.cpp
@@ -0,0 +1,1074 @@
+/*
+* Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved.
+*
+* * Licensed under the Apache License, Version 2.0 (the "License");
+* * you may not use this file except in compliance with the License.
+* * You may obtain a copy of the License at
+* *
+* * http://www.apache.org/licenses/LICENSE-2.0
+* *
+* * Unless required by applicable law or agreed to in writing, software
+* * distributed under the License is distributed on an "AS IS" BASIS,
+* * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* * See the License for the specific language governing permissions and
+* * limitations under the License.
+* */
+#include "sparse_moe.h"
+#include
+#include
+#include
+#include "moe_mlp.h"
+#include "device_limited_routing.h"
+#include "operations/fusion/moe/ep/dynamic_ep_moe.h"
+#include "operations/aclnn/ops/moe_topk_softmax_operation.h"
+#include "operations/aclnn/ops/vector_norm_operation.h"
+#include "operations/aclnn/ops/std_operation.h"
+#include "operations/aclnn/ops/sigmoid_operation.h"
+#include "operations/aclnn/ops/matmul_operation.h"
+#include "operations/aclnn/ops/concat_operation.h"
+#include "atb_speed/base/event_manager.h"
+
+namespace atb_speed {
+namespace common {
+
+const uint64_t NODE_SIZE_INCR_NORMALIZATION = 2;
+static const uint64_t STREAM1 = 1;
+static const uint64_t NUM1 = 1;
+static const uint64_t NUM2 = 2;
+static const uint64_t NUM3 = 3;
+static const uint64_t NUM4 = 4;
+static const uint64_t NUM5 = 5;
+constexpr uint32_t TOPK_IN_NUM = 4;
+constexpr uint32_t TOPK_IN3_DIM = 3;
+
+std::map<std::string, std::vector<std::string>> GetSparseMoeInTensorCandidates()
+{
+    std::map<std::string, std::vector<std::string>> moeMlpInTensorCandidates = {
+        {"default", {
+            "in_hiddenstates", "in_gate_weight", "in_gate_bias", "in_gate_descale", "in_gate_offset",
+            "in_gate_scale", "in_gate_compress_idx", "in_mlp_gateup_weight_expert", "in_mlp_gateup_bias_expert",
+            "in_mlp_gateup_descale_expert", "in_mlp_gateup_offset_expert", "in_mlp_gateup_scale_expert",
+            "in_mlp_gateup_compress_idx_expert", "in_mlp_down_weight_expert",
+            "in_mlp_down_bias_expert", "in_mlp_down_descale_expert", "in_mlp_down_offset_expert",
+            "in_mlp_down_scale_expert", "in_mlp_down_compress_idx_expert", "in_expert_array",
+            "in_expert_group", "in_one_hot", "in_zero_hot"}
+        },
+        {"ep", {
+            "in_start_expert_idx", "in_device_expert_count", "in_padding_idx"}
+        },
+        {"dynamic_ep", {
+            "in_dynamic_ep_idx", "in_moe_idx"}
+        },
+        {"force_load_balance", {
+            "in_fake_topk"}
+        },
+        {"epwb", {
+            "in_expert_routing_map"}
+        },
+        {"gating_dp", {
+            "in_hiddenstates_slice", "in_attn_unpadding_idx"}
+        },
+        {"fp32_gate_input", {
+            "in_hiddenstates_fp32"}
+        }
+    };
+    return moeMlpInTensorCandidates;
+}
+
+atb::Status ConstructATBGateMatmulTensorMap(const SparseMoeParam &param, std::vector<std::string> &interTensorList)
+{
+    if (!param.enableFp32GateInput) {
+        interTensorList.push_back("intermediate_hiddenstates_fp32");
+    }
+    interTensorList.push_back("intermediate_router_logits_fp32");
+    return atb::NO_ERROR;
+}
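+
+// NOTE: ConstructRoutingTensorMap below registers extra intermediate tensors
+// depending on param.processLogits: "normalization"/"norm"/"normScaling" need
+// both the summed and the reduced topk weights, "scaling" only the reduced
+// ones, and any value except "none" additionally keeps an fp32 copy of the
+// reduced weights (plus an fp16 copy when MoeDistribute is disabled).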
+
+atb::Status ConstructRoutingTensorMap(const SparseMoeParam &param, std::vector<std::string> &interTensorList)
+{
+    if (param.enableATBGateMatmul) {
+        ConstructATBGateMatmulTensorMap(param, interTensorList);
+    }
+    if (!param.enableFusedTopk) {
+        interTensorList.push_back("intermediate_router_weights");
+        interTensorList.push_back("intermediate_router_weights_topk");
+        if (param.routingMethod == "noAuxTc") {
+            interTensorList.push_back("intermediate_router_weights_for_choice");
+            interTensorList.push_back("intermediate_router_weights_topk_temp");
+        }
+        if (param.useStdNorm) {
+            interTensorList.push_back("intermediate_router_logits_std");
+        }
+        if (param.processLogits == "normalization" ||
+            param.processLogits == "norm" ||
+            param.processLogits == "normScaling") {
+            interTensorList.push_back("intermediate_router_weights_topk_reduced");
+            interTensorList.push_back("intermediate_router_weights_topk_sumed");
+        } else if (param.processLogits == "scaling") {
+            interTensorList.push_back("intermediate_router_weights_topk_reduced");
+        }
+    }
+    if (param.processLogits != "none") {
+        interTensorList.push_back("intermediate_router_weights_topk_reduced_fp32");
+        if (!param.enableMoeDistribute) {
+            interTensorList.push_back("intermediate_router_weights_topk_reduced_fp16");
+        }
+    }
+    if (param.mixSharedRouting) {
+        interTensorList.push_back("intermediate_router_weights_topk_reduced_mix_shared");
+        interTensorList.push_back("intermediate_selected_experts_mix_shared");
+    }
+    return atb::NO_ERROR;
+}
+
+std::map<std::string, uint32_t> ConstructTensorMap(
+    const SparseMoeParam &param,
+    uint32_t &inTensorNum, uint32_t &outTensorNum, uint32_t &interTensorNum)
+{
+    auto moeMlpInTensorCandidates = GetSparseMoeInTensorCandidates();
+    std::vector<std::string> inTensorList = {};
+    std::vector<std::string> interTensorList = {
+        "intermediate_router_logits", "intermediate_selected_experts"};
+    std::vector<std::string> outTensorList = {"out_moe_rout"};
+    AddTensorToList(moeMlpInTensorCandidates, "default", inTensorList);
+    if (param.enableLoadBalance) {
+        AddTensorToList(moeMlpInTensorCandidates, "force_load_balance", inTensorList);
+    }
+    if (param.enableExpertCumSumOutput) {
+        outTensorList.push_back("out_gmm_cumsum_list");
+    }
+    ConstructRoutingTensorMap(param, interTensorList);
+    if (param.hasMoeEp) {
+        AddTensorToList(moeMlpInTensorCandidates, "ep", inTensorList);
+        if (param.isDynamicEp) {
+            AddTensorToList(moeMlpInTensorCandidates, "dynamic_ep", inTensorList);
+        }
+    }
+    if (param.enableEPWB) {
+        AddTensorToList(moeMlpInTensorCandidates, "epwb", inTensorList);
+        interTensorList.push_back("intermediate_selected_experts_routed");
+    }
+    if (param.enableGatingDp) {
+        AddTensorToList(moeMlpInTensorCandidates, "gating_dp", inTensorList);
+        interTensorList.push_back("intermediate_selected_experts_with_padding_all");
+        interTensorList.push_back("intermediate_selected_experts_all");
+        interTensorList.push_back("intermediate_router_weights_topk_reduced_with_padding_all");
+        interTensorList.push_back("intermediate_router_weights_topk_reduced_all");
+    }
+    if (param.enableGatingShift) {
+        interTensorList.push_back("intermediate_router_logits_split_1");
+        interTensorList.push_back("intermediate_router_logits_split_2");
+        interTensorList.push_back("intermediate_router_logits_shifted");
+    }
+    if (param.enableFp32GateInput) {
+        AddTensorToList(moeMlpInTensorCandidates, "fp32_gate_input", inTensorList);
+    }
+    if (param.mixSharedRouting) {
+        inTensorList.push_back("mix_shared_routing_weight");
+        inTensorList.push_back("mix_shared_routing_expert");
+    }
+    inTensorNum = inTensorList.size();
+    outTensorNum = outTensorList.size();
+    interTensorNum = interTensorList.size();
+    return GetTensorMap(inTensorList, outTensorList, interTensorList);
+}
+
+
+atb::Status CreateSparseMoemoeGate(
+    std::map<std::string, uint32_t> &tensorMap,
+    const SparseMoeParam &param, atb::GraphParam &opGraph)
+{
+    atb::Node linearNode;
+    FusionLinearParam moeGateParam;
+    moeGateParam.transposeType = common::TRANSPOSE;
+    moeGateParam.hasBias = param.rounterHasBias;
+    moeGateParam.isBF16 = param.isBF16;
+    moeGateParam.quantType = atb_speed::common::GetLinearQuantType(
+        param.denseQuantType == atb_speed::common::PackQuantType::PACK_QUANT_UNDEFINED
+            ? param.packQuantType : param.denseQuantType,
+        param.moeLinearQuantType[SparseMoeIdx::ROUTER_IDX], false);
+    moeGateParam.quantGroupSize = 0;
+    CHECK_OPERATION_STATUS_RETURN(FusionLinear(moeGateParam, &linearNode.operation));
+    linearNode.inTensorIds = {GetTensorIdx(tensorMap, "in_hiddenstates"),
+        GetTensorIdx(tensorMap, "in_gate_weight"),
+        GetTensorIdx(tensorMap, "in_gate_scale"),
+        GetTensorIdx(tensorMap, "in_gate_offset"),
+        GetTensorIdx(tensorMap, "in_gate_descale"),
+        GetTensorIdx(tensorMap, "in_gate_bias"),
+        GetTensorIdx(tensorMap, "in_gate_compress_idx")};
+    linearNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_router_logits")};
+    opGraph.nodes.push_back(linearNode);
+    ATB_SPEED_LOG_DEBUG("Router logits calculation success");
+    return atb::NO_ERROR;
+}
+
+atb::Status CreateSparseMoemoeGateFp32Atb(std::map<std::string, uint32_t> &tensorMap,
+    const SparseMoeParam &param, atb::GraphParam &opGraph)
+{
+    if (!param.enableFp32GateInput) {
+        atb::Node castUp;
+        atb::infer::ElewiseParam castUpParam;
+        castUpParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_CAST;
+        castUpParam.outTensorType = ACL_FLOAT;
+        castUp.inTensorIds = {GetTensorIdx(tensorMap, (param.enableGatingDp) ?
+            "in_hiddenstates_slice" : "in_hiddenstates")};
+        castUp.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_hiddenstates_fp32")};
+        CHECK_OPERATION_STATUS_RETURN(CreateOperation(castUpParam, &castUp.operation));
+        opGraph.nodes.push_back(castUp);
+    }
+
+    atb::Node linearNode;
+    atb::infer::LinearParam moeGateParam;
+    moeGateParam.hasBias = false;
+    moeGateParam.transposeB = true;
+    linearNode.inTensorIds = {GetTensorIdx(tensorMap, param.enableFp32GateInput ? "in_hiddenstates_fp32" :
+        "intermediate_hiddenstates_fp32"),
+        GetTensorIdx(tensorMap, "in_gate_weight")};
+    linearNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_router_logits_fp32")};
+    CHECK_OPERATION_STATUS_RETURN(CreateOperation(moeGateParam, &linearNode.operation));
+    if (param.enableGatingOverlap) {
+        atb::SetExecuteStreamId(linearNode.operation, STREAM1);
+    }
+    opGraph.nodes.push_back(linearNode);
+
+    atb::Node castDown;
+    atb::infer::ElewiseParam castDownParam;
+    castDownParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_CAST;
+    castDownParam.outTensorType = param.isBF16 ? ACL_BF16 : ACL_FLOAT16;
+    castDown.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_router_logits_fp32")};
+    castDown.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_router_logits")};
+    CHECK_OPERATION_STATUS_RETURN(CreateOperation(castDownParam, &castDown.operation));
+    opGraph.nodes.push_back(castDown);
+
+    ATB_SPEED_LOG_DEBUG("Router logits calculation success");
+    return atb::NO_ERROR;
+}
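+
+// NOTE: both gate variants compute the router logits at higher precision,
+// presumably because topk expert selection is sensitive to fp16/bf16 rounding;
+// the ATB variant above casts to float32 explicitly, while the ACLNN matmul
+// variant below appears to rely on the operator's internal precision handling.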
+ "in_hiddenstates_slice" : "in_hiddenstates"), + GetTensorIdx(tensorMap, "in_gate_weight")}; + linearNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_router_logits")}; + if (param.enableGatingOverlap) { + atb::SetExecuteStreamId(linearNode.operation, STREAM1); + } + opGraph.nodes.push_back(linearNode); + ATB_SPEED_LOG_DEBUG("Router logits calculation success"); + return atb::NO_ERROR; +} + +atb::Status CreateSparseMoemoeGateFp32(std::map &tensorMap, + const SparseMoeParam ¶m, atb::GraphParam &opGraph) +{ + if (param.enableATBGateMatmul) { + ATB_SPEED_LOG_DEBUG("Create ATB Fp32 Gate Matmul"); + CHECK_OPERATION_STATUS_RETURN(CreateSparseMoemoeGateFp32Atb(tensorMap, param, opGraph)); + } else { + ATB_SPEED_LOG_DEBUG("Create ACLNN Fp32 Gate Matmul"); + CHECK_OPERATION_STATUS_RETURN(CreateSparseMoemoeGateFp32Aclnn(tensorMap, param, opGraph)); + } + return atb::NO_ERROR; +} + +atb::Status CreateSplit(std::map &tensorMap, const SparseMoeParam ¶m, + atb::GraphParam &opGraph) +{ + atb::Node splitNode; + atb::infer::SplitParam splitParam; + splitParam.splitDim = NUM1; + splitParam.splitNum = NUM2; + if (param.deviceExpert[0] == 0) { + splitParam.splitSizes={static_cast(NUM1), \ + static_cast(param.numOfExperts) - static_cast(NUM1)}; + } else { + splitParam.splitSizes={param.deviceExpert[0], static_cast(param.numOfExperts) - param.deviceExpert[0]}; + } + CREATE_OPERATION(splitParam, &splitNode.operation); + splitNode.inTensorIds = atb_speed::common::GetTensorIdxList(tensorMap, {"intermediate_router_logits"}); + splitNode.outTensorIds = atb_speed::common::GetTensorIdxList(tensorMap, {"intermediate_router_logits_split_1", + "intermediate_router_logits_split_2"}); + opGraph.nodes.push_back(splitNode); + return atb::NO_ERROR; +} + +atb::Status CreateConcat(std::map &tensorMap, const SparseMoeParam ¶m, + atb::GraphParam &opGraph) +{ + atb::Node concatNode; + atb::infer::ConcatParam catParam; + catParam.concatDim = -1; + CREATE_OPERATION(catParam, &concatNode.operation); + if (param.deviceExpert[0] == 0) { + concatNode.inTensorIds = atb_speed::common::GetTensorIdxList(tensorMap, + {"intermediate_router_logits_split_1", "intermediate_router_logits_split_2"}); + } else { + concatNode.inTensorIds = atb_speed::common::GetTensorIdxList(tensorMap, + {"intermediate_router_logits_split_2", "intermediate_router_logits_split_1"}); + } + concatNode.outTensorIds = atb_speed::common::GetTensorIdxList(tensorMap, {"intermediate_router_logits_shifted"}); + opGraph.nodes.push_back(concatNode); + return atb::NO_ERROR; +} + +atb::Status CreateFusedAddTopkDiv(std::map &tensorMap, + const SparseMoeParam ¶m, atb::GraphParam &opGraph) +{ + atb::Node fusedAddTopkDivNode; + atb::infer::FusedAddTopkDivParam fusedAddTopkDivParam; + fusedAddTopkDivParam.groupNum = param.numOfGroups; + fusedAddTopkDivParam.groupTopk = param.topkGroups[0]; + // 2: chosen number within each group + constexpr uint32_t chosenNumEachGroup = 2; + uint32_t numOfGroups = param.numOfGroups > 0 ? param.numOfGroups : 1; + fusedAddTopkDivParam.n = std::min(param.numOfExperts / numOfGroups, chosenNumEachGroup); + fusedAddTopkDivParam.k = param.num[0]; + fusedAddTopkDivParam.scale = param.routedScalingFactor; + CHECK_OPERATION_STATUS_RETURN(CreateOperation(fusedAddTopkDivParam, &fusedAddTopkDivNode.operation)); + fusedAddTopkDivNode.inTensorIds = {GetTensorIdx(tensorMap, (param.enableGatingShift) ? 
\ + "intermediate_router_logits_shifted" : "intermediate_router_logits"), + GetTensorIdx(tensorMap, "in_gate_bias")}; + fusedAddTopkDivNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_router_weights_topk_reduced_fp32"), + GetTensorIdx(tensorMap, "intermediate_selected_experts")}; + if (param.enableGatingOverlap) { + atb::SetExecuteStreamId(fusedAddTopkDivNode.operation, STREAM1); + } + opGraph.nodes.push_back(fusedAddTopkDivNode); + ATB_SPEED_LOG_DEBUG("FusedAddTopkDivOperation calculation success"); + return atb::NO_ERROR; +} + +atb::Status CreateSparseMoeStd(std::map &tensorMap, atb::GraphParam &opGraph) +{ + atb::Node stdNode; + stdNode.operation = new atb_speed::common::StdOperation("SparseMoeStdNode"); + stdNode.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_router_logits")}; + stdNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_router_logits_std")}; + opGraph.nodes.push_back(stdNode); + ATB_SPEED_LOG_DEBUG("Router logits std calculation success"); + return atb::NO_ERROR; +} + +atb::Status CreateSparseMoeNorm(std::map &tensorMap, atb::GraphParam &opGraph) +{ + atb::Node normNode; + atb::infer::ElewiseParam normParam; + normParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_REALDIV; + CHECK_OPERATION_STATUS_RETURN(CreateOperation(normParam, &normNode.operation)); + normNode.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_router_logits"), + GetTensorIdx(tensorMap, "intermediate_router_logits_std")}; + normNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_router_logits")}; + opGraph.nodes.push_back(normNode); + ATB_SPEED_LOG_DEBUG("Router weights norm calculated success"); + return atb::NO_ERROR; +} + +atb::Status CreateSparseMoesoftMax( + std::map &tensorMap, + const SparseMoeParam ¶m, atb::GraphParam &opGraph) +{ + atb::Node softMaxNode; + atb::infer::SoftmaxParam softMaxParam; + softMaxParam.axes = param.axes; + CHECK_OPERATION_STATUS_RETURN(CreateOperation(softMaxParam, &softMaxNode.operation)); + softMaxNode.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_router_logits")}; + softMaxNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_router_weights")}; + opGraph.nodes.push_back(softMaxNode); + ATB_SPEED_LOG_DEBUG("Router weights calculation success"); + return atb::NO_ERROR; +} + +atb::Status CreateSparseMoeSigmoid(std::map &tensorMap, atb::GraphParam &opGraph) +{ + atb::Node sigmoidNode; + sigmoidNode.operation = new atb_speed::common::SigmoidOperation("SparseMoeSigmoidNode"); + sigmoidNode.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_router_logits")}; + sigmoidNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_router_weights")}; + opGraph.nodes.push_back(sigmoidNode); + ATB_SPEED_LOG_DEBUG("Router logits sigmoid calculation success"); + return atb::NO_ERROR; +} + +atb::Status CreateScoreAdd(std::map &tensorMap, atb::GraphParam &opGraph) +{ + atb::Node scoreAddNode; + atb::infer::ElewiseParam scoreAddParam; + scoreAddParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_ADD; + CHECK_OPERATION_STATUS_RETURN(CreateOperation(scoreAddParam, &scoreAddNode.operation)); + scoreAddNode.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_router_weights"), + GetTensorIdx(tensorMap, "in_gate_bias")}; + scoreAddNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_router_weights_for_choice")}; + opGraph.nodes.push_back(scoreAddNode); + ATB_SPEED_LOG_DEBUG("Score add calculated success"); + return atb::NO_ERROR; +} + +atb::Status CreateSparseMoetopK( + std::map &tensorMap, + const 
+    const SparseMoeParam &param, atb::GraphParam &opGraph)
+{
+    atb::Node topKNode;
+    atb::infer::SortParam topKParam;
+    topKParam.num = param.num;
+    CHECK_OPERATION_STATUS_RETURN(CreateOperation(topKParam, &topKNode.operation));
+    std::string topKInTensorName;
+    std::vector<std::string> topKOutTensorNames;
+    if (param.routingMethod == "noAuxTc") {
+        topKInTensorName = "intermediate_router_weights_for_choice";
+        topKOutTensorNames.push_back("intermediate_router_weights_topk_temp");
+    } else {
+        topKInTensorName = "intermediate_router_weights";
+        topKOutTensorNames.push_back("intermediate_router_weights_topk");
+    }
+    topKOutTensorNames.push_back("intermediate_selected_experts");
+    topKNode.inTensorIds = {GetTensorIdx(tensorMap, topKInTensorName)};
+    topKNode.outTensorIds = GetTensorIdxList(tensorMap, topKOutTensorNames);
+    opGraph.nodes.push_back(topKNode);
+    ATB_SPEED_LOG_DEBUG("Expert selection success");
+    return atb::NO_ERROR;
+}
+
+atb::Status CreateSparseMoetopKGather(
+    std::map<std::string, uint32_t> &tensorMap, atb::GraphParam &opGraph)
+{
+    atb::Node topKGaterNode;
+    atb::infer::GatherParam topKGatherParam;
+    topKGatherParam.axis = 1;
+    topKGatherParam.batchDims = 1;
+    CHECK_OPERATION_STATUS_RETURN(CreateOperation(topKGatherParam, &topKGaterNode.operation));
+    topKGaterNode.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_router_weights"),
+                                 GetTensorIdx(tensorMap, "intermediate_selected_experts")};
+    topKGaterNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_router_weights_topk")};
+    opGraph.nodes.push_back(topKGaterNode);
+    ATB_SPEED_LOG_DEBUG("Expert weight gather success");
+    return atb::NO_ERROR;
+}
+
+atb::Status CreateSparseMoereduce(
+    std::map<std::string, uint32_t> &tensorMap, atb::GraphParam &opGraph)
+{
+    atb::Node reduceNode;
+    atb::infer::ReduceParam reduceParam;
+    reduceParam.reduceType = atb::infer::ReduceParam::ReduceType::REDUCE_SUM;
+    reduceParam.axis = {1};
+    CHECK_OPERATION_STATUS_RETURN(CreateOperation(reduceParam, &reduceNode.operation));
+    reduceNode.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_router_weights_topk")};
+    reduceNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_router_weights_topk_sumed")};
+    opGraph.nodes.push_back(reduceNode);
+    ATB_SPEED_LOG_DEBUG("Reduce sum calculated success");
+    return atb::NO_ERROR;
+}
+
+atb::Status CreateSparseMoedivide(
+    std::map<std::string, uint32_t> &tensorMap, atb::GraphParam &opGraph)
+{
+    atb::Node divideNode;
+    atb::infer::ElewiseParam divideParam;
+    divideParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_REALDIV;
+    CHECK_OPERATION_STATUS_RETURN(CreateOperation(divideParam, &divideNode.operation));
+    divideNode.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_router_weights_topk"),
+                              GetTensorIdx(tensorMap, "intermediate_router_weights_topk_sumed")};
+    divideNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_router_weights_topk_reduced")};
+    divideNode.inTensorReshapeFuncs.resize(divideNode.inTensorIds.size());
+    divideNode.inTensorReshapeFuncs[1] = [](const atb::Dims &oldShape, atb::Dims &newShape) {
+        newShape.dimNum = 2; // 2: number of dimensions of the new shape
+        newShape.dims[0] = oldShape.dims[0];
+        newShape.dims[1] = 1;
+    };
+    opGraph.nodes.push_back(divideNode);
+    ATB_SPEED_LOG_DEBUG("Router weights calculated success");
+    return atb::NO_ERROR;
+}
+
+atb::Status CreateElewiseMuls(
+    std::map<std::string, uint32_t> &tensorMap,
+    const SparseMoeParam &param, atb::GraphParam &opGraph)
+{
+    atb::Node mulNode;
+    atb::infer::ElewiseParam elewiseParam;
+    elewiseParam.mulsParam.varAttr = param.routedScalingFactor;
+    elewiseParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_MULS;
+    CHECK_OPERATION_STATUS_RETURN(CreateOperation(elewiseParam, &mulNode.operation));
+    std::string mulInTensorName = param.processLogits == "normScaling" ? \
+        "intermediate_router_weights_topk_reduced" : "intermediate_router_weights_topk";
+    mulNode.inTensorIds = {GetTensorIdx(tensorMap, mulInTensorName)};
+    mulNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_router_weights_topk_reduced")};
+    opGraph.nodes.push_back(mulNode);
+    ATB_SPEED_LOG_DEBUG("ElewiseMuls calculation success");
+    return atb::NO_ERROR;
+}
+
+atb::Status CreateDeviceLimitedRouting(
+    std::map<std::string, uint32_t> &tensorMap,
+    const SparseMoeParam &param, atb::GraphParam &opGraph)
+{
+    atb::Node deviceLimitedNode;
+    atb_speed::deviceLimitedRouting::DeviceLimitedRoutingParam deviceLimitedRoutingParam;
+    deviceLimitedRoutingParam.numOfExperts = param.numOfExperts;
+    deviceLimitedRoutingParam.numOfGroups = param.numOfGroups;
+    deviceLimitedRoutingParam.topkGroups = param.topkGroups;
+    atb_speed::deviceLimitedRouting::CreateDeviceLimitedRoutingOperation(deviceLimitedRoutingParam,
+                                                                         &deviceLimitedNode.operation);
+    deviceLimitedNode.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_router_weights"),
+                                     GetTensorIdx(tensorMap, "in_expert_group"),
+                                     GetTensorIdx(tensorMap, "in_one_hot"),
+                                     GetTensorIdx(tensorMap, "in_zero_hot")};
+    deviceLimitedNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_router_weights")};
+    opGraph.nodes.push_back(deviceLimitedNode);
+    ATB_SPEED_LOG_DEBUG("Router logits calculation success");
+    return atb::NO_ERROR;
+}
+
+atb::Status CreateGroupOperation(
+    std::map<std::string, uint32_t> &tensorMap,
+    const SparseMoeParam &param, atb::GraphParam &opGraph)
+{
+    atb::Node deviceLimitedNode;
+    atb::infer::GroupTopkParam groupedParam;
+    groupedParam.groupNum = param.numOfGroups;
+    groupedParam.k = param.topkGroups[0];
+    if (param.routingMethod == "noAuxTc") {
+        groupedParam.groupMultiFlag = atb::infer::GroupTopkParam::GroupMultiFlag::SUM_MULTI_MAX;
+        groupedParam.n = 2; // 2: chosen number within each group
+    }
+    CHECK_OPERATION_STATUS_RETURN(CreateOperation(groupedParam, &deviceLimitedNode.operation));
+    std::vector<std::string> deviceLimitedInTensorNames;
+    std::string deviceLimitedOutTensorName;
+    if (param.routingMethod == "noAuxTc") {
+        deviceLimitedInTensorNames.push_back("intermediate_router_weights_for_choice");
+        deviceLimitedOutTensorName = "intermediate_router_weights_for_choice";
+    } else {
+        deviceLimitedInTensorNames.push_back("intermediate_router_weights");
+        deviceLimitedOutTensorName = "intermediate_router_weights";
+    }
+    deviceLimitedInTensorNames.push_back("in_expert_group");
+    deviceLimitedNode.inTensorIds = GetTensorIdxList(tensorMap, deviceLimitedInTensorNames);
+    deviceLimitedNode.outTensorIds = {GetTensorIdx(tensorMap, deviceLimitedOutTensorName)};
+    opGraph.nodes.push_back(deviceLimitedNode);
+    ATB_SPEED_LOG_DEBUG("Fusion Router logits calculation success");
+    return atb::NO_ERROR;
+}
+
+atb::Status CreateSparseTopkSoftMax(
+    std::map<std::string, uint32_t> &tensorMap,
+    const SparseMoeParam &param, atb::GraphParam &opGraph)
+{
+    atb::Node topKNode;
+    atb_speed::common::MoeTopkSoftmaxParam moeTopkSoftmaxParam;
+    moeTopkSoftmaxParam.topkNum = int64_t(param.num.at(0));
+    topKNode.operation = new atb_speed::common::MoeTopkSoftmaxOperation("MoeTopkSoftmaxOperation", moeTopkSoftmaxParam);
+    topKNode.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_router_logits")};
+    topKNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_router_weights_topk"),
"intermediate_selected_experts"), + GetTensorIdx(tensorMap, "intermediate_router_weights")}; + opGraph.nodes.push_back(topKNode); + ATB_SPEED_LOG_DEBUG("Expert selection success"); + return atb::NO_ERROR; +} + +atb::Status CreateVectorNorm(std::map &tensorMap, atb::GraphParam &opGraph) +{ + atb::Node vectorNormNode; + atb_speed::common::AclNNVectorNormParam aclNNVectorNormParam; + vectorNormNode.operation = new atb_speed::common::VectorNormOperation("vectorNormOperation", aclNNVectorNormParam); + vectorNormNode.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_router_weights_topk")}; + vectorNormNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_router_weights_topk_sumed")}; + opGraph.nodes.push_back(vectorNormNode); + ATB_SPEED_LOG_DEBUG("execute vector norm success"); + return atb::NO_ERROR; +} + +atb::Status SetDynamicExpertParam(atb_speed::common::DynamicEpMoEParam &dynamicExpertParam, const SparseMoeParam ¶m) +{ + dynamicExpertParam.transpose = param.transpose; + dynamicExpertParam.topk = param.num.at(0); + if (param.mixSharedRouting) { + dynamicExpertParam.topk = dynamicExpertParam.topk + 1; + } + dynamicExpertParam.scaledTopk = param.scaledTopk; + dynamicExpertParam.enableInitRoutingCutoff = param.enableInitRoutingCutoff; + dynamicExpertParam.numOfExperts = param.numOfExperts; + dynamicExpertParam.numOfDeviceExperts = param.numOfDeviceExperts; + dynamicExpertParam.supportSwiGLU = param.supportSwiGLU; + dynamicExpertParam.expertParallelDegree = param.expertParallelDegree; + dynamicExpertParam.isDynamicEp = param.isDynamicEp; + dynamicExpertParam.deviceExpert = param.deviceExpert; + dynamicExpertParam.enableMoeDistribute = param.enableMoeDistribute; + dynamicExpertParam.enableFusedTopk = param.enableFusedTopk; + dynamicExpertParam.moeLinearQuantType = param.moeLinearQuantType; + dynamicExpertParam.packQuantType = param.packQuantType; + dynamicExpertParam.denseQuantType = param.denseQuantType; + dynamicExpertParam.isBF16 = param.isBF16; + dynamicExpertParam.gateUpTransposeB = param.gateUpTransposeB; + dynamicExpertParam.downTransposeB = param.downTransposeB; + dynamicExpertParam.enableFusedRouting = param.enableFusedRouting; + dynamicExpertParam.enableGMMSwigluQuant = param.enableGMMSwigluQuant; + dynamicExpertParam.enableInitQuant = param.enableInitQuant; + dynamicExpertParam.enableSwigluQuant = param.enableSwigluQuant; + dynamicExpertParam.enableAtlasGMMFused = param.enableAtlasGMMFused; + dynamicExpertParam.backend = param.backend; + dynamicExpertParam.hcclComm = param.hcclComm; + dynamicExpertParam.hasMoeEp = param.hasMoeEp; + dynamicExpertParam.moeEpRank = param.moeEpRank; + dynamicExpertParam.moeEpSize = param.moeEpSize; + dynamicExpertParam.moeEpDomain = param.moeEpDomain; + dynamicExpertParam.moeEpRankTableFile = param.moeEpRankTableFile; + dynamicExpertParam.quantGroupSize = param.quantGroupSize; + dynamicExpertParam.enableCVOverlap = param.enableCVOverlap; + dynamicExpertParam.routingMethod = param.routingMethod; + dynamicExpertParam.maxDecodeDpTokenSize = param.maxDecodeDpTokenSize; + if (param.enableEPWB) { + dynamicExpertParam.numOfExperts = param.numOfExperts + param.numOfRedundantExpert; + } + dynamicExpertParam.numDanglingSharedExperts = param.numDanglingSharedExperts; + dynamicExpertParam.numOfRedundantExpert = param.numOfRedundantExpert; + dynamicExpertParam.enableExpertCumSumOutput = param.enableExpertCumSumOutput; + dynamicExpertParam.enableGatingDp = param.enableGatingDp; + dynamicExpertParam.enableDispatchCombineV2 = param.enableDispatchCombineV2; + 
+    dynamicExpertParam.enableLcocAll2All = param.enableLcocAll2All;
+    dynamicExpertParam.lcclMoeEpDomain = param.lcclMoeEpDomain;
+    dynamicExpertParam.lcclMoeEpHcclComm = param.lcclMoeEpHcclComm;
+    dynamicExpertParam.mixSharedRouting = param.mixSharedRouting;
+    return atb::NO_ERROR;
+}
+
+atb::Status CreateFp32Cast(std::map<std::string, uint32_t> &tensorMap, atb::GraphParam &opGraph)
+{
+    atb::Node castNode;
+    atb::infer::ElewiseParam castParam;
+    castParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_CAST;
+    castParam.outTensorType = ACL_FLOAT;
+    CHECK_OPERATION_STATUS_RETURN(CreateOperation(castParam, &castNode.operation));
+    castNode.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_router_weights_topk_reduced")};
+    castNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_router_weights_topk_reduced_fp32")};
+    opGraph.nodes.push_back(castNode);
+    ATB_SPEED_LOG_DEBUG("Cast calculation success");
+    return atb::NO_ERROR;
+}
+
+atb::Status CreateFp16Cast(std::map<std::string, uint32_t> &tensorMap,
+    const SparseMoeParam &param,
+    atb::GraphParam &opGraph)
+{
+    atb::Node castNode;
+    atb::infer::ElewiseParam castParam;
+    castParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_CAST;
+    castParam.outTensorType = param.isBF16 ? ACL_BF16 : ACL_FLOAT16;
+    CHECK_OPERATION_STATUS_RETURN(CreateOperation(castParam, &castNode.operation));
+    if (param.mixSharedRouting) {
+        castNode.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_router_weights_topk_reduced_mix_shared")};
+    } else {
+        castNode.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_router_weights_topk_reduced_fp32")};
+    }
+    castNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_router_weights_topk_reduced_fp16")};
+    if (param.enableGatingOverlap) {
+        atb::SetExecuteStreamId(castNode.operation, STREAM1);
+    }
+    opGraph.nodes.push_back(castNode);
+    ATB_SPEED_LOG_DEBUG("Cast calculation success");
+    return atb::NO_ERROR;
+}
+
+atb::Status SetExpertRoutingMap(
+    std::map<std::string, uint32_t> &tensorMap,
+    const SparseMoeParam &param,
+    atb::GraphParam &opGraph)
+{
+    atb::Node padNode;
+    atb::infer::GatherParam padParam;
+    atb::CreateOperation(padParam, &padNode.operation);
+    if (param.mixSharedRouting) {
+        padNode.inTensorIds = atb_speed::common::GetTensorIdxList(
+            tensorMap, { "in_expert_routing_map", "intermediate_selected_experts_mix_shared"});
+    } else {
+        padNode.inTensorIds = atb_speed::common::GetTensorIdxList(
+            tensorMap, { "in_expert_routing_map", "intermediate_selected_experts"});
+    }
+    padNode.outTensorIds = atb_speed::common::GetTensorIdxList(
+        tensorMap, {"intermediate_selected_experts_routed"});
+
+    padNode.inTensorReshapeFuncs.resize(padNode.inTensorIds.size());
+    padNode.inTensorReshapeFuncs[0] = [=](const atb::Dims &oldShape, atb::Dims &newShape) {
+        newShape.dimNum = 1; // 1: flatten to a single dimension
+        newShape.dims[0] = oldShape.dims[0] * oldShape.dims[1];
+    };
+    padNode.inTensorReshapeFuncs[1] = [=](const atb::Dims &oldShape, atb::Dims &newShape) {
+        newShape.dimNum = 1; // 1: flatten to a single dimension
+        newShape.dims[0] = oldShape.dims[0] * oldShape.dims[1];
+    };
+    opGraph.nodes.push_back(padNode);
+    ATB_SPEED_LOG_DEBUG("create padNode");
+    return atb::NO_ERROR;
+}
+
+std::list<std::string> GetOutTensorName(const SparseMoeParam &param)
+{
+    std::list<std::string> nameList;
+    nameList.push_back("in_hiddenstates");
+    nameList.push_back("in_mlp_gateup_weight_expert");
+    nameList.push_back("in_mlp_gateup_bias_expert");
+    nameList.push_back("in_mlp_gateup_descale_expert");
+    nameList.push_back("in_mlp_gateup_offset_expert");
+    nameList.push_back("in_mlp_gateup_scale_expert");
nameList.push_back("in_mlp_gateup_compress_idx_expert"); + nameList.push_back("in_mlp_down_weight_expert"); + nameList.push_back("in_mlp_down_bias_expert"); + nameList.push_back("in_mlp_down_descale_expert"); + nameList.push_back("in_mlp_down_offset_expert"); + nameList.push_back("in_mlp_down_scale_expert"); + nameList.push_back("in_mlp_down_compress_idx_expert"); + nameList.push_back("in_expert_array"); + if (!param.enableGatingDp) { + nameList.push_back(param.enableLoadBalance ? "in_fake_topk" : + (param.enableEPWB ? "intermediate_selected_experts_routed" : "intermediate_selected_experts")); + + if (param.processLogits != "none") { + std::string routerWeightsTopkReducedName; + if (param.enableMoeDistribute) { + if (param.mixSharedRouting) { + routerWeightsTopkReducedName = "intermediate_router_weights_topk_reduced_mix_shared"; + } else { + routerWeightsTopkReducedName = "intermediate_router_weights_topk_reduced_fp32"; + } + } else if (!param.enableMoeDistribute) { + routerWeightsTopkReducedName = "intermediate_router_weights_topk_reduced_fp16"; + } else { + if (param.mixSharedRouting) { + routerWeightsTopkReducedName = "intermediate_router_weights_topk_reduced_mix_shared"; + } else { + routerWeightsTopkReducedName = "intermediate_router_weights_topk_reduced"; + } + } + nameList.push_back(routerWeightsTopkReducedName); + } else { + nameList.push_back("intermediate_router_weights_topk"); + } + } else { + nameList.push_back("intermediate_selected_experts_all"); + nameList.push_back("intermediate_router_weights_topk_reduced_all"); + } + nameList.push_back("in_one_hot"); + nameList.push_back("in_zero_hot"); + if (param.hasMoeEp) { + nameList.push_back("in_start_expert_idx"); + nameList.push_back("in_device_expert_count"); + nameList.push_back("in_padding_idx"); + if (param.isDynamicEp) { + nameList.push_back("in_dynamic_ep_idx"); + nameList.push_back("in_moe_idx"); + } + } + return nameList; +} + +atb::Status GetoutTensorIdx( + std::map &tensorMap, + const SparseMoeParam ¶m, atb::GraphParam &opGraph) +{ + if (param.enableEPWB) { + if (param.mixSharedRouting) { + CHECK_OPERATION_STATUS_RETURN(CreateConcatExpertOperation(tensorMap, opGraph)); + } + SetExpertRoutingMap(tensorMap, param, opGraph); + } + atb::Node expertNode; + atb_speed::common::DynamicEpMoEParam dynamicExpertParam; + SetDynamicExpertParam(dynamicExpertParam, param); + atb_speed::common::CreateDynamicEpMoEOperation(dynamicExpertParam, &expertNode.operation); + expertNode.outTensorIds = {GetTensorIdx(tensorMap, "out_moe_rout")}; + if (param.enableExpertCumSumOutput) { + expertNode.outTensorIds.push_back(GetTensorIdx(tensorMap, "out_gmm_cumsum_list")); + } + std::list nameList = GetOutTensorName(param); + for (auto iter = nameList.cbegin(); iter != nameList.cend(); ++iter) { + expertNode.inTensorIds.push_back(GetTensorIdx(tensorMap, *iter)); + } + if (param.enableEPWB) { + expertNode.inTensorReshapeFuncs.resize(expertNode.inTensorIds.size()); + if (!param.mixSharedRouting) { + expertNode.inTensorReshapeFuncs[14] = [=](const atb::Dims &oldShape, atb::Dims &newShape) { + newShape.dimNum = 2; // 2: dimNum + newShape.dims[0] = oldShape.dims[0] / param.num.at(0); + newShape.dims[1] = param.num.at(0); + }; + } else { + expertNode.inTensorReshapeFuncs[14] = [=](const atb::Dims &oldShape, atb::Dims &newShape) { // 14: topk + newShape.dimNum = 2; // 2: dimNum + newShape.dims[0] = oldShape.dims[0] / (param.num.at(0) + 1); + newShape.dims[1] = param.num.at(0) + 1; + }; + } + } + opGraph.nodes.push_back(expertNode); + 
ATB_SPEED_LOG_DEBUG("Expert Group calculation success5"); + return atb::NO_ERROR; +} + +atb::Status RoutingBlock( + std::map &tensorMap, + const SparseMoeParam ¶m, atb::GraphParam &opGraph) +{ + if (param.enableFusedTopk) { + CHECK_OPERATION_STATUS_RETURN(CreateFusedAddTopkDiv(tensorMap, param, opGraph)); + if (param.mixSharedRouting) { + CHECK_OPERATION_STATUS_RETURN(CreateConcatWeightOperation(tensorMap, opGraph)); + } + } else { + if (param.routingMethod == "deviceLimited") { + CHECK_OPERATION_STATUS_RETURN(CreateSparseMoesoftMax(tensorMap, param, opGraph)); + CHECK_OPERATION_STATUS_RETURN(CreateGroupOperation(tensorMap, param, opGraph)); + CHECK_OPERATION_STATUS_RETURN(CreateSparseMoetopK(tensorMap, param, opGraph)); + } else if (param.routingMethod == "integratedSoftmaxTopK") { + CHECK_OPERATION_STATUS_RETURN(CreateSparseTopkSoftMax(tensorMap, param, opGraph)); + } else if (param.routingMethod == "noAuxTc") { + CHECK_OPERATION_STATUS_RETURN(CreateSparseMoeSigmoid(tensorMap, opGraph)); + CHECK_OPERATION_STATUS_RETURN(CreateScoreAdd(tensorMap, opGraph)); + CHECK_OPERATION_STATUS_RETURN(CreateGroupOperation(tensorMap, param, opGraph)); + CHECK_OPERATION_STATUS_RETURN(CreateSparseMoetopK(tensorMap, param, opGraph)); + CHECK_OPERATION_STATUS_RETURN(CreateSparseMoetopKGather(tensorMap, opGraph)); + } else { + CHECK_OPERATION_STATUS_RETURN(CreateSparseMoesoftMax(tensorMap, param, opGraph)); + CHECK_OPERATION_STATUS_RETURN(CreateSparseMoetopK(tensorMap, param, opGraph)); + } + } + ATB_SPEED_LOG_DEBUG("Routing Block success"); + return atb::NO_ERROR; +} + +atb::Status CreateConcatWeightOperation( + std::map &tensorMap, + atb::GraphParam &opGraph) +{ + ATB_SPEED_LOG_DEBUG("AclNNConcat weight start"); + atb::Node aclnnConcatNode; + atb_speed::common::AclNNConcatParam aclNNConcatParam; + aclNNConcatParam.dim = 1; + std::vector concatInTensorNames; + std::string concatOutTensorName; + concatInTensorNames.push_back("intermediate_router_weights_topk_reduced_fp32"); + concatInTensorNames.push_back("mix_shared_routing_weight"); + concatOutTensorName = "intermediate_router_weights_topk_reduced_mix_shared"; + aclnnConcatNode.operation = new atb_speed::common::ConcatOperation("concatWeight", aclNNConcatParam); + aclnnConcatNode.inTensorIds = GetTensorIdxList(tensorMap, concatInTensorNames); + aclnnConcatNode.outTensorIds = {GetTensorIdx(tensorMap, concatOutTensorName)}; + opGraph.nodes.push_back(aclnnConcatNode); + ATB_SPEED_LOG_DEBUG("AclNNConcat weight success"); + return atb::NO_ERROR; +} + +atb::Status CreateConcatExpertOperation( + std::map &tensorMap, + atb::GraphParam &opGraph) +{ + ATB_SPEED_LOG_DEBUG("AclNNConcat expert start"); + atb::Node aclnnConcatNode; + atb_speed::common::AclNNConcatParam aclNNConcatParam; + aclNNConcatParam.dim = 1; + std::vector concatInTensorNames; + std::string concatOutTensorName; + concatInTensorNames.push_back("intermediate_selected_experts"); + concatInTensorNames.push_back("mix_shared_routing_expert"); + concatOutTensorName = "intermediate_selected_experts_mix_shared"; + aclnnConcatNode.operation = new atb_speed::common::ConcatOperation("concatExpert", aclNNConcatParam); + aclnnConcatNode.inTensorIds = GetTensorIdxList(tensorMap, concatInTensorNames); + aclnnConcatNode.outTensorIds = {GetTensorIdx(tensorMap, concatOutTensorName)}; + opGraph.nodes.push_back(aclnnConcatNode); + ATB_SPEED_LOG_DEBUG("AclNNConcat expert success"); + return atb::NO_ERROR; +} + +atb::Status CreateRecord(const SparseMoeParam ¶m, atb::GraphParam &opGraph, + atb_speed::EventAction 
+    atb_speed::EventAction eventAction, const std::string &cvKey)
+{
+    if (param.enableCVOverlap || param.enableGatingOverlap) {
+        atb::Node recordNode;
+        recordNode.inTensorIds = {};
+        recordNode.outTensorIds = {};
+        CHECK_OPERATION_STATUS_RETURN(atb_speed::EventManager::GetInstance().RecordEvent(
+            recordNode.operation,
+            eventAction,
+            cvKey));
+        if (param.enableGatingOverlap) {
+            atb::SetExecuteStreamId(recordNode.operation, STREAM1);
+        }
+        opGraph.nodes.push_back(recordNode);
+        ATB_SPEED_LOG_DEBUG("Record event success");
+    }
+    return atb::NO_ERROR;
+}
+
+atb::Status CreateWait(const SparseMoeParam &param, atb::GraphParam &opGraph,
+    atb_speed::EventAction eventAction, const std::string &cvKey)
+{
+    if (param.enableCVOverlap || param.enableGatingOverlap) {
+        atb::Node waitNode;
+        waitNode.inTensorIds = {};
+        waitNode.outTensorIds = {};
+        CHECK_OPERATION_STATUS_RETURN(atb_speed::EventManager::GetInstance().WaitEvent(
+            waitNode.operation,
+            eventAction,
+            cvKey));
+        if (param.enableGatingOverlap) {
+            atb::SetExecuteStreamId(waitNode.operation, STREAM1);
+        }
+        opGraph.nodes.push_back(waitNode);
+        ATB_SPEED_LOG_DEBUG("Wait event success");
+    }
+    return atb::NO_ERROR;
+}
+
+atb::Status SetTPAllGatherNode(std::map<std::string, uint32_t> tensorMap, const SparseMoeParam &param,
+    atb::GraphParam &opGraph, bool isTopk)
+{
+    atb::Node allGatherNode;
+    atb::infer::AllGatherParam allGatherParam;
+    allGatherParam.rank = param.mlpTpRank;
+    allGatherParam.rankSize = param.mlpTpSize;
+    allGatherParam.backend = param.mlpTpBackend;
+    allGatherParam.commDomain = param.mlpTpDomain;
+    allGatherParam.rankTableFile = param.mlpTpRankTableFile;
+    allGatherParam.hcclComm = param.hcclTpComm;
+
+    CreateOperation(allGatherParam, &allGatherNode.operation);
+
+    allGatherNode.inTensorIds = atb_speed::common::GetTensorIdxList(tensorMap,
+        {(isTopk) ? "intermediate_selected_experts" : (param.enableFusedTopk) ? \
+        "intermediate_router_weights_topk_reduced_fp16" : "intermediate_router_weights_topk_reduced"});
+    allGatherNode.outTensorIds = atb_speed::common::GetTensorIdxList(tensorMap,
+        {(isTopk) ? "intermediate_selected_experts_with_padding_all" : \
+        "intermediate_router_weights_topk_reduced_with_padding_all"});
+    CHECK_OPERATION_STATUS_RETURN(common::AddDapEventsBeforeComm(opGraph));
+    opGraph.nodes.push_back(allGatherNode);
+    CHECK_OPERATION_STATUS_RETURN(common::AddDapEventsAfterComm(opGraph));
+    return atb::NO_ERROR;
+}
+
+atb::Status SetTPUnPadding(std::map<std::string, uint32_t> &tensorMap, atb::GraphParam &opGraph, bool isTopk)
+{
+    atb::Node unpadNode;
+    atb::infer::GatherParam unpadParam;
+    CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(unpadParam, &unpadNode.operation));
+    unpadNode.inTensorIds = atb_speed::common::GetTensorIdxList(
+        tensorMap, {(isTopk) ? "intermediate_selected_experts_with_padding_all" : \
+        "intermediate_router_weights_topk_reduced_with_padding_all", "in_attn_unpadding_idx"});
+    unpadNode.outTensorIds = atb_speed::common::GetTensorIdxList(tensorMap,
"intermediate_selected_experts_all" : "intermediate_router_weights_topk_reduced_all"}); + unpadNode.inTensorReshapeFuncs.reserve(unpadNode.inTensorIds.size()); + unpadNode.inTensorReshapeFuncs.resize(unpadNode.inTensorIds.size()); + unpadNode.inTensorReshapeFuncs[0] = [=] (const atb::Dims &oldShape, atb::Dims &newShape) { + newShape.dimNum = 2; // 2:新shape维度为2 + if (oldShape.dimNum == 3) { // 3:旧shape维度为3 + newShape.dims[0] = oldShape.dims[0] * oldShape.dims[1]; // 0, 0, 1: 新shape前两维合轴 + newShape.dims[1] = oldShape.dims[2]; // 1, 2: 新shape最后一维不变 + } else { + newShape.dims[0] = oldShape.dims[0]; + newShape.dims[1] = oldShape.dims[1]; // 1, 2: 新shape最后一维不变 + } + }; + opGraph.nodes.push_back(unpadNode); + ATB_SPEED_LOG_DEBUG("AllGather calculation success"); + return atb::NO_ERROR; +} + +atb::Status CreateSparseMoeOperation(const SparseMoeParam ¶m, atb::Operation **operation) +{ + atb::GraphParam opGraph; + opGraph.name = "SparseMoe"; + std::map tensorMap = ConstructTensorMap( + param, opGraph.inTensorNum, opGraph.outTensorNum, opGraph.internalTensorNum); + ATB_SPEED_LOG_DEBUG("opGraph.inTensorNum " << opGraph.inTensorNum); + ATB_SPEED_LOG_DEBUG("opGraph.outTensorNum " << opGraph.outTensorNum); + ATB_SPEED_LOG_DEBUG("opGraph.internalTensorNum" << opGraph.internalTensorNum); + + if (param.enableCVOverlap) { + CHECK_OPERATION_STATUS_RETURN(CreateRecord( + param, opGraph, atb_speed::EventAction::POP, atb_speed::common::CV_START)); + } + if (param.routingMethod == "noAuxTc") { + CHECK_OPERATION_STATUS_RETURN(CreateSparseMoemoeGateFp32(tensorMap, param, opGraph)); + } else { + CHECK_OPERATION_STATUS_RETURN(CreateSparseMoemoeGate(tensorMap, param, opGraph)); + } + if (param.enableCVOverlap) { + CHECK_OPERATION_STATUS_RETURN(CreateWait( + param, opGraph, atb_speed::EventAction::POP, atb_speed::common::VECTOR_CONTROL)); + CHECK_OPERATION_STATUS_RETURN(CreateRecord( + param, opGraph, atb_speed::EventAction::POP, atb_speed::common::CUBE_CONTROL)); + } + if (param.useStdNorm) { + CHECK_OPERATION_STATUS_RETURN(CreateSparseMoeStd(tensorMap, opGraph)); + CHECK_OPERATION_STATUS_RETURN(CreateSparseMoeNorm(tensorMap, opGraph)); + } + if (param.enableGatingShift) { + CHECK_OPERATION_STATUS_RETURN(CreateSplit(tensorMap, param, opGraph)); + CHECK_OPERATION_STATUS_RETURN(CreateConcat(tensorMap, param, opGraph)); + } + CHECK_OPERATION_STATUS_RETURN(RoutingBlock(tensorMap, param, opGraph)); + + if (!param.enableFusedTopk && param.processLogits != "none") { + if (param.processLogits == "normalization") { + // In_tensor[0]: router_weights: Batch * Seq; 2 + CHECK_OPERATION_STATUS_RETURN(CreateSparseMoereduce(tensorMap, opGraph)); + // In_tensor[0]: router_weights: Batch * Seq; 2 + CHECK_OPERATION_STATUS_RETURN(CreateSparseMoedivide(tensorMap, opGraph)); + } else if (param.processLogits == "scaling") { + CHECK_OPERATION_STATUS_RETURN(CreateElewiseMuls(tensorMap, param, opGraph)); + } else if (param.processLogits == "norm") { + CHECK_OPERATION_STATUS_RETURN(CreateVectorNorm(tensorMap, opGraph)); + CHECK_OPERATION_STATUS_RETURN(CreateSparseMoedivide(tensorMap, opGraph)); + } else if (param.processLogits == "normScaling") { + // In_tensor[0]: router_weights: Batch * Seq; 2 + CHECK_OPERATION_STATUS_RETURN(CreateSparseMoereduce(tensorMap, opGraph)); + // In_tensor[0]: router_weights: Batch * Seq; 2 + CHECK_OPERATION_STATUS_RETURN(CreateSparseMoedivide(tensorMap, opGraph)); + CHECK_OPERATION_STATUS_RETURN(CreateElewiseMuls(tensorMap, param, opGraph)); + } + CHECK_OPERATION_STATUS_RETURN(CreateFp32Cast(tensorMap, opGraph)); + 
+        if (param.mixSharedRouting) {
+            CHECK_OPERATION_STATUS_RETURN(CreateConcatWeightOperation(tensorMap, opGraph));
+        }
+    }
+    if (!param.enableMoeDistribute && (param.processLogits != "none")) {
+        CHECK_OPERATION_STATUS_RETURN(CreateFp16Cast(tensorMap, param, opGraph));
+    }
+    if (param.enableGatingOverlap) {
+        CHECK_OPERATION_STATUS_RETURN(CreateRecord(
+            param, opGraph, atb_speed::EventAction::POP, atb_speed::common::COMP_CONTROL));
+        CHECK_OPERATION_STATUS_RETURN(CreateWait(
+            param, opGraph, atb_speed::EventAction::POP, atb_speed::common::COMM_CONTROL));
+    }
+
+    if (param.enableGatingDp) {
+        CHECK_OPERATION_STATUS_RETURN(SetTPAllGatherNode(tensorMap, param, opGraph, true));
+        CHECK_OPERATION_STATUS_RETURN(SetTPUnPadding(tensorMap, opGraph, true));
+        CHECK_OPERATION_STATUS_RETURN(SetTPAllGatherNode(tensorMap, param, opGraph, false));
+        CHECK_OPERATION_STATUS_RETURN(SetTPUnPadding(tensorMap, opGraph, false));
+    }
+
+    CHECK_OPERATION_STATUS_RETURN(GetoutTensorIdx(tensorMap, param, opGraph));
+
+    opGraph.inferShapeFunc = [=](const atb::SVector<atb::TensorDesc> &inTensorDescs,
+                                 atb::SVector<atb::TensorDesc> &outTensorDescs) {
+        outTensorDescs.at(0) = inTensorDescs.at(0);
+        if (param.enableExpertCumSumOutput) {
+            outTensorDescs.at(1) = atb::TensorDesc{};
+            outTensorDescs.at(1).format = ACL_FORMAT_ND;
+            outTensorDescs.at(1).shape.dimNum = 1;
+            outTensorDescs.at(1).dtype = ACL_INT64;
+            outTensorDescs.at(1).shape.dims[0] = param.numOfDeviceExperts;
+        }
+
+        return atb::NO_ERROR;
+    };
+
+    return atb::CreateOperation(opGraph, operation);
+}
+}
+} // namespace atb_speed
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/fusion/moe/sparse_moe.h b/tests/proftest/layer_test_framework/operations/fusion/moe/sparse_moe.h
new file mode 100644
index 00000000..dcc9d55e
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/fusion/moe/sparse_moe.h
@@ -0,0 +1,161 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef ATB_SPEED_MODELS_SPARSE_MOE_OPERATION_H
+#define ATB_SPEED_MODELS_SPARSE_MOE_OPERATION_H
+#include <vector>
+#include <map>
+#include <atb/atb_infer.h>
+#include "atb_speed/log.h"
+#include "atb_speed/utils/operation_util.h"
+#include "operations/fusion/utils.h"
+#include "operations/fusion/linear/linear.h"
+#include "operations/fusion/linear/linear_parallel.h"
+#include "operations/fusion/norm/norm_linear.h"
+
+namespace atb_speed {
+namespace common {
+
+enum SparseMoeIdx : int {
+    ROUTER_IDX = 0,
+    MOE_MLP_GATE_IDX,
+    MOE_MLP_UP_IDX,
+    MOE_MLP_DOWN_IDX
+};
+
+struct SparseMoeParam {
+    atb::SVector<int64_t> axes = {1}; /// The axes on which softmax is applied
+    atb::SVector<int32_t> num = {6}; /// The number of experts selected for each token
+    atb::SVector<int32_t> topkGroups = {3}; /// The number of groups/devices selected
+    int scaledTopk = -1; /// Non-DeepSeek models do not enable the scaledTopk feature by default
+    bool enableInitRoutingCutoff = false; /// A flag indicating whether to use the scaled topk option
+    std::vector<int> moeLinearQuantType = {}; /// The list of quantization types of linear operations in the MoE graph
+    std::vector<int32_t> deviceExpert = {}; /// The list of experts loaded on the device
+    uint32_t numOfDeviceExperts = 64; /// The number of experts loaded to the device
+    uint32_t numOfExperts = 64; /// The total number of experts utilized by the model
+    int numOfGroups = 8; /// The number of groups in total
+    int expertParallelDegree = 0; /// The specific realization of the expert parallelism strategy utilized by the model
+    float routedScalingFactor = 1.0; /// The optional scaling factor for expert scores
+    bool transpose = true; /// A flag indicating whether matrices need to be transposed for matrix multiplications
+    bool supportSwiGLU = true; /// A flag indicating whether the device supports the SwiGLU operator
+    bool isBF16 = false; /// A flag indicating whether the model runs on bfloat16
+    bool isDynamicEp = false; /// A flag indicating whether to use the dynamic expert parallelism mechanism
+    std::string routingMethod = "softMaxTopK"; /// The way in which the top k experts are selected
+    std::string processLogits = "none"; /// The way in which expert scores are further processed
+    bool gateUpTransposeB = false; /// A flag indicating whether the B matrix of the gate-up operation should be transposed
+    bool downTransposeB = false; /// A flag indicating whether the B matrix of the down operation should be transposed
+    bool enableFusedRouting = false; /// A flag indicating whether or not to use integrated routing operators
+    bool enableInitQuant = false; /// A flag indicating whether to use the routing-quant integrated operator
+    bool enableSwigluQuant = false; /// A flag indicating whether to use the swiglu-quant integrated operator
+    bool enableMoeParallel = false; /// A flag indicating whether the model uses MoE parallelism
+    bool enableCVOverlap = false; /// A flag indicating whether the model uses cube and vector parallelism
+    bool enableFusedTopk = false; /// A flag indicating whether to use the fused topk operator
+    bool rounterHasBias = false; /// A flag indicating whether there is a bias in the expert selection process
+    /// A flag indicating whether or not to use integrated GMM+Swiglu+quant operators.
+    bool enableGMMSwigluQuant = false;
+    /// A flag indicating whether or not to use fused ATB GMM+Swiglu+quant operators instead of aclnn.
+    bool enableAtlasGMMFused = false;
+    int packQuantType = atb_speed::common::PackQuantType::ALL_FP; /// The quantization type of the packed weights
+    int quantGroupSize = 0; /// Group size of per-group quantization
+    /// The quantization type used to facilitate the calculation of the quantization type of the linear operation
+    int denseQuantType = atb_speed::common::PackQuantType::PACK_QUANT_UNDEFINED;
+    bool useStdNorm = false; /// A flag indicating whether the model utilizes std to normalize expert scores
+    bool enableDispatchCombineV2 = false; /// A flag indicating whether to use dispatch_v2 and combine_v2
+
+    std::string backend = "hccl"; /// The communication backend
+    std::string mlpTpBackend = "hccl"; /// The communication backend used for MLP tensor parallelism
+    bool hasMoeEp = false; /// A flag indicating whether the model utilizes expert parallelism
+    int moeEpRank = 0; /// The rank of this device in the expert parallelism communication domain
+    int moeEpSize = 1; /// The size of the expert parallelism communication domain
+    int maxDecodeDpTokenSize = 0;
+    std::string moeEpDomain = ""; /// The communication domain of expert parallelism
+    std::string moeEpRankTableFile = ""; /// The rankTableFile for the device in the communication domain
+    bool hasMlpTp = false; /// A flag indicating whether the FFN utilizes tensor parallelism
+    int mlpTpRank = 0; /// The rank of this device in the tensor parallelism communication domain of the FFN
+    int mlpTpSize = 1; /// The size of the tensor parallelism communication domain of the FFN
+    std::string mlpTpDomain = ""; /// The communication domain of tensor parallelism of the FFN
+    std::string mlpTpRankTableFile = ""; /// The rankTableFile for the device in the communication domain
+    bool enableMoeDistribute = false; /// A flag indicating whether to use the MoE distribute fusion operator
+    bool enableExpertCumSumOutput = false; /// A flag indicating whether to output the expert cumulative sum
+    bool enableGatingDp = false; /// A flag indicating whether gating data parallelism is enabled
+    bool enableGatingShift = false; /// A flag indicating whether the gate logits need to be shifted
+    bool enableGatingOverlap = false; /// A flag indicating whether gating overlap is enabled
+    bool enableFp32GateInput = false;
+    int64_t numDanglingSharedExperts = 0;
+
+    bool enableATBGateMatmul = false; /// A flag indicating whether to enable the ATB gate matmul
+    bool enableLoadBalance = false;
+    bool enableEPWB = false;
+    uint32_t numOfRedundantExpert = 0;
+    bool hasBias = false;
+    HcclComm hcclComm = nullptr;
+    HcclComm hcclTpComm = nullptr;
+
+    bool enableLcocAll2All = false;
+    std::string lcclMoeEpDomain = "";
+    HcclComm lcclMoeEpHcclComm = nullptr;
+
+    bool mixSharedRouting = false;
+};
+
+/// This function creates the graph of the MoE of a model.
+/// \return A flag that indicates whether the operation is successfully created or not.
+atb::Status CreateSparseMoeOperation(const SparseMoeParam &param, atb::Operation **operation);
+/// This function adds a linear transformation operator that calculates the raw score of each expert on each token.
+/// \return A flag that indicates whether the operation is successfully created or not.
+atb::Status CreateSparseMoemoeGate(
+    std::map<std::string, uint32_t> &tensorMap, const SparseMoeParam &param, atb::GraphParam &opGraph);
+/// This function adds a linear transformation operator that calculates the raw score of each expert on each token
+/// in float32 dtype.
+/// \return A flag that indicates whether the operation is successfully created or not.
+atb::Status CreateSparseMoemoeGateFp32(
+    std::map<std::string, uint32_t> &tensorMap, const SparseMoeParam &param, atb::GraphParam &opGraph);
+/// This function adds a softmax operator that processes the score of each expert on each token to the graph.
+/// \return A flag that indicates whether the operation is successfully created or not.
+atb::Status CreateSparseMoesoftMax(
+    std::map<std::string, uint32_t> &tensorMap, const SparseMoeParam &param, atb::GraphParam &opGraph);
+/// This function adds a sorting operator that selects the top experts for each token to the graph.
+/// \return A flag that indicates whether the operation is successfully created or not.
+atb::Status CreateSparseMoetopK(
+    std::map<std::string, uint32_t> &tensorMap, const SparseMoeParam &param, atb::GraphParam &opGraph);
+/// This function, working along with `CreateSparseMoedivide`, normalizes the scores of the top experts.
+/// \return A flag that indicates whether the operation is successfully created or not.
+atb::Status CreateSparseMoereduce(
+    std::map<std::string, uint32_t> &tensorMap, atb::GraphParam &opGraph);
+/// This function, working along with `CreateSparseMoereduce`, normalizes the scores of the top experts.
+/// \return A flag that indicates whether the operation is successfully created or not.
+atb::Status CreateSparseMoedivide(
+    std::map<std::string, uint32_t> &tensorMap, atb::GraphParam &opGraph);
+/// This function adds an std operator that normalizes expert scores to the graph.
+/// \return A flag that indicates whether the operation is successfully created or not.
+atb::Status CreateSparseMoeStd(
+    std::map<std::string, uint32_t> &tensorMap, atb::GraphParam &opGraph);
+/// This function, working along with `CreateSparseMoeStd`, normalizes the scores of the top experts.
+/// \return A flag that indicates whether the operation is successfully created or not.
+atb::Status CreateSparseMoeNorm(
+    std::map<std::string, uint32_t> &tensorMap, atb::GraphParam &opGraph);
+/// This function, working along with `CreateFusedAddTopkDiv`, concatenates an extra expert to the top experts.
+/// \return A flag that indicates whether the operation is successfully created or not.
+atb::Status CreateConcatExpertOperation(
+    std::map<std::string, uint32_t> &tensorMap,
+    atb::GraphParam &opGraph);
+/// This function, working along with `CreateConcatExpertOperation`, concatenates an extra weight for the top experts.
+/// \return A flag that indicates whether the operation is successfully created or not.
+atb::Status CreateConcatWeightOperation(
+    std::map<std::string, uint32_t> &tensorMap,
+    atb::GraphParam &opGraph);
+
+}
+} // namespace atb_speed
+#endif
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/fusion/norm/norm_linear.cpp b/tests/proftest/layer_test_framework/operations/fusion/norm/norm_linear.cpp
new file mode 100644
index 00000000..5f368fa2
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/fusion/norm/norm_linear.cpp
@@ -0,0 +1,425 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "atb_speed/log.h"
+#include "operations/aclnn/ops/add_rms_norm_quant_operation.h"
+#include "operations/aclnn/ops/add_rms_norm_dynamic_quant_operation.h"
+#include "operations/aclnn/ops/obfuscation_calculate_operation.h"
+#include "operations/aclnn/utils/utils.h"
+#include "operations/fusion/utils.h"
+#include "operations/fusion/norm/norm_linear.h"
+
+namespace atb_speed {
+namespace common {
+
+template <typename NormParamType>
+bool UseNormQuant(const NormLinearParam<NormParamType> &param)
+{
+    if (param.fusionLinearParam.quantType == LinearQuantType::LINEAR_W8A8_DEQUANT || \
+        param.fusionLinearParam.quantType == LinearQuantType::LINEAR_W8A8_SC_DEQUANT || \
+        (param.fusionLinearParam.quantType == LinearQuantType::LINEAR_W8A8_DYNAMIC_DEQUANT && \
+        param.enableAddNorm) || \
+        (param.fusionLinearParam.quantType == LinearQuantType::LINEAR_W4A8_DYNAMIC_DEQUANT && \
+        param.enableAddNorm)) {
+        return true;
+    } else {
+        return false;
+    }
+}
+
+template <typename NormParamType>
+bool UseAddRmsNormQuant(const NormLinearParam<NormParamType> &param)
+{
+    // The operator supports the 310P, A2 and A3 SoCs
+    return (Is310P() || IsA2()) && param.enableAddNorm &&
+        param.fusionLinearParam.quantType == LinearQuantType::LINEAR_W8A8_DEQUANT;
+}
+
+template <typename NormParamType>
+bool UseAddRmsNormDynamicQuant(const NormLinearParam<NormParamType> &param)
+{
+    // The operator supports the A2 and A3 SoCs
+    return IsA2() && param.enableAddNorm &&
+        (param.fusionLinearParam.quantType == LinearQuantType::LINEAR_W8A8_DYNAMIC_DEQUANT \
+        || param.fusionLinearParam.quantType == LinearQuantType::LINEAR_W4A8_DYNAMIC_DEQUANT);
+}
+
+std::map<std::string, std::vector<std::string>> GetInTensorCandidates()
+{
+    std::map<std::string, std::vector<std::string>> normInTensorCandidates = {
+        {"default", {
+            "in_input", "in_norm_weight", "in_norm_bias", "in_norm_new_weight", "in_norm_new_bias",
+            "in_linear_weight", "in_scale", "in_offset", "in_descale", "in_bias", "in_compress_idx"}
+        },
+        {"add_norm", {"in_residual_input"}},
+        {"add_rmsnorm_quant", {"in_scale_fill", "in_offset_fill"}},
+        {"lora", {"in_seq_len_cum_sum", "in_linear_lora_a", "in_linear_lora_b"}},
+        {"lora_with_mask", {"in_im_mask"}},
+        {"flash_comm", {
+            "send_counts", "sdispls", "send_count", "recv_counts", "rdispls", "recv_count", "fake_ag_shape"}
+        },
+    };
+    return normInTensorCandidates;
+}
+
+std::map<std::string, std::vector<std::string>> GetIntermediateTensorCandidates()
+{
+    std::map<std::string, std::vector<std::string>> normIntermediateTensorCandidates = {
+        {"default", {"intermediate_norm"}},
+        {"addrmsnormquant", {"y2_out"}},
+        {"addrmsnormdynamicquant", {"y2_out", "scale1_out", "scale2_out"}},
+        {"pmcc", {"intermediate_pmcc"}},
+        {"pmcc_multi", {"intermediate_pmcc", "intermediate_norm_per_rank",
+                        "intermediate_pmcc_gather", "intermediate_pmcc_gather_out"}},
+    };
+    return normIntermediateTensorCandidates;
+}
+
+std::map<std::string, std::vector<std::string>> GetOutTensorCandidates()
+{
+    std::map<std::string, std::vector<std::string>> normOutTensorCandidates = {
+        {"default", {"out_linear"}},
+        {"add_norm", {"out_add"}},
+    };
+    return normOutTensorCandidates;
+}
+
+template <typename NormParamType>
+std::map<std::string, uint32_t> ConstructNormTensorMap(
+    const NormLinearParam<NormParamType> &param,
+    uint32_t &inTensorNum, uint32_t &outTensorNum, uint32_t &internalTensorNum)
+{
+    auto normInTensorCandidates = GetInTensorCandidates();
+    auto normIntermediateTensorCandidates = GetIntermediateTensorCandidates();
+    auto normOutTensorCandidates = GetOutTensorCandidates();
+
+    std::vector<std::string> inTensorList = {};
+    std::vector<std::string> intermediateTensorList = {};
+    std::vector<std::string> outTensorList = {};
+
+    // Add the default tensors
+    AddTensorToList(normInTensorCandidates, "default", inTensorList);
+    if (!param.skipNorm) {
+        AddTensorToList(normIntermediateTensorCandidates, "default", intermediateTensorList);
+    }
+
+    // Add tensors for the add-norm feature and the AddRmsNormQuant feature
+    if (param.enableAddNorm) {
+        AddTensorToList(normInTensorCandidates, "add_rmsnorm_quant", inTensorList);
"add_rmsnorm_quant", inTensorList); + AddTensorToList(normInTensorCandidates, "add_norm", inTensorList); + } + + // 添加lora特性的Tensor + if (param.fusionLinearParam.supportLora) { + if (param.fusionLinearParam.useImMask) { + AddTensorToList(normInTensorCandidates, "lora_with_mask", inTensorList); + } + AddTensorToList(normInTensorCandidates, "lora", inTensorList); + } + // Add Flashcomm 1.0 Tensor + if (param.fusionLinearParam.enableFlashComm) { + AddTensorToList(normInTensorCandidates, "flash_comm", inTensorList); + } + + // 添加outTensor + AddTensorToList(normOutTensorCandidates, "default", outTensorList); + if (param.enableAddNorm) { + AddTensorToList(normOutTensorCandidates, "add_norm", outTensorList); + } + if (UseAddRmsNormQuant(param)) { + AddTensorToList(normIntermediateTensorCandidates, "addrmsnormquant", intermediateTensorList); + } + if (UseAddRmsNormDynamicQuant(param)) { + AddTensorToList(normIntermediateTensorCandidates, "addrmsnormdynamicquant", intermediateTensorList); + } + if (param.enableModelConfuscation) { + if (param.modelObfuscationParallelInfo.worldSize > 1) { + AddTensorToList(normIntermediateTensorCandidates, "pmcc_multi", intermediateTensorList); + } else { + AddTensorToList(normIntermediateTensorCandidates, "pmcc", intermediateTensorList); + } + } + inTensorNum = inTensorList.size(); + outTensorNum = outTensorList.size(); + internalTensorNum = intermediateTensorList.size(); + + return GetTensorMap(inTensorList, outTensorList, intermediateTensorList); +} + +template +int64_t InsertNorm( + atb::GraphParam &opGraph, + const NormLinearParam ¶m, + std::map &tensorMap) +{ + bool useNormQuant = UseNormQuant(param); + bool useAddRmsNormQuant = UseAddRmsNormQuant(param); + bool useAddRmsNormDynamicQuant = UseAddRmsNormDynamicQuant(param); + atb::Node normNode; + if (param.enableAddNorm) { + normNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_residual_input")); + } + normNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_norm")}; + if (useAddRmsNormQuant || useAddRmsNormDynamicQuant) { + normNode.outTensorIds.push_back(GetTensorIdx(tensorMap, "y2_out")); + normNode.outTensorIds.push_back(GetTensorIdx(tensorMap, "out_add")); + if (useAddRmsNormDynamicQuant) { + normNode.outTensorIds.push_back(GetTensorIdx(tensorMap, "scale1_out")); + normNode.outTensorIds.push_back(GetTensorIdx(tensorMap, "scale2_out")); + } + } else if (param.enableAddNorm) { + normNode.outTensorIds.push_back(GetTensorIdx(tensorMap, "out_add")); + } + if (useNormQuant) { // activation quant + normNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_input")); + normNode.inTensorIds.push_back(param.isAntiOutlier ? \ + GetTensorIdx(tensorMap, "in_norm_new_weight") : GetTensorIdx(tensorMap, "in_norm_weight")); + if (!useAddRmsNormQuant && !useAddRmsNormDynamicQuant) { + normNode.inTensorIds.push_back(param.isAntiOutlier ? 
+                GetTensorIdx(tensorMap, "in_norm_new_bias") : GetTensorIdx(tensorMap, "in_norm_bias"));
+        }
+        if (useAddRmsNormQuant) {
+            normNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_scale_fill"));
+            normNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_offset_fill"));
+        } else if (!useAddRmsNormDynamicQuant) {
+            normNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_scale"));
+            normNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_offset"));
+        }
+        if (useAddRmsNormQuant) { // aclnn operator
+            normNode.operation = new atb_speed::common::AddRmsNormQuantOperation(
+                "AddRmsNormQuantOperation", param.normParamType.normParam.epsilon);
+        } else if (useAddRmsNormDynamicQuant) { // aclnn operator
+            normNode.operation = new atb_speed::common::AddRmsNormDynamicQuantOperation(
+                "AddRmsNormDynamicQuantOperation", param.normParamType.normParam.epsilon);
+        } else { // ATB operator
+            CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(param.normQuantParamType, &normNode.operation));
+        }
+    } else { // activation no quant
+        CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(param.normParamType, &normNode.operation));
+        normNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "in_input"));
+        normNode.inTensorIds.push_back(param.isAntiOutlier ? \
+            GetTensorIdx(tensorMap, "in_norm_new_weight") : GetTensorIdx(tensorMap, "in_norm_weight"));
+        if (param.normHasBias) {
+            normNode.inTensorIds.push_back(param.isAntiOutlier ? \
+                GetTensorIdx(tensorMap, "in_norm_new_bias") : GetTensorIdx(tensorMap, "in_norm_bias"));
+        }
+    }
+    opGraph.nodes.push_back(normNode);
+    return atb::NO_ERROR;
+}
+
+template <typename NormParamType>
+atb::Status InsertObfuscationCalculate(atb::GraphParam &opGraph, const NormLinearParam<NormParamType> &param,
+    std::map<std::string, uint32_t> &tensorMap)
+{
+    bool isMultiRank = param.modelObfuscationParallelInfo.worldSize > 1;
+    int rank = param.modelObfuscationParallelInfo.rank;
+    if (isMultiRank) {
+        atb::Node sliceNode;
+        atb::infer::SliceParam sliceParam;
+        sliceParam.offsets = {0, rank * param.hiddenSizePerRank};
+        sliceParam.size = {-1, param.hiddenSizePerRank};
+        sliceNode.inTensorIds.push_back(GetTensorIdx(tensorMap, "intermediate_norm"));
+        sliceNode.outTensorIds.push_back(GetTensorIdx(tensorMap, "intermediate_norm_per_rank"));
+        CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(sliceParam, &sliceNode.operation));
+        opGraph.nodes.push_back(sliceNode);
+    }
+    atb::Node obfCalNode;
+    atb_speed::common::ObfuscationCalculateParam obfCalParam;
+    obfCalParam.fd = param.modelConfuscationFd;
+    obfCalParam.hiddenSizePerRank = param.hiddenSizePerRank;
+    obfCalNode.inTensorIds.push_back(GetTensorIdx(tensorMap,
+        isMultiRank ? "intermediate_norm_per_rank" : "intermediate_norm"));
+    obfCalNode.outTensorIds.push_back(GetTensorIdx(tensorMap,
"intermediate_pmcc_gather" : "intermediate_pmcc")); + obfCalNode.operation = new atb_speed::common::ObfuscationCalculateOperation("ObfCalculateOperation", obfCalParam); + opGraph.nodes.push_back(obfCalNode); + if (isMultiRank) { + atb::Node allGatherNode; + atb::infer::AllGatherParam allGatherParam; + allGatherParam.rank = param.modelObfuscationParallelInfo.rank; + allGatherParam.rankSize = param.modelObfuscationParallelInfo.worldSize; + allGatherParam.backend = param.modelObfuscationParallelInfo.backend; + allGatherParam.rankTableFile = param.modelObfuscationParallelInfo.rankTableFile; + allGatherParam.hcclComm = param.modelObfuscationParallelInfo.hcommInfo; + allGatherParam.commDomain = param.modelObfuscationParallelInfo.commDomain; + CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(allGatherParam, &allGatherNode.operation)); + allGatherNode.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_pmcc_gather")}; + allGatherNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_pmcc_gather_out")}; + opGraph.nodes.push_back(allGatherNode); + + atb::Node transposeNode; + atb::infer::TransposeParam transposeParam; + transposeParam.perm = {1, 0, 2}; + CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(transposeParam, &transposeNode.operation)); + transposeNode.inTensorIds = {GetTensorIdx(tensorMap, "intermediate_pmcc_gather_out")}; + transposeNode.outTensorIds = {GetTensorIdx(tensorMap, "intermediate_pmcc")}; + opGraph.nodes.push_back(transposeNode); + } + return atb::NO_ERROR; +} + +template +atb::Status NormLinear(const NormLinearParam ¶m, atb::Operation **operation) +{ + atb::GraphParam opGraph; + opGraph.name = "NormLinear"; + + std::map tensorMap = ConstructNormTensorMap( + param, opGraph.inTensorNum, opGraph.outTensorNum, opGraph.internalTensorNum); + + // 校验(当前不支持PostNorm) + bool normIsPostNorm = static_cast(param.normParamType.layerType) == \ + static_cast(atb::infer::RmsNormParam::RmsNormType::RMS_NORM_POSTNORM) || \ + static_cast(param.normParamType.layerType) == \ + static_cast(atb::infer::LayerNormParam::LayerNormType::LAYER_NORM_POSTNORM); + bool normQuantIsPostNorm = static_cast(param.normQuantParamType.layerType) == \ + static_cast(atb::infer::RmsNormParam::RmsNormType::RMS_NORM_POSTNORM) || \ + static_cast(param.normQuantParamType.layerType) == \ + static_cast(atb::infer::LayerNormParam::LayerNormType::LAYER_NORM_POSTNORM); + if (normIsPostNorm || (UseNormQuant(param) && normQuantIsPostNorm)) { + ATB_SPEED_LOG_ERROR("Common Op NormLinear not support POSTNORM"); + return atb::ERROR_INTERNAL_ERROR; + } + + if (!param.skipNorm) { CHECK_OPERATION_STATUS_RETURN(InsertNorm(opGraph, param, tensorMap)); } + if (param.enableModelConfuscation) { + CHECK_OPERATION_STATUS_RETURN(InsertObfuscationCalculate(opGraph, param, tensorMap)); + } + + atb::Node linearNode; + atb_speed::common::FusionLinearParam linearParam = param.fusionLinearParam; + // 如果是LINEAR_W8A8_DYNAMIC_DEQUANT且AddRmsNormQuant为false, 则设为LinearQuantType::LINEAR_W8A8_DYNAMIC_QUANT + if (linearParam.quantType == LinearQuantType::LINEAR_W8A8_DYNAMIC_DEQUANT && !param.enableAddNorm) { + linearParam.quantType = LinearQuantType::LINEAR_W8A8_DYNAMIC_QUANT; + } + // if LINEAR_W4A8_DYNAMIC_DEQUANT and AddRmsNormQuant is false, then set to LINEAR_W4A8_DYNAMIC_QUANT, + if (linearParam.quantType == LinearQuantType::LINEAR_W4A8_DYNAMIC_DEQUANT && !param.enableAddNorm) { + linearParam.quantType = LinearQuantType::LINEAR_W4A8_DYNAMIC_QUANT; + } + CHECK_OPERATION_STATUS_RETURN(FusionLinear(linearParam, &linearNode.operation)); + 
+    std::vector<std::string> linearInTensor = {
+        param.skipNorm ? "in_input" : (param.enableModelConfuscation ? "intermediate_pmcc" : "intermediate_norm"),
+        "in_linear_weight", "in_scale", "in_offset", "in_descale", "in_bias", "in_compress_idx"
+    };
+    if (UseAddRmsNormDynamicQuant(param)) { linearInTensor.push_back("scale1_out"); }
+    if (param.fusionLinearParam.supportLora) {
+        if (param.fusionLinearParam.useImMask) { linearInTensor.push_back("in_im_mask"); }
+        linearInTensor.push_back("in_seq_len_cum_sum");
+        linearInTensor.push_back("in_linear_lora_a");
+        linearInTensor.push_back("in_linear_lora_b");
+    }
+    if (param.fusionLinearParam.enableFlashComm) {
+        linearInTensor.push_back("send_counts");
+        linearInTensor.push_back("sdispls");
+        linearInTensor.push_back("send_count");
+        linearInTensor.push_back("recv_counts");
+        linearInTensor.push_back("rdispls");
+        linearInTensor.push_back("recv_count");
+        linearInTensor.push_back("fake_ag_shape");
+    }
+    linearNode.inTensorIds = GetTensorIdxList(tensorMap, linearInTensor);
+    linearNode.outTensorIds = {GetTensorIdx(tensorMap, "out_linear")};
+    if (param.enableModelConfuscation && param.modelObfuscationParallelInfo.worldSize > 1) {
+        linearNode.inTensorReshapeFuncs.resize(linearNode.inTensorIds.size());
+        linearNode.inTensorReshapeFuncs[0] = [=] (const atb::Dims &oldShape, atb::Dims &newShape) {
+            newShape.dimNum = 2; // 2: dim num
+            newShape.dims[0] = oldShape.dims[0];
+            newShape.dims[1] = oldShape.dims[1] * oldShape.dims[2]; // 2: dim
+        };
+    }
+    opGraph.nodes.push_back(linearNode);
+
+    CHECK_OPERATION_STATUS_RETURN(atb::CreateOperation(opGraph, operation));
+    return atb::NO_ERROR;
+}
+
+LinearQuantType GetLinearQuantType(
+    const int &packQuantType, const int &linearType, bool hasNorm, const int &linearDesc
+)
+{
+    if (linearType == atb_speed::common::LinearType::FP || \
+        linearDesc == LinearDesc::FLOAT16_DESC || \
+        linearDesc == LinearDesc::BFLOAT16_DESC) {
+        return atb_speed::common::LinearQuantType::NO_QUANT;
+    } else if (packQuantType == atb_speed::common::ALL_W4A16 || \
+        packQuantType == atb_speed::common::ALL_W4A16_ANTI || \
+        packQuantType == atb_speed::common::MIX_W4A16 || \
+        packQuantType == atb_speed::common::MIX_W4A16_ANTI || \
+        linearDesc == LinearDesc::W4A16_DESC
+    ) {
+        return atb_speed::common::LinearQuantType::W4A16;
+    } else if (packQuantType == atb_speed::common::ALL_W8A16 || \
+        packQuantType == atb_speed::common::ALL_W8A16_ANTI || \
+        packQuantType == atb_speed::common::MIX_W8A16 || \
+        packQuantType == atb_speed::common::MIX_W8A16_ANTI || \
+        linearDesc == LinearDesc::W8A16_DESC
+    ) {
+        return atb_speed::common::LinearQuantType::W8A16;
+    } else if (
+        packQuantType == atb_speed::common::ALL_W8A8_DYNAMIC || \
+        packQuantType == atb_speed::common::ALL_W8A8_DYNAMIC_ANTI || \
+        packQuantType == atb_speed::common::MIX_W8A8_DYNAMIC || \
+        packQuantType == atb_speed::common::MIX_W8A8_DYNAMIC_ANTI || \
+        linearDesc == LinearDesc::W8A8_DYNAMIC_DESC
+    ) {
+        return hasNorm ? LinearQuantType::LINEAR_W8A8_DYNAMIC_DEQUANT : LinearQuantType::LINEAR_W8A8_DYNAMIC_QUANT;
+    } else if (packQuantType == atb_speed::common::ALL_W4A8 || \
+        packQuantType == atb_speed::common::ALL_W4A8_ANTI || \
+        packQuantType == atb_speed::common::MIX_W4A8 || \
+        packQuantType == atb_speed::common::MIX_W4A8_ANTI || \
+        linearDesc == LinearDesc::W4A8_DESC
+    ) {
+        return hasNorm ? LinearQuantType::LINEAR_W4A8_DYNAMIC_DEQUANT : LinearQuantType::LINEAR_W4A8_DYNAMIC_QUANT;
LinearQuantType::LINEAR_W4A8_DYNAMIC_DEQUANT : LinearQuantType::LINEAR_W4A8_DYNAMIC_QUANT;
+    } else {
+        if (packQuantType == atb_speed::common::ALL_W8A8SC || \
+            packQuantType == atb_speed::common::MIX_W8A8SC || \
+            packQuantType == atb_speed::common::ALL_W8A8SC_ANTI || \
+            packQuantType == atb_speed::common::MIX_W8A8SC_ANTI || \
+            linearDesc == LinearDesc::W8A8SC_DESC
+        ) {
+            return hasNorm ? LinearQuantType::LINEAR_W8A8_SC_DEQUANT : LinearQuantType::LINEAR_W8A8_SC_QUANT;
+        } else {
+            return hasNorm ? LinearQuantType::LINEAR_W8A8_DEQUANT : LinearQuantType::LINEAR_W8A8_QUANT;
+        }
+    }
+}
+
+template bool UseNormQuant(const NormLinearParam<atb::infer::RmsNormParam> &param);
+template bool UseAddRmsNormQuant(const NormLinearParam<atb::infer::RmsNormParam> &param);
+template bool UseAddRmsNormDynamicQuant(const NormLinearParam<atb::infer::RmsNormParam> &param);
+template std::map<std::string, uint32_t> ConstructNormTensorMap(
+    const NormLinearParam<atb::infer::RmsNormParam> &param,
+    uint32_t &inTensorNum, uint32_t &outTensorNum, uint32_t &internalTensorNum);
+template int64_t InsertNorm(
+    atb::GraphParam &opGraph, const NormLinearParam<atb::infer::RmsNormParam> &param,
+    std::map<std::string, uint32_t> &tensorMap);
+template atb::Status InsertObfuscationCalculate(atb::GraphParam &opGraph,
+    const NormLinearParam<atb::infer::RmsNormParam> &param, std::map<std::string, uint32_t> &tensorMap);
+template atb::Status NormLinear(const NormLinearParam<atb::infer::RmsNormParam> &param, atb::Operation **operation);
+
+template bool UseNormQuant(const NormLinearParam<atb::infer::LayerNormParam> &param);
+template std::map<std::string, uint32_t> ConstructNormTensorMap(
+    const NormLinearParam<atb::infer::LayerNormParam> &param,
+    uint32_t &inTensorNum, uint32_t &outTensorNum, uint32_t &internalTensorNum);
+template int64_t InsertNorm(
+    atb::GraphParam &opGraph, const NormLinearParam<atb::infer::LayerNormParam> &param,
+    std::map<std::string, uint32_t> &tensorMap);
+template atb::Status InsertObfuscationCalculate(atb::GraphParam &opGraph,
+    const NormLinearParam<atb::infer::LayerNormParam> &param, std::map<std::string, uint32_t> &tensorMap);
+template atb::Status NormLinear(const NormLinearParam<atb::infer::LayerNormParam> &param, atb::Operation **operation);
+
+} // namespace common
+} // namespace atb_speed
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/fusion/norm/norm_linear.h b/tests/proftest/layer_test_framework/operations/fusion/norm/norm_linear.h
new file mode 100644
index 00000000..68fddc63
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/fusion/norm/norm_linear.h
@@ -0,0 +1,157 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef ASCEND_SPEED_INFERENCE_COMMON_NORM_LINEAR_H
+#define ASCEND_SPEED_INFERENCE_COMMON_NORM_LINEAR_H
+
+#include
+#include "atb_speed/utils/operation_util.h"
+#include "operations/fusion/linear/linear.h"
+#include "operations/fusion/linear/linear_parallel.h"
+#include "operations/fusion/utils.h"
+
+namespace atb_speed {
+namespace common {
+
+/// Parameters for the normalization and linear module
+/// \tparam NormParamType Types of the normalization parameters. Available options are `atb::infer::RmsNormParam`
+/// and `atb::infer::LayerNormParam`.
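+///
+/// A minimal configuration sketch (illustrative values, assuming the `atb::infer::RmsNormParam`
+/// specialization; only the fields shown are touched):
+/// \code
+/// atb_speed::common::NormLinearParam<atb::infer::RmsNormParam> normLinearParam;
+/// normLinearParam.skipNorm = false;
+/// normLinearParam.normHasBias = false;
+/// normLinearParam.fusionLinearParam.quantType = atb_speed::common::LinearQuantType::NO_QUANT;
+/// \endcode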
+template <typename NormParamType>
+struct NormLinearParam {
+    /// A flag indicating whether anti-outlier is enabled
+    bool isAntiOutlier = false;
+    /// A flag indicating whether normalization is skipped
+    bool skipNorm = false;
+    /// A flag indicating whether normalization has bias
+    bool normHasBias = false;
+    /// A flag indicating whether to use the AddNorm fusion operation
+    bool enableAddNorm = false;
+    /// Normalization parameters for float operation
+    NormParamType normParamType;
+    /// Normalization parameters for quantization operation
+    NormParamType normQuantParamType;
+    /// Parameters for the FusionLinear module
+    atb_speed::common::FusionLinearParam fusionLinearParam;
+    /// A flag indicating whether to use pmcc obfuscation
+    bool enableModelConfuscation = false;
+    /// A handle used by pmcc model obfuscation
+    int32_t modelConfuscationFd = 0;
+    /// Hidden size per rank
+    int32_t hiddenSizePerRank = 0;
+    /// Parallel info, currently only used by pmcc obfuscation in the multi-rank scenario
+    atb_speed::common::TensorParallelInfo modelObfuscationParallelInfo;
+};
+
+/// Get `LinearQuantType` by the quantization type of the linear modules and the position of the linear module
+/// \param packQuantType The quantization type of the packed linear modules. Refer to `PackQuantType`
+/// in the `operations/utils.h`.
+/// \param linearType The type of one linear module. Refer to `LinearType` in the `operations/utils.h`.
+/// \param hasNorm A flag indicating whether the linear module includes a preceding normalization module
+/// \param linearDesc The weight and activation description of the linear module. Refer to `LinearDesc`
+/// in the `operations/utils.h`.
+LinearQuantType GetLinearQuantType(
+    const int &packQuantType = PackQuantType::PACK_QUANT_UNDEFINED,
+    const int &linearType = LinearType::INVALID,
+    bool hasNorm = false,
+    const int &linearDesc = LinearDesc::INVALID_DESC);
+
+/// This function constructs an operation that combines a normalization module with a `FusionLinear` module.
+///
+/// \tparam NormParamType Types of the normalization parameters. Available options are `atb::infer::RmsNormParam`
+/// and `atb::infer::LayerNormParam`.
+/// \param param Parameters for the normalization and linear module
+/// \param operation The address of a pointer to a default operation
+/// \return A flag that indicates whether the operation has been successfully created.
+///
+/// Operation's inputs:
+/// Name               | Dtype            | Shape | Description |
+/// -------------------|------------------|-------|----------|
+/// in_input           | float16/bfloat16 | [m,k] | |
+/// in_norm_weight     | float16/bfloat16 | [k]   | |
+/// in_norm_bias       | float16/bfloat16 | [k]   | Used when `param.normHasBias` is true |
+/// in_norm_new_weight | float16/bfloat16 | [k]   | Used when `param.isAntiOutlier` is true |
+/// in_norm_new_bias   | float16          | [1]   | Used when `param.normHasBias` and `param.isAntiOutlier` are true |
+/// in_linear_weight   |                  |       | The same specifications as the FusionLinear module |
+/// in_scale           |                  |       | The same specifications as the FusionLinear module |
+/// in_offset          |                  |       | The same specifications as the FusionLinear module |
+/// in_descale         |                  |       | The same specifications as the FusionLinear module |
+/// in_bias            |                  |       | The same specifications as the FusionLinear module |
+/// in_compress_idx    |                  |       | The same specifications as the FusionLinear module |
+/// in_residual_input  | float16/bfloat16 | [m,k] | Used when `enableAddNorm` is true |
+/// in_seq_len_cum_sum |                  |       | The same specifications as the FusionLinear module |
+/// in_linear_lora_a   |                  |       | The same specifications as the FusionLinear module |
+/// in_linear_lora_b   |                  |       | The same specifications as the FusionLinear module |
+/// in_im_mask         |                  |       | The same specifications as the FusionLinear module |
+///
+/// Operation's outputs:
+/// Name       | Dtype            | Shape | Description |
+/// -----------|------------------|-------|----------|
+/// out_linear | float16/bfloat16 | [m,n] | Output tensor of the linear module |
+/// out_add    | float16/bfloat16 | [m,k] | Output tensor of the residual add. Exists when `enableAddNorm` is true. |
+///
+/// Example:
+/// \code
+/// enum TensorIdx : uint32_t {
+///     IN_INPUT = 0,
+///     IN_NORM_WEIGHT,
+///     IN_NORM_BIAS,
+///     IN_LINEAR_WEIGHT,
+///     IN_PLACEHOLDER,
+///     OUT,
+/// };
+///
+/// atb::Node normLinearNode;
+/// atb_speed::common::NormLinearParam<atb::infer::RmsNormParam> normLinearParam;
+/// // Modify normLinearParam's attribute if needed.
+/// NormLinear(normLinearParam, &normLinearNode.operation);
+/// normLinearNode.inTensorIds = {IN_INPUT, IN_NORM_WEIGHT, IN_NORM_BIAS, IN_PLACEHOLDER, IN_PLACEHOLDER,
+///     IN_LINEAR_WEIGHT, IN_PLACEHOLDER, IN_PLACEHOLDER, IN_PLACEHOLDER, IN_PLACEHOLDER, IN_PLACEHOLDER};
+/// normLinearNode.outTensorIds = {OUT};
+/// atb::GraphParam opGraph;
+/// opGraph.nodes.push_back(normLinearNode);
+/// \endcode
+template <typename NormParamType>
+atb::Status NormLinear(const NormLinearParam<NormParamType> &param, atb::Operation **operation);
+
+/// This function adds a normalization node into the graph.
+///
+/// \tparam NormParamType Types of the normalization parameters. Available options are `atb::infer::RmsNormParam`
+/// and `atb::infer::LayerNormParam`.
+/// \param opGraph The graph to be constructed
+/// \param param Parameters for the normalization and linear module
+/// \param tensorMap A map containing all the required tensors for the node, with the key representing
+/// the tensor name and the value corresponding to the tensor index. It is used to identify the node's input
+/// and output tensors based on tensor names.
+/// \return A flag that indicates whether the operation has been successfully added to the graph.
+///
+/// Example:
+/// \code
+/// atb_speed::common::NormLinearParam<atb::infer::RmsNormParam> normParam;
+/// // Modify normParam's attribute if needed.
+/// std::map<std::string, uint32_t> normTensorMap = {
+///     {"in_input", 0}, {"in_norm_weight", 1}, {"in_norm_bias", 2}, {"in_norm_new_weight", 3}, {"in_norm_new_bias", 4},
+///     {"in_scale", 5}, {"in_offset", 6}, {"intermediate_norm", 7}, {"out_add", 8}, {"in_residual_input", 9}
+/// };
+/// atb::GraphParam graph;
+/// atb_speed::common::InsertNorm(graph, normParam, normTensorMap);
+/// \endcode
+template <typename NormParamType>
+int64_t InsertNorm(atb::GraphParam &opGraph, const NormLinearParam<NormParamType> &param,
+    std::map<std::string, uint32_t> &tensorMap);
+
+} // namespace common
+} // namespace atb_speed
+
+#endif
diff --git a/tests/proftest/layer_test_framework/operations/fusion/parallel_info.cpp b/tests/proftest/layer_test_framework/operations/fusion/parallel_info.cpp
new file mode 100644
index 00000000..4a53b838
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/fusion/parallel_info.cpp
@@ -0,0 +1,83 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "atb_speed/utils/singleton.h"
+#include "atb_speed/base/external_comm_manager.h"
+#include "parallel_info.h"
+
+namespace atb_speed {
+namespace common {
+
+std::string InitCommBackend(uint32_t localWorldSize, const std::vector<uint32_t> rankIds, std::string commBackend)
+{
+    if (localWorldSize == 0) {
+        throw std::runtime_error("Number of devices in the current node is 0.");
+    }
+    // Get backend
+    std::string backend = commBackend;
+    // Switch to hccl if the communication channel spans multiple nodes
+    int32_t currentDevice = -1;
+    for (uint32_t item : rankIds) {
+        // item / localWorldSize is the node index of the rank (integer division)
+        if (currentDevice != -1 && static_cast<int32_t>(item / localWorldSize) != currentDevice) {
+            backend = "hccl";
+            break;
+        }
+        currentDevice = static_cast<int32_t>(item / localWorldSize);
+    }
+    // The hccl backend is utilized in the single node scenario
+    // when a rankTableFile is supplied and the communication channel spans the entire world size.
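+    // (Illustration: with localWorldSize = 8, rankIds = {0, 1, ..., 7} all resolve to node 0 and the
+    // requested backend is kept, while rankIds = {0, 8} spans node 0 and node 1 and forces hccl above.)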
+    uint32_t worldSize = GetSingleton<ExternalCommManager>().worldSize_;
+    if (worldSize <= localWorldSize && GetSingleton<ExternalCommManager>().rankTableFile_ != "" && \
+        rankIds.size() == worldSize) {
+        backend = "hccl";
+    }
+    return backend;
+}
+
+void ParallelInfo::InitCommDomain(HcclComm& hcclComm, std::string& commDomain, std::string backend) const
+{
+    if (backend == "") {
+        backend = this->defaultBackend;
+    }
+    // Get current stream id
+    uint32_t streamId = GetSingleton<DapManager>().GetStreamId();
+
+    // Assign commDomain by rankIds and rank
+    commDomain = GetSingleton<ExternalCommManager>().GetCommDomain(
+        this->groupId, this->rankIds, this->rank, backend, this->bufferSize, streamId);
+    // Get hcclComm (only created when the hccl backend is used and inference spans multiple nodes)
+    hcclComm = GetSingleton<ExternalCommManager>().GetCommPtr(commDomain);
+
+    ATB_SPEED_LOG_DEBUG(this->ToString());
+}
+
+bool ParallelInfo::IsEnabled() const
+{
+    return this->rankIds.size() > 1;
+}
+
+std::string ParallelInfo::ToString() const
+{
+    std::stringstream ss;
+    ss << "ParallelInfo: rank: " << this->rank
+       << ", rankIds: " << this->rankIds
+       << ", groupId: " << this->groupId
+       << ", defaultBackend: " << this->defaultBackend
+       << ", bufferSize: " << this->bufferSize;
+    return ss.str();
+}
+
+} // namespace common
+} // namespace atb_speed
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/fusion/parallel_info.h b/tests/proftest/layer_test_framework/operations/fusion/parallel_info.h
new file mode 100644
index 00000000..b44efb38
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/fusion/parallel_info.h
@@ -0,0 +1,58 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef ATB_SPEED_PARALLEL_INFO_H
+#define ATB_SPEED_PARALLEL_INFO_H
+#include
+#include "atb_speed/log.h"
+#include "operations/fusion/utils.h"
+
+namespace atb_speed {
+namespace common {
+/// Parameters related to parallelism
+struct ParallelInfo {
+    /// Rank of the device within the communication group
+    uint32_t rank = 0;
+    /// Ranks that form the communication group
+    std::vector<uint32_t> rankIds = {};
+    /// Index of the current communication group
+    uint32_t groupId = 0;
+    /// Backend communication method
+    std::string defaultBackend = "";
+    /// The size of the buffer area for sharing data between devices
+    uint32_t bufferSize = 0;
+
+    /// Initialize the hccl communication handle on demand and get a unique communication domain from rankIds
+    void InitCommDomain(HcclComm& hcclComm, std::string& commDomain, std::string backend = "") const;
+    /// Check if the parallel strategy is enabled
+    bool IsEnabled() const;
+    /// A summary of the `ParallelInfo` object
+    std::string ToString() const;
+};
+
+std::string InitCommBackend(uint32_t localWorldSize, const std::vector<uint32_t> rankIds, std::string commBackend);
+
+/// Parameters related to pipeline parallelism
+struct PpParallelInfo : public ParallelInfo {
+    /// Micro batch size
+    int microBatchSize = 1;
+    /// Parameters related to the internal tensor parallelism
+    ParallelInfo internalTp;
+};
+
+} // namespace common
+} // namespace atb_speed
+
+#endif
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/fusion/parallel_layer.cpp b/tests/proftest/layer_test_framework/operations/fusion/parallel_layer.cpp
new file mode 100644
index 00000000..501aa278
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/fusion/parallel_layer.cpp
@@ -0,0 +1,163 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +#include "parallel_layer.h" + +#include +#include + +#include "atb_speed/log.h" + +namespace atb_speed { +namespace common { +enum ParallelType : int { + ROW_PARALLEL = 0, + COLUMN_PARALLEL, +}; + +atb::Status InnerParallelLinearBase(const ParallelParam ¶m, atb::GraphParam &opGraph, + const ParallelType parallelType) +{ + if (parallelType == ROW_PARALLEL) { + opGraph.inferShapeFunc = [=](const atb::SVector &inTensorDescs, + atb::SVector &outTensorDescs) { + outTensorDescs.at(0).dtype = inTensorDescs.at(0).dtype; + outTensorDescs.at(0).format = inTensorDescs.at(0).format; + auto dimNum = inTensorDescs.at(0).shape.dimNum; + outTensorDescs.at(0).shape.dimNum = dimNum; + outTensorDescs.at(0).shape.dims[0] = inTensorDescs.at(0).shape.dims[0]; + if (dimNum == 3) { // 维度数量 3 + outTensorDescs.at(0).shape.dims[1] = inTensorDescs.at(0).shape.dims[1]; + } + if (param.transposeB) { + outTensorDescs.at(0).shape.dims[dimNum - 1] = inTensorDescs.at(1).shape.dims[0]; + } else { + outTensorDescs.at(0).shape.dims[dimNum - 1] = inTensorDescs.at(1).shape.dims[1]; + } + return atb::NO_ERROR; + }; + } else { + opGraph.inferShapeFunc = [=](const atb::SVector &inTensorDescs, + atb::SVector &outTensorDescs) { + outTensorDescs.at(0).dtype = inTensorDescs.at(0).dtype; + outTensorDescs.at(0).format = inTensorDescs.at(0).format; + auto dimNum = inTensorDescs.at(0).shape.dimNum; + outTensorDescs.at(0).shape.dimNum = dimNum + 1; // add rank dim + outTensorDescs.at(0).shape.dims[0] = param.rankSize; + outTensorDescs.at(0).shape.dims[1] = inTensorDescs.at(0).shape.dims[0]; + if (dimNum == 3) { // 维度数量 3 + outTensorDescs.at(0).shape.dims[2] = inTensorDescs.at(0).shape.dims[1]; // dim 2 + } + if (param.transposeB) { + outTensorDescs.at(0).shape.dims[dimNum] = inTensorDescs.at(1).shape.dims[0]; // last dim + } else { + outTensorDescs.at(0).shape.dims[dimNum] = inTensorDescs.at(1).shape.dims[1]; // last dim + } + return atb::NO_ERROR; + }; + } + return atb::NO_ERROR; +} + +template +atb::Status ParallelLinearBase(const ParallelParam ¶m, atb::Operation **operation, T config, + const ParallelType parallelType) +{ + atb::GraphParam opGraph; + opGraph.name = "ParallelLinearBase"; + opGraph.inTensorNum = static_cast(config.inTensorNum); + opGraph.outTensorNum = static_cast(config.outTensorNum); + opGraph.internalTensorNum = static_cast(config.interTensorNum); + opGraph.nodes.resize(config.nodeCount); + + size_t nodeId = 0; + atb::Node &matmulNode = opGraph.nodes.at(nodeId++); + + atb::infer::LinearParam matmulParam = { param.transposeA, param.transposeB, false }; + CREATE_OPERATION(matmulParam, &matmulNode.operation); + matmulNode.inTensorIds = { config.IN_INPUT, config.IN_WEIGHT }; + matmulNode.outTensorIds = { config.INTERMIDATE_MATMULOUT }; + + if (param.rankSize > 1) { + atb::Node ¶llelNode = opGraph.nodes.at(nodeId++); + + if (parallelType == ROW_PARALLEL) { + atb::infer::AllReduceParam allReduceParam; + allReduceParam.rank = param.rank; + allReduceParam.rankSize = param.rankSize; + allReduceParam.backend = param.backend; + CREATE_OPERATION(allReduceParam, ¶llelNode.operation); + } else { + atb::infer::AllGatherParam allGatherParam; + allGatherParam.rank = param.rank; + allGatherParam.rankSize = param.rankSize; + allGatherParam.backend = param.backend; + CREATE_OPERATION(allGatherParam, ¶llelNode.operation); + } + + parallelNode.inTensorIds = { config.INTERMIDATE_MATMULOUT }; + parallelNode.outTensorIds = { config.INTERMIDATE_ALLREDUCEOUT }; + } + + if (param.isBias) { + atb::Node &addNode = opGraph.nodes.at(nodeId++); 
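+        // The bias add is deliberately placed after the collective: under ROW_PARALLEL each rank holds
+        // only a partial matmul sum, so adding the bias before the all-reduce would accumulate it
+        // rankSize times in the reduced result.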
+ atb::infer::ElewiseParam addParam; + addParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_ADD; + CREATE_OPERATION(addParam, &addNode.operation); + addNode.inTensorIds = { param.rankSize > 1 ? config.INTERMIDATE_ALLREDUCEOUT : config.INTERMIDATE_MATMULOUT, + config.IN_BIAS }; + addNode.outTensorIds = { config.OUT_LINEAROUT }; + } + InnerParallelLinearBase(param, opGraph, parallelType); + CREATE_OPERATION(opGraph, operation); + return atb::NO_ERROR; +} + +atb::Status ParallelLinear(const ParallelParam ¶m, atb::Operation **operation, const ParallelType parallelType) +{ + if (param.isBias && (param.rankSize > 1)) { + return ParallelLinearBase(param, operation, + LinearWithBiasAndParallel(3, 1, 2, 3), // 3是输入张量数量 1是输出张量数量 2是中间张量数量 3是节点数量 + parallelType); // 3:in 1:out 2:inter 3:node + } else if (param.isBias) { + return ParallelLinearBase(param, operation, + LinearWithBias(3, 1, 1, 2), // 3是输入张量数量 1是输出张量数量 1是中间张量数量 2是节点数量 + parallelType); // 3:in 1:out 1:inter 2:node + } else if (param.rankSize > 1) { + return ParallelLinearBase(param, operation, + LinearWithParallel(2, 1, 1, 2), // 2是输入张量数量 1是输出张量数量 1是中间张量数量 2是节点数量 + parallelType); // 2:in 1:out 1:inter 2:node + } else { + return ParallelLinearBase(param, operation, LinearOnly(2, 1, 0, 1), parallelType); // 2:in 1:out 0:inter 1:node + } +} + +atb::Status RowParallelLinear(const ParallelParam ¶m, atb::Operation **operation) +{ + return ParallelLinear(param, operation, ROW_PARALLEL); +} + +atb::Status ColumnParallelLinear(const ParallelParam ¶m, atb::Operation **operation) +{ + return ParallelLinear(param, operation, COLUMN_PARALLEL); +} + +atb::Status VocabParallelEmbedding(const atb::Operation **operation) +{ + (void)&operation; + return 0; +} +} // namespace common +} // namespace atb_speed diff --git a/tests/proftest/layer_test_framework/operations/fusion/parallel_layer.h b/tests/proftest/layer_test_framework/operations/fusion/parallel_layer.h new file mode 100644 index 00000000..6f80c40e --- /dev/null +++ b/tests/proftest/layer_test_framework/operations/fusion/parallel_layer.h @@ -0,0 +1,101 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef ATB_SPEED_LAYERS_PARALLEL_LAYER_H +#define ATB_SPEED_LAYERS_PARALLEL_LAYER_H +#include +#include "atb_speed/log.h" +#include "atb_speed/utils/operation_util.h" +#include "nlohmann/json.hpp" +#include "common_op_base.h" + +namespace atb_speed { +namespace common { +struct ParallelParam { + int rank = 0; + int rankSize = 1; + int rankRoot = 0; + void *hcclComm = nullptr; + bool isBias = false; + bool transposeA = false; + bool transposeB = true; + std::string backend = "hccl"; + bool isBF16 = false; +}; + +class LinearWithBiasAndParallel : public CommonOpBase { +public: + using CommonOpBase::CommonOpBase; + + enum LinearWithBiasAndParallelId : unsigned int { + IN_INPUT = 0, + IN_WEIGHT, + IN_BIAS, + OUT_LINEAROUT, + INTERMIDATE_MATMULOUT, + INTERMIDATE_ALLREDUCEOUT, + }; +}; + +class LinearWithParallel : public CommonOpBase { +public: + using CommonOpBase::CommonOpBase; + + enum LinearWithParallelId : unsigned int { + IN_INPUT = 0, + IN_WEIGHT, + INTERMIDATE_ALLREDUCEOUT, + INTERMIDATE_MATMULOUT, + IN_BIAS, + OUT_LINEAROUT, + }; +}; + +class LinearWithBias : public CommonOpBase { +public: + using CommonOpBase::CommonOpBase; + + enum LinearWithBiasId : unsigned int { + IN_INPUT = 0, + IN_WEIGHT, + IN_BIAS, + OUT_LINEAROUT, + INTERMIDATE_MATMULOUT, + INTERMIDATE_ALLREDUCEOUT, + }; +}; + +class LinearOnly : public CommonOpBase { +public: + using CommonOpBase::CommonOpBase; + + enum LinearOnlyId : unsigned int { + IN_INPUT = 0, + IN_WEIGHT, + INTERMIDATE_MATMULOUT, + IN_BIAS, + OUT_LINEAROUT, + INTERMIDATE_ALLREDUCEOUT, + }; +}; + +atb::Status RowParallelLinear(const ParallelParam ¶m, atb::Operation **operation); +atb::Status ColumnParallelLinear(const ParallelParam ¶m, atb::Operation **operation); +atb::Status VocabParallelEmbedding(const atb::Operation **operation); +} // namespace common +} // namespace atb_speed + +#endif diff --git a/tests/proftest/layer_test_framework/operations/fusion/parallel_layer_v2.cpp b/tests/proftest/layer_test_framework/operations/fusion/parallel_layer_v2.cpp new file mode 100644 index 00000000..37e25943 --- /dev/null +++ b/tests/proftest/layer_test_framework/operations/fusion/parallel_layer_v2.cpp @@ -0,0 +1,289 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include "nlohmann/json.hpp" +#include "parallel_layer_v2.h" + + +namespace atb_speed { +namespace common { +enum ParallelType : int { + ROW_PARALLEL = 0, + COLUMN_PARALLEL, +}; + +enum InTensorId : int { + IN_INPUT = 0, + IN_WEIGHT, + IN_BIAS, + IN_DEQSCALE, + IN_INDEX_IDS, + IN_OFFSET, + IN_COMPRESSINFO, + OUT_LINEAR, + INTER_ID, +}; + +atb::Status CalNodeNum(size_t &nodeCount, size_t &internalTensorNum, const ParallelParamV2 ¶m, + const ParallelType parallelType) +{ + if (param.isQuant) { + if (param.quantParam.isQuantOp) { + nodeCount += 1; + internalTensorNum += 1; + } + } else { + if (param.isBias) { + nodeCount += 1; + internalTensorNum += 1; + } + } + + if (param.commParam.rankSize > 1) { + nodeCount += 1; + internalTensorNum += 1; + if (parallelType == COLUMN_PARALLEL && param.isAllGatherTranspose) { + nodeCount += 1; + internalTensorNum += 1; + } + } + return atb::NO_ERROR; +} + +atb::Status AddmatmulNode(const ParallelParamV2 ¶m, atb::GraphParam &opGraph, size_t &nodeId, uint32_t &inteId) +{ + if (!param.isQuant) { + atb::Node &matmulNode = opGraph.nodes.at(nodeId++); + atb::infer::LinearParam matmulParam = { param.transposeA, param.transposeB, false }; + CREATE_OPERATION(matmulParam, &matmulNode.operation); + matmulNode.inTensorIds = { IN_INPUT, IN_WEIGHT }; + matmulNode.outTensorIds = { (param.commParam.rankSize > 1 || param.isBias) ? + inteId : + static_cast(OUT_LINEAR) }; + } else { + if (param.quantParam.isQuantOp) { + atb::Node &quantNode = opGraph.nodes.at(nodeId++); + atb::infer::ElewiseParam quantParam; + quantParam.elewiseType = param.quantParam.elewiseType; + quantParam.quantParam.inputScale = param.quantParam.inputScale; + quantParam.quantParam.inputOffset = param.quantParam.inputOffset; + CREATE_OPERATION(quantParam, &quantNode.operation); + quantNode.inTensorIds = { IN_INPUT }; + quantNode.outTensorIds = { inteId }; + } + + if (param.isSparse) { + atb::Node &matmulNode = opGraph.nodes.at(nodeId++); + atb::infer::LinearSparseParam linearSparseParam = { false, true, 8, 8 }; // 8 压缩参数 + CREATE_OPERATION(linearSparseParam, &matmulNode.operation); + matmulNode.inTensorIds = { param.quantParam.isQuantOp ? inteId++ : static_cast(IN_INPUT), + static_cast(IN_WEIGHT), static_cast(IN_BIAS), static_cast(IN_DEQSCALE), + static_cast(IN_INDEX_IDS) }; + matmulNode.outTensorIds = { param.commParam.rankSize > 1 ? inteId : static_cast(OUT_LINEAR) }; + } else { + atb::Node &matmulNode = opGraph.nodes.at(nodeId++); + atb::infer::LinearParam matmulParam; + matmulParam.transposeA = param.transposeA; + matmulParam.transposeB = param.transposeB; + matmulParam.outDataType = ACL_FLOAT16; + CREATE_OPERATION(matmulParam, &matmulNode.operation); + matmulNode.inTensorIds = { param.quantParam.isQuantOp ? inteId++ : static_cast(IN_INPUT), + static_cast(IN_WEIGHT), static_cast(IN_BIAS), static_cast(IN_DEQSCALE) }; + matmulNode.outTensorIds = { param.commParam.rankSize > 1 ? 
inteId : static_cast(OUT_LINEAR) }; + } + } + return atb::NO_ERROR; +} + +atb::Status CalMulRank(const ParallelParamV2 ¶m, atb::GraphParam &opGraph, size_t &nodeId, uint32_t &inteId, + const ParallelType parallelType) +{ + if (param.commParam.rankSize > 1) { + atb::Node ¶llelNode = opGraph.nodes.at(nodeId++); + + if (parallelType == ROW_PARALLEL) { + atb::infer::AllReduceParam allReduceParam; + allReduceParam.rank = param.commParam.rank; + allReduceParam.rankSize = param.commParam.rankSize; + allReduceParam.backend = param.commParam.backend; + CREATE_OPERATION(allReduceParam, ¶llelNode.operation); + parallelNode.inTensorIds = { inteId++ }; + parallelNode.outTensorIds = { param.isBias && !param.isQuant ? inteId : static_cast(OUT_LINEAR) }; + } else { + atb::infer::AllGatherParam allGatherParam; + allGatherParam.rank = param.commParam.rank; + allGatherParam.rankSize = param.commParam.rankSize; + allGatherParam.backend = param.commParam.backend; + CREATE_OPERATION(allGatherParam, ¶llelNode.operation); + parallelNode.inTensorIds = { inteId++ }; + parallelNode.outTensorIds = { (param.isBias && !param.isQuant) || param.isAllGatherTranspose ? + inteId : + static_cast(OUT_LINEAR) }; + + // (world_size,bs,seq,vocab_size//world_size) + // -> (bs,seq,world_size,vocab_size//world_size) + // -> (bs,seq,vocab_size) + if (param.isAllGatherTranspose) { + atb::Node &gatherTransposeNode = opGraph.nodes.at(nodeId++); + atb::infer::TransposeParam gatherTransposeParam; + gatherTransposeParam.perm = { 1, 2, 0, 3 }; + CREATE_OPERATION(gatherTransposeParam, &gatherTransposeNode.operation); + gatherTransposeNode.inTensorIds = { inteId++ }; + gatherTransposeNode.outTensorIds = { param.isBias && !param.isQuant ? + inteId : + static_cast(OUT_LINEAR) }; + } + } + } + return atb::NO_ERROR; +} + +atb::Status CalBias(const ParallelParamV2 ¶m, atb::GraphParam &opGraph, size_t &nodeId, const uint32_t &inteId) +{ + if (param.isBias && !param.isQuant) { + atb::Node &addNode = opGraph.nodes.at(nodeId++); + atb::infer::ElewiseParam addParam; + addParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_ADD; + CREATE_OPERATION(addParam, &addNode.operation); + addNode.inTensorIds = { inteId, IN_BIAS }; + addNode.outTensorIds = { OUT_LINEAR }; + } + return atb::NO_ERROR; +} + +atb::Status RowParallelInferShape(const ParallelParamV2 ¶m, atb::GraphParam &opGraph) +{ + opGraph.inferShapeFunc = [=](const atb::SVector &inTensorDescs, + atb::SVector &outTensorDescs) { + if (param.isQuant) { + outTensorDescs.at(0).dtype = ACL_FLOAT16; + } else { + outTensorDescs.at(0).dtype = inTensorDescs.at(0).dtype; + } + outTensorDescs.at(0).format = inTensorDescs.at(0).format; + auto dimNum = inTensorDescs.at(0).shape.dimNum; + auto wdim = inTensorDescs.at(1).shape.dimNum; + outTensorDescs.at(0).shape.dimNum = dimNum; + outTensorDescs.at(0).shape.dims[0] = inTensorDescs.at(0).shape.dims[0]; + if (param.isQuant && param.isSparse) { + if (dimNum == 3) { // 3是张量的维度数 + outTensorDescs.at(0).shape.dims[1] = inTensorDescs.at(0).shape.dims[1]; + outTensorDescs.at(0).shape.dims[2] = inTensorDescs.at(2).shape.dims[0]; // 2 dim维度数下标 2 下标 + } else { + outTensorDescs.at(0).shape.dims[1] = inTensorDescs.at(2).shape.dims[0]; // 2 下标 + } + } else if (param.isQuant) { + if (dimNum == 3) { // 3是张量的维度数 + outTensorDescs.at(0).shape.dims[1] = inTensorDescs.at(0).shape.dims[1]; + } + outTensorDescs.at(0).shape.dims[dimNum - 1] = inTensorDescs.at(1).shape.dims[wdim - 2]; // ND,NZ统一为-2轴 + } else { + if (dimNum == 3) { // 3是张量的维度数 + 
outTensorDescs.at(0).shape.dims[1] = inTensorDescs.at(0).shape.dims[1]; + } + outTensorDescs.at(0).shape.dims[dimNum - 1] = inTensorDescs.at(1).shape.dims[0]; + } + return atb::NO_ERROR; + }; + return atb::NO_ERROR; +} + +atb::Status NoRowParallelInferShape(const ParallelParamV2 ¶m, atb::GraphParam &opGraph) +{ + opGraph.inferShapeFunc = [=](const atb::SVector &inTensorDescs, + atb::SVector &outTensorDescs) { + if (!param.isQuant) { + outTensorDescs.at(0).dtype = inTensorDescs.at(0).dtype; + } else { + outTensorDescs.at(0).dtype = ACL_FLOAT16; + } + outTensorDescs.at(0).format = inTensorDescs.at(0).format; + auto dimNum = inTensorDescs.at(0).shape.dimNum; + if (param.isAllGatherTranspose) { + outTensorDescs.at(0).shape.dimNum = dimNum; + outTensorDescs.at(0).shape.dims[0] = inTensorDescs.at(0).shape.dims[0]; + if (dimNum == 3) { // 3是张量的维度数 + outTensorDescs.at(0).shape.dims[1] = inTensorDescs.at(0).shape.dims[1]; // dim 2 + } + outTensorDescs.at(0).shape.dims[dimNum - 1] = + inTensorDescs.at(1).shape.dims[0] * param.commParam.rankSize; // last dim + } else { + outTensorDescs.at(0).shape.dimNum = dimNum + 1; // add rank dim + outTensorDescs.at(0).shape.dims[0] = param.commParam.rankSize; + outTensorDescs.at(0).shape.dims[1] = inTensorDescs.at(0).shape.dims[0]; + if (dimNum == 3) { // 3是张量的维度数 + outTensorDescs.at(0).shape.dims[2] = inTensorDescs.at(0).shape.dims[1]; // dim 2 + } + outTensorDescs.at(0).shape.dims[dimNum] = inTensorDescs.at(1).shape.dims[0]; // last dim + } + + return atb::NO_ERROR; + }; + return atb::NO_ERROR; +} + +atb::Status ParallelLinearBaseV2(const ParallelParamV2 ¶m, atb::Operation **operation, + const ParallelType parallelType) +{ + atb::GraphParam opGraph; + opGraph.name = "ParallelLinearBaseV2"; + opGraph.inTensorNum = 7; // 7是输入张量数量 + opGraph.outTensorNum = 1; + // 判断node个数 + size_t nodeCount = 1; + size_t internalTensorNum = 0; + CHECK_OPERATION_STATUS_RETURN(CalNodeNum(nodeCount, internalTensorNum, param, parallelType)); + opGraph.internalTensorNum = internalTensorNum; + opGraph.nodes.resize(nodeCount); + size_t nodeId = 0; + uint32_t inteId = INTER_ID; + CHECK_OPERATION_STATUS_RETURN(AddmatmulNode(param, opGraph, nodeId, inteId)); + CHECK_OPERATION_STATUS_RETURN(CalMulRank(param, opGraph, nodeId, inteId, parallelType)); + CHECK_OPERATION_STATUS_RETURN(CalBias(param, opGraph, nodeId, inteId)); + if (parallelType == ROW_PARALLEL) { + CHECK_OPERATION_STATUS_RETURN(RowParallelInferShape(param, opGraph)); + } else { + CHECK_OPERATION_STATUS_RETURN(NoRowParallelInferShape(param, opGraph)); + } + CREATE_OPERATION(opGraph, operation); + return atb::NO_ERROR; +} + + +atb::Status ParallelLinearV2(const ParallelParamV2 ¶m, atb::Operation **operation, const ParallelType parallelType) +{ + return ParallelLinearBaseV2(param, operation, parallelType); // 5:in 1:out 3:inter +} + + +atb::Status RowParallelLinearV2(const ParallelParamV2 ¶m, atb::Operation **operation) +{ + return ParallelLinearV2(param, operation, ROW_PARALLEL); +} + +atb::Status ColumnParallelLinearV2(const ParallelParamV2 ¶m, atb::Operation **operation) +{ + return ParallelLinearV2(param, operation, COLUMN_PARALLEL); +} + +atb::Status VocabParallelEmbeddingV2(const atb::Operation **operation) +{ + (void)operation; + return 0; +} +} // namespace common +} // namespace atb_speed diff --git a/tests/proftest/layer_test_framework/operations/fusion/parallel_layer_v2.h b/tests/proftest/layer_test_framework/operations/fusion/parallel_layer_v2.h new file mode 100644 index 00000000..8591831c --- /dev/null +++ 
b/tests/proftest/layer_test_framework/operations/fusion/parallel_layer_v2.h @@ -0,0 +1,62 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef ATB_SPEED_LAYER_PARALLEL_LAYER_V2_H +#define ATB_SPEED_LAYER_PARALLEL_LAYER_V2_H +#include +#include "atb_speed/utils/operation_util.h" +#include "nlohmann/json.hpp" + +namespace atb_speed { +namespace common { + +struct QuantParam { + atb::infer::QuantType quantType; + atb::infer::ElewiseParam::ElewiseType elewiseType; + float inputScale = 1.0f; + int inputOffset = 0; + int tilingN = 0; + int tilingK = 0; + bool isQuantOp = false; +}; + +struct CommParam { + int rank = 0; + int rankSize = 1; + int rankRoot = 0; + void *hcclComm = nullptr; + std::string backend = "hccl"; +}; + +struct ParallelParamV2 { + bool isBias = false; + bool transposeA = false; + bool transposeB = true; + bool isQuant = false; + bool isSparse = false; + bool isAllGatherTranspose = false; + bool isBF16 = false; + CommParam commParam; + QuantParam quantParam; +}; + +atb::Status RowParallelLinearV2(const ParallelParamV2 ¶m, atb::Operation **operation); +atb::Status ColumnParallelLinearV2(const ParallelParamV2 ¶m, atb::Operation **operation); +atb::Status VocabParallelEmbeddingV2(const atb::Operation **operation); +} // namespace common +} // namespace atb_speed + +#endif diff --git a/tests/proftest/layer_test_framework/operations/fusion/parallel_lmhead.cpp b/tests/proftest/layer_test_framework/operations/fusion/parallel_lmhead.cpp new file mode 100644 index 00000000..b9660026 --- /dev/null +++ b/tests/proftest/layer_test_framework/operations/fusion/parallel_lmhead.cpp @@ -0,0 +1,124 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "parallel_lmhead.h" + +#include + +#include "parallel_layer.h" + +namespace atb_speed { +namespace common { +atb::Status InnerInferShape(const ParallelLmHeadParam ¶m, atb::GraphParam &opGraph) +{ + if (param.gatherAhead) { + opGraph.inferShapeFunc = [=](const atb::SVector &inTensorDescs, + atb::SVector &outTensorDescs) { + outTensorDescs.at(0) = inTensorDescs.at(0); + auto dimLast = inTensorDescs.at(0).shape.dimNum - 1; + outTensorDescs.at(0).shape.dims[0] = inTensorDescs.at(2).shape.dims[0]; // 2 下标 + outTensorDescs.at(0).shape.dims[dimLast] = inTensorDescs.at(1).shape.dims[0] * param.rankSize; + return atb::NO_ERROR; + }; + } else { + opGraph.inferShapeFunc = [=](const atb::SVector &inTensorDescs, + atb::SVector &outTensorDescs) { + outTensorDescs.at(0) = inTensorDescs.at(0); + auto dimLast = inTensorDescs.at(0).shape.dimNum - 1; + outTensorDescs.at(0).shape.dims[dimLast] = inTensorDescs.at(1).shape.dims[0] * param.rankSize; + return atb::NO_ERROR; + }; + } + return atb::NO_ERROR; +} + +template +atb::Status CreateParallelLmHeadBase(const ParallelLmHeadParam ¶m, atb::Operation **operation, T config) +{ + atb::GraphParam opGraph; + opGraph.inTensorNum = config.inTensorNum; + opGraph.outTensorNum = config.outTensorNum; + opGraph.internalTensorNum = config.interTensorNum; + opGraph.nodes.resize(config.nodeCount); + if (param.gatherAhead) { + opGraph.name = "Parallel_LmHead_GatherAhead"; + } else { + opGraph.name = "Parallel_LmHead"; + } + size_t nodeId = 0; + if (param.gatherAhead) { + auto &gatherNode = opGraph.nodes.at(nodeId++); + atb::infer::GatherParam gatherParam; + CREATE_OPERATION(gatherParam, &gatherNode.operation); + gatherNode.inTensorIds = { config.IN_HIDDENSTATES_ID, config.IN_LMHEAD_INDICES_ID }; + gatherNode.outTensorIds = { config.INTERMIDATE_GATHER_OUT_ID }; + } + atb::Node ¶llelLinearNode = opGraph.nodes.at(nodeId++); + atb_speed::common::ParallelParam parallelParam; + parallelParam.rank = param.rank; + parallelParam.rankSize = param.rankSize; + parallelParam.isBias = false; + parallelParam.backend = param.backend; + parallelParam.isBF16 = param.isBF16; + atb_speed::common::ColumnParallelLinear(parallelParam, ¶llelLinearNode.operation); + parallelLinearNode.inTensorIds = { param.gatherAhead ? config.INTERMIDATE_GATHER_OUT_ID : config.IN_HIDDENSTATES_ID, + config.IN_WEIGHT_ID }; + parallelLinearNode.outTensorIds = { param.rankSize > 1 ? 
config.INTERMEDIATE_ALLGATHER_OUT_ID : + config.OUT_LOGITS_ID }; + if (param.rankSize > 1) { + atb::Node &transposeNode = opGraph.nodes.at(nodeId++); + atb::infer::TransposeParam transposeParam; + if (param.unpadInputs) { + transposeParam.perm = { 1, 0, 2 }; // 2 是维度重新排列顺序 + } else { + transposeParam.perm = { 1, 2, 0, 3 }; // 2 3 是维度重新排列顺序 + } + CREATE_OPERATION(transposeParam, &transposeNode.operation); + transposeNode.inTensorIds = { config.INTERMEDIATE_ALLGATHER_OUT_ID }; + transposeNode.outTensorIds = { config.OUT_LOGITS_ID }; + } + CHECK_OPERATION_STATUS_RETURN(InnerInferShape(param, opGraph)); + CREATE_OPERATION(opGraph, operation); + return atb::NO_ERROR; +} + +atb::Status ParallelLmHead(const ParallelLmHeadParam ¶m, atb::Operation **operation) +{ + if (param.rankSize > 1) { + if (!param.gatherAhead) { + return CreateParallelLmHeadBase(param, operation, + ParallelLmHeadConfig(2, 1, 1, 2)); // 2 输入张量 1 输出张量 1 中间张量 2 节点 数量 + } else if (param.unpadInputs) { + return CreateParallelLmHeadBase(param, operation, + ParallelLmHeadGatherAheadConfig(3, 1, 2, 3)); // 3 输入张量 1 输出张量 2 中间张量 3 节点 数量 + } else { + ATB_SPEED_LOG_ERROR("[gatherAhead] can only used with [unpadInputs]"); + return atb::ERROR_INVALID_PARAM; + } + } else { + if (!param.gatherAhead) { + return CreateParallelLmHeadBase(param, operation, + ParallelLmHeadConfig(2, 1, 0, 1)); // 2 输入张量 1 输出张量 0 中间张量 1 节点 数量 + } else if (param.unpadInputs) { + return CreateParallelLmHeadBase(param, operation, + ParallelLmHeadGatherAheadConfig(3, 1, 1, 2)); // 3 输入张量 1 输出张量 1 中间张量 2 节点 数量 + } else { + ATB_SPEED_LOG_ERROR("[gatherAhead] can only used with [unpadInputs]"); + return atb::ERROR_INVALID_PARAM; + } + } +} +} // namespace common +} // namespace atb_speed \ No newline at end of file diff --git a/tests/proftest/layer_test_framework/operations/fusion/parallel_lmhead.h b/tests/proftest/layer_test_framework/operations/fusion/parallel_lmhead.h new file mode 100644 index 00000000..868cda2f --- /dev/null +++ b/tests/proftest/layer_test_framework/operations/fusion/parallel_lmhead.h @@ -0,0 +1,69 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef ATB_SPEED_LAYERS_PARALLEL_LMHEAD_LAYER_H +#define ATB_SPEED_LAYERS_PARALLEL_LMHEAD_LAYER_H +#include +#include + +#include "atb_speed/log.h" +#include "atb_speed/utils/operation_util.h" +#include "common_op_base.h" + +namespace atb_speed { +namespace common { +struct ParallelLmHeadParam { + int rank = 0; + int rankSize = 1; + std::string backend = "hccl"; + bool unpadInputs = false; + bool gatherAhead = false; + bool transposeA = false; + bool transposeB = true; + bool isBF16 = false; +}; + +class ParallelLmHeadConfig : public CommonOpBase { +public: + using CommonOpBase::CommonOpBase; + + enum ParallelLmHeadId : unsigned int { + IN_HIDDENSTATES_ID = 0, + IN_WEIGHT_ID, + OUT_LOGITS_ID, + INTERMEDIATE_ALLGATHER_OUT_ID, + IN_LMHEAD_INDICES_ID, + INTERMIDATE_GATHER_OUT_ID, + }; +}; + +class ParallelLmHeadGatherAheadConfig : public CommonOpBase { +public: + using CommonOpBase::CommonOpBase; + + enum ParallelLmHeadGatherAheadId : unsigned int { + IN_HIDDENSTATES_ID = 0, + IN_WEIGHT_ID, + IN_LMHEAD_INDICES_ID, + OUT_LOGITS_ID, + INTERMIDATE_GATHER_OUT_ID, + INTERMEDIATE_ALLGATHER_OUT_ID, + }; +}; + +atb::Status ParallelLmHead(const ParallelLmHeadParam ¶m, atb::Operation **operation); +} // namespace common +} // namespace atb_speed +#endif \ No newline at end of file diff --git a/tests/proftest/layer_test_framework/operations/fusion/utils.cpp b/tests/proftest/layer_test_framework/operations/fusion/utils.cpp new file mode 100644 index 00000000..3dc09b76 --- /dev/null +++ b/tests/proftest/layer_test_framework/operations/fusion/utils.cpp @@ -0,0 +1,378 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "operations/fusion/utils.h" +#include "atb_speed/base/event_manager.h" +#include "atb_speed/utils/operation_util.h" +#include "atb_speed/utils/singleton.h" + +namespace atb_speed { +namespace common { + +const std::string CMO_COMPUTE = "cmo_compute"; +const std::string CMO_OPROJ = "cmo_oproj"; +const std::string CMO_MLAPO = "cmo_mlapo"; +const std::string CV_START = "cv_start"; +const std::string VECTOR_CONTROL = "vector_control"; +const std::string CUBE_CONTROL = "cube_control"; +const std::string COMPUTE_EVENT = "compute"; +const std::string COMM_EVENT = "comm"; +const std::string END_EVENT = "end"; +const std::string CC_START = "cc_start"; +const std::string COMM_CONTROL = "comm_control"; +const std::string COMP_CONTROL = "compute_control"; + +void DapManager::SetRole(DapRole role) { this->currentRole = role; } + +DapRole DapManager::GetRole() const { return this->currentRole; } + +std::string DapManager::GetSuccessorSuffix() const +{ + return "_successor"; +} + +uint32_t DapManager::GetStreamId() const +{ + return this->currentRole == DapRole::SUCCESSOR ? 
1 : 0; +} + +int32_t CommOpCounter::Increment() +{ + DapRole currentRole = GetSingleton().GetRole(); + std::map::iterator it = this->count.find(currentRole); + if (it == this->count.end()) { + this->count[currentRole] = 1; + return 1; + } + + int ¤tRoleCount = it->second; + currentRoleCount += 1; + return currentRoleCount; +} + +int32_t CommOpCounter::GetCount() +{ + DapRole currentRole = GetSingleton().GetRole(); + std::map::iterator it = this->count.find(currentRole); + if (it == this->count.end()) { + return 0; + } + int ¤tRoleCount = it->second; + return currentRoleCount; +} + +void CommOpCounter::Reset() +{ + std::map::iterator it; + for (it = this->count.begin(); it != this->count.end(); it++) { + it->second = 0; + } +} + +atb::Status AddDapEventsBeforeComm(atb::GraphParam &opGraph) +{ + DapRole dapRole = GetSingleton().GetRole(); + atb_speed::EventAction actionType = + dapRole == DapRole::PRECEDER ? atb_speed::EventAction::PUSH : atb_speed::EventAction::POP; + std::stringstream ss; + std::string role = dapRole == DapRole::PRECEDER ? "PRECEDER" : "SUCCESSOR"; + std::string action = actionType == atb_speed::EventAction::PUSH ? "PUSH" : "POP"; + if (dapRole != DapRole::UNDEFINED_ROLE) { + atb::Node computeRecordNode; + computeRecordNode.inTensorIds = {}; + computeRecordNode.outTensorIds = {}; + CHECK_OPERATION_STATUS_RETURN( + atb_speed::EventManager::GetInstance().RecordEvent(computeRecordNode.operation, actionType, COMPUTE_EVENT)); + opGraph.nodes.push_back(computeRecordNode); + ss.str(""); + ss << "[Events] [" << role << "] [" << action << "] [RECORD] [COMPUTE]"; + ATB_SPEED_LOG_DEBUG(ss.str()); + + if (!(dapRole == DapRole::PRECEDER && GetSingleton().GetCount() == 0)) { + atb::Node commWaitNode; + commWaitNode.inTensorIds = {}; + commWaitNode.outTensorIds = {}; + CHECK_OPERATION_STATUS_RETURN( + atb_speed::EventManager::GetInstance().WaitEvent(commWaitNode.operation, actionType, COMM_EVENT)); + opGraph.nodes.push_back(commWaitNode); + ss.str(""); + ss << "[Events] [" << role << "] [" << action << "] [WAIT] [COMM]"; + ATB_SPEED_LOG_DEBUG(ss.str()); + } + } + return atb::NO_ERROR; +}; + +atb::Status AddDapEventsAfterComm(atb::GraphParam &opGraph) +{ + DapRole dapRole = GetSingleton().GetRole(); + atb_speed::EventAction actionType = + dapRole == DapRole::PRECEDER ? atb_speed::EventAction::PUSH : atb_speed::EventAction::POP; + std::stringstream ss; + std::string role = dapRole == DapRole::PRECEDER ? "PRECEDER" : "SUCCESSOR"; + std::string action = actionType == atb_speed::EventAction::PUSH ? 
"PUSH" : "POP"; + if (dapRole != DapRole::UNDEFINED_ROLE) { + atb::Node commRecordNode; + commRecordNode.inTensorIds = {}; + commRecordNode.outTensorIds = {}; + CHECK_OPERATION_STATUS_RETURN( + atb_speed::EventManager::GetInstance().RecordEvent(commRecordNode.operation, actionType, COMM_EVENT)); + opGraph.nodes.push_back(commRecordNode); + ss.str(""); + ss << "[Events] [" << role << "] [" << action << "] [RECORD] [COMM]"; + ATB_SPEED_LOG_DEBUG(ss.str()); + + atb::Node computeWaitNode; + computeWaitNode.inTensorIds = {}; + computeWaitNode.outTensorIds = {}; + CHECK_OPERATION_STATUS_RETURN( + atb_speed::EventManager::GetInstance().WaitEvent(computeWaitNode.operation, actionType, COMPUTE_EVENT)); + opGraph.nodes.push_back(computeWaitNode); + ss.str(""); + ss << "[Events] [" << role << "] [" << action << "] [WAIT] [COMPUTE]"; + ATB_SPEED_LOG_DEBUG(ss.str()); + } + + GetSingleton().Increment(); + return atb::NO_ERROR; +}; + +void AssignTensorIdx( + const std::map> &tensorCandidates, + std::string targetKey, uint32_t &tensorIdx, std::map &tensorMap) +{ + if (tensorCandidates.find(targetKey) == tensorCandidates.end()) { + ATB_SPEED_LOG_WARN("targetKey: " << targetKey << " not found in tensorCandidates"); + return; + } + + for (std::string tensor : tensorCandidates.at(targetKey)) { + tensorMap[tensor] = tensorIdx; + tensorIdx++; + } +} + +void AssignTensorIdx( + const std::map> &tensorCandidates, + std::string targetKey, std::map &tensorMap) +{ + if (tensorCandidates.find(targetKey) == tensorCandidates.end()) { + ATB_SPEED_LOG_WARN("targetKey: " << targetKey << " not found in tensorCandidates"); + return; + } + + uint32_t startTensorIdx = tensorMap.size(); + for (std::string tensor : tensorCandidates.at(targetKey)) { + tensorMap[tensor] = startTensorIdx; + startTensorIdx++; + } +} + +template +void AddTensorToList( + const std::map &tensorCandidates, + std::string targetKey, T &tensorList) +{ + if (tensorCandidates.find(targetKey) == tensorCandidates.end()) { + ATB_SPEED_LOG_WARN("targetKey: " << targetKey << " not found in tensorCandidates"); + return; + } + + for (const auto& item : tensorCandidates.at(targetKey)) { + tensorList.push_back(item); + } +} + +std::map GetTensorMap( + std::vector &inTensorList, std::vector &outTensorList, + std::vector &intermediateTensorList) +{ + std::map tensorMap = {}; + uint32_t tensorIdx = 0; + + // 添加inTensor + for (const auto &tensor : inTensorList) { + tensorMap[tensor] = tensorIdx; + tensorIdx++; + } + + // 添加outTensor + for (const auto &tensor : outTensorList) { + tensorMap[tensor] = tensorIdx; + tensorIdx++; + } + + // 添加intermediateTensor + for (const auto &tensor : intermediateTensorList) { + tensorMap[tensor] = tensorIdx; + tensorIdx++; + } + + std::stringstream ss; + for (auto tensor = tensorMap.cbegin(); tensor != tensorMap.cend(); ++tensor) { + ss << "tensor name: " << tensor->first << ", tensor id: " << tensor->second << std::endl; + } + ATB_SPEED_LOG_DEBUG("tensor map\n" << ss.str()); + + return tensorMap; +} + +uint32_t GetTensorIdx(const std::map &tensorMap, std::string tensorName) +{ + if (tensorMap.find(tensorName) == tensorMap.end()) { + ATB_SPEED_LOG_DEBUG("Cannot find " << tensorName << " in tensor Map"); + return UINT32_MAX; + } + return tensorMap.at(tensorName); +} + +atb::SVector GetTensorIdxList(const std::map &tensorMap, + std::vectortensorNames) +{ + atb::SVector tensorIdxList = {}; + for (std::string tensorName : tensorNames) { + tensorIdxList.push_back(GetTensorIdx(tensorMap, tensorName)); + } + return tensorIdxList; +} + +bool 
CheckAntiOutlier(const int &packQuantType) +{ + bool isAntiOutlier = packQuantType == atb_speed::common::MIX_W8A8_ANTI || \ + packQuantType == atb_speed::common::ALL_W8A8_ANTI || \ + packQuantType == atb_speed::common::ALL_W8A8SC_ANTI || \ + packQuantType == atb_speed::common::MIX_W8A8SC_ANTI || \ + packQuantType == atb_speed::common::ALL_W8A16_ANTI || \ + packQuantType == atb_speed::common::MIX_W8A16_ANTI || \ + packQuantType == atb_speed::common::ALL_W4A16_ANTI || \ + packQuantType == atb_speed::common::MIX_W4A16_ANTI || \ + packQuantType == atb_speed::common::ALL_W8A8_DYNAMIC_ANTI || \ + packQuantType == atb_speed::common::MIX_W8A8_DYNAMIC_ANTI || \ + packQuantType == atb_speed::common::ALL_W4A8_ANTI || \ + packQuantType == atb_speed::common::MIX_W4A8_ANTI; + return isAntiOutlier; +} + +bool CheckPack(const int &packQuantType, const std::vector &linearDescs, const std::vector &linearIndex) +{ + static const std::unordered_set PackableQuantTypes = { + atb_speed::common::ALL_FP, + atb_speed::common::ALL_W8A16, atb_speed::common::ALL_W8A16_ANTI, + atb_speed::common::ALL_W4A16, atb_speed::common::ALL_W4A16_ANTI, + atb_speed::common::ALL_W8A8, atb_speed::common::ALL_W8A8_ANTI, + atb_speed::common::ALL_W8A8SC, atb_speed::common::ALL_W8A8SC_ANTI, + atb_speed::common::ALL_W8A8_DYNAMIC, atb_speed::common::ALL_W8A8_DYNAMIC_ANTI, + atb_speed::common::ALL_W4A8, atb_speed::common::ALL_W4A8_ANTI + }; + + // "packable" packQuantType + if (PackableQuantTypes.count(packQuantType)) { + return true; + } + // "unpackable" packQuantType + if (packQuantType != atb_speed::common::PACK_QUANT_UNDEFINED) { + return false; + } + // undefined packQuantType, check pack from linearDescs (assume the first desc to be valid) + int currentDesc = LinearDesc::INVALID_DESC; + for (const int &index : linearIndex) { + if (index >= static_cast(linearDescs.size())) { + ATB_SPEED_LOG_WARN(index << " out of range in CheckPack"); + continue; + } + int desc = linearDescs.at(index); + // skip invalid desc, usually placeholder for packed linear + if (desc == LinearDesc::INVALID_DESC) { + continue; + } + // init with first valid desc + if (currentDesc == LinearDesc::INVALID_DESC) { + currentDesc = desc; + } else if (desc != currentDesc) { + // if valid and differ from prev descs -> unpackable + return false; + } + } + return true; +} + +atb::Status CheckParamVectorSize(const std::vector &vector, size_t threshold) +{ + if (vector.size() < threshold) { + return atb::ERROR_INVALID_PARAM; + } + return atb::NO_ERROR; +} + +atb::Status CreateRecordWithoutNodeId(atb::GraphParam &opGraph, + atb_speed::EventAction eventAction, const std::string &cvKey) +{ + atb::Node recordNode; + recordNode.inTensorIds = {}; + recordNode.outTensorIds = {}; + CHECK_OPERATION_STATUS_RETURN(atb_speed::EventManager::GetInstance().RecordEvent( + recordNode.operation, + eventAction, + cvKey)); + opGraph.nodes.push_back(recordNode); + ATB_SPEED_LOG_DEBUG("Record event success"); + return atb::NO_ERROR; +} + +atb::Status CreateWaitWithoutNodeId(atb::GraphParam &opGraph, + atb_speed::EventAction eventAction, const std::string &cvKey) +{ + atb::Node waitNode; + waitNode.inTensorIds = {}; + waitNode.outTensorIds = {}; + CHECK_OPERATION_STATUS_RETURN(atb_speed::EventManager::GetInstance().WaitEvent( + waitNode.operation, + eventAction, + cvKey)); + opGraph.nodes.push_back(waitNode); + ATB_SPEED_LOG_DEBUG("Wait event success"); + return atb::NO_ERROR; +} + +PackQuantType ConvertQuantTypeToPackType(std::string quantType) +{ + const std::unordered_map quantTypeToPackType = 
{
+        {"float", atb_speed::common::PackQuantType::ALL_FP},
+        {"w8a8", atb_speed::common::PackQuantType::ALL_W8A8},
+        {"w8a8s", atb_speed::common::PackQuantType::ALL_W8A8},
+        {"w8a8sc", atb_speed::common::PackQuantType::ALL_W8A8SC},
+        {"w8a8_dynamic", atb_speed::common::PackQuantType::ALL_W8A8_DYNAMIC},
+        {"w8a16", atb_speed::common::PackQuantType::ALL_W8A16},
+        {"w4a16", atb_speed::common::PackQuantType::ALL_W4A16},
+        {"", atb_speed::common::PackQuantType::PACK_QUANT_UNDEFINED},
+    };
+
+    auto it = quantTypeToPackType.find(quantType);
+    if (it == quantTypeToPackType.end()) {
+        return atb_speed::common::PackQuantType::PACK_QUANT_UNDEFINED;
+    }
+
+    return it->second;
+}
+
+template void AddTensorToList(
+    const std::map<std::string, std::vector<std::string>> &tensorCandidates,
+    std::string targetKey, std::vector<std::string> &tensorList);
+template void AddTensorToList(
+    const std::map<std::string, std::vector<std::string>> &tensorCandidates,
+    std::string targetKey, atb::SVector<std::string> &tensorList);
+} // namespace common
+} // namespace atb_speed
\ No newline at end of file
diff --git a/tests/proftest/layer_test_framework/operations/fusion/utils.h b/tests/proftest/layer_test_framework/operations/fusion/utils.h
new file mode 100644
index 00000000..89107164
--- /dev/null
+++ b/tests/proftest/layer_test_framework/operations/fusion/utils.h
@@ -0,0 +1,317 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef ATB_SPEED_MODELS_COMMON_UTILS_H
+#define ATB_SPEED_MODELS_COMMON_UTILS_H
+
+#include
+#include
+#include
+#include
+#include "atb_speed/log.h"
+#include "atb_speed/base/event_manager.h"
+
+
+namespace atb_speed {
+namespace common {
+
+/// The pack and quantization type of linear operations.
+/// Q, K and V linears may be packed. Gate and up linears may be packed.
+///
+/// Each value except `PACK_QUANT_UNDEFINED` represents a combination of pack type and quantization type.
+/// Explanation of each key word:
+/// - `PACK_QUANT_UNDEFINED`: undefined pack and quantization type.
+/// - `ALL`: all linears in the pack use the same quantization, so weights can be combined to accelerate computation.
+/// - `MIX`: linears in the pack are a mixture of quantized and float. Computation is performed separately.
+/// - `W8A8`: weights and activation values are quantized to int8.
+/// - `W4A8`: weights are quantized to int4 and activation values are quantized to int8.
+/// - `W8A16`: weights are quantized to int8 and activation values are kept in float16/bfloat16.
+/// - `W4A16`: weights are quantized to int4 and activation values are kept in float16/bfloat16.
+/// - `ANTI`: quantization with anti-outlier.
+/// - `SC`: quantization with model sparsification and compression. Exclusively supported by Atlas 300I Duo.
+/// - `DYNAMIC`: using per-token quantization.
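+///
+/// For example, `ALL_W8A8_ANTI` means every linear in the pack shares one W8A8 quantization (so the
+/// packed weights can be fused into a single matmul) with anti-outlier enabled, whereas `MIX_W8A8`
+/// means quantized and float linears coexist in the pack and are computed separately.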
+enum PackQuantType : unsigned int { + PACK_QUANT_UNDEFINED = 0, + ALL_FP = 1, + ALL_W8A8 = 2, + ALL_W8A8_ANTI = 3, + MIX_W8A8 = 4, + MIX_W8A8_ANTI = 5, + ALL_W8A16 = 6, + ALL_W8A8SC = 7, + MIX_W8A8SC = 8, + ALL_W8A8SC_ANTI = 9, + MIX_W8A8SC_ANTI = 10, + ALL_W4A16 = 11, + ALL_W8A16_ANTI = 12, + ALL_W4A16_ANTI = 13, + MIX_W4A16 = 14, + MIX_W4A16_ANTI = 15, + MIX_W8A16 = 16, + MIX_W8A16_ANTI = 17, + ALL_W8A8_DYNAMIC = 18, + ALL_W8A8_DYNAMIC_ANTI = 19, + MIX_W8A8_DYNAMIC = 20, + MIX_W8A8_DYNAMIC_ANTI = 21, + ALL_W4A8 = 22, + MIX_W4A8 = 23, + ALL_W4A8_ANTI = 24, + MIX_W4A8_ANTI = 25 +}; + +/// A listing of operation backends. +enum OpBackend: unsigned int { + /// Ascend Transformer Boost backend. + ATB = 0, + /// Ascend Computing Language Neural Network backend. + ACLNN = 1, +}; + +/// An enumeration of quantization types for the linear operations. +enum LinearQuantType : unsigned int { + /// No quantization. + NO_QUANT = 0, + /// Weights are quantized to int8 and activation values are quantized to int8. + /// Quantization is performed in the normalization operation and dequantization is performed in the linear operation. + LINEAR_W8A8_DEQUANT, + /// Weights are quantized to int8 and activation values are quantized to int8. + /// Quantization and dequantization are both performed in the linear operation. + LINEAR_W8A8_QUANT, + /// Weights are quantized to int4. + /// Quantization and dequantization are both performed in the linear operation. + W4A16, + /// Weights are quantized to int8. + /// Quantization and dequantization are both performed in the linear operation. + W8A16, + /// Weights are quantized to int8 and activation values are quantized to int8, using sparse compression. + /// Quantization is performed in the normalization operation and dequantization is performed in the linear operation. + LINEAR_W8A8_SC_DEQUANT, + /// Weights are quantized to int8 and activation values are quantized to int8, using sparse compression. + /// Quantization and dequantization are both performed in the linear operation. + LINEAR_W8A8_SC_QUANT, + /// Weights are quantized to int8 and activation values are quantized to int8, using per-token quantization. + /// Quantization is performed in the normalization operation and dequantization is performed in the linear operation. + LINEAR_W8A8_DYNAMIC_DEQUANT, + /// Weights are quantized to int8 and activation values are quantized to int8, using per-token quantization. + /// Quantization and dequantization are both performed in the linear operation. + LINEAR_W8A8_DYNAMIC_QUANT, + LINEAR_W4A8_DYNAMIC_DEQUANT, + LINEAR_W4A8_DYNAMIC_QUANT +}; + +/// An enumeration of linear data types. +enum LinearType : int { + /// Invalid type. + INVALID = -1, + /// Float type. + FP = 0, + /// Integer type. + INT = 1, +}; + +enum LinearDesc : int { + INVALID_DESC = -1, + FLOAT16_DESC = 0, + BFLOAT16_DESC = 1, + W4A16_DESC = 2, + W8A16_DESC = 3, + W8A8_PER_TENSOR_DESC = 4, + W8A8S_DESC = 5, + W8A8SC_DESC = 6, + W8A8_DYNAMIC_DESC = 7, + W8A8_PDMIX_DESC = 8, + W4A8_DESC = 9 +}; + +/// Transpose type of the B matrix in matmul operations. +enum TransposeType : int { + /// Invalid type. + TRANSPOSE_INVALID = -1, + /// Do not transpose the B matrix in the matmul operation. + NOT_TRANSPOSE = 0, + /// Transpose the B matrix in the matmul operation. + TRANSPOSE = 1, +}; + +/// Details about tensor parallelism. +/// +/// Parameters will be directly passed to `AllReduceOperation`, `AllGatherOperation` +/// or `LinearParallelOperation` defined in `atb/atb_infer.h`.
+struct TensorParallelInfo { + /// Rank of the current process + int rank = 0; + /// Number of processes participating in the job + int worldSize = 1; + /// Communication backend. Options: `hccl`, `lccl` + std::string backend = "hccl"; + /// Path of the cluster information config file. Used for single-node or multi-node communication. + std::string rankTableFile = ""; + /// The HCCL communicator handle. + HcclComm hcommInfo = nullptr; + /// A communication device group is identified by a communication domain name. + std::string commDomain = ""; + /// Quant type + atb::infer::AllReduceParam::QuantType quantType = atb::infer::AllReduceParam::QuantType::QUANT_TYPE_UNDEFINED; + /// The data type of the output tensor + aclDataType outDataType = ACL_DT_UNDEFINED; +}; + +extern const std::string CMO_COMPUTE; +extern const std::string CMO_OPROJ; +extern const std::string CMO_MLAPO; +extern const std::string CV_START; +extern const std::string VECTOR_CONTROL; +extern const std::string CUBE_CONTROL; +extern const std::string COMPUTE_EVENT; +extern const std::string COMM_EVENT; +extern const std::string END_EVENT; +extern const std::string CC_START; +extern const std::string COMM_CONTROL; +extern const std::string COMP_CONTROL; + +enum DapRole : uint32_t { + UNDEFINED_ROLE = 0, + PRECEDER = 1, + SUCCESSOR = 2, +}; + +class DapManager { +public: + void SetRole(DapRole role); + DapRole GetRole() const; + std::string GetSuccessorSuffix() const; + uint32_t GetStreamId() const; + +private: + DapRole currentRole = DapRole::UNDEFINED_ROLE; +}; + +class CommOpCounter { +public: + int32_t Increment(); + int32_t GetCount(); + void Reset(); + +private: + std::map<DapRole, int32_t> count = {}; +}; + +atb::Status AddDapEventsBeforeComm(atb::GraphParam &opGraph); + +atb::Status AddDapEventsAfterComm(atb::GraphParam &opGraph); + +/// Assign indices to tensors based on the provided tensorCandidates and targetKey. +/// Indices start from the provided tensorIdx. +/// +/// \param tensorCandidates A map where each tensor key maps to a list of tensor names. +/// \param targetKey The tensor key identifying which tensor name list to assign. +/// \param tensorIdx The initial index to assign. +/// \param tensorMap A map storing the tensor name to index mapping. +void AssignTensorIdx( + const std::map<std::string, std::vector<std::string>> &tensorCandidates, + std::string targetKey, uint32_t &tensorIdx, std::map<std::string, uint32_t> &tensorMap); + +/// Assign indices to tensors based on the provided tensorCandidates and targetKey. +/// Indices start from the size of the existing tensorMap. +/// +/// \param tensorCandidates A map where each tensor key maps to a list of tensor names. +/// \param targetKey The tensor key identifying which tensor name list to assign. +/// \param tensorMap A map storing the tensor name to index mapping. +void AssignTensorIdx( + const std::map<std::string, std::vector<std::string>> &tensorCandidates, + std::string targetKey, std::map<std::string, uint32_t> &tensorMap); + +/// Add tensorCandidates.at(targetKey) to tensorList. +/// +/// \tparam T The type of the tensor list. This can be any container type that supports push_back. +/// \param tensorCandidates A map where each tensor key maps to a list of tensor names. +/// \param targetKey The tensor key identifying which tensor name list to add. +/// \param tensorList The list of type T to which the tensors are added. +template <typename T> +void AddTensorToList( + const std::map<std::string, std::vector<std::string>> &tensorCandidates, + std::string targetKey, T &tensorList);
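+ +// Example (illustrative tensor names): +// std::map<std::string, std::vector<std::string>> candidates = {{"default", {"in_hidden_states", "in_attention_mask"}}}; +// std::map<std::string, uint32_t> tensorMap; +// AssignTensorIdx(candidates, "default", tensorMap); // "in_hidden_states" -> 0, "in_attention_mask" -> 1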
+ +/// Return a map of tensor name to tensor index. +/// +/// Assuming the lengths of in/out/intermediateTensorList are n1/n2/n3, the returned map is +/// \code +/// { +/// {inTensorList[0], 0}, {inTensorList[1], 1}, ..., {inTensorList[n1-1], n1-1}, +/// {outTensorList[0], n1}, {outTensorList[1], n1+1}, ..., {outTensorList[n2-1], n1+n2-1}, +/// {intermediateTensorList[0], n1+n2}, ..., {intermediateTensorList[n3-1], n1+n2+n3-1} +/// }; +/// \endcode +/// +/// \param inTensorList A list of input tensors of an operation. +/// \param outTensorList A list of output tensors of an operation. +/// \param intermediateTensorList A list of intermediate tensors of an operation. +/// \return A map of tensor name to tensor index. +std::map<std::string, uint32_t> GetTensorMap( + std::vector<std::string> &inTensorList, std::vector<std::string> &outTensorList, + std::vector<std::string> &intermediateTensorList); + +/// Retrieve the tensor index using the `tensorName` from the `tensorMap`. +/// +/// \param tensorMap A map of tensor name to tensor index. +/// \param tensorName The name of the tensor. +/// \return The index of the tensor. +uint32_t GetTensorIdx(const std::map<std::string, uint32_t> &tensorMap, std::string tensorName); + +/// Return a list of tensor indices from the `tensorMap` referenced by `tensorNames`. +/// +/// \param tensorMap A map of tensor name to tensor index. +/// \param tensorNames A list of tensor names. +/// \return A list of tensor indices. +atb::SVector<uint32_t> GetTensorIdxList(const std::map<std::string, uint32_t> &tensorMap, + std::vector<std::string> tensorNames); + +/// Verify whether `packQuantType` supports quantization with anti-outlier. +/// +/// \param packQuantType The pack and quantization type of linear operations. +/// Refer to `atb_speed::common::PackQuantType` in `operations/fusion/utils.h` for more details. +/// \return True if `packQuantType` supports quantization with anti-outlier. +bool CheckAntiOutlier(const int &packQuantType); + +/// Check whether linear weights are packed. +/// \param packQuantType The pack and quantization type of linear operations. +/// Refer to `atb_speed::common::PackQuantType` in `operations/fusion/utils.h` for more details. +/// \param linearDescs Weight descriptions of the linear modules. +/// \param linearIndex A list of indices of the target linear modules. +/// \return True if linear weights are packed. +bool CheckPack(const int &packQuantType = PackQuantType::PACK_QUANT_UNDEFINED, + const std::vector<int> &linearDescs = {}, + const std::vector<int> &linearIndex = {}); + +/// Validate the size of `vector`; it must not be smaller than `threshold`. +/// \param vector The vector to be checked. +/// \param threshold The minimum allowed size of `vector`. +/// \return A status indicating whether the size of `vector` is valid. +atb::Status CheckParamVectorSize(const std::vector<int> &vector, size_t threshold); + +atb::Status CreateRecordWithoutNodeId(atb::GraphParam &opGraph, + atb_speed::EventAction eventAction, const std::string &cvKey); + +atb::Status CreateWaitWithoutNodeId(atb::GraphParam &opGraph, + atb_speed::EventAction eventAction, const std::string &cvKey);
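+ +// Example (illustrative): typical index bookkeeping when building a graph. +// auto tensorMap = GetTensorMap(inNames, outNames, interNames); // hypothetical name lists +// uint32_t hiddenIdx = GetTensorIdx(tensorMap, "in_hidden_states"); +// atb::SVector<uint32_t> ids = GetTensorIdxList(tensorMap, {"in_hidden_states", "out_hidden_states"});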
+ +/// Convert quantType to packType, e.g. `"w8a8"` -> `ALL_W8A8`. +/// \param quantType The quantization type. Valid values are `float`/`w8a8`/`w8a8s`/`w8a8sc`/`w8a8_dynamic`/`w8a16`/ +/// `w4a16`; any other input value returns `PACK_QUANT_UNDEFINED`. +/// \return The corresponding pack type. +PackQuantType ConvertQuantTypeToPackType(std::string quantType); +} // namespace common +} // namespace atb_speed +#endif \ No newline at end of file diff --git a/tests/proftest/main.cpp b/tests/proftest/main.cpp new file mode 100644 index 00000000..f6789d8a --- /dev/null +++ b/tests/proftest/main.cpp @@ -0,0 +1,143 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#include <benchmark/benchmark.h> +#include <filesystem> +#include <fstream> +#include <iostream> +#include <sstream> +#include <string> +#include <vector> +// Run the benchmark + +bool isBaseLine = false; + +void loadArgs(const std::vector<std::string> &args) +{ + for (const auto &arg : args) { + if (arg == "atb_proftest_baseline") { + isBaseLine = true; + } + } +} + +class ProftestReporter : public benchmark::ConsoleReporter { +public: + ProftestReporter(std::string execPath) : execPath(execPath) {} + void ReportRuns(const std::vector<Run> &reports) override + { + for (auto &report : reports) { + std::string caseName = report.benchmark_name(); + size_t index = caseName.find('/'); + if (index != std::string::npos) { + caseName = caseName.substr(0, index); + } + std::filesystem::path baselinePath = std::filesystem::weakly_canonical( + std::filesystem::path(execPath + "/../../../../proftest/BaseLine/" + caseName + "_BaseLine.csv")); + std::filesystem::path resultPath = std::filesystem::weakly_canonical( + std::filesystem::path(execPath + "/../../../../proftest/Result/" + caseName + "_Result.csv")); + if (isBaseLine) { + std::filesystem::create_directories(baselinePath.parent_path()); + std::ofstream baselineFile(baselinePath, std::ios::out | std::ios::trunc); + if (!baselineFile.is_open()) { + std::cerr << "Failed to open baseline file: " << baselinePath << std::endl; + exit(1); + } + baselineFile << "Benchmark_Name,Real Time (s),CPU Time (s)\n"; + baselineFile << caseName << "," << report.real_accumulated_time << "," << report.cpu_accumulated_time << "\n"; + baselineFile.close(); + return; + } + + std::filesystem::create_directories(resultPath.parent_path()); + std::ofstream resultFile(resultPath, std::ios::out | std::ios::trunc); + if (!resultFile.is_open()) { + std::cerr << "Failed to open result file: " << resultPath << std::endl; + exit(1); + } + resultFile << "Benchmark_Name,Real Time (s),CPU Time (s)\n"; + resultFile << caseName << "," << report.real_accumulated_time << "," << report.cpu_accumulated_time << "\n"; + resultFile.close();
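+ + // Both the result and baseline CSVs share this layout (values illustrative): + // Benchmark_Name,Real Time (s),CPU Time (s) + // Layer_Bloom_7B,0.0123,0.0119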
+ + if (std::filesystem::exists(baselinePath)) { + std::ifstream baselineFile(baselinePath, std::ios::in); + if (!baselineFile.is_open()) { + std::cerr << "Failed to open baseline file: " << baselinePath << " , skip compare.\n"; + return; + } + double baselineCpuTime = 0; + std::string line; + while (std::getline(baselineFile, line)) { + std::istringstream iss(line); + std::string benchmarkName, realTime, cpuTime; + if (std::getline(iss, benchmarkName, ',') && std::getline(iss, realTime, ',') && + std::getline(iss, cpuTime, ',')) { + if (benchmarkName == caseName) { + baselineCpuTime = std::stod(cpuTime); + } + } + } + baselineFile.close(); + double resultCpuTime = report.cpu_accumulated_time; + if (resultCpuTime > baselineCpuTime * 1.05) { + std::cerr << "Benchmark " << caseName << " has degraded. Please check:\n"; + std::cerr << "Baseline file: " << baselinePath << "\n"; + std::cerr << "Result file: " << resultPath << "\n"; + exit(1); + } + } + } + + return; + } + +private: + std::string execPath; +}; + + +int main(int argc, char **argv) +{ + // Save the executable path; it is used to locate the baseline and result directories. + std::string execPath = argv[0]; + benchmark::MaybeReenterWithoutASLR(argc, argv); + char arg0_default[] = "benchmark"; + char *args_default = reinterpret_cast<char *>(arg0_default); + if (!argv) { + argc = 1; + argv = &args_default; + } + { + // Read the custom arguments and strip them from the benchmark arguments. + std::vector<std::string> proftestArgs; + std::vector<char *> new_argv; + new_argv.push_back(argv[0]); + for (int i = 1; i < argc; ++i) { + std::string arg = argv[i]; + if (arg.substr(0, 13) == "atb_proftest_") { + proftestArgs.push_back(arg); + } else { + new_argv.push_back(argv[i]); + } + } + argc = static_cast<int>(new_argv.size()); + argv = new_argv.data(); + + loadArgs(proftestArgs); + } + benchmark::Initialize(&argc, argv); + if (benchmark::ReportUnrecognizedArguments(argc, argv)) + return 1; + benchmark::RunSpecifiedBenchmarks(new ProftestReporter(execPath)); + benchmark::Shutdown(); + std::cout << "Result: OK. All proftest cases are done." << std::endl; + return 0; +} \ No newline at end of file diff --git a/tests/proftest/test_cases/bloom_7b/main.cpp b/tests/proftest/test_cases/bloom_7b/main.cpp new file mode 100644 index 00000000..c3393b9e --- /dev/null +++ b/tests/proftest/test_cases/bloom_7b/main.cpp @@ -0,0 +1,219 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License.
+ */ + +#include <benchmark/benchmark.h> + +#include <cstdint> +#include <string> +#include <vector> + +#include <acl/acl.h> + +#include <atb/atb_infer.h> +#include <atb/context.h> +#include <atb/operation.h> +#include <atb/types.h> + +#include "context_utils.h" +#include "tensor_utils.h" + + +enum class TensorId : int32_t { + ElewiseAddNode0In, +}; +#include "models/base/param/param.h" +#include "models/base/param/layer_param.h" +#include "models/bloom/layer/bloom_decoder_layer.h" +static void Layer_Bloom_7B(benchmark::State &state) +{ + for (auto _ : state) { + state.PauseTiming(); + aclInit(nullptr); + uint32_t deviceId = 0; + aclrtSetDevice(deviceId); + aclrtStream stream; + aclrtCreateStream(&stream); + atb::Context *context = nullptr; + atb::CreateContext(&context); + context->SetExecuteStream(stream); + atb::Operation *graphOperation = nullptr; + { + atb_speed::base::LayerParam layerParam; + layerParam.layerId = 0; + layerParam.numHiddenLayers = 1; + layerParam.isFA = false; + layerParam.isUnpadInputs = true; + layerParam.isPrefill = false; + layerParam.isBF16 = false; + layerParam.isEdgeHardware = false; + layerParam.enableSwiGLU = false; + layerParam.enableLcoc = false; + layerParam.enableMC2 = false; + layerParam.enableSpeculate = false; + layerParam.enableCompressHead = false; + layerParam.enableOmniAttention = false; + layerParam.useQKNorm = false; + layerParam.enableSplitFuse = false; + layerParam.enableLora = false; + layerParam.enablePreFetchWeight = false; + layerParam.loraEnableGMM = false; + layerParam.enableKvQuant = false; + layerParam.enableFA3 = false; + layerParam.kvQuantHasOffset = true; + layerParam.enableReduceQuant = false; + layerParam.enableInterLayerAddNorm = false; + layerParam.enableIntraLayerAddNorm = false; + layerParam.enablePrefixCache = false; + layerParam.attnBackend = atb_speed::common::OpBackend::ATB; + layerParam.matmulBackend = atb_speed::common::OpBackend::ATB; + layerParam.positionEmbeddingType = atb_speed::base::PositionEmbeddingType::ALIBI; + layerParam.normEps = 1e-05; + layerParam.normType = atb_speed::base::NormType::LAYER_NORM; + layerParam.quantGroupSize = 0; + layerParam.numAttentionHeadsPerRank = 32; + layerParam.hiddenSizePerAttentionHead = 128; + layerParam.numKeyValueHeadsPerRank = 32; + layerParam.enableFlashComm = 0; + layerParam.enableModelConfuscation = 0; + layerParam.modelConfuscationFd = 0; + layerParam.packQuantType = {atb_speed::common::PackQuantType::ALL_FP, + atb_speed::common::PackQuantType::ALL_FP}; + layerParam.linearQuantType = { + atb_speed::common::LinearType::FP, atb_speed::common::LinearType::INVALID, + atb_speed::common::LinearType::INVALID, atb_speed::common::LinearType::FP, + atb_speed::common::LinearType::FP, atb_speed::common::LinearType::INVALID, + atb_speed::common::LinearType::FP}; + layerParam.linearTransposeType = {1, -1, -1, 1, 1, -1, 1}; + layerParam.linearHasBias = {true, true, true, true}; + layerParam.weightQuantType = ""; + layerParam.backend = "lccl"; + layerParam.tensorParallelInfo = {0, 1, "lccl", "", nullptr}; + layerParam.hasAttnTp = false; + layerParam.attnTpRank = 0; + layerParam.attnTpSize = 1; + layerParam.attnTpDomain = ""; + layerParam.attnTpRankTableFile = ""; + layerParam.hasAttnDp = false; + layerParam.attnDpRank = 0; + layerParam.attnDpSize = 1; + layerParam.attnDpDomain = ""; + layerParam.attnDpRankTableFile = ""; + layerParam.hasMlpTp = false; + layerParam.mlpTpRank = 0; + layerParam.mlpTpSize = 1; + layerParam.mlpTpDomain = ""; + layerParam.mlpTpRankTableFile = ""; + layerParam.enableSwigluQuant = false; + atb_speed::bloom::BloomDecoderLayer bloomDecoderLayer(layerParam);
+ bloomDecoderLayer.BuildGraph(&graphOperation); + } + std::vector<atb::TensorDesc> inTensorDesc{ + {.dtype = ACL_FLOAT16, .format = ACL_FORMAT_ND, .shape = atb::Dims{.dims = {4096}, .dimNum = 1}}, + {.dtype = ACL_FLOAT16, .format = ACL_FORMAT_ND, .shape = atb::Dims{.dims = {4096}, .dimNum = 1}}, + {.dtype = ACL_FLOAT16, .format = ACL_FORMAT_ND, .shape = atb::Dims{.dims = {1}, .dimNum = 1}}, + {.dtype = ACL_FLOAT16, .format = ACL_FORMAT_ND, .shape = atb::Dims{.dims = {1}, .dimNum = 1}}, + {.dtype = ACL_FLOAT16, .format = ACL_FORMAT_ND, .shape = atb::Dims{.dims = {12288, 4096}, .dimNum = 2}}, + {.dtype = ACL_FLOAT16, .format = ACL_FORMAT_ND, .shape = atb::Dims{.dims = {12288}, .dimNum = 1}}, + {.dtype = ACL_FLOAT16, .format = ACL_FORMAT_ND, .shape = atb::Dims{.dims = {1}, .dimNum = 1}}, + {.dtype = ACL_FLOAT16, .format = ACL_FORMAT_ND, .shape = atb::Dims{.dims = {1}, .dimNum = 1}}, + {.dtype = ACL_FLOAT16, .format = ACL_FORMAT_ND, .shape = atb::Dims{.dims = {1}, .dimNum = 1}}, + {.dtype = ACL_FLOAT16, .format = ACL_FORMAT_ND, .shape = atb::Dims{.dims = {1}, .dimNum = 1}}, + {.dtype = ACL_FLOAT16, .format = ACL_FORMAT_ND, .shape = atb::Dims{.dims = {1}, .dimNum = 1}}, + {.dtype = ACL_FLOAT16, .format = ACL_FORMAT_ND, .shape = atb::Dims{.dims = {1}, .dimNum = 1}}, + {.dtype = ACL_FLOAT16, .format = ACL_FORMAT_ND, .shape = atb::Dims{.dims = {1}, .dimNum = 1}}, + {.dtype = ACL_FLOAT16, .format = ACL_FORMAT_ND, .shape = atb::Dims{.dims = {1}, .dimNum = 1}}, + {.dtype = ACL_FLOAT16, .format = ACL_FORMAT_ND, .shape = atb::Dims{.dims = {1}, .dimNum = 1}}, + {.dtype = ACL_FLOAT16, .format = ACL_FORMAT_ND, .shape = atb::Dims{.dims = {1}, .dimNum = 1}}, + {.dtype = ACL_FLOAT16, .format = ACL_FORMAT_ND, .shape = atb::Dims{.dims = {1}, .dimNum = 1}}, + {.dtype = ACL_FLOAT16, .format = ACL_FORMAT_ND, .shape = atb::Dims{.dims = {1}, .dimNum = 1}}, + {.dtype = ACL_FLOAT16, .format = ACL_FORMAT_ND, .shape = atb::Dims{.dims = {1}, .dimNum = 1}}, + {.dtype = ACL_FLOAT16, .format = ACL_FORMAT_ND, .shape = atb::Dims{.dims = {1}, .dimNum = 1}}, + {.dtype = ACL_FLOAT16, .format = ACL_FORMAT_ND, .shape = atb::Dims{.dims = {1}, .dimNum = 1}}, + {.dtype = ACL_FLOAT16, .format = ACL_FORMAT_ND, .shape = atb::Dims{.dims = {1}, .dimNum = 1}}, + {.dtype = ACL_FLOAT16, .format = ACL_FORMAT_ND, .shape = atb::Dims{.dims = {4096, 4096}, .dimNum = 2}}, + {.dtype = ACL_FLOAT16, .format = ACL_FORMAT_ND, .shape = atb::Dims{.dims = {4096}, .dimNum = 1}}, + {.dtype = ACL_FLOAT16, .format = ACL_FORMAT_ND, .shape = atb::Dims{.dims = {1}, .dimNum = 1}}, + {.dtype = ACL_FLOAT16, .format = ACL_FORMAT_ND, .shape = atb::Dims{.dims = {1}, .dimNum = 1}}, + {.dtype = ACL_FLOAT16, .format = ACL_FORMAT_ND, .shape = atb::Dims{.dims = {1}, .dimNum = 1}}, + {.dtype = ACL_FLOAT16, .format = ACL_FORMAT_ND, .shape = atb::Dims{.dims = {1}, .dimNum = 1}}, + {.dtype = ACL_FLOAT16, .format = ACL_FORMAT_ND, .shape = atb::Dims{.dims = {4096}, .dimNum = 1}}, + {.dtype = ACL_FLOAT16, .format = ACL_FORMAT_ND, .shape = atb::Dims{.dims = {4096}, .dimNum = 1}}, + {.dtype = ACL_FLOAT16, .format = ACL_FORMAT_ND, .shape = atb::Dims{.dims = {1}, .dimNum = 1}}, + {.dtype = ACL_FLOAT16, .format = ACL_FORMAT_ND, .shape = atb::Dims{.dims = {1}, .dimNum = 1}}, + {.dtype = ACL_FLOAT16, .format = ACL_FORMAT_ND, .shape = atb::Dims{.dims = {16384, 4096}, .dimNum = 2}}, + {.dtype = ACL_FLOAT16, .format = ACL_FORMAT_ND, .shape = atb::Dims{.dims = {16384}, .dimNum = 1}}, + {.dtype = ACL_FLOAT16, .format = ACL_FORMAT_ND, .shape = atb::Dims{.dims = {1}, .dimNum = 1}},
+ {.dtype = ACL_FLOAT16, .format = ACL_FORMAT_ND, .shape = atb::Dims{.dims = {1}, .dimNum = 1}}, + {.dtype = ACL_FLOAT16, .format = ACL_FORMAT_ND, .shape = atb::Dims{.dims = {1}, .dimNum = 1}}, + {.dtype = ACL_FLOAT16, .format = ACL_FORMAT_ND, .shape = atb::Dims{.dims = {1}, .dimNum = 1}}, + {.dtype = ACL_FLOAT16, .format = ACL_FORMAT_ND, .shape = atb::Dims{.dims = {1}, .dimNum = 1}}, + {.dtype = ACL_FLOAT16, .format = ACL_FORMAT_ND, .shape = atb::Dims{.dims = {1}, .dimNum = 1}}, + {.dtype = ACL_FLOAT16, .format = ACL_FORMAT_ND, .shape = atb::Dims{.dims = {1}, .dimNum = 1}}, + {.dtype = ACL_FLOAT16, .format = ACL_FORMAT_ND, .shape = atb::Dims{.dims = {1}, .dimNum = 1}}, + {.dtype = ACL_FLOAT16, .format = ACL_FORMAT_ND, .shape = atb::Dims{.dims = {1}, .dimNum = 1}}, + {.dtype = ACL_FLOAT16, .format = ACL_FORMAT_ND, .shape = atb::Dims{.dims = {1}, .dimNum = 1}}, + {.dtype = ACL_FLOAT16, .format = ACL_FORMAT_ND, .shape = atb::Dims{.dims = {4096, 16384}, .dimNum = 2}}, + {.dtype = ACL_FLOAT16, .format = ACL_FORMAT_ND, .shape = atb::Dims{.dims = {4096}, .dimNum = 1}}, + {.dtype = ACL_FLOAT16, .format = ACL_FORMAT_ND, .shape = atb::Dims{.dims = {1}, .dimNum = 1}}, + {.dtype = ACL_FLOAT16, .format = ACL_FORMAT_ND, .shape = atb::Dims{.dims = {1}, .dimNum = 1}}, + {.dtype = ACL_FLOAT16, .format = ACL_FORMAT_ND, .shape = atb::Dims{.dims = {1}, .dimNum = 1}}, + {.dtype = ACL_FLOAT16, .format = ACL_FORMAT_ND, .shape = atb::Dims{.dims = {1}, .dimNum = 1}}, + {.dtype = ACL_FLOAT16, .format = ACL_FORMAT_ND, .shape = atb::Dims{.dims = {1, 4096}, .dimNum = 2}}, + {.dtype = ACL_FLOAT16, .format = ACL_FORMAT_ND, .shape = atb::Dims{.dims = {1}, .dimNum = 1}}, + {.dtype = ACL_FLOAT16, .format = ACL_FORMAT_ND, .shape = atb::Dims{.dims = {1}, .dimNum = 1}}, + {.dtype = ACL_FLOAT16, .format = ACL_FORMAT_ND, .shape = atb::Dims{.dims = {32, 1, 4096}, .dimNum = 3}}, + {.dtype = ACL_FLOAT16, .format = ACL_FORMAT_ND, .shape = atb::Dims{.dims = {9, 128, 32, 128}, .dimNum = 4}}, + {.dtype = ACL_FLOAT16, .format = ACL_FORMAT_ND, .shape = atb::Dims{.dims = {9, 128, 32, 128}, .dimNum = 4}}, + {.dtype = ACL_INT32, .format = ACL_FORMAT_ND, .shape = atb::Dims{.dims = {1}, .dimNum = 1}}, + {.dtype = ACL_INT32, .format = ACL_FORMAT_ND, .shape = atb::Dims{.dims = {1}, .dimNum = 1}}, + {.dtype = ACL_INT32, .format = ACL_FORMAT_ND, .shape = atb::Dims{.dims = {1}, .dimNum = 1}}, + {.dtype = ACL_INT32, .format = ACL_FORMAT_ND, .shape = atb::Dims{.dims = {1, 1}, .dimNum = 2}}, + {.dtype = ACL_INT32, .format = ACL_FORMAT_ND, .shape = atb::Dims{.dims = {1}, .dimNum = 1}}, + }; + std::vector<atb::Tensor> inTensor = FillTensorDataByOne(inTensorDesc); + std::vector<atb::TensorDesc> outTensorDesc{ + {.dtype = ACL_FLOAT16, .format = ACL_FORMAT_ND, .shape = atb::Dims{.dims = {1, 4096}, .dimNum = 2}}, + }; + std::vector<atb::Tensor> outTensor = FillTensorDataByZero(outTensorDesc); + atb::VariantPack variantPack; + for (size_t i = 0; i < inTensor.size(); ++i) { + variantPack.inTensors.push_back(inTensor[i]); + } + for (size_t i = 0; i < outTensor.size(); ++i) { + variantPack.outTensors.push_back(outTensor[i]); + } + uint64_t workspaceSize = 0; + graphOperation->Setup(variantPack, workspaceSize, context); + void *workSpace = nullptr; + aclrtMalloc(&workSpace, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST); + state.ResumeTiming(); + graphOperation->Execute(variantPack, (uint8_t *)workSpace, workspaceSize, context); + aclrtSynchronizeStream(stream); + state.PauseTiming(); + atb::DestroyContext(context); + atb::DestroyOperation(graphOperation); + for (size_t i = 0; i < variantPack.outTensors.size(); ++i) {
+ PrintDeviceTensor(variantPack.outTensors.at(i)); + } + for (size_t i = 0; i < variantPack.inTensors.size(); ++i) { + FreeTensor(variantPack.inTensors.at(i)); + } + for (size_t i = 0; i < variantPack.outTensors.size(); ++i) { + FreeTensor(variantPack.outTensors.at(i)); + } + + aclrtFree(workSpace); + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + } +} + +BENCHMARK(Layer_Bloom_7B)->Iterations(1); diff --git a/tests/proftest/utils/include/context_utils.h b/tests/proftest/utils/include/context_utils.h new file mode 100644 index 00000000..7080702f --- /dev/null +++ b/tests/proftest/utils/include/context_utils.h @@ -0,0 +1,21 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef CONTEXT_UTILS_H +#define CONTEXT_UTILS_H +#include <cstdint> + +#include <atb/atb_infer.h> + +atb::Context *GetDefaultContext(uint32_t deviceId); + +void ReleaseDefaultContext(uint32_t deviceId, atb::Context *context); + +#endif diff --git a/tests/proftest/utils/include/tensor_utils.h b/tests/proftest/utils/include/tensor_utils.h new file mode 100644 index 00000000..b6d5bf0f --- /dev/null +++ b/tests/proftest/utils/include/tensor_utils.h @@ -0,0 +1,59 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License.
+ */ + +#ifndef TENSOR_UTILS_H +#define TENSOR_UTILS_H +#include <string> +#include <utility> +#include <vector> +#include <atb/types.h> + +atb::Tensor FillTensorDataRandomly(const atb::TensorDesc &desc, float range_min, float range_max); + +atb::Tensor FillTensorDataRandomly(const atb::TensorDesc &desc); + +atb::Tensor FillTensorDataRandomly(const atb::TensorDesc &desc, const std::pair<float, float> range); + +std::vector<atb::Tensor> FillTensorDataRandomly(const std::vector<atb::TensorDesc> &descs); + +std::vector<atb::Tensor> FillTensorDataRandomly(const std::vector<atb::TensorDesc> &descs, + const std::vector<float> &range_mins, + const std::vector<float> &range_maxs); + +std::vector<atb::Tensor> FillTensorDataRandomly(const std::vector<atb::TensorDesc> &descs, float range_min, + float range_max); + +std::vector<atb::Tensor> FillTensorDataRandomly(const std::vector<atb::TensorDesc> &descs, + const std::vector<std::pair<float, float>> &ranges); + +std::vector<atb::Tensor> FillTensorDataRandomly(const std::vector<atb::TensorDesc> &descs, + std::pair<float, float> range); + +atb::Tensor FillTensorDataByZero(const atb::TensorDesc &desc); + +std::vector<atb::Tensor> FillTensorDataByZero(const std::vector<atb::TensorDesc> &descs); + +atb::Tensor FillTensorDataByOne(const atb::TensorDesc &desc); + +std::vector<atb::Tensor> FillTensorDataByOne(const std::vector<atb::TensorDesc> &descs); + +atb::Tensor FillTensorDataByFile(const atb::TensorDesc &desc, const std::string &filePath); + +std::vector<atb::Tensor> FillTensorDataByFile(const std::vector<atb::TensorDesc> &descs, + const std::vector<std::string> &filePaths); + +void FreeTensor(atb::Tensor &tensor); + +void FreeTensor(std::vector<atb::Tensor> &tensors); + +void PrintDeviceTensor(const atb::Tensor &tensor); + +void PrintDeviceTensor(const std::vector<atb::Tensor> &tensors);
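+ +// Example (illustrative): create a ones-filled float16 tensor, hand it to an operation, then release it. +// atb::TensorDesc desc{.dtype = ACL_FLOAT16, .format = ACL_FORMAT_ND, .shape = atb::Dims{.dims = {2, 3}, .dimNum = 2}}; +// atb::Tensor t = FillTensorDataByOne(desc); +// /* ... use t ... */ +// FreeTensor(t);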
+ */ + +#include "context_utils.h" + +#include +#include +#include + +atb::Context *GetDefaultContext(uint32_t deviceId) +{ + aclrtSetDevice(deviceId); + atb::Context *context = nullptr; + atb::CreateContext(&context); + return context; +} + +void ReleaseDefaultContext(uint32_t deviceId, atb::Context *context) +{ + (void)deviceId; + DestroyContext(context); +} \ No newline at end of file diff --git a/tests/proftest/utils/src/tensor_utils.cpp b/tests/proftest/utils/src/tensor_utils.cpp new file mode 100644 index 00000000..7c779a22 --- /dev/null +++ b/tests/proftest/utils/src/tensor_utils.cpp @@ -0,0 +1,627 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#include "tensor_utils.h" +#include "type_utils.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +size_t GetDataItemSize(aclDataType dtype) +{ + switch (dtype) { + case ACL_DT_UNDEFINED: + return sizeof(bool); + case ACL_BOOL: + return sizeof(bool); + case ACL_FLOAT: + return sizeof(float); + case ACL_FLOAT16: + return sizeof(uint16_t); + case ACL_INT8: + return sizeof(int8_t); + case ACL_INT16: + return sizeof(int16_t); + case ACL_INT32: + return sizeof(int32_t); + case ACL_INT64: + return sizeof(int64_t); + case ACL_UINT8: + return sizeof(uint8_t); + case ACL_UINT16: + return sizeof(uint16_t); + case ACL_UINT32: + return sizeof(uint32_t); + case ACL_UINT64: + return sizeof(uint64_t); + case ACL_BF16: + return sizeof(uint16_t); + case ACL_DOUBLE: + return sizeof(double); + default: + return 0; + } +} + +static std::mt19937 gen(0); + +template T random_float(float min, float max) +{ + std::uniform_real_distribution dist(min, max); + return dist(gen); +} + + +template T random_int(float min, float max) +{ + int32_t min_int32 = static_cast(std::round(min)); + int32_t max_int32 = static_cast(std::round(max)); + std::uniform_int_distribution dist(min_int32, max_int32); + return dist(gen); +} + +template T random_uint(float min, float max) +{ + int32_t min_int32 = static_cast(std::round(min)); + int32_t max_int32 = static_cast(std::round(max)); + min_int32 = min_int32 < 0 ? 0 : min_int32; + max_int32 = max_int32 < 0 ? 
+ +atb::Tensor FillTensorDataRandomly(const atb::TensorDesc &desc, float range_min, float range_max) +{ + atb::Tensor tensor{desc, nullptr, nullptr, 0}; + tensor.dataSize = atb::Utils::GetTensorSize(desc); + aclrtMallocHost((void **)&tensor.hostData, tensor.dataSize); + { + size_t dataItemSize = GetDataItemSize(desc.dtype); + uint64_t tensorNumel = atb::Utils::GetTensorNumel(desc); + void *basePtr = static_cast<void *>(tensor.hostData); + for (uint64_t i = 0; i < tensorNumel; ++i) { + void *elementPtr = static_cast<uint8_t *>(basePtr) + i * dataItemSize; + switch (desc.dtype) { + case ACL_FLOAT: + *static_cast<float *>(elementPtr) = random_float<float>(range_min, range_max); + break; + case ACL_DOUBLE: + *static_cast<double *>(elementPtr) = random_float<double>(range_min, range_max); + break; + case ACL_INT8: + *static_cast<int8_t *>(elementPtr) = random_int<int8_t>(range_min, range_max); + break; + case ACL_INT16: + *static_cast<int16_t *>(elementPtr) = random_int<int16_t>(range_min, range_max); + break; + case ACL_INT32: + *static_cast<int32_t *>(elementPtr) = random_int<int32_t>(range_min, range_max); + break; + case ACL_INT64: + *static_cast<int64_t *>(elementPtr) = random_int<int64_t>(range_min, range_max); + break; + case ACL_UINT8: + *static_cast<uint8_t *>(elementPtr) = random_uint<uint8_t>(range_min, range_max); + break; + case ACL_UINT16: + *static_cast<uint16_t *>(elementPtr) = random_uint<uint16_t>(range_min, range_max); + break; + case ACL_UINT32: + *static_cast<uint32_t *>(elementPtr) = random_uint<uint32_t>(range_min, range_max); + break; + case ACL_UINT64: + *static_cast<uint64_t *>(elementPtr) = random_uint<uint64_t>(range_min, range_max); + break; + case ACL_BOOL: + *static_cast<bool *>(elementPtr) = random_bool(); + break; + case ACL_FLOAT16: + *static_cast<float16 *>(elementPtr) = FloatToFloat16(random_float<float>(range_min, range_max)); + break; + case ACL_BF16: + *static_cast<bfloat16 *>(elementPtr) = FloatToBfloat16(random_float<float>(range_min, range_max)); + break; + default: + break; + } + } + } + aclrtMalloc((void **)&tensor.deviceData, tensor.dataSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMemcpy(tensor.deviceData, tensor.dataSize, tensor.hostData, tensor.dataSize, ACL_MEMCPY_HOST_TO_DEVICE); + + return tensor; +} + +atb::Tensor FillTensorDataRandomly(const atb::TensorDesc &desc) +{ + atb::Tensor tensor = FillTensorDataRandomly(desc, -5, 5); + return tensor; +} + +atb::Tensor FillTensorDataRandomly(const atb::TensorDesc &desc, const std::pair<float, float> range) +{ + atb::Tensor tensor = FillTensorDataRandomly(desc, range.first, range.second); + return tensor; +} + +std::vector<atb::Tensor> FillTensorDataRandomly(const std::vector<atb::TensorDesc> &descs) +{ + std::vector<atb::Tensor> tensors; + for (const atb::TensorDesc &desc : descs) { + atb::Tensor tensor = FillTensorDataRandomly(desc); + tensors.push_back(tensor); + } + + return tensors; +} + +std::vector<atb::Tensor> FillTensorDataRandomly(const std::vector<atb::TensorDesc> &descs, + const std::vector<float> &range_mins, + const std::vector<float> &range_maxs) +{ + std::vector<atb::Tensor> tensors; + if (range_mins.size() != range_maxs.size()) { + std::cout << "range_mins.size() != range_maxs.size()" << std::endl; + return tensors; + } + if (descs.size() < range_mins.size()) { + std::cout << "descs.size() < range_mins.size(); extra ranges will be discarded" << std::endl; + } else if (descs.size() > range_mins.size()) { + std::cout << "descs.size() > range_mins.size(); remaining tensors will be filled with zeros" << std::endl; + } + + for (size_t i = 0; i < descs.size(); ++i) { + if (i < range_mins.size()) { + atb::Tensor tensor = FillTensorDataRandomly(descs[i], range_mins[i], range_maxs[i]); + tensors.push_back(tensor);
+ } else { + atb::Tensor tensor = FillTensorDataByZero(descs[i]); + tensors.push_back(tensor); + } + } + return tensors; +} + +std::vector<atb::Tensor> FillTensorDataRandomly(const std::vector<atb::TensorDesc> &descs, float range_min, + float range_max) +{ + std::vector<atb::Tensor> tensors; + for (size_t i = 0; i < descs.size(); ++i) { + atb::Tensor tensor = FillTensorDataRandomly(descs[i], range_min, range_max); + tensors.push_back(tensor); + } + return tensors; +} + +std::vector<atb::Tensor> FillTensorDataRandomly(const std::vector<atb::TensorDesc> &descs, + const std::vector<std::pair<float, float>> &ranges) +{ + std::vector<atb::Tensor> tensors; + if (descs.size() < ranges.size()) { + std::cout << "descs.size() < ranges.size(); extra ranges will be discarded" << std::endl; + } else if (descs.size() > ranges.size()) { + std::cout << "descs.size() > ranges.size(); remaining tensors will be filled with zeros" << std::endl; + } + + for (size_t i = 0; i < descs.size(); ++i) { + if (i < ranges.size()) { + atb::Tensor tensor = FillTensorDataRandomly(descs[i], ranges[i]); + tensors.push_back(tensor); + } else { + atb::Tensor tensor = FillTensorDataByZero(descs[i]); + tensors.push_back(tensor); + } + } + return tensors; +} + +std::vector<atb::Tensor> FillTensorDataRandomly(const std::vector<atb::TensorDesc> &descs, + std::pair<float, float> range) +{ + std::vector<atb::Tensor> tensors; + for (size_t i = 0; i < descs.size(); ++i) { + atb::Tensor tensor = FillTensorDataRandomly(descs[i], range); + tensors.push_back(tensor); + } + return tensors; +} + +atb::Tensor FillTensorDataByZero(const atb::TensorDesc &desc) +{ + atb::Tensor tensor{desc, nullptr, nullptr, 0}; + tensor.dataSize = atb::Utils::GetTensorSize(desc); + aclrtMallocHost((void **)&tensor.hostData, tensor.dataSize); + { + size_t dataItemSize = GetDataItemSize(desc.dtype); + uint64_t tensorNumel = atb::Utils::GetTensorNumel(desc); + void *basePtr = static_cast<void *>(tensor.hostData); + for (uint64_t i = 0; i < tensorNumel; ++i) { + void *elementPtr = static_cast<uint8_t *>(basePtr) + i * dataItemSize; + switch (desc.dtype) { + case ACL_FLOAT: + *static_cast<float *>(elementPtr) = 0.0f; + break; + case ACL_DOUBLE: + *static_cast<double *>(elementPtr) = 0.0; + break; + case ACL_INT8: + *static_cast<int8_t *>(elementPtr) = 0; + break; + case ACL_INT16: + *static_cast<int16_t *>(elementPtr) = 0; + break; + case ACL_INT32: + *static_cast<int32_t *>(elementPtr) = 0; + break; + case ACL_INT64: + *static_cast<int64_t *>(elementPtr) = 0; + break; + case ACL_UINT8: + *static_cast<uint8_t *>(elementPtr) = 0; + break; + case ACL_UINT16: + *static_cast<uint16_t *>(elementPtr) = 0; + break; + case ACL_UINT32: + *static_cast<uint32_t *>(elementPtr) = 0; + break; + case ACL_UINT64: + *static_cast<uint64_t *>(elementPtr) = 0; + break; + case ACL_BOOL: + *static_cast<bool *>(elementPtr) = false; + break; + case ACL_FLOAT16: + *static_cast<float16 *>(elementPtr) = FloatToFloat16(0.0f); + break; + case ACL_BF16: + *static_cast<bfloat16 *>(elementPtr) = FloatToBfloat16(0.0f); + break; + default: + break; + } + } + } + aclrtMalloc((void **)&tensor.deviceData, tensor.dataSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMemcpy(tensor.deviceData, tensor.dataSize, tensor.hostData, tensor.dataSize, ACL_MEMCPY_HOST_TO_DEVICE); + return tensor; +} + +std::vector<atb::Tensor> FillTensorDataByZero(const std::vector<atb::TensorDesc> &descs) +{ + std::vector<atb::Tensor> tensors; + for (const atb::TensorDesc &desc : descs) { + atb::Tensor tensor = FillTensorDataByZero(desc); + tensors.push_back(tensor); + } + + return tensors; +} + +atb::Tensor FillTensorDataByOne(const atb::TensorDesc &desc) +{ + atb::Tensor tensor{desc, nullptr, nullptr, 0}; + tensor.dataSize = atb::Utils::GetTensorSize(desc); + aclrtMallocHost((void **)&tensor.hostData, tensor.dataSize); + { + size_t dataItemSize = GetDataItemSize(desc.dtype); + uint64_t tensorNumel = atb::Utils::GetTensorNumel(desc); + void *basePtr = static_cast<void *>(tensor.hostData); + for (uint64_t i = 0; i < tensorNumel; ++i) { + void *elementPtr = static_cast<uint8_t *>(basePtr) + i * dataItemSize; + switch (desc.dtype) { + case ACL_FLOAT: + *static_cast<float *>(elementPtr) = 1.0f; + break; + case ACL_DOUBLE: + *static_cast<double *>(elementPtr) = 1.0; + break; + case ACL_INT8: + *static_cast<int8_t *>(elementPtr) = 1; + break; + case ACL_INT16: + *static_cast<int16_t *>(elementPtr) = 1; + break; + case ACL_INT32: + *static_cast<int32_t *>(elementPtr) = 1; + break; + case ACL_INT64: + *static_cast<int64_t *>(elementPtr) = 1; + break; + case ACL_UINT8: + *static_cast<uint8_t *>(elementPtr) = 1; + break; + case ACL_UINT16: + *static_cast<uint16_t *>(elementPtr) = 1; + break; + case ACL_UINT32: + *static_cast<uint32_t *>(elementPtr) = 1; + break; + case ACL_UINT64: + *static_cast<uint64_t *>(elementPtr) = 1; + break; + case ACL_BOOL: + *static_cast<bool *>(elementPtr) = true; + break; + case ACL_FLOAT16: + *static_cast<float16 *>(elementPtr) = FloatToFloat16(1.0f); + break; + case ACL_BF16: + *static_cast<bfloat16 *>(elementPtr) = FloatToBfloat16(1.0f); + break; + default: + break; + } + } + } + aclrtMalloc((void **)&tensor.deviceData, tensor.dataSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMemcpy(tensor.deviceData, tensor.dataSize, tensor.hostData, tensor.dataSize, ACL_MEMCPY_HOST_TO_DEVICE); + + return tensor; +} + +std::vector<atb::Tensor> FillTensorDataByOne(const std::vector<atb::TensorDesc> &descs) +{ + std::vector<atb::Tensor> tensors; + for (const atb::TensorDesc &desc : descs) { + atb::Tensor tensor = FillTensorDataByOne(desc); + tensors.push_back(tensor); + } + + return tensors; +}
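+ +// File layout consumed by FillTensorDataByFile below (an assumption based on the msit tensor dump format): +// a text header whose final line contains the marker "$End=1", followed immediately by the tensor's raw +// binary data; if the binary payload does not match the tensor's dataSize exactly, the file content is ignored.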
+ +atb::Tensor FillTensorDataByFile(const atb::TensorDesc &desc, const std::string &filePath) +{ + atb::Tensor tensor{desc, nullptr, nullptr, 0}; + tensor.dataSize = atb::Utils::GetTensorSize(desc); + aclrtMallocHost((void **)&tensor.hostData, tensor.dataSize); + std::fstream file(filePath, std::ios::in | std::ios::binary | std::ios::ate); + size_t fileSize = file.tellg(); + file.seekg(0, std::ios::beg); + std::vector<char> fileData(fileSize); + file.read(fileData.data(), fileSize); + file.close(); + size_t begin_offset = 0; + size_t data_start = 0; + const std::string end_marker = "$End=1"; + for (size_t i = 0; i < fileSize; ++i) { + if (fileData[i] == '\n') { + std::string line(fileData.data() + begin_offset, fileData.data() + i); + begin_offset = i + 1; + if (line.find(end_marker) != std::string::npos) { + data_start = i + 1; + break; + } + } + } + + size_t binary_size = fileSize - data_start; + if (binary_size == tensor.dataSize) { + aclrtMemcpy(tensor.hostData, tensor.dataSize, fileData.data() + data_start, tensor.dataSize, + ACL_MEMCPY_HOST_TO_HOST); + } + aclrtMalloc((void **)&tensor.deviceData, tensor.dataSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMemcpy(tensor.deviceData, tensor.dataSize, tensor.hostData, tensor.dataSize, ACL_MEMCPY_HOST_TO_DEVICE); + + return tensor; +} + +std::vector<atb::Tensor> FillTensorDataByFile(const std::vector<atb::TensorDesc> &descs, + const std::vector<std::string> &filePaths) +{ + std::vector<atb::Tensor> tensors; + if (descs.size() < filePaths.size()) { + std::cout << "descs.size() < filePaths.size(); extra file paths will be discarded" << std::endl; + } else if (descs.size() > filePaths.size()) { + std::cout << "descs.size() > filePaths.size(); remaining tensors will be filled with zeros" << std::endl; + } + + for (size_t i = 0; i < descs.size(); ++i) { + if (i < filePaths.size()) { + atb::Tensor tensor = FillTensorDataByFile(descs[i], filePaths[i]); + tensors.push_back(tensor); + } else {
+ atb::Tensor tensor = FillTensorDataByZero(descs[i]); + tensors.push_back(tensor); + } + } + + return tensors; +} + +void FreeTensor(atb::Tensor &tensor) +{ + if (tensor.hostData != nullptr) { + aclrtFreeHost(tensor.hostData); + tensor.hostData = nullptr; + } + if (tensor.deviceData != nullptr) { + aclrtFree(tensor.deviceData); + tensor.deviceData = nullptr; + } +} + +void FreeTensor(std::vector<atb::Tensor> &tensors) +{ + for (atb::Tensor &tensor : tensors) { + FreeTensor(tensor); + } +} + +std::string FormatPrintTensorData(aclDataType dtype, size_t dataItemSize, void *data, uint64_t dimNum, + const int64_t *dims, size_t offset = 0, uint64_t depth = 0) +{ + if (depth == dimNum) { + void *elementPtr = static_cast<uint8_t *>(data) + offset; + switch (dtype) { + case ACL_BOOL: + return *static_cast<bool *>(elementPtr) ? "true" : "false"; + case ACL_FLOAT: + { + float val = *static_cast<float *>(elementPtr); + std::ostringstream oss; + oss << std::fixed << std::setprecision(6) << val; + return oss.str(); + } + case ACL_FLOAT16: + { + float16 val = *static_cast<float16 *>(elementPtr); + float fval = Float16ToFloat(val); + std::ostringstream oss; + oss << std::fixed << std::setprecision(6) << fval; + return oss.str(); + } + case ACL_INT8: + return std::to_string(*static_cast<int8_t *>(elementPtr)); + case ACL_INT16: + return std::to_string(*static_cast<int16_t *>(elementPtr)); + case ACL_INT32: + return std::to_string(*static_cast<int32_t *>(elementPtr)); + case ACL_INT64: + return std::to_string(*static_cast<int64_t *>(elementPtr)); + case ACL_UINT8: + return std::to_string(*static_cast<uint8_t *>(elementPtr)); + case ACL_UINT16: + return std::to_string(*static_cast<uint16_t *>(elementPtr)); + case ACL_UINT32: + return std::to_string(*static_cast<uint32_t *>(elementPtr)); + case ACL_UINT64: + return std::to_string(*static_cast<uint64_t *>(elementPtr)); + case ACL_BF16: + { + bfloat16 val = *static_cast<bfloat16 *>(elementPtr); + float fval = Bfloat16ToFloat(val); + std::ostringstream oss; + oss << std::fixed << std::setprecision(6) << fval; + return oss.str(); + } + case ACL_DOUBLE: + { + double val = *static_cast<double *>(elementPtr); + std::ostringstream oss; + oss << std::fixed << std::setprecision(6) << val; + return oss.str(); + } + default: + return "unsupported"; + } + } + size_t stride = dataItemSize; + for (size_t i = depth + 1; i < dimNum; ++i) { + stride *= dims[i]; + } + std::ostringstream oss; + oss << "["; + for (uint64_t i = 0; i < static_cast<uint64_t>(dims[depth]); ++i) { + if (i > 0) { + if (depth == dimNum - 1) { + oss << ","; + } else { + oss << "\n"; + } + } + oss << FormatPrintTensorData(dtype, dataItemSize, data, dimNum, dims, offset + i * stride, depth + 1); + } + oss << "]"; + return oss.str(); +}
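+ +// Illustrative output for a 2x2 ACL_FLOAT tensor of ones, as printed by PrintDeviceTensor below: +// [[1.000000,1.000000] +// [1.000000,1.000000]] aclDataType: ACL_FLOAT, dim: [2,2]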
+ + +std::string GetDataTypeString(aclDataType dtype) +{ + switch (dtype) { + case ACL_DT_UNDEFINED: + return "ACL_DT_UNDEFINED"; + case ACL_BOOL: + return "ACL_BOOL"; + case ACL_FLOAT: + return "ACL_FLOAT"; + case ACL_FLOAT16: + return "ACL_FLOAT16"; + case ACL_INT8: + return "ACL_INT8"; + case ACL_INT16: + return "ACL_INT16"; + case ACL_INT32: + return "ACL_INT32"; + case ACL_INT64: + return "ACL_INT64"; + case ACL_UINT8: + return "ACL_UINT8"; + case ACL_UINT16: + return "ACL_UINT16"; + case ACL_UINT32: + return "ACL_UINT32"; + case ACL_UINT64: + return "ACL_UINT64"; + case ACL_BF16: + return "ACL_BF16"; + case ACL_DOUBLE: + return "ACL_DOUBLE"; + default: + return ""; + } +} +std::string FormatPrintTensorDesc(const atb::Tensor &tensor) +{ + std::ostringstream oss; + oss << "aclDataType: " << GetDataTypeString(tensor.desc.dtype) << ", dim: ["; + for (size_t i = 0; i < tensor.desc.shape.dimNum; ++i) { + oss << std::to_string(tensor.desc.shape.dims[i]); + if (i != tensor.desc.shape.dimNum - 1) { + oss << ","; + } + } + oss << "]"; + return oss.str(); +} +void PrintDeviceTensor(const atb::Tensor &tensor) +{ + if (tensor.deviceData == nullptr) { + std::cout << "tensor's deviceData == nullptr, no print\n"; + return; + } + void *hostData; + aclrtMallocHost((void **)&hostData, tensor.dataSize); + aclrtMemcpy(hostData, tensor.dataSize, tensor.deviceData, tensor.dataSize, ACL_MEMCPY_DEVICE_TO_HOST); + { + size_t dataItemSize = GetDataItemSize(tensor.desc.dtype); + std::cout << FormatPrintTensorData(tensor.desc.dtype, dataItemSize, hostData, tensor.desc.shape.dimNum, + tensor.desc.shape.dims) + << " " << FormatPrintTensorDesc(tensor) << std::endl; + } + aclrtFreeHost(hostData); +} + + +void PrintDeviceTensor(const std::vector<atb::Tensor> &tensors) +{ + for (const atb::Tensor &tensor : tensors) { + PrintDeviceTensor(tensor); + } +} diff --git a/tests/proftest/utils/src/type_utils.cpp b/tests/proftest/utils/src/type_utils.cpp new file mode 100644 index 00000000..e0dda3ef --- /dev/null +++ b/tests/proftest/utils/src/type_utils.cpp @@ -0,0 +1,131 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#include "type_utils.h" +#include <cmath> +#include <cstdint> +#include <cstring> +float16 FloatToFloat16(float fp32) +{ + if (fp32 == 0.0f) { + return (std::signbit(fp32) ? 0x8000 : 0x0000); + } + + uint32_t float_bits; + static_assert(sizeof(float) == sizeof(uint32_t), "Float size mismatch"); + std::memcpy(&float_bits, &fp32, sizeof(float)); + + const uint32_t sign = (float_bits >> 31) & 0x1; + const uint32_t exp = (float_bits >> 23) & 0xFF; + const uint32_t mant = float_bits & 0x7FFFFF; + if (exp == 0xFF) { + if (mant == 0) { + return (sign << 15) | 0x7C00; + } else { + return (sign << 15) | 0x7C00 | (mant >> 13); + } + } + + int32_t exp_fp16 = static_cast<int32_t>(exp) - 127 + 15; + if (exp_fp16 <= 0) { + return (sign << 15); + } + + if (exp_fp16 >= 0x1F) { + return (sign << 15) | 0x7C00; + } + + uint32_t mant24 = (1 << 23) | mant; + uint32_t round_bits = mant24 & 0x1FFF; + uint32_t base = (mant24 >> 13) & 0x3FF; + + if (round_bits > 0x1000 || (round_bits == 0x1000 && (base & 1))) { + base++; + if (base > 0x3FF) { + base = 0; + exp_fp16++; + if (exp_fp16 >= 0x1F) { + return (sign << 15) | 0x7C00; + } + } + } + + return (sign << 15) | (exp_fp16 << 10) | base; +}
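+ +// Sanity examples (IEEE-754 binary16, illustrative): FloatToFloat16(1.0f) == 0x3C00, FloatToFloat16(-2.0f) == 0xC000, +// and values above the half-precision maximum of 65504.0f saturate to infinity, e.g. FloatToFloat16(65536.0f) == 0x7C00.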
+ +bfloat16 FloatToBfloat16(float fp32) +{ + if (fp32 == 0.0f) { + return (std::signbit(fp32) ? 0x8000 : 0x0000); + } + + uint32_t float_bits; + static_assert(sizeof(float) == sizeof(uint32_t), "Float size mismatch"); + std::memcpy(&float_bits, &fp32, sizeof(float)); + + // truncating conversion: bfloat16 keeps the high 16 bits of the binary32 representation + bfloat16 bfloat16_bits = static_cast<bfloat16>(float_bits >> 16); + + const uint32_t exp = (float_bits >> 23) & 0xFF; + const uint32_t mant = float_bits & 0x7FFFFF; + if (exp == 0xFF && mant != 0) { + // keep NaN a NaN even if the truncated mantissa became all zeros + bfloat16_bits |= 0x01; + } + + return bfloat16_bits; +} + +float Float16ToFloat(float16 fp16) +{ + const uint32_t sign = (fp16 >> 15) & 0x1; + const uint32_t exp_f16 = (fp16 >> 10) & 0x1F; + const uint32_t mant_f16 = fp16 & 0x3FF; + if (exp_f16 == 0x1F) { + uint32_t inf_nan = (sign << 31) | 0x7F800000 | (mant_f16 << 13); + float result; + std::memcpy(&result, &inf_nan, sizeof(float)); + return result; + } + if (exp_f16 == 0) { + if (mant_f16 == 0) { + uint32_t sign_bit = sign << 31; + float result; + std::memcpy(&result, &sign_bit, sizeof(float)); + return result; + } else { + // subnormal: shift the mantissa left until the implicit leading one reaches bit 10 + uint32_t shift = 0; + uint32_t mant = mant_f16; + while ((mant & 0x400) == 0) { + mant <<= 1; + ++shift; + } + const int32_t exp_float = -14 - shift + 127; + const uint32_t mant_float = (mant & 0x3FF) << 13; + uint32_t combined = (sign << 31) | (exp_float << 23) | mant_float; + float result; + std::memcpy(&result, &combined, sizeof(float)); + return result; + } + } + + const uint32_t exp_float = exp_f16 + 112; + const uint32_t mant_float = mant_f16 << 13; + uint32_t combined = (sign << 31) | (exp_float << 23) | mant_float; + float result; + std::memcpy(&result, &combined, sizeof(float)); + return result; +} + +float Bfloat16ToFloat(bfloat16 bf16) +{ + uint32_t float_bits = static_cast<uint32_t>(bf16) << 16; + float result; + std::memcpy(&result, &float_bits, sizeof(float)); + return result; +} \ No newline at end of file -- Gitee From 2f505f9d0da1f3772ff36c6786ed3328391c781c Mon Sep 17 00:00:00 2001 From: zouyanlong Date: Thu, 14 Aug 2025 11:54:24 +0800 Subject: [PATCH 2/2] fix --- tests/proftest/main.cpp | 11 +++++++---- tests/proftest/test_cases/bloom_7b/main.cpp | 20 +++++++++++++------- tests/proftest/utils/src/tensor_utils.cpp | 21 +++++++++++++------ 3 files changed, 35 insertions(+), 17 deletions(-) diff --git a/tests/proftest/main.cpp b/tests/proftest/main.cpp index f6789d8a..5ebb3ba2 100644 --- a/tests/proftest/main.cpp +++ b/tests/proftest/main.cpp @@ -49,11 +49,14 @@ public: if (!baselineFile.is_open()) { std::cerr << "Failed to open baseline file: " << baselinePath << std::endl; exit(1); + } else { + std::cout << caseName << " baseline file path: " << baselinePath << std::endl; } baselineFile << "Benchmark_Name,Real Time (s),CPU Time (s)\n"; - baselineFile << caseName << "," << report.real_accumulated_time << "," << report.cpu_accumulated_time << "\n"; + baselineFile << caseName << "," << report.real_accumulated_time << "," << report.cpu_accumulated_time + << "\n"; baselineFile.close(); - return; + continue; } @@ -70,7 +73,7 @@ public: std::ifstream baselineFile(baselinePath, std::ios::in); if (!baselineFile.is_open()) { std::cerr << "Failed to open baseline file: " << baselinePath << " , skip compare.\n"; - return; + exit(1); } double baselineCpuTime = 0; std::string line; @@ -94,7 +97,7 @@ public: } } } - + benchmark::ConsoleReporter::ReportRuns(reports); return; } diff --git a/tests/proftest/test_cases/bloom_7b/main.cpp b/tests/proftest/test_cases/bloom_7b/main.cpp index c3393b9e..10988449 100644 --- a/tests/proftest/test_cases/bloom_7b/main.cpp +++ b/tests/proftest/test_cases/bloom_7b/main.cpp @@ -24,15 +24,12 @@
 #include "context_utils.h" #include "tensor_utils.h" - -enum class TensorId : int32_t { - ElewiseAddNode0In, -}; #include "models/base/param/param.h" #include "models/base/param/layer_param.h" #include "models/bloom/layer/bloom_decoder_layer.h" static void Layer_Bloom_7B(benchmark::State &state) { + int round = 0; for (auto _ : state) { state.PauseTiming(); aclInit(nullptr); @@ -177,7 +174,12 @@ static void Layer_Bloom_7B(benchmark::State &state) {.dtype = ACL_INT32, .format = ACL_FORMAT_ND, .shape = atb::Dims{.dims = {1, 1}, .dimNum = 2}}, {.dtype = ACL_INT32, .format = ACL_FORMAT_ND, .shape = atb::Dims{.dims = {1}, .dimNum = 1}}, }; - std::vector<atb::Tensor> inTensor = FillTensorDataByOne(inTensorDesc); + std::vector<std::string> filePath; + for (int i = 0; i < 61; ++i) { + filePath.push_back("/home/zouyanlong/msit_dump/tensors/0_15072/0/2_Prefill_layer/before/intensor" + + std::to_string(i) + ".bin"); + } + std::vector<atb::Tensor> inTensor = FillTensorDataByFile(inTensorDesc, filePath); std::vector<atb::TensorDesc> outTensorDesc{ {.dtype = ACL_FLOAT16, .format = ACL_FORMAT_ND, .shape = atb::Dims{.dims = {1, 4096}, .dimNum = 2}}, }; @@ -193,7 +195,10 @@ static void Layer_Bloom_7B(benchmark::State &state) graphOperation->Setup(variantPack, workspaceSize, context); void *workSpace = nullptr; aclrtMalloc(&workSpace, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST); - state.ResumeTiming(); + if (round < 20) + round++; + else + state.ResumeTiming(); graphOperation->Execute(variantPack, (uint8_t *)workSpace, workspaceSize, context); aclrtSynchronizeStream(stream); state.PauseTiming(); @@ -213,7 +218,8 @@ static void Layer_Bloom_7B(benchmark::State &state) aclrtDestroyStream(stream); aclrtResetDevice(deviceId); aclFinalize(); + state.ResumeTiming(); } } -BENCHMARK(Layer_Bloom_7B)->Iterations(1); +BENCHMARK(Layer_Bloom_7B)->Iterations(100); diff --git a/tests/proftest/utils/src/tensor_utils.cpp b/tests/proftest/utils/src/tensor_utils.cpp index 7c779a22..9259dc04 100644 --- a/tests/proftest/utils/src/tensor_utils.cpp +++ b/tests/proftest/utils/src/tensor_utils.cpp @@ -397,10 +397,11 @@ std::vector<atb::Tensor> FillTensorDataByOne(const std::vector<atb::TensorDesc> atb::Tensor FillTensorDataByFile(const atb::TensorDesc &desc, const std::string &filePath) { - atb::Tensor tensor{desc, nullptr, nullptr, 0}; - tensor.dataSize = atb::Utils::GetTensorSize(desc); - aclrtMallocHost((void **)&tensor.hostData, tensor.dataSize); std::fstream file(filePath, std::ios::in | std::ios::binary | std::ios::ate); + if (!file.is_open()) { + std::cerr << "Can't open: " << filePath << std::endl; + exit(1); + } size_t fileSize = file.tellg(); file.seekg(0, std::ios::beg); std::vector<char> fileData(fileSize); file.read(fileData.data(), fileSize); @@ -421,10 +422,17 @@ atb::Tensor FillTensorDataByFile(const atb::TensorDesc &desc, const std::string } size_t binary_size = fileSize - data_start; - if (binary_size == tensor.dataSize) { - aclrtMemcpy(tensor.hostData, tensor.dataSize, fileData.data() + data_start, tensor.dataSize, - ACL_MEMCPY_HOST_TO_HOST); + atb::Tensor tensor{desc, nullptr, nullptr, 0}; + tensor.dataSize = atb::Utils::GetTensorSize(desc); + if (binary_size < tensor.dataSize) { + std::cerr << "binary_size < tensor.dataSize" << "\n" + << "filePath: " << filePath << "\n" + << "binary_size: " << binary_size << " tensor.dataSize: " << tensor.dataSize << std::endl; + exit(1); } + aclrtMallocHost((void **)&tensor.hostData, tensor.dataSize); + aclrtMemcpy(tensor.hostData, tensor.dataSize, fileData.data() + data_start, tensor.dataSize, + ACL_MEMCPY_HOST_TO_HOST); aclrtMalloc((void **)&tensor.deviceData, tensor.dataSize,
ACL_MEM_MALLOC_HUGE_FIRST); aclrtMemcpy(tensor.deviceData, tensor.dataSize, tensor.hostData, tensor.dataSize, ACL_MEMCPY_HOST_TO_DEVICE); @@ -600,6 +608,7 @@ std::string FormatPrintTensorDesc(const atb::Tensor &tensor) oss << "]"; return oss.str(); } + void PrintDeviceTensor(const atb::Tensor &tensor) { if (tensor.deviceData == nullptr) { -- Gitee