From bb7bc9d5a822d77ffcadd59b4635e071251117cf Mon Sep 17 00:00:00 2001 From: lijs77 Date: Fri, 9 May 2025 16:35:59 +0800 Subject: [PATCH 1/4] ms-kvscale-modified --- .../ascend/kernel/internal/kv_scale_cache.cc | 37 +++++ .../ascend/kernel/internal/kv_scale_cache.h | 30 ++++ .../ops/infer/ops_func_impl/kv_scale_cache.cc | 128 ++++++++++++++++++ .../ops/infer/ops_func_impl/kv_scale_cache.h | 48 +++++++ .../ops/op_def/yaml/kv_scale_cache_op.yaml | 21 +++ .../mindspore/ops/operations/__init__.py | 3 +- .../python/mindspore/ops/operations/nn_ops.py | 2 +- tests/ut/cpp/ops/test_ops_kv_scale_cache.cc | 77 +++++++++++ 8 files changed, 344 insertions(+), 2 deletions(-) create mode 100644 mindspore/ccsrc/plugin/device/ascend/kernel/internal/kv_scale_cache.cc create mode 100644 mindspore/ccsrc/plugin/device/ascend/kernel/internal/kv_scale_cache.h create mode 100644 mindspore/ops/infer/ops_func_impl/kv_scale_cache.cc create mode 100644 mindspore/ops/infer/ops_func_impl/kv_scale_cache.h create mode 100644 mindspore/ops/op_def/yaml/kv_scale_cache_op.yaml create mode 100644 tests/ut/cpp/ops/test_ops_kv_scale_cache.cc diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/internal/kv_scale_cache.cc b/mindspore/ccsrc/plugin/device/ascend/kernel/internal/kv_scale_cache.cc new file mode 100644 index 00000000000..2bd90c209cd --- /dev/null +++ b/mindspore/ccsrc/plugin/device/ascend/kernel/internal/kv_scale_cache.cc @@ -0,0 +1,37 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/ascend/kernel/internal/kv_scale_cache.h" +#include +#include "common/kernel.h" +#include "plugin/device/ascend/kernel/internal/internal_kernel_in_out_map.h" + +namespace mindspore { +namespace kernel { +internal::InternalOpPtr InternalKvScaleCache::CreateKernel(const internal::InputsImmutableInfoList &inputs_ii, + const internal::OutputsImmutableInfoList &outputs_ii, + const std::vector &ms_inputs, + const std::vector &ms_outputs) { + internal::KvScaleCacheParam param; + param.cache_mode = static_cast(ms_inputs[kIndex4]->GetValue().value()); + MS_LOG(INFO) << "Create kernel: " << internal::kInternalKvScaleCacheOpName << " cache_mode: " << param.cache_mode; + return internal::CreateKvScaleCacheOp(inputs_ii, outputs_ii, param, internal::kInternalKvScaleCacheOpName); +} +MS_INTERNAL_KERNEL_FACTORY_REG(KvScaleCache, internal::kInternalKvScaleCacheOpName, InternalKvScaleCache); +REG_MS_TO_INTERNAL_IN_TENSOR_IDX_MAP(KvScaleCache, INPUT_NUM_4, INDEX_0, INDEX_1, INDEX_3, INDEX_2); + +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/internal/kv_scale_cache.h b/mindspore/ccsrc/plugin/device/ascend/kernel/internal/kv_scale_cache.h new file mode 100644 index 00000000000..97fc8741a91 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/ascend/kernel/internal/kv_scale_cache.h @@ -0,0 +1,30 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_KERNEL_INTERNAL_KV_SCALE_CACHE_H_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_KERNEL_INTERNAL_KV_SCALE_CACHE_H_ + +#include +#include +#include +#include "plugin/device/ascend/kernel/internal/internal_kernel_mod.h" +#include "include/internal.h" + +namespace mindspore { +namespace kernel { +DECLARE_INTERNAL_KERNEL_MOD(KvScaleCache) +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_KERNEL_INTERNAL_KV_SCALE_CACHE_H_ \ No newline at end of file diff --git a/mindspore/ops/infer/ops_func_impl/kv_scale_cache.cc b/mindspore/ops/infer/ops_func_impl/kv_scale_cache.cc new file mode 100644 index 00000000000..d460b4a4c29 --- /dev/null +++ b/mindspore/ops/infer/ops_func_impl/kv_scale_cache.cc @@ -0,0 +1,128 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "infer/ops_func_impl/kv_scale_cache.h" +#include +#include +#include +#include "utils/check_convert_utils.h" +#include "utils/ms_context.h" +#include "utils/convert_utils_base.h" +#include "mindspore/ops/ops_utils/op_utils.h" + +namespace mindspore { +namespace ops { +namespace { +static constexpr int32_t prefill_mode = 1; +static constexpr int32_t incremental_mode = 0; +} +BaseShapePtr KvScaleCacheFuncImpl::InferShape(const PrimitivePtr &primitive, + const std::vector &input_args) const { + auto op_name = primitive->name(); + + auto key_scale_shape_ptr = input_args[KvScaleCacheInputKeyScaleIndex]->GetShape(); + auto value_scale_shape_ptr = input_args[KvScaleCacheInputValueScaleIndex]->GetShape(); + MS_EXCEPTION_IF_NULL(key_scale_shape_ptr); + MS_EXCEPTION_IF_NULL(value_scale_shape_ptr); + + auto key_scale_cache_shape_ptr = input_args[KvScaleCacheInputKeyValueScaleCacheIndex]->GetShape(); + MS_EXCEPTION_IF_NULL(key_scale_cache_shape_ptr); + auto key_scale_cache_shape = key_scale_cache_shape_ptr->GetShapeVector(); + + auto batch_valid_shape_ptr = input_args[KvScaleCacheInputBatchVaildLengthIndex]->GetShape(); + MS_EXCEPTION_IF_NULL(batch_valid_shape_ptr); + auto batch_valid_shape = batch_valid_shape_ptr->GetShapeVector(); + + if (IsDynamicRank(key_scale_shape_ptr->GetShapeVector()) || IsDynamicRank(batch_valid_shape_ptr->GetShapeVector())) { + return std::make_shared(ShapeVector{abstract::Shape::kShapeRankAny}); + } + + const int64_t input_num_dims = 2; + MS_CHECK_VALUE(key_scale_shape_ptr->GetShapeVector().size() == input_num_dims, + CheckAndConvertUtils::FormatCommMsg("rank of kscale must be 2, but got:", + key_scale_shape_ptr->GetShapeVector().size())); + MS_CHECK_VALUE(value_scale_shape_ptr->GetShapeVector().size() == input_num_dims, + CheckAndConvertUtils::FormatCommMsg("rank of vscale must be 2, but got:", + value_scale_shape_ptr->GetShapeVector().size())); + const size_t batch_valid_size = batch_valid_shape.size(); + (void)CheckAndConvertUtils::CheckInteger(batch_valid_size + "batch_valid_size must be greater than 0, but got:", batch_valid_size, kGreaterEqual, 0, + op_name); + + if (!IsDynamic(key_scale_cache_shape) && !IsDynamic(batch_valid_shape)) { + const size_t key_scale_cache_dim = key_scale_cache_shape[0]; + const size_t max_batch_size = key_scale_cache_shape[1]; + // max_batch_size 约束 + MS_CHECK_VALUE(batch_valid_size <= max_batch_size, CheckAndConvertUtils::FormatCommMsg("The batch_size must not bigger than max_batch_size, but got batch_valid_size: ", batch_valid_size, ", max_batch_size: ", max_batch_size)); + MS_CHECK_VALUE(key_scale_cache_dim == input_num_dims, CheckAndConvertUtils::FormatCheckIntegerMsg("key_scale_cache_dim", SizeToLong(key_scale_cache_dim), kEqual, 2, primitive)); + MS_CHECK_VALUE(max_batch_size != 0, CheckAndConvertUtils::FormatCheckIntegerMsg("max_batch_size", SizeToLong(max_batch_size), kNotEqual, 0, primitive)); + // max_seqlens约束 + const size_t max_seqlens = key_scale_cache_shape[2]; + MS_CHECK_VALUE(max_seqlens != 0, CheckAndConvertUtils::FormatCheckIntegerMsg("max_seqlens", SizeToLong(max_seqlens), kNotEqual, 0, primitive)); + auto batch_valid_tensor = input_args[KvScaleCacheInputBatchVaildLengthIndex]; + //获取 batch_valid_length 的最大值 + if (batch_valid_tensor->GetValue() != nullptr) { + auto shape_ptr = batch_valid_tensor->GetShape()->cast(); + MS_EXCEPTION_IF_NULL(shape_ptr); + const auto &shape = shape_ptr->shape(); + auto max_value = *std::max_element(shape.begin(), shape.end()); + MS_CHECK_VALUE(max_value <= static_cast(max_seqlens), CheckAndConvertUtils::FormatCommMsg( + "Max seqlen in batch exceeds limit:", max_value, + " > max_seqlens:", max_seqlens)); + } + } + + // decode-check + auto cache_mode_scalar = GetScalarValue(input_args[KvScaleCacheInputCacheModeIndex]->GetValue()); + size_t decode_batch = key_scale_shape_ptr->GetShapeVector()[0]; + size_t seqlens = key_scale_shape_ptr->GetShapeVector()[1]; + if (cache_mode_scalar.has_value()) { + auto cache_mode = static_cast(cache_mode_scalar.value()); + MS_LOG(INFO) << "cache_mode: " << cache_mode; + if (cache_mode != incremental_mode && cache_mode != prefill_mode && cache_mode != -1){ + MS_LOG(EXCEPTION) << "this cache_mode is not supported, but got cache_mode: " << cache_mode; + } + if (cache_mode == incremental_mode) { + MS_CHECK_VALUE((decode_batch >= batch_valid_size) && (seqlens == 1), + CheckAndConvertUtils::FormatCommMsg( + "For ", op_name, + ", decode_batch must be more than or equal to batch_valid_size, seqlens must be 1, but got decode_batch: ", decode_batch, ", batch_valid_size: ", batch_valid_size, "seqlens: ", seqlens) + ); + } + } + + auto shape_element = key_scale_cache_shape_ptr->cast(); + return shape_element; +} + +TypePtr KvScaleCacheFuncImpl::InferType(const PrimitivePtr &primitive, + const std::vector &input_args) const { + const std::set valid_types = {kFloat32}; + auto op_name = primitive->name(); + std::map types; + + (void)types.emplace("key_scale", input_args[KvScaleCacheInputKeyScaleIndex]->GetType()); + (void)types.emplace("value_scale", input_args[KvScaleCacheInputValueScaleIndex]->GetType()); + (void)types.emplace("key_value_scale_cache", input_args[KvScaleCacheInputKeyValueScaleCacheIndex]->GetType()); + auto type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, op_name); + + auto bvl_type = input_args[KvScaleCacheInputBatchVaildLengthIndex]->GetType(); + const std::set int32_valid_types = {kInt32}; + std::map int32_types; + (void)int32_types.emplace("batch_valid_length", bvl_type); + (void)CheckAndConvertUtils::CheckTensorTypeValid("bvl", bvl_type, int32_valid_types, op_name); + return type; +} +} // namespace ops +} // namespace mindspore diff --git a/mindspore/ops/infer/ops_func_impl/kv_scale_cache.h b/mindspore/ops/infer/ops_func_impl/kv_scale_cache.h new file mode 100644 index 00000000000..5926cfbc950 --- /dev/null +++ b/mindspore/ops/infer/ops_func_impl/kv_scale_cache.h @@ -0,0 +1,48 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_OPS_FUNC_IMPL_KV_SCALE_CACHE_H_ +#define MINDSPORE_CORE_OPS_OPS_FUNC_IMPL_KV_SCALE_CACHE_H_ +#include +#include +#include +#include +#include +#include +#include "ops/base_operator.h" +#include "mindapi/base/types.h" +#include "ops/ops_func_impl/op_func_impl.h" + +namespace mindspore { +namespace ops { +enum KvScaleCacheInputIndex : size_t { + KvScaleCacheInputKeyScaleIndex, + KvScaleCacheInputValueScaleIndex, + KvScaleCacheInputKeyValueScaleCacheIndex, + KvScaleCacheInputBatchVaildLengthIndex, + KvScaleCacheInputCacheModeIndex, + kvScaleCacheInputsNum +}; + +class OPS_API KvScaleCacheFuncImpl : public OpFuncImpl { + public: + BaseShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) const override; + TypePtr InferType(const PrimitivePtr &primitive, const std::vector &input_args) const override; +}; +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPS_OPS_FUNC_IMPL_KV_SCALE_CACHE_H_ diff --git a/mindspore/ops/op_def/yaml/kv_scale_cache_op.yaml b/mindspore/ops/op_def/yaml/kv_scale_cache_op.yaml new file mode 100644 index 00000000000..56a0214b9c8 --- /dev/null +++ b/mindspore/ops/op_def/yaml/kv_scale_cache_op.yaml @@ -0,0 +1,21 @@ +#operator kv_scale_cache +kv_scale_cache: + args: + key_scale: + dtype: tensor + value_scale: + dtype: tensor + key_value_scale_cache: + dtype: tensor + batch_valid_length: + dtype: tensor + cache_mode: + dtype: int + args_signature: + rw_write: key_value_scale_cache + dtype_group: (key_scale, value_scale) + labels: + side_effect_mem: True + returns: + out: + dtype: tensor diff --git a/mindspore/python/mindspore/ops/operations/__init__.py b/mindspore/python/mindspore/ops/operations/__init__.py index cb746ad4477..82aa06eae36 100644 --- a/mindspore/python/mindspore/ops/operations/__init__.py +++ b/mindspore/python/mindspore/ops/operations/__init__.py @@ -118,7 +118,7 @@ from .nn_ops import (LSTM, SGD, Adam, AdamWeightDecay, FusedSparseAdam, FusedSpa Dilation2D, DataFormatVecPermute, DeformableOffsets, Dense, FractionalAvgPool, FractionalMaxPool, FractionalMaxPool3DWithFixedKsize, FractionalMaxPoolWithFixedKsize, GridSampler2D, TripletMarginLoss, UpsampleNearest3D, UpsampleTrilinear3D, PadV3, ChannelShuffle, - GLU, MaxUnpool3D, Pdist, RmsNorm, PagedAttention, PagedAttentionMask, ReshapeAndCache, + GLU, MaxUnpool3D, Pdist, RmsNorm, PagedAttention, PagedAttentionMask, ReshapeAndCache, KvScaleCache, ApplyRotaryPosEmb, GroupTopk) from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, ConfusionMatrix, UpdateState, Load, StopGradient, Reusing, @@ -714,6 +714,7 @@ __all__ = [ "PagedAttention", "PagedAttentionMask", "ReshapeAndCache", + "KvScaleCache", "ApplyRotaryPosEmb", "GroupTopk", "RmsNorm", diff --git a/mindspore/python/mindspore/ops/operations/nn_ops.py b/mindspore/python/mindspore/ops/operations/nn_ops.py index 4ae45af670f..cab2594ef05 100644 --- a/mindspore/python/mindspore/ops/operations/nn_ops.py +++ b/mindspore/python/mindspore/ops/operations/nn_ops.py @@ -37,7 +37,7 @@ from ..auto_generate import (CeLU, Flatten, LogSoftmax, LogSoftmaxExt, GLU, ReLU Elu, Sigmoid, Softmax, SoftplusExt, HSwish, HSigmoid, AvgPool, BiasAdd, NLLLoss, OneHot, GeLU, FastGeLU, PReLU, RmsNorm, IncreFlashAttention, MSELossExt, GridSampler3D, GridSampler2D, LayerNorm, LayerNormExt, HShrink, AdamWeightDecay, Dropout, - ApplyRotaryPosEmb, GroupTopk, PagedAttention, PagedAttentionMask, ReshapeAndCache, + ApplyRotaryPosEmb, GroupTopk, PagedAttention, PagedAttentionMask, ReshapeAndCache, KvScaleCache, FlashAttentionScore, PromptFlashAttention, Embedding, UpsampleNearest1D, UpsampleNearest2D, UpsampleNearest3D, UpsampleTrilinear3D, SoftMarginLoss, UpsampleBilinear2D, UpsampleLinear1D, diff --git a/tests/ut/cpp/ops/test_ops_kv_scale_cache.cc b/tests/ut/cpp/ops/test_ops_kv_scale_cache.cc new file mode 100644 index 00000000000..cbaf9389da2 --- /dev/null +++ b/tests/ut/cpp/ops/test_ops_kv_scale_cache.cc @@ -0,0 +1,77 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include "common/common_test.h" +#include "ir/dtype/type.h" +#include "abstract/dshape.h" +#include "utils/tensor_construct_utils.h" +#include "ir/primitive.h" +#include "abstract/abstract_value.h" +#include "include/backend/optimizer/helper.h" +#include "ops/test_ops.h" +#include "infer/ops_func_impl/kv_scale_cache.h" +#include "ops/test_value_utils.h" + +namespace mindspore { +namespace ops { +struct KvScaleCacheShapeParams { + ShapeVector key_scale_shape; + TypePtr key_scale_type; + ShapeVector value_scale_shape; + TypePtr value_scale_type; + ShapeVector key_value_scale_cache_shape; + TypePtr key_value_scale_cache_type; + ShapeVector batch_valid_length_shape; + TypePtr batch_valid_length_type; + ValuePtr cache_mode; +}; + +class TestKvScaleCache : public TestOps, public testing::WithParamInterface {}; + +TEST_P(TestKvScaleCache, DynShape) { + const auto ¶m = GetParam(); + auto key_scale = std::make_shared(param.key_scale_type, param.key_scale_shape); + auto value_scale = std::make_shared(param.value_scale_type, param.value_scale_shape); + auto batch_valid_length = + std::make_shared(param.batch_valid_length_type, param.batch_valid_length_shape); + auto key_value_scale_cache = + std::make_shared(param.key_value_scale_cache_type, param.key_value_scale_cache_shape); + auto cache_mode = param.cache_mode->ToAbstract(); + auto key_value_scale_cache_shape = std::make_shared(param.key_value_scale_cache_shape); + auto expect_shape = key_value_scale_cache_shape; + auto expect_type = param.key_value_scale_cache_type; + + KvScaleCacheFuncImpl func_impl; + auto prim = std::make_shared("KvScaleCache"); + auto out_dtype = + func_impl.InferType(prim, {key_scale, value_scale, key_value_scale_cache, batch_valid_length, cache_mode}); + ASSERT_TRUE(*out_dtype == *expect_type); + auto out_shape = + func_impl.InferShape(prim, {key_scale, value_scale, key_value_scale_cache, batch_valid_length, cache_mode}); + ASSERT_TRUE(*out_shape == *expect_shape); +} + +INSTANTIATE_TEST_CASE_P( + TestKvScaleCache, TestKvScaleCache, + testing::Values( + KvScaleCacheShapeParams{ + {3, 4, 20}, kFloat32, {3, 4, 20}, kFloat32, {3, 4, 20}, kFloat32, {12}, kInt32, CreateScalar(1)}, + KvScaleCacheShapeParams{ + {3, 4, 20}, kFloat32, {3, 4, 20}, kFloat32, {3, 4, 20}, kFloat32, {12}, kInt32, CreateScalar(0)} + )); +} // namespace ops +} // namespace mindspore -- Gitee From 18649986187139caf10cc1df136b27b706ee1be0 Mon Sep 17 00:00:00 2001 From: yonibaehr Date: Tue, 13 May 2025 15:31:36 +0300 Subject: [PATCH 2/4] Ascend310P kernel fusion --- .../kernel/internal/internal_kernel_plugin.cc | 4 + .../ascend/kernel/internal/matmul_elemwise.cc | 3 + .../kernel/internal/multi_weight_matmul.cc | 67 ++++ .../kernel/internal/multi_weight_matmul.h | 2 + .../kernel/internal/quant_batch_matmul.cc | 7 + .../device/ascend/optimizer/CMakeLists.txt | 1 + .../inference_matmul_split_fusion.cc | 377 +++++++++++++++++- .../inference_matmul_split_fusion.h | 37 +- .../inference_qbmm_elemwise_fusion.cc | 140 +++++++ .../inference_qbmm_elemwise_fusion.h | 46 +++ .../ir_fusion_infer/matmul_elemwise_fusion.cc | 15 +- mindspore/core/utils/ms_context.cc | 9 +- ...matmul_split_silu_fastgelu_add_mul_out1.cc | 72 ++++ .../matmul_split_silu_fastgelu_add_mul_out1.h | 34 ++ .../matmul_split_silu_mul_out1.cc | 72 ++++ .../matmul_split_silu_mul_out1.h | 34 ++ ...matmul_split_silu_fastgelu_add_mul_out1.cc | 72 ++++ ..._matmul_split_silu_fastgelu_add_mul_out1.h | 34 ++ .../q_matmul_split_silu_mul_out1.cc | 73 ++++ .../q_matmul_split_silu_mul_out1.h | 34 ++ mindspore/ops/op_def/nn_op_name.h | 4 + ...l_split_silu_fastgelu_add_mul_out1_op.yaml | 16 + .../infer/matmul_split_silu_mul_out1_op.yaml | 16 + ...l_split_silu_fastgelu_add_mul_out1_op.yaml | 20 + .../q_matmul_split_silu_mul_out1_op.yaml | 20 + .../python/mindspore/ops/operations/nn_ops.py | 1 + 26 files changed, 1182 insertions(+), 28 deletions(-) create mode 100644 mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fusion_infer/inference_qbmm_elemwise_fusion.cc create mode 100644 mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fusion_infer/inference_qbmm_elemwise_fusion.h create mode 100644 mindspore/ops/infer/ops_func_impl/matmul_split_silu_fastgelu_add_mul_out1.cc create mode 100644 mindspore/ops/infer/ops_func_impl/matmul_split_silu_fastgelu_add_mul_out1.h create mode 100644 mindspore/ops/infer/ops_func_impl/matmul_split_silu_mul_out1.cc create mode 100644 mindspore/ops/infer/ops_func_impl/matmul_split_silu_mul_out1.h create mode 100644 mindspore/ops/infer/ops_func_impl/q_matmul_split_silu_fastgelu_add_mul_out1.cc create mode 100644 mindspore/ops/infer/ops_func_impl/q_matmul_split_silu_fastgelu_add_mul_out1.h create mode 100644 mindspore/ops/infer/ops_func_impl/q_matmul_split_silu_mul_out1.cc create mode 100644 mindspore/ops/infer/ops_func_impl/q_matmul_split_silu_mul_out1.h create mode 100644 mindspore/ops/op_def/yaml/infer/matmul_split_silu_fastgelu_add_mul_out1_op.yaml create mode 100644 mindspore/ops/op_def/yaml/infer/matmul_split_silu_mul_out1_op.yaml create mode 100644 mindspore/ops/op_def/yaml/infer/q_matmul_split_silu_fastgelu_add_mul_out1_op.yaml create mode 100644 mindspore/ops/op_def/yaml/infer/q_matmul_split_silu_mul_out1_op.yaml diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/internal/internal_kernel_plugin.cc b/mindspore/ccsrc/plugin/device/ascend/kernel/internal/internal_kernel_plugin.cc index 6806a689e40..6fabc13916a 100644 --- a/mindspore/ccsrc/plugin/device/ascend/kernel/internal/internal_kernel_plugin.cc +++ b/mindspore/ccsrc/plugin/device/ascend/kernel/internal/internal_kernel_plugin.cc @@ -67,6 +67,10 @@ static std::unordered_map &ms_inputs, const std::vector &ms_outputs) override; const std::string op_name_{"UnknownOp"}; + bool split_two_{false}; + bool fused_{false}; }; } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/internal/quant_batch_matmul.cc b/mindspore/ccsrc/plugin/device/ascend/kernel/internal/quant_batch_matmul.cc index e3cd96db869..8cc7998dfd8 100644 --- a/mindspore/ccsrc/plugin/device/ascend/kernel/internal/quant_batch_matmul.cc +++ b/mindspore/ccsrc/plugin/device/ascend/kernel/internal/quant_batch_matmul.cc @@ -31,6 +31,13 @@ internal::InternalOpPtr InternalQuantBatchMatmul::CreateKernel(const internal::I param.with_bias = !(ms_inputs[kIndex4]->GetType()->isa()); param.enable_shuffle = false; // the real definition is in internal param.enable_dequant = true; + bool has_element_type = primitive_->HasAttr("ElemwiseType"); + auto value_str = primitive_->GetAttr("ElemwiseType"); + if (has_element_type && (value_str != nullptr)) { + if (GetValue(value_str) == "fastgelu") { + param.with_fastgelu = true; + } + } output_format_ = outputs[0].GetFormat(); return internal::CreateMatmulOp(inputs, outputs, param, internal::kInternalMatMulOpName); } diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/CMakeLists.txt b/mindspore/ccsrc/plugin/device/ascend/optimizer/CMakeLists.txt index 2ed5e37009d..6dd263909db 100644 --- a/mindspore/ccsrc/plugin/device/ascend/optimizer/CMakeLists.txt +++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/CMakeLists.txt @@ -54,6 +54,7 @@ file(GLOB_RECURSE MS_OPTIMIZER_910B RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "./ir_fusion_infer/matmul_sigmoid_add_fusion.cc" "./ir_fusion_infer/matmul_sigmoid_cast_add_fusion.cc" "./ir_fusion_infer/matmul_elemwise_fusion.cc" + "./ir_fusion_infer/inference_qbmm_elemwise_fusion.cc" "./ir_fusion_infer/remove_fa_tensor_to_tuple_ops.cc" "./ir_fusion_infer/transpose_batch_matmul_transpose_fusion.cc" "./ir_fusion_infer/moe_init_routing_dyn_quantv2.cc" diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fusion_infer/inference_matmul_split_fusion.cc b/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fusion_infer/inference_matmul_split_fusion.cc index b8cfc3c73c8..62cf24e48cf 100644 --- a/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fusion_infer/inference_matmul_split_fusion.cc +++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fusion_infer/inference_matmul_split_fusion.cc @@ -23,9 +23,11 @@ #include "include/backend/anf_runtime_algorithm.h" #include "include/common/utils/anfalgo.h" #include "include/common/utils/utils.h" +#include "mindspore/ops/op_def/nn_optimizer_ops.h" #include "utils/ms_context.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" +#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_f.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_q.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_r.h" @@ -43,33 +45,45 @@ bool InferenceMatmulSplitFusion::Run(const FuncGraphPtr &graph) { MS_EXCEPTION_IF_NULL(ms_context); constexpr auto kInferenceMatmulSplitSiluName = "InferenceMatmulSplitSilu"; constexpr auto kInferenceMatmulSplitName = "InferenceMatmulSplit"; + constexpr auto kInferenceMatmulSplitSiluFastgeluAddMulName = "InferenceGatedFFN"; auto enable_op_list = ms_context->ms_internal_enable_custom_kernel_list(); + auto enable_fuse_gated_ffn = (std::find(enable_op_list.begin(), enable_op_list.end(), + kInferenceMatmulSplitSiluFastgeluAddMulName) != enable_op_list.end()); auto enable_fusion = (std::find(enable_op_list.begin(), enable_op_list.end(), kInferenceMatmulSplitName) != enable_op_list.end()); - if (!enable_fusion) { + if (!enable_fusion && !enable_fuse_gated_ffn) { return false; } - enable_fusion_silu = - (std::find(enable_op_list.begin(), enable_op_list.end(), kInferenceMatmulSplitSiluName) != enable_op_list.end()); + enable_fusion_silu = enable_fusion && (std::find(enable_op_list.begin(), enable_op_list.end(), + kInferenceMatmulSplitSiluName) != enable_op_list.end()); std::string pattern_name = ""; auto node_list = TopoSort(graph->output()); std::reverse(node_list.begin(), node_list.end()); for (const auto &node : node_list) { + bool fuse_gated_ffn = false; if (node == nullptr || !node->isa()) { continue; } auto cnode = node->cast(); auto node_name = common::AnfAlgo::GetCNodeName(cnode); - if (node_name != prim::kPrimSplitWithSize->name() && node_name != prim::kPrimSiLU->name()) { + if (enable_fuse_gated_ffn && (node_name == prim::kPrimMul->name())) { + // last node is Mul and Fusion is allowed + fuse_gated_ffn = true; + } else if (node_name != prim::kPrimSplitWithSize->name() && node_name != prim::kPrimSiLU->name()) { continue; } if (visited_cnodes.find(cnode) != visited_cnodes.end()) { continue; } - pattern_name = GetFusionPatternName(cnode); + if (fuse_gated_ffn) { + pattern_name = GetGatedFFNFusionPatternName(cnode); + } else if (enable_fusion) { + pattern_name = GetFusionPatternName(cnode); + } MS_LOG(DEBUG) << "fusion pattern is : " << pattern_name; if (!pattern_name.empty()) { + MS_LOG(DEBUG) << "pattern name is not empty. node name is " << node_name; auto new_node = Process(pattern_name, graph, node); changed |= new_node != nullptr; } @@ -125,6 +139,96 @@ std::string InferenceMatmulSplitFusion::GetSplitFusionPatternName(const CNodePtr return pattern_name; } +std::string InferenceMatmulSplitFusion::GetSiluMulPattern(const CNodePtr &mul_input0_node, + const CNodePtr &mul_input1_node) const { + // in this branch we aim to find the first pattern -- kPatternNameMatMulSplitSiluMul + std::string pattern_name = ""; + auto silu_input_node = common::AnfAlgo::GetInputNode(mul_input0_node->cast(), kIndex0); + auto silu_input_name = common::AnfAlgo::GetCNodeName(silu_input_node); + auto tuple_input_node = common::AnfAlgo::GetInputNode(mul_input1_node->cast(), kIndex0); + auto tuple_input_name = common::AnfAlgo::GetCNodeName(tuple_input_node); + if ((silu_input_name == prim::kPrimTupleGetItem->name()) && (tuple_input_name == prim::kPrimSplitWithSize->name())) { + auto tuple2_input_node = common::AnfAlgo::GetInputNode(silu_input_node->cast(), kIndex0); + auto split_input_node = common::AnfAlgo::GetInputNode(tuple_input_node->cast(), kIndex0); + auto split_input_name = common::AnfAlgo::GetCNodeName(split_input_node); + if ((tuple_input_node == tuple2_input_node) && (split_input_name == prim::kPrimReshape->name())) { + auto reshape_input_node = common::AnfAlgo::GetInputNode(split_input_node->cast(), kIndex0); + auto reshape_input_name = common::AnfAlgo::GetCNodeName(reshape_input_node); + if (reshape_input_name == prim::kPrimMatMul->name()) { + pattern_name = kPatternNameMatMulSplitSiluMul; + } else if (reshape_input_name == prim::kPrimQuantBatchMatmul->name()) { + pattern_name = kPatternNameQMatMulSplitSiluMul; + } + } + } + return pattern_name; +} + +std::string InferenceMatmulSplitFusion::GetSiluFastGeluAddMulPattern(const CNodePtr &mul_input0_node, + const CNodePtr &mul_input1_node) const { + // in this branch we aim to find the second pattern -- kPatternNameMatMulSplitSiluFastgeluAddMul + std::string pattern_name = ""; + auto add_input0_node = common::AnfAlgo::GetInputNode(mul_input0_node->cast(), kIndex0); + auto add_input0_name = common::AnfAlgo::GetCNodeName(add_input0_node); + auto add_input1_node = common::AnfAlgo::GetInputNode(mul_input0_node->cast(), kIndex1); + auto add_input1_name = common::AnfAlgo::GetCNodeName(add_input1_node); + auto tuple_input_node = common::AnfAlgo::GetInputNode(mul_input1_node->cast(), kIndex0); + auto tuple_input_name = common::AnfAlgo::GetCNodeName(tuple_input_node); + if ((add_input0_name == prim::kPrimSiLU->name()) && (add_input1_name == prim::kPrimFastGeLU->name()) && + (tuple_input_name == prim::kPrimSplitWithSize->name())) { + auto silu_input_node = common::AnfAlgo::GetInputNode(add_input0_node->cast(), kIndex0); + auto silu_input_name = common::AnfAlgo::GetCNodeName(silu_input_node); + auto fastgelu_input_node = common::AnfAlgo::GetInputNode(add_input1_node->cast(), kIndex0); + auto fastgelu_input_name = common::AnfAlgo::GetCNodeName(fastgelu_input_node); + if ((silu_input_name == prim::kPrimTupleGetItem->name()) && + (fastgelu_input_name == prim::kPrimTupleGetItem->name())) { + auto tuple2_input_node = common::AnfAlgo::GetInputNode(silu_input_node->cast(), kIndex0); + auto tuple2_input_name = common::AnfAlgo::GetCNodeName(tuple2_input_node); + auto tuple3_input_node = common::AnfAlgo::GetInputNode(fastgelu_input_node->cast(), kIndex0); + if ((tuple2_input_name == prim::kPrimSplitWithSize->name()) && (tuple2_input_node == tuple3_input_node) && + (tuple2_input_node == tuple_input_node)) { + auto split_input_node = common::AnfAlgo::GetInputNode(tuple_input_node->cast(), kIndex0); + auto split_input_name = common::AnfAlgo::GetCNodeName(split_input_node); + if (split_input_name == prim::kPrimReshape->name()) { + auto reshape_input_node = common::AnfAlgo::GetInputNode(split_input_node->cast(), kIndex0); + auto reshape_input_name = common::AnfAlgo::GetCNodeName(reshape_input_node); + if (reshape_input_name == prim::kPrimMatMul->name()) { + pattern_name = kPatternNameMatMulSplitSiluFastgeluAddMul; + } else if (reshape_input_name == prim::kPrimQuantBatchMatmul->name()) { + pattern_name = kPatternNameQMatMulSplitSiluFastgeluAddMul; + } + } + } + } + } + return pattern_name; +} + +std::string InferenceMatmulSplitFusion::GetGatedFFNFusionPatternName(const CNodePtr &mul_cnode) const { + // in this Fusion Pass we are searching for two constructions: + // silu(w1 * x) ( w3 * x ) and + // [silu(w1 * x) + FastGeLU(w11 * x)] (w3 * x) + // in each of these constructions, in order to speed up computation, the weights are concatenated into a single W + // such that a single MatMul operation is performed, then the output is Reshape, Split and used accordingly + // This modification is performed in the python level, and here, during the fusion pass we capture it and replace it + // with a single kernel + std::string pattern_name = ""; + auto mul_i0_node = common::AnfAlgo::GetInputNode(mul_cnode, kIndex0); + auto mul_i1_node = common::AnfAlgo::GetInputNode(mul_cnode, kIndex1); + if (mul_i0_node == nullptr || !mul_i0_node->isa() || mul_i1_node == nullptr || !mul_i1_node->isa()) { + return ""; + } + auto mul_i0_name = common::AnfAlgo::GetCNodeName(mul_i0_node); + auto mul_i1_name = common::AnfAlgo::GetCNodeName(mul_i1_node); + if ((mul_i0_name == prim::kPrimSiLU->name()) && (mul_i1_name == prim::kPrimTupleGetItem->name())) { + pattern_name = GetSiluMulPattern(mul_i0_node->cast(), mul_i1_node->cast()); + } else if ((mul_i0_name == prim::kPrimAdd->name()) && (mul_i1_name == prim::kPrimTupleGetItem->name())) { + pattern_name = GetSiluFastGeluAddMulPattern(mul_i0_node->cast(), mul_i1_node->cast()); + } + MS_LOG(DEBUG) << " found pattern " << pattern_name; + return pattern_name; +} + std::string InferenceMatmulSplitFusion::GetFusionPatternName(const CNodePtr &cnode) const { std::string pattern_name = ""; auto cnode_name = common::AnfAlgo::GetCNodeName(cnode); @@ -229,8 +333,10 @@ PrimitivePtr InferenceMatmulSplitFusion::CreateMatmulSplitPrim(const CNodePtr &s MS_CHECK_TRUE_RET(!prim_name.empty(), nullptr); matmul_split_prim = std::make_shared(prim_name); MS_CHECK_TRUE_RET(matmul_split_prim != nullptr, nullptr); - auto split_size = split_cnode->input(kIndex2)->cast(); - matmul_split_prim->AddAttr("n_lens", split_size->value()); + if (split_size_len != 1) { + auto split_size = split_cnode->input(kIndex2)->cast(); + matmul_split_prim->AddAttr("n_lens", split_size->value()); + } return matmul_split_prim; } @@ -511,6 +617,244 @@ CNodePtr InferenceMatmulSplitFusion::CreateMatmulSplitSiluNode(const FuncGraphPt return new_item_cnode; } +CNodePtr InferenceMatmulSplitFusion::CreateMatmulSplitSiluMulNode(const FuncGraphPtr &func_graph, + const AnfNodePtr &node, + const std::string &pattern_name) const { + MS_LOG(DEBUG) << "start create MatmulSplitSiluMul node"; + MS_ASSERT(func_graph != nullptr && node != nullptr); + auto elem_mul_cnode = node->cast(); + MS_CHECK_TRUE_RET(elem_mul_cnode != nullptr, nullptr); + + auto silu_cnode = elem_mul_cnode->input(kIndex1)->cast(); + MS_CHECK_TRUE_RET(silu_cnode != nullptr, nullptr); + auto tuple_get_item_cnode = elem_mul_cnode->input(kIndex2)->cast(); + MS_CHECK_TRUE_RET(tuple_get_item_cnode != nullptr, nullptr); + auto split_cnode = tuple_get_item_cnode->input(kIndex1)->cast(); + MS_CHECK_TRUE_RET(split_cnode != nullptr, nullptr); + + size_t split_size_len = kMatmulFfnSplitSizeLen; + auto reshape_cnode = split_cnode->input(kIndex1)->cast(); + MS_CHECK_TRUE_RET(reshape_cnode != nullptr, nullptr); + auto tuple = reshape_cnode->input(kIndex2); + MS_CHECK_TRUE_RET(tuple != nullptr, nullptr); + + auto matmul_cnode = reshape_cnode->input(kIndex1)->cast(); + MS_CHECK_TRUE_RET(matmul_cnode != nullptr, nullptr); + MS_CHECK_TRUE_RET(matmul_cnode->func_graph() == split_cnode->func_graph(), nullptr); + + auto pre_reshape = matmul_cnode->input(kIndex1)->cast(); + MS_CHECK_TRUE_RET(pre_reshape != nullptr, nullptr); + + auto x_node = pre_reshape->input(kIndex1); + MS_EXCEPTION_IF_NULL(x_node); + auto weight_node = matmul_cnode->input(kIndex2); + MS_EXCEPTION_IF_NULL(weight_node); + const std::set support_dtype = {kNumberTypeFloat16, kNumberTypeBFloat16}; + if (!CheckSupportDataType(x_node, support_dtype) || !CheckMatMulDataFormat(matmul_cnode)) { + return nullptr; + } + auto fusion_prim = CreateMatmulSplitPrim(split_cnode, split_size_len, pattern_name); + fusion_prim->AddAttr("silu_position", MakeValue(1)); + std::vector matmul_split_inputs = {x_node, weight_node, tuple}; + auto matmul_split_cnode = func_graph->NewCNode(fusion_prim, matmul_split_inputs); + MS_EXCEPTION_IF_NULL(matmul_split_cnode); + + matmul_split_cnode->set_fullname_with_scope(matmul_cnode->fullname_with_scope() + "-SplitWithSiluMul"); + if (node->abstract() != nullptr) { + matmul_split_cnode->set_abstract(elem_mul_cnode->abstract()->Clone()); + } + + visited_cnodes.insert({silu_cnode, split_cnode}); + MS_LOG(DEBUG) << "create MatmulSplitSiluMul node success."; + return matmul_split_cnode; +} + +CNodePtr InferenceMatmulSplitFusion::CreateMatmulSplitSiluFastgeluAddMulNode(const FuncGraphPtr &func_graph, + const AnfNodePtr &node, + const std::string &pattern_name) const { + MS_LOG(DEBUG) << "start create MatmulSplitSiluFastgeluAddMul node"; + MS_ASSERT(func_graph != nullptr && node != nullptr); + auto elem_mul_cnode = node->cast(); + MS_CHECK_TRUE_RET(elem_mul_cnode != nullptr, nullptr); + + auto add_cnode = elem_mul_cnode->input(kIndex1)->cast(); + MS_CHECK_TRUE_RET(add_cnode != nullptr, nullptr); + auto silu_cnode = add_cnode->input(kIndex1)->cast(); + MS_CHECK_TRUE_RET(silu_cnode != nullptr, nullptr); + + auto tuple_get_item_cnode = elem_mul_cnode->input(kIndex2)->cast(); + MS_CHECK_TRUE_RET(tuple_get_item_cnode != nullptr, nullptr); + auto split_cnode = tuple_get_item_cnode->input(kIndex1)->cast(); + MS_CHECK_TRUE_RET(split_cnode != nullptr, nullptr); + + size_t split_size_len = kMatmulQkvSplitSizeLen; + + auto reshape_cnode = split_cnode->input(kIndex1)->cast(); + MS_CHECK_TRUE_RET(reshape_cnode != nullptr, nullptr); + auto tuple = reshape_cnode->input(kIndex2); + MS_CHECK_TRUE_RET(tuple != nullptr, nullptr); + + auto matmul_cnode = reshape_cnode->input(kIndex1)->cast(); + MS_CHECK_TRUE_RET(matmul_cnode != nullptr, nullptr); + MS_CHECK_TRUE_RET(matmul_cnode->func_graph() == split_cnode->func_graph(), nullptr); + + auto pre_reshape = matmul_cnode->input(kIndex1)->cast(); + MS_CHECK_TRUE_RET(pre_reshape != nullptr, nullptr); + + auto x_node = pre_reshape->input(kIndex1); + MS_EXCEPTION_IF_NULL(x_node); + auto weight_node = matmul_cnode->input(kIndex2); + MS_EXCEPTION_IF_NULL(weight_node); + const std::set support_dtype = {kNumberTypeFloat16, kNumberTypeBFloat16}; + if (!CheckSupportDataType(x_node, support_dtype) || !CheckMatMulDataFormat(matmul_cnode)) { + return nullptr; + } + auto fusion_prim = CreateMatmulSplitPrim(split_cnode, split_size_len, pattern_name); + fusion_prim->AddAttr("silu_position", MakeValue(1)); + std::vector matmul_split_inputs = {x_node, weight_node, tuple}; + auto matmul_split_cnode = func_graph->NewCNode(fusion_prim, matmul_split_inputs); + MS_EXCEPTION_IF_NULL(matmul_split_cnode); + + matmul_split_cnode->set_fullname_with_scope(matmul_cnode->fullname_with_scope() + "-SplitWithSiluFastGeluAddMul"); + if (node->abstract() != nullptr) { + matmul_split_cnode->set_abstract(elem_mul_cnode->abstract()->Clone()); + } + + visited_cnodes.insert({silu_cnode, split_cnode}); + MS_LOG(DEBUG) << "create MatmulSplitSiluFastgeluAddMul node success."; + return matmul_split_cnode; +} + +CNodePtr InferenceMatmulSplitFusion::CreateQMatmulSplitSiluMulNode(const FuncGraphPtr &func_graph, + const AnfNodePtr &node, + const std::string &pattern_name) const { + MS_LOG(DEBUG) << "start create MatmulSplitSiluMul node"; + MS_ASSERT(func_graph != nullptr && node != nullptr); + auto elem_mul_cnode = node->cast(); + MS_CHECK_TRUE_RET(elem_mul_cnode != nullptr, nullptr); + + auto silu_cnode = elem_mul_cnode->input(kIndex1)->cast(); + MS_CHECK_TRUE_RET(silu_cnode != nullptr, nullptr); + auto tuple_get_item_cnode = elem_mul_cnode->input(kIndex2)->cast(); + MS_CHECK_TRUE_RET(tuple_get_item_cnode != nullptr, nullptr); + auto split_cnode = tuple_get_item_cnode->input(kIndex1)->cast(); + MS_CHECK_TRUE_RET(split_cnode != nullptr, nullptr); + + size_t split_size_len = kMatmulFfnSplitSizeLen; + auto reshape_cnode = split_cnode->input(kIndex1)->cast(); + MS_CHECK_TRUE_RET(reshape_cnode != nullptr, nullptr); + auto reshape_tuple = reshape_cnode->input(kIndex2); + MS_CHECK_TRUE_RET(reshape_tuple != nullptr, nullptr); + + auto qbmm_cnode = reshape_cnode->input(kIndex1)->cast(); + MS_CHECK_TRUE_RET(qbmm_cnode != nullptr, nullptr); + MS_CHECK_TRUE_RET(qbmm_cnode->func_graph() == split_cnode->func_graph(), nullptr); + + auto pre_reshape = qbmm_cnode->input(kIndex1)->cast(); + MS_CHECK_TRUE_RET(pre_reshape != nullptr, nullptr); + auto qbmm_x = pre_reshape->input(kIndex1); + MS_EXCEPTION_IF_NULL(qbmm_x); + auto qbmm_w = qbmm_cnode->input(kIndex2); + MS_EXCEPTION_IF_NULL(qbmm_w); + auto input_bias = qbmm_cnode->input(kIndex5); + MS_EXCEPTION_IF_NULL(input_bias); + auto pertoken_scale = qbmm_cnode->input(kIndex6); + MS_EXCEPTION_IF_NULL(pertoken_scale); + if (!IsValueNode(pertoken_scale)) { + MS_LOG(INFO) << "Currently, do not support to fuse qbmm(pertoken) with split."; + return nullptr; + } + auto input_scale = qbmm_cnode->input(kIndex3); + MS_EXCEPTION_IF_NULL(input_scale); + const std::set support_dtype = {kNumberTypeInt8}; + if (!CheckSupportDataType(qbmm_x, support_dtype) || !CheckMatMulDataFormat(qbmm_cnode) || + !CheckSplitSize(qbmm_w, split_cnode)) { + return nullptr; + } + + auto fusion_prim = CreateMatmulSplitPrim(split_cnode, split_size_len, pattern_name); + fusion_prim->AddAttr("silu_position", MakeValue(1)); + std::vector qbmm_split_inputs = {qbmm_x, qbmm_w, reshape_tuple, input_bias, input_scale}; + + auto qmatmul_split_cnode = func_graph->NewCNode(fusion_prim, qbmm_split_inputs); + MS_EXCEPTION_IF_NULL(qmatmul_split_cnode); + + qmatmul_split_cnode->set_fullname_with_scope(qbmm_cnode->fullname_with_scope() + "-SplitWithSiluMul"); + if (node->abstract() != nullptr) { + qmatmul_split_cnode->set_abstract(elem_mul_cnode->abstract()->Clone()); + } + + visited_cnodes.insert({silu_cnode, split_cnode}); + MS_LOG(DEBUG) << "create QMatmulSplitSiluMul node success."; + return qmatmul_split_cnode; +} + +CNodePtr InferenceMatmulSplitFusion::CreateQMatmulSplitSiluFastgeluAddMulNode(const FuncGraphPtr &func_graph, + const AnfNodePtr &node, + const std::string &pattern_name) const { + MS_LOG(DEBUG) << "start create MatmulSplitSiluFastgeluAddMul node"; + MS_ASSERT(func_graph != nullptr && node != nullptr); + auto elem_mul_cnode = node->cast(); + MS_CHECK_TRUE_RET(elem_mul_cnode != nullptr, nullptr); + + auto add_cnode = elem_mul_cnode->input(kIndex1)->cast(); + MS_CHECK_TRUE_RET(add_cnode != nullptr, nullptr); + auto silu_cnode = add_cnode->input(kIndex1)->cast(); + MS_CHECK_TRUE_RET(silu_cnode != nullptr, nullptr); + + auto tuple_get_item_cnode = elem_mul_cnode->input(kIndex2)->cast(); + MS_CHECK_TRUE_RET(tuple_get_item_cnode != nullptr, nullptr); + auto split_cnode = tuple_get_item_cnode->input(kIndex1)->cast(); + MS_CHECK_TRUE_RET(split_cnode != nullptr, nullptr); + + size_t split_size_len = kMatmulQkvSplitSizeLen; + auto reshape_cnode = split_cnode->input(kIndex1)->cast(); + MS_CHECK_TRUE_RET(reshape_cnode != nullptr, nullptr); + auto reshape_tuple = reshape_cnode->input(kIndex2); + MS_CHECK_TRUE_RET(reshape_tuple != nullptr, nullptr); + + auto qbmm_cnode = reshape_cnode->input(kIndex1)->cast(); + MS_CHECK_TRUE_RET(qbmm_cnode != nullptr, nullptr); + MS_CHECK_TRUE_RET(qbmm_cnode->func_graph() == split_cnode->func_graph(), nullptr); + + auto pre_reshape = qbmm_cnode->input(kIndex1)->cast(); + MS_CHECK_TRUE_RET(pre_reshape != nullptr, nullptr); + auto qbmm_x = pre_reshape->input(kIndex1); + MS_EXCEPTION_IF_NULL(qbmm_x); + auto qbmm_w = qbmm_cnode->input(kIndex2); + MS_EXCEPTION_IF_NULL(qbmm_w); + auto input_bias = qbmm_cnode->input(kIndex5); + MS_EXCEPTION_IF_NULL(input_bias); + auto pertoken_scale = qbmm_cnode->input(kIndex6); + MS_EXCEPTION_IF_NULL(pertoken_scale); + if (!IsValueNode(pertoken_scale)) { + MS_LOG(INFO) << "Currently, do not support to fuse qbmm(pertoken) with split."; + return nullptr; + } + auto input_scale = qbmm_cnode->input(kIndex3); + MS_EXCEPTION_IF_NULL(input_scale); + const std::set support_dtype = {kNumberTypeInt8}; + if (!CheckSupportDataType(qbmm_x, support_dtype) || !CheckMatMulDataFormat(qbmm_cnode) || + !CheckSplitSize(qbmm_w, split_cnode)) { + return nullptr; + } + + auto fusion_prim = CreateMatmulSplitPrim(split_cnode, split_size_len, pattern_name); + fusion_prim->AddAttr("silu_position", MakeValue(1)); + std::vector qbmm_split_inputs = {qbmm_x, qbmm_w, reshape_tuple, input_bias, input_scale}; + auto qmatmul_split_cnode = func_graph->NewCNode(fusion_prim, qbmm_split_inputs); + MS_EXCEPTION_IF_NULL(qmatmul_split_cnode); + + qmatmul_split_cnode->set_fullname_with_scope(qbmm_cnode->fullname_with_scope() + "-SplitWithSiluFastGeluAddMul"); + if (node->abstract() != nullptr) { + qmatmul_split_cnode->set_abstract(elem_mul_cnode->abstract()->Clone()); + } + + visited_cnodes.insert({silu_cnode, split_cnode}); + MS_LOG(DEBUG) << "create MatmulSplitSiluFastgeluAddMul node success."; + return qmatmul_split_cnode; +} + CNodePtr InferenceMatmulSplitFusion::CreateMatmulBiasAddSplitSiluNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const std::string &pattern_name) const { @@ -642,8 +986,8 @@ AnfNodePtr InferenceMatmulSplitFusion::Process(const std::string &pattern_name, auto manager = func_graph->manager(); MS_EXCEPTION_IF_NULL(manager); - auto split_cnode = node->cast(); - MS_CHECK_TRUE_RET(split_cnode != nullptr, nullptr); + auto top_cnode = node->cast(); + MS_CHECK_TRUE_RET(top_cnode != nullptr, nullptr); CNodePtr matmul_split_cnode = nullptr; if (pattern_name == kPatternNameMatMulSplit) { @@ -655,7 +999,18 @@ AnfNodePtr InferenceMatmulSplitFusion::Process(const std::string &pattern_name, if (pattern_name == kPatternNameQuantbatchmatmulSplit) { matmul_split_cnode = CreateQuantbatchmatmulSplitNode(func_graph, node, pattern_name); } - + if (pattern_name == kPatternNameMatMulSplitSiluMul) { + matmul_split_cnode = CreateMatmulSplitSiluMulNode(func_graph, node, pattern_name); + } + if (pattern_name == kPatternNameMatMulSplitSiluFastgeluAddMul) { + matmul_split_cnode = CreateMatmulSplitSiluFastgeluAddMulNode(func_graph, node, pattern_name); + } + if (pattern_name == kPatternNameQMatMulSplitSiluMul) { + matmul_split_cnode = CreateQMatmulSplitSiluMulNode(func_graph, node, pattern_name); + } + if (pattern_name == kPatternNameQMatMulSplitSiluFastgeluAddMul) { + matmul_split_cnode = CreateQMatmulSplitSiluFastgeluAddMulNode(func_graph, node, pattern_name); + } if (pattern_name == kPatternNameMatMulSplitSilu) { matmul_split_cnode = CreateMatmulSplitSiluNode(func_graph, node, pattern_name); } @@ -667,7 +1022,7 @@ AnfNodePtr InferenceMatmulSplitFusion::Process(const std::string &pattern_name, } MS_CHECK_TRUE_RET(matmul_split_cnode != nullptr, nullptr); - (void)manager->Replace(split_cnode, matmul_split_cnode); + (void)manager->Replace(top_cnode, matmul_split_cnode); MS_LOG(DEBUG) << "MatmulSplit replace success"; return matmul_split_cnode; } diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fusion_infer/inference_matmul_split_fusion.h b/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fusion_infer/inference_matmul_split_fusion.h index ebe2790e1f3..4b5582377e4 100644 --- a/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fusion_infer/inference_matmul_split_fusion.h +++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fusion_infer/inference_matmul_split_fusion.h @@ -35,6 +35,7 @@ namespace mindspore { namespace opt { constexpr auto kMatmulQkvSplitSizeLen = 3; constexpr auto kMatmulFfnSplitSizeLen = 2; +constexpr auto kMatmulGatedSizeLen = 1; constexpr auto kTuplePlaceHolderNum = 0; class InferenceMatmulSplitFusion : public Pass { @@ -46,6 +47,7 @@ class InferenceMatmulSplitFusion : public Pass { private: bool CheckReshapeNode(const AnfNodePtr &node) const; std::string GetFusionPatternName(const CNodePtr &cnode) const; + std::string GetGatedFFNFusionPatternName(const CNodePtr &cnode) const; std::string GetSplitFusionPatternName(const CNodePtr &cnode) const; bool CheckMatMulDataFormat(const CNodePtr &matmul_cnode) const; bool CheckSplitSize(const AnfNodePtr &weight_cnode, const CNodePtr &split_cnode) const; @@ -60,10 +62,19 @@ class InferenceMatmulSplitFusion : public Pass { CNodePtr CreateQuantbatchmatmulSplitNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const std::string &) const; CNodePtr CreateMatmulSplitSiluNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const std::string &) const; + CNodePtr CreateMatmulSplitSiluMulNode(const FuncGraphPtr &f_graph, const AnfNodePtr &node, const std::string &) const; + CNodePtr CreateMatmulSplitSiluFastgeluAddMulNode(const FuncGraphPtr &f_graph, const AnfNodePtr &node, + const std::string &) const; + CNodePtr CreateQMatmulSplitSiluMulNode(const FuncGraphPtr &graph, const AnfNodePtr &node, const std::string &) const; + CNodePtr CreateQMatmulSplitSiluFastgeluAddMulNode(const FuncGraphPtr &f_graph, const AnfNodePtr &node, + const std::string &) const; CNodePtr CreateMatmulBiasAddSplitSiluNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const std::string &) const; CNodePtr CreateQuantbatchmatmulSplitSiluNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const std::string &) const; + std::string GetSiluMulPattern(const CNodePtr &mul_input0_node, const CNodePtr &mul_input1_node) const; + std::string GetSiluFastGeluAddMulPattern(const CNodePtr &mul_input0_node, const CNodePtr &mul_input1_node) const; + bool enable_fusion_silu = false; mutable std::set visited_cnodes; @@ -72,6 +83,8 @@ class InferenceMatmulSplitFusion : public Pass { const std::string kPrimNameMatmulSplitOut2 = "MatmulSplitOut2"; const std::string kPrimNameMatmulSplitOut3 = "MatmulSplitOut3"; const std::string kPrimNameMatmulSplitSiluOut2 = "MatmulSplitSiluOut2"; + const std::string kPrimNameMatmulSplitSiluMulOut1 = "MatmulSplitSiluMulOut1"; + const std::string kPrimNameQMatmulSplitSiluMulOut1 = "QMatmulSplitSiluMulOut1"; const std::string kPrimNameMatmulBiasSplitOut2 = "MatmulBiasSplitOut2"; const std::string kPrimNameMatmulBiasSplitOut3 = "MatmulBiasSplitOut3"; const std::string kPrimNameMatmulBiasSplitSiluOut2 = "MatmulBiasSplitSiluOut2"; @@ -81,26 +94,34 @@ class InferenceMatmulSplitFusion : public Pass { const std::string kPatternNameMatMulSplit = "MatmulSplit"; const std::string kPatternNameMatMulSplitSilu = "MatmulSplitSilu"; + const std::string kPatternNameMatMulSplitSiluMul = "MatmulSplitSiluMul"; + const std::string kPatternNameMatMulSplitSiluFastgeluAddMul = "MatmulSplitSiluFastgeluAddMul"; + const std::string kPatternNameQMatMulSplitSiluMul = "QMatmulSplitSiluMul"; + const std::string kPatternNameQMatMulSplitSiluFastgeluAddMul = "QMatmulSplitSiluFastgeluAddMul"; const std::string kPatternNameMatMulBiasAddSplit = "MatmulBiasAddSplit"; const std::string kPatternNameMatMulBiasAddSplitSilu = "MatmulBiasAddSplitSilu"; const std::string kPatternNameQuantbatchmatmulSplit = "QuantbatchmatmulSplit"; const std::string kPatternNameQuantbatchmatmulSplitSilu = "QuantbatchmatmulSplitSilu"; std::map> PatternPrimMap = { - { - kMatmulQkvSplitSizeLen, - {{kPatternNameMatMulSplit, kPrimNameMatmulSplitOut3}, - {kPatternNameMatMulBiasAddSplit, kPrimNameMatmulBiasSplitOut3}, - {kPatternNameQuantbatchmatmulSplit, kPrimNameQuantbatchmatmulSplitOut3}}, - }, - + {kMatmulGatedSizeLen, + {{kPatternNameMatMulSplitSiluMul, kPrimNameMatmulSplitSiluMulOut1}, + {kPatternNameMatMulSplitSiluFastgeluAddMul, kPrimNameMatmulSplitSiluFastgeluAddMulOut1}}}, + {kMatmulQkvSplitSizeLen, + {{kPatternNameMatMulSplit, kPrimNameMatmulSplitOut3}, + {kPatternNameMatMulBiasAddSplit, kPrimNameMatmulBiasSplitOut3}, + {kPatternNameQuantbatchmatmulSplit, kPrimNameQuantbatchmatmulSplitOut3}, + {kPatternNameMatMulSplitSiluFastgeluAddMul, kPrimNameMatmulSplitSiluFastgeluAddMulOut1}, + {kPatternNameQMatMulSplitSiluFastgeluAddMul, kPrimNameQMatmulSplitSiluFastgeluAddMulOut1}}}, {kMatmulFfnSplitSizeLen, {{kPatternNameMatMulSplit, kPrimNameMatmulSplitOut2}, {kPatternNameMatMulSplitSilu, kPrimNameMatmulSplitSiluOut2}, {kPatternNameMatMulBiasAddSplit, kPrimNameMatmulBiasSplitOut2}, {kPatternNameMatMulBiasAddSplitSilu, kPrimNameMatmulBiasSplitSiluOut2}, {kPatternNameQuantbatchmatmulSplit, kPrimNameQuantbatchmatmulSplitOut2}, - {kPatternNameQuantbatchmatmulSplitSilu, kPrimNameQuantbatchmatmulSplitSiluOut2}}}}; + {kPatternNameQuantbatchmatmulSplitSilu, kPrimNameQuantbatchmatmulSplitSiluOut2}, + {kPatternNameMatMulSplitSiluMul, kPrimNameMatmulSplitSiluMulOut1}, + {kPatternNameQMatMulSplitSiluMul, kPrimNameQMatmulSplitSiluMulOut1}}}}; }; } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fusion_infer/inference_qbmm_elemwise_fusion.cc b/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fusion_infer/inference_qbmm_elemwise_fusion.cc new file mode 100644 index 00000000000..25673601e97 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fusion_infer/inference_qbmm_elemwise_fusion.cc @@ -0,0 +1,140 @@ +/** + * Copyright 2024 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "plugin/device/ascend/optimizer/ir_fusion_infer/inference_qbmm_elemwise_fusion.h" +#include +#include +#include "backend/common/pass/common/gllo_utils.h" +#include "plugin/device/ascend/optimizer/ir_fusion_infer/inference_weight_preprocess_utils.h" +#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_f.h" +#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_q.h" +#include "mindspore/ops/op_def/nn_ops.h" +#include "mindspore/ops/op_def/math_ops.h" +#include "include/backend/anf_runtime_algorithm.h" +#include "include/common/utils/anfalgo.h" +#include "include/common/utils/utils.h" +#include "utils/ms_context.h" + +namespace mindspore { +namespace opt { + +CNodePtr QMatmulElemFusion::CreateQbmmElemNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &equiv) const { + MS_LOG(DEBUG) << "start CreateQbmmElemNode"; + MS_ASSERT(func_graph != nullptr && node != nullptr && equiv != nullptr); + auto qbmm_prim = std::make_shared("QuantBatchMatmul"); + qbmm_prim->AddAttr("ElemwiseType", MakeValue("fastgelu")); + std::vector inputs = {x_node_, w_node_, scale_node_, + offset_node_, bias_node_, pertoken_scale_node_, + trans_a_node_, trans_b_node_, out_dtype_node_}; + auto new_qbmm_node = func_graph->NewCNode(qbmm_prim, inputs); + MS_CHECK_TRUE_RET(new_qbmm_node != nullptr, nullptr); + new_qbmm_node->set_scope(node->scope()); + + if (node->abstract() != nullptr) { + new_qbmm_node->set_abstract(node->abstract()->Clone()); + } + MS_LOG(DEBUG) << "create QbmmElem node success."; + return new_qbmm_node; +} + +std::vector QMatmulElemFusion::MustExistPrimitiveName() const { + std::vector ret{prim::kPrimQuantBatchMatmul->name(), prim::kPrimFastGeLU->name()}; + return ret; +} + +void QMatmulElemFusion::SetInternalNodes(const EquivPtr &equiv) const { + x_node_ = utils::cast((*equiv)[x_]); + MS_ASSERT(x_node_ != nullptr); + w_node_ = utils::cast((*equiv)[w_]); + MS_ASSERT(w_node_ != nullptr); + scale_node_ = utils::cast((*equiv)[scale_]); + MS_ASSERT(scale_node_ != nullptr); + offset_node_ = utils::cast((*equiv)[offset_]); + MS_ASSERT(offset_node != nullptr); + bias_node_ = utils::cast((*equiv)[bias_]); + MS_ASSERT(bias_node_ != nullptr); + pertoken_scale_node_ = utils::cast((*equiv)[pertoken_scale_]); + MS_ASSERT(pertoken_scale_node_ != nullptr); + trans_a_node_ = utils::cast((*equiv)[trans_a_]); + MS_ASSERT(trans_a_node != nullptr); + trans_b_node_ = utils::cast((*equiv)[trans_b_]); + MS_ASSERT(trans_b_node != nullptr); + out_dtype_node_ = utils::cast((*equiv)[out_dtype_]); + MS_ASSERT(out_dtype_node_ != nullptr); +} + +const BaseRef QMatmulElemFusion::DefinePattern() const { + if (!Init()) { + MS_LOG(DEBUG) << "initial member failed."; + return {}; + } + VectorRef qbmm_ref({qbmm_prim_, x_, w_, scale_, offset_, bias_, pertoken_scale_, trans_a_, trans_b_, out_dtype_}); + bias_tensor_ = std::make_shared(); + auto is_fast = std::make_shared(IsSpecifiedNode<&prim::kPrimFastGeLU>); + MS_CHECK_TRUE_RET(is_fast != nullptr, {}); + VectorRef fast_ref({is_fast, qbmm_ref}); + return fast_ref; +} + +const AnfNodePtr QMatmulElemFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &equiv) const { + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + auto const &soc_version = ms_context->ascend_soc_version(); + if (!soc_version.empty() && soc_version != "ascend310p") { + return nullptr; + } + constexpr auto kQbmmElemName = "MatMulElemwise"; + if (func_graph == nullptr || node == nullptr || equiv == nullptr || !PassEnable(kQbmmElemName)) { + return nullptr; + } + + SetInternalNodes(equiv); + if (!IsValueNode(pertoken_scale_node_)) { + MS_LOG(INFO) << "Currently, do not support to fuse qbmm(pertoken) with add."; + return nullptr; + } + CheckIOValid(); + auto cnode = CreateQbmmElemNode(func_graph, node, equiv); + return cnode; +} + +bool QMatmulElemFusion::CheckIOValid() const { + if (!CheckSupportDataType(scale_node_, {kNumberTypeInt64}) || !CheckSupportDataType(bias_node_, {kNumberTypeInt32})) { + return false; + } + auto dtype_value = GetValue(out_dtype_node_->cast()->value()); + if (dtype_value != static_cast(kNumberTypeFloat16)) { + return false; + } + auto bias_shape = common::AnfAlgo::GetOutputInferShape(bias_node_, kIndex0); + auto scale_shape = common::AnfAlgo::GetOutputInferShape(scale_node_, kIndex0); + if (bias_shape.size() != 1 || scale_shape.size() != 1 || bias_shape[0] != scale_shape[0]) { + return false; + } + auto scale_param = GetParamFromLoad(scale_node_->cast(), false); + if (!scale_param) { + return false; + } + auto bias_param = GetParamFromLoad(bias_node_->cast(), false); + if (!bias_param) { + return false; + } + return true; +} + +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fusion_infer/inference_qbmm_elemwise_fusion.h b/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fusion_infer/inference_qbmm_elemwise_fusion.h new file mode 100644 index 00000000000..c0579d71b9b --- /dev/null +++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fusion_infer/inference_qbmm_elemwise_fusion.h @@ -0,0 +1,46 @@ +/** + * Copyright 2024 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_OPTIMIZER_ASCEND_IR_FUSION_INFER_INFERENCE_QBMM_ELEMWISE_FUSION_H_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_OPTIMIZER_ASCEND_IR_FUSION_INFER_INFERENCE_QBMM_ELEMWISE_FUSION_H_ + +#include +#include +#include +#include +#include "plugin/device/ascend/optimizer/ir_fusion_infer/inference_qbmm_fusion_base.h" +#include "include/backend/optimizer/optimizer.h" +#include "mindspore/ops/op_def/math_ops.h" + +namespace mindspore { +namespace opt { +class QMatmulElemFusion : public QbmmFusionBase { + public: + explicit QMatmulElemFusion(bool multigraph = true, const string &pass_name = "quant_matmul_elemwise_fusion") + : QbmmFusionBase(pass_name, multigraph) {} + ~QMatmulElemFusion() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &equiv) const override; + + private: + void SetInternalNodes(const EquivPtr &equiv) const; + bool CheckIOValid() const; + CNodePtr CreateQbmmElemNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv) const; + std::vector MustExistPrimitiveName() const override; +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_OPTIMIZER_ASCEND_IR_FUSION_INFER_INFERENCE_QBMM_ELEMWISE_FUSION_H_ diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fusion_infer/matmul_elemwise_fusion.cc b/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fusion_infer/matmul_elemwise_fusion.cc index 11677373dd6..de098484adb 100644 --- a/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fusion_infer/matmul_elemwise_fusion.cc +++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fusion_infer/matmul_elemwise_fusion.cc @@ -25,6 +25,7 @@ #include "utils/ms_context.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_b.h" +#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_f.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_g.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_r.h" @@ -41,7 +42,8 @@ bool IsElemNode(const BaseRef &ref) { if (utils::isa(ref)) { AnfNodePtr node = utils::cast(ref); MS_EXCEPTION_IF_NULL(node); - if (IsOneOfPrimitive(node, {prim::kPrimBiasAdd, prim::kPrimAdd, prim::kPrimReLU, prim::kPrimGeLU})) { + if (IsOneOfPrimitive(node, + {prim::kPrimBiasAdd, prim::kPrimAdd, prim::kPrimReLU, prim::kPrimGeLU, prim::kPrimFastGeLU})) { return true; } } @@ -54,6 +56,7 @@ std::string MatmulElemFusion::GetElemwiseType(const CNodePtr &elemwise_node) con static const std::map kOpElemiseTypeMap = {{prim::kPrimBiasAdd->name(), "bias_add"}, {prim::kPrimAdd->name(), "bias_add"}, {prim::kPrimReLU->name(), "relu"}, + {prim::kPrimFastGeLU->name(), "fastgelu"}, {prim::kPrimGeLU->name(), "gelu"}}; return kOpElemiseTypeMap.at(common::AnfAlgo::GetCNodeName(elemwise_node)); } @@ -88,10 +91,10 @@ const AnfNodePtr MatmulElemFusion::Process(const FuncGraphPtr &func_graph, const auto ms_context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(ms_context); auto const &soc_version = ms_context->ascend_soc_version(); - if (!soc_version.empty() && soc_version != "ascend910b" && soc_version != "ascend910_93") { + if (!soc_version.empty() && soc_version != "ascend910b" && soc_version != "ascend910_93" && + soc_version != "ascend310p") { return nullptr; } - auto enable_op_list = ms_context->ms_internal_enable_custom_kernel_list(); bool enable_matmul_elemwise = (std::find(enable_op_list.begin(), enable_op_list.end(), "MatMulElemwise") != enable_op_list.end()); @@ -108,8 +111,12 @@ const AnfNodePtr MatmulElemFusion::Process(const FuncGraphPtr &func_graph, const auto matmul_cnode = elemwise_node->input(kIndex1)->cast(); MS_CHECK_TRUE_RET(matmul_cnode != nullptr, {}); MS_CHECK_TRUE_RET(matmul_cnode->func_graph() == elemwise_node->func_graph(), {}); - std::string elemwise_type = GetElemwiseType(elemwise_node); + const std::string fastgelu_str = "fastgelu"; + if (elemwise_type != fastgelu_str && !soc_version.empty() && soc_version != "ascend910b" && + soc_version != "ascend910_93") { + return nullptr; + } const std::string bias_add_str = "bias_add"; if (elemwise_type == bias_add_str && (common::AnfAlgo::GetPrevNodeOutputInferShape(node, 1).size() > 1 || common::AnfAlgo::GetOutputInferDataType(node, 0) != kFloat16->type_id())) { diff --git a/mindspore/core/utils/ms_context.cc b/mindspore/core/utils/ms_context.cc index 092a944925d..5f53c1a2ef5 100644 --- a/mindspore/core/utils/ms_context.cc +++ b/mindspore/core/utils/ms_context.cc @@ -738,15 +738,14 @@ inline std::string SetToString(const std::set &kernel_list) { } void MsContext::SetMsInternalEnableCustomKernelList() { - if (!ms_internal_enable_custom_kernel_list_.empty()) { - return; - } const std::string kDefaultEnabledOpList = "MatMul,RmsNorm,Add,Sub,FlashAttentionScore,PagedAttention,PagedAttentionMask,AddRmsNorm,AddLayerNorm," "MatMulAllReduce,InferenceMatmulSplit,AddRmsNormQuantV2,InferenceSwiGLU,QbmmAllReduceAdd,QbmmAdd," - "AddRmsNormDynamicQuant,MatMulElemwise,RmsNormQuant,MatMulSigmoidCastAdd,TransposeBatchMatmulTranspose," + "AddRmsNormDynamicQuant,MatMulElemwise,RmsNormQuant,MatMulSigmoidCastAdd,SwiGLUDynamicQuant"; + const std::string k310pDefaultEnabledOpList = + "MatMul,QuantBatchMatmul,QuantLinearSparse,QbmmAllReduceAdd,QbmmAdd,InferenceGatedFFN,MatMulElemwise," + "AddRmsNormDynamicQuant,RmsNormQuant,MatMulSigmoidCastAdd,TransposeBatchMatmulTranspose," "FusedAddTopKDiv,SwiGLUDynamicQuant,SwiGLUReshapeDynamicQuant,QbmmAllReduceConvertBias"; - const std::string k310pDefaultEnabledOpList = "MatMul,QuantBatchMatmul,QuantLinearSparse,QbmmAllReduceAdd,QbmmAdd"; auto internal_op_boost_env = common::GetEnv("MS_ENABLE_INTERNAL_BOOST"); bool is_enable_internal_op = true; bool is_310p = ascend_soc_version() == "ascend310p"; diff --git a/mindspore/ops/infer/ops_func_impl/matmul_split_silu_fastgelu_add_mul_out1.cc b/mindspore/ops/infer/ops_func_impl/matmul_split_silu_fastgelu_add_mul_out1.cc new file mode 100644 index 00000000000..2f4c0bcb56c --- /dev/null +++ b/mindspore/ops/infer/ops_func_impl/matmul_split_silu_fastgelu_add_mul_out1.cc @@ -0,0 +1,72 @@ +/** + * Copyright 2024 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "infer/ops_func_impl/matmul_split_silu_fastgelu_add_mul_out1.h" +#include "infer/ops_func_impl/matmul_fusion_utils.h" +#include "mindspore/ops/ops_utils/op_utils.h" +#include "utils/check_convert_utils.h" +#include "utils/convert_utils_base.h" +#include "utils/shape_utils.h" + +namespace mindspore { +namespace ops { + +BaseShapePtr MatmulSplitSiluFastgeluAddMulOut1FuncImpl::InferShape( + const PrimitivePtr &primitive, const std::vector &input_args) const { + auto op_name = primitive->name(); + auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->GetShape())[kShape]; + auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->GetShape())[kShape]; + if (IsDynamicRank(x_shape) || IsDynamicRank(w_shape)) { + MS_LOG(EXCEPTION) << "For " << op_name << ", dynamic rank is not supported"; + } + constexpr size_t kSize2 = 2; + constexpr size_t kSize3 = 3; + const size_t x_rank = x_shape.size(); + const size_t w_rank = w_shape.size(); + MS_CHECK_VALUE(x_rank == kSize3, + CheckAndConvertUtils::FormatCommMsg("For '" + primitive->name() + "', x_rank should be 3.")); + + MS_CHECK_VALUE(w_rank == kSize2, + CheckAndConvertUtils::FormatCommMsg("For '" + primitive->name() + "', w_rank should be 2.")); + + auto b = x_shape[0]; // in matmul, m = b * s + auto s = x_shape[1]; + auto k = x_shape[2]; + auto k0 = w_shape[1]; + MS_CHECK_VALUE(k == k0, CheckAndConvertUtils::FormatCommMsg( + "For '" + primitive->name() + "', the K axis of all inputs must have the same length.")); + + MS_CHECK_VALUE(primitive->HasAttr("n_lens"), + CheckAndConvertUtils::FormatCommMsg("For '" + primitive->name() + "', op must have attr 'n_lens'.")); + + std::vector n_len_list = GetValue>(primitive->GetAttr("n_lens")); + MS_CHECK_VALUE( + (n_len_list.size() == kSize2 || n_len_list.size() == kSize3), + CheckAndConvertUtils::FormatCommMsg("For '" + primitive->name() + "', attr 'n_lens' must have 2 or 3 value.")); + + ShapeVector output_0_shape = {b, s, n_len_list[kIndex0]}; + std::vector shape_lists; + (void)shape_lists.emplace_back(std::make_shared(output_0_shape)); + return std::make_shared(shape_lists); +} + +TypePtr MatmulSplitSiluFastgeluAddMulOut1FuncImpl::InferType(const PrimitivePtr &primitive, + const std::vector &input_args) const { + return MatmulFusionUtils::InferenceMultiMatmulInferType(primitive, input_args); +} + +} // namespace ops +} // namespace mindspore diff --git a/mindspore/ops/infer/ops_func_impl/matmul_split_silu_fastgelu_add_mul_out1.h b/mindspore/ops/infer/ops_func_impl/matmul_split_silu_fastgelu_add_mul_out1.h new file mode 100644 index 00000000000..9c4b8d1de64 --- /dev/null +++ b/mindspore/ops/infer/ops_func_impl/matmul_split_silu_fastgelu_add_mul_out1.h @@ -0,0 +1,34 @@ +/** + * Copyright 2024 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_OPS_FUNC_IMPL_MATMUL_SPLIT_SILU_FASTGELU_ADD_MUL_OUT1_H_ +#define MINDSPORE_CORE_OPS_OPS_FUNC_IMPL_MATMUL_SPLIT_SILU_FASTGELU_ADD_MUL_OUT1_H_ +#include +#include +#include "ops/ops_func_impl/op_func_impl.h" + +namespace mindspore { +namespace ops { +class OPS_API MatmulSplitSiluFastgeluAddMulOut1FuncImpl : public OpFuncImpl { + public: + BaseShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) const override; + TypePtr InferType(const PrimitivePtr &primitive, const std::vector &input_args) const override; +}; + +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPS_OPS_FUNC_IMPL_MATMUL_SPLIT_SILU_FASTGELU_ADD_MUL_OUT1_H_ diff --git a/mindspore/ops/infer/ops_func_impl/matmul_split_silu_mul_out1.cc b/mindspore/ops/infer/ops_func_impl/matmul_split_silu_mul_out1.cc new file mode 100644 index 00000000000..bcc25f9fba4 --- /dev/null +++ b/mindspore/ops/infer/ops_func_impl/matmul_split_silu_mul_out1.cc @@ -0,0 +1,72 @@ +/** + * Copyright 2024 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "infer/ops_func_impl/matmul_split_silu_mul_out1.h" +#include "infer/ops_func_impl/matmul_fusion_utils.h" +#include "mindspore/ops/ops_utils/op_utils.h" +#include "utils/check_convert_utils.h" +#include "utils/convert_utils_base.h" +#include "utils/shape_utils.h" + +namespace mindspore { +namespace ops { + +BaseShapePtr MatmulSplitSiluMulOut1FuncImpl::InferShape(const PrimitivePtr &primitive, + const std::vector &input_args) const { + auto op_name = primitive->name(); + auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->GetShape())[kShape]; + auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->GetShape())[kShape]; + if (IsDynamicRank(x_shape) || IsDynamicRank(w_shape)) { + MS_LOG(EXCEPTION) << "For " << op_name << ", dynamic rank is not supported"; + } + constexpr size_t kSize2 = 2; + constexpr size_t kSize3 = 3; + const size_t x_rank = x_shape.size(); + const size_t w_rank = w_shape.size(); + MS_CHECK_VALUE(x_rank == kSize3, + CheckAndConvertUtils::FormatCommMsg("For '" + primitive->name() + "', x_rank should be 3.")); + + MS_CHECK_VALUE(w_rank == kSize2, + CheckAndConvertUtils::FormatCommMsg("For '" + primitive->name() + "', w_rank should be 2.")); + + auto b = x_shape[0]; // in matmul, m = b * s + auto s = x_shape[1]; + auto k = x_shape[2]; + auto k0 = w_shape[1]; + MS_CHECK_VALUE(k == k0, CheckAndConvertUtils::FormatCommMsg( + "For '" + primitive->name() + "', the K axis of all inputs must have the same length.")); + + MS_CHECK_VALUE(primitive->HasAttr("n_lens"), + CheckAndConvertUtils::FormatCommMsg("For '" + primitive->name() + "', op must have attr 'n_lens'.")); + + std::vector n_len_list = GetValue>(primitive->GetAttr("n_lens")); + MS_CHECK_VALUE( + (n_len_list.size() == kSize2 || n_len_list.size() == kSize3), + CheckAndConvertUtils::FormatCommMsg("For '" + primitive->name() + "', attr 'n_lens' must have 2 or 3 value.")); + + ShapeVector output_0_shape = {b, s, n_len_list[kIndex0]}; + std::vector shape_lists; + (void)shape_lists.emplace_back(std::make_shared(output_0_shape)); + return std::make_shared(shape_lists); +} + +TypePtr MatmulSplitSiluMulOut1FuncImpl::InferType(const PrimitivePtr &primitive, + const std::vector &input_args) const { + return MatmulFusionUtils::InferenceMultiMatmulInferType(primitive, input_args); +} + +} // namespace ops +} // namespace mindspore diff --git a/mindspore/ops/infer/ops_func_impl/matmul_split_silu_mul_out1.h b/mindspore/ops/infer/ops_func_impl/matmul_split_silu_mul_out1.h new file mode 100644 index 00000000000..60fc5321440 --- /dev/null +++ b/mindspore/ops/infer/ops_func_impl/matmul_split_silu_mul_out1.h @@ -0,0 +1,34 @@ +/** + * Copyright 2024 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_OPS_FUNC_IMPL_MATMUL_SPLIT_SILU_MUL_OUT1_H_ +#define MINDSPORE_CORE_OPS_OPS_FUNC_IMPL_MATMUL_SPLIT_SILU_MUL_OUT1_H_ +#include +#include +#include "ops/ops_func_impl/op_func_impl.h" + +namespace mindspore { +namespace ops { +class OPS_API MatmulSplitSiluMulOut1FuncImpl : public OpFuncImpl { + public: + BaseShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) const override; + TypePtr InferType(const PrimitivePtr &primitive, const std::vector &input_args) const override; +}; + +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPS_OPS_FUNC_IMPL_MATMUL_SPLIT_SILU_MUL_OUT1_H_ diff --git a/mindspore/ops/infer/ops_func_impl/q_matmul_split_silu_fastgelu_add_mul_out1.cc b/mindspore/ops/infer/ops_func_impl/q_matmul_split_silu_fastgelu_add_mul_out1.cc new file mode 100644 index 00000000000..a57d227a471 --- /dev/null +++ b/mindspore/ops/infer/ops_func_impl/q_matmul_split_silu_fastgelu_add_mul_out1.cc @@ -0,0 +1,72 @@ +/** + * Copyright 2024 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "infer/ops_func_impl/q_matmul_split_silu_fastgelu_add_mul_out1.h" +#include "infer/ops_func_impl/matmul_fusion_utils.h" +#include "mindspore/ops/ops_utils/op_utils.h" +#include "utils/check_convert_utils.h" +#include "utils/convert_utils_base.h" +#include "utils/shape_utils.h" + +namespace mindspore { +namespace ops { + +BaseShapePtr QMatmulSplitSiluFastgeluAddMulOut1FuncImpl::InferShape( + const PrimitivePtr &primitive, const std::vector &input_args) const { + auto op_name = primitive->name(); + auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->GetShape())[kShape]; + auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->GetShape())[kShape]; + if (IsDynamicRank(x_shape) || IsDynamicRank(w_shape)) { + MS_LOG(EXCEPTION) << "For " << op_name << ", dynamic rank is not supported"; + } + constexpr size_t kSize2 = 2; + constexpr size_t kSize3 = 3; + const size_t x_rank = x_shape.size(); + const size_t w_rank = w_shape.size(); + MS_CHECK_VALUE(x_rank == kSize3, + CheckAndConvertUtils::FormatCommMsg("For '" + primitive->name() + "', x_rank should be 3.")); + + MS_CHECK_VALUE(w_rank == kSize2, + CheckAndConvertUtils::FormatCommMsg("For '" + primitive->name() + "', w_rank should be 2.")); + + auto b = x_shape[0]; // in matmul, m = b * s + auto s = x_shape[1]; + auto k = x_shape[2]; + auto k0 = w_shape[1]; + MS_CHECK_VALUE(k == k0, CheckAndConvertUtils::FormatCommMsg( + "For '" + primitive->name() + "', the K axis of all inputs must have the same length.")); + + MS_CHECK_VALUE(primitive->HasAttr("n_lens"), + CheckAndConvertUtils::FormatCommMsg("For '" + primitive->name() + "', op must have attr 'n_lens'.")); + + std::vector n_len_list = GetValue>(primitive->GetAttr("n_lens")); + MS_CHECK_VALUE( + (n_len_list.size() == kSize2 || n_len_list.size() == kSize3), + CheckAndConvertUtils::FormatCommMsg("For '" + primitive->name() + "', attr 'n_lens' must have 2 or 3 value.")); + + ShapeVector output_0_shape = {b, s, n_len_list[kIndex0]}; + std::vector shape_lists; + (void)shape_lists.emplace_back(std::make_shared(output_0_shape)); + return std::make_shared(shape_lists); +} + +TypePtr QMatmulSplitSiluFastgeluAddMulOut1FuncImpl::InferType(const PrimitivePtr &primitive, + const std::vector &input_args) const { + return MatmulFusionUtils::InferenceMultiMatmulInferType(primitive, input_args); +} + +} // namespace ops +} // namespace mindspore diff --git a/mindspore/ops/infer/ops_func_impl/q_matmul_split_silu_fastgelu_add_mul_out1.h b/mindspore/ops/infer/ops_func_impl/q_matmul_split_silu_fastgelu_add_mul_out1.h new file mode 100644 index 00000000000..c932ca53b17 --- /dev/null +++ b/mindspore/ops/infer/ops_func_impl/q_matmul_split_silu_fastgelu_add_mul_out1.h @@ -0,0 +1,34 @@ +/** + * Copyright 2024 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_OPS_FUNC_IMPL_QMATMUL_SPLIT_SILU_FASTGELU_ADD_MUL_OUT1_H_ +#define MINDSPORE_CORE_OPS_OPS_FUNC_IMPL_QMATMUL_SPLIT_SILU_FASTGELU_ADD_MUL_OUT1_H_ +#include +#include +#include "ops/ops_func_impl/op_func_impl.h" + +namespace mindspore { +namespace ops { +class OPS_API QMatmulSplitSiluFastgeluAddMulOut1FuncImpl : public OpFuncImpl { + public: + BaseShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) const override; + TypePtr InferType(const PrimitivePtr &primitive, const std::vector &input_args) const override; +}; + +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPS_OPS_FUNC_IMPL_QMATMUL_SPLIT_SILU_FASTGELU_ADD_MUL_OUT1_H_ diff --git a/mindspore/ops/infer/ops_func_impl/q_matmul_split_silu_mul_out1.cc b/mindspore/ops/infer/ops_func_impl/q_matmul_split_silu_mul_out1.cc new file mode 100644 index 00000000000..bdcfb50b84a --- /dev/null +++ b/mindspore/ops/infer/ops_func_impl/q_matmul_split_silu_mul_out1.cc @@ -0,0 +1,73 @@ +/** + * Copyright 2024 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "infer/ops_func_impl/q_matmul_split_silu_mul_out1.h" +#include +#include "infer/ops_func_impl/matmul_fusion_utils.h" +#include "mindspore/ops/ops_utils/op_utils.h" +#include "utils/check_convert_utils.h" +#include "utils/convert_utils_base.h" +#include "utils/shape_utils.h" + +namespace mindspore { +namespace ops { + +BaseShapePtr QMatmulSplitSiluMulOut1FuncImpl::InferShape(const PrimitivePtr &primitive, + const std::vector &input_args) const { + auto op_name = primitive->name(); + auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->GetShape())[kShape]; + auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->GetShape())[kShape]; + if (IsDynamicRank(x_shape) || IsDynamicRank(w_shape)) { + MS_LOG(EXCEPTION) << "For " << op_name << ", dynamic rank is not supported"; + } + constexpr size_t kSize2 = 2; + constexpr size_t kSize3 = 3; + const size_t x_rank = x_shape.size(); + const size_t w_rank = w_shape.size(); + MS_CHECK_VALUE(x_rank == kSize3, + CheckAndConvertUtils::FormatCommMsg("For '" + primitive->name() + "', x_rank should be 3.")); + + MS_CHECK_VALUE(w_rank == kSize2, + CheckAndConvertUtils::FormatCommMsg("For '" + primitive->name() + "', w_rank should be 2.")); + + auto b = x_shape[0]; // in matmul, m = b * s + auto s = x_shape[1]; + auto k = x_shape[2]; + auto k0 = w_shape[1]; + MS_CHECK_VALUE(k == k0, CheckAndConvertUtils::FormatCommMsg( + "For '" + primitive->name() + "', the K axis of all inputs must have the same length.")); + + MS_CHECK_VALUE(primitive->HasAttr("n_lens"), + CheckAndConvertUtils::FormatCommMsg("For '" + primitive->name() + "', op must have attr 'n_lens'.")); + + std::vector n_len_list = GetValue>(primitive->GetAttr("n_lens")); + MS_CHECK_VALUE( + (n_len_list.size() == kSize2 || n_len_list.size() == kSize3), + CheckAndConvertUtils::FormatCommMsg("For '" + primitive->name() + "', attr 'n_lens' must have 2 or 3 value.")); + + ShapeVector output_0_shape = {b, s, n_len_list[kIndex0]}; + std::vector shape_lists; + (void)shape_lists.emplace_back(std::make_shared(output_0_shape)); + return std::make_shared(shape_lists); +} + +TypePtr QMatmulSplitSiluMulOut1FuncImpl::InferType(const PrimitivePtr &primitive, + const std::vector &input_args) const { + return MatmulFusionUtils::InferenceMultiMatmulInferType(primitive, input_args); +} + +} // namespace ops +} // namespace mindspore diff --git a/mindspore/ops/infer/ops_func_impl/q_matmul_split_silu_mul_out1.h b/mindspore/ops/infer/ops_func_impl/q_matmul_split_silu_mul_out1.h new file mode 100644 index 00000000000..085c8528a03 --- /dev/null +++ b/mindspore/ops/infer/ops_func_impl/q_matmul_split_silu_mul_out1.h @@ -0,0 +1,34 @@ +/** + * Copyright 2024 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_OPS_FUNC_IMPL_Q_MATMUL_SPLIT_SILU_MUL_OUT1_H_ +#define MINDSPORE_CORE_OPS_OPS_FUNC_IMPL_Q_MATMUL_SPLIT_SILU_MUL_OUT1_H_ +#include +#include +#include "ops/ops_func_impl/op_func_impl.h" + +namespace mindspore { +namespace ops { +class OPS_API QMatmulSplitSiluMulOut1FuncImpl : public OpFuncImpl { + public: + BaseShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) const override; + TypePtr InferType(const PrimitivePtr &primitive, const std::vector &input_args) const override; +}; + +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPS_OPS_FUNC_IMPL_Q_MATMUL_SPLIT_SILU_MUL_OUT1_H_ diff --git a/mindspore/ops/op_def/nn_op_name.h b/mindspore/ops/op_def/nn_op_name.h index 0877e89d7aa..6d8b4d2f61e 100644 --- a/mindspore/ops/op_def/nn_op_name.h +++ b/mindspore/ops/op_def/nn_op_name.h @@ -158,6 +158,10 @@ constexpr auto kRNNTLossOpName = "RNNTLoss"; constexpr auto kAllFiniteOpName = "AllFinite"; constexpr auto kWeightQuantMatmulQkvOpName = "WeightQuantMatmulQkv"; constexpr auto kWeightQuantMatmulFfnOpName = "WeightQuantMatmulFfn"; +constexpr auto kPrimNameMatmulSplitSiluFastgeluAddMulOut1 = "MatmulSplitSiluFastgeluAddMulOut1"; +constexpr auto kPrimNameMatmulSplitSiluMulOut1 = "MatmulSplitSiluMulOut1"; +constexpr auto kPrimNameQMatmulSplitSiluFastgeluAddMulOut1 = "QMatmulSplitSiluFastgeluAddMulOut1"; +constexpr auto kPrimNameQMatmulSplitSiluMulOut1 = "QMatmulSplitSiluMulOut1"; } // namespace mindspore #endif // MINDSPORE_CORE_BASE_NN_OP_NAME_H_ diff --git a/mindspore/ops/op_def/yaml/infer/matmul_split_silu_fastgelu_add_mul_out1_op.yaml b/mindspore/ops/op_def/yaml/infer/matmul_split_silu_fastgelu_add_mul_out1_op.yaml new file mode 100644 index 00000000000..c74d2f6237f --- /dev/null +++ b/mindspore/ops/op_def/yaml/infer/matmul_split_silu_fastgelu_add_mul_out1_op.yaml @@ -0,0 +1,16 @@ +#operator MatmulSplitSiluFastGeluAddMulOut1 +matmul_split_silu_fastgelu_add_mul_out1: + args: + input: + dtype: tensor + weight: + dtype: tensor + reshape: + dtype: tuple[int] + class: + disable: True + function: + disable: True + returns: + output0: + dtype: tensor diff --git a/mindspore/ops/op_def/yaml/infer/matmul_split_silu_mul_out1_op.yaml b/mindspore/ops/op_def/yaml/infer/matmul_split_silu_mul_out1_op.yaml new file mode 100644 index 00000000000..b72c63d5858 --- /dev/null +++ b/mindspore/ops/op_def/yaml/infer/matmul_split_silu_mul_out1_op.yaml @@ -0,0 +1,16 @@ +#operator MatmulSplitSiluMulOut1 +matmul_split_silu_mul_out1: + args: + input: + dtype: tensor + weight: + dtype: tensor + reshape: + dtype: tuple[int] + class: + disable: True + function: + disable: True + returns: + output0: + dtype: tensor diff --git a/mindspore/ops/op_def/yaml/infer/q_matmul_split_silu_fastgelu_add_mul_out1_op.yaml b/mindspore/ops/op_def/yaml/infer/q_matmul_split_silu_fastgelu_add_mul_out1_op.yaml new file mode 100644 index 00000000000..39700e1b787 --- /dev/null +++ b/mindspore/ops/op_def/yaml/infer/q_matmul_split_silu_fastgelu_add_mul_out1_op.yaml @@ -0,0 +1,20 @@ +#operator QMatmulSplitSiluFastGeluAddMulOut1 +q_matmul_split_silu_fastgelu_add_mul_out1: + args: + input: + dtype: tensor + weight: + dtype: tensor + reshape: + dtype: tuple[int] + bias: + dtype: tensor + scale: + dtype: tensor + class: + disable: True + function: + disable: True + returns: + output0: + dtype: tensor diff --git a/mindspore/ops/op_def/yaml/infer/q_matmul_split_silu_mul_out1_op.yaml b/mindspore/ops/op_def/yaml/infer/q_matmul_split_silu_mul_out1_op.yaml new file mode 100644 index 00000000000..38ad6a65655 --- /dev/null +++ b/mindspore/ops/op_def/yaml/infer/q_matmul_split_silu_mul_out1_op.yaml @@ -0,0 +1,20 @@ +#operator QMatmulSplitSiluMulOut1 +q_matmul_split_silu_mul_out1: + args: + input: + dtype: tensor + weight: + dtype: tensor + reshape: + dtype: tuple[int] + bias: + dtype: tensor + scale: + dtype: tensor + class: + disable: True + function: + disable: True + returns: + output0: + dtype: tensor diff --git a/mindspore/python/mindspore/ops/operations/nn_ops.py b/mindspore/python/mindspore/ops/operations/nn_ops.py index cab2594ef05..d483f910b49 100644 --- a/mindspore/python/mindspore/ops/operations/nn_ops.py +++ b/mindspore/python/mindspore/ops/operations/nn_ops.py @@ -39,6 +39,7 @@ from ..auto_generate import (CeLU, Flatten, LogSoftmax, LogSoftmaxExt, GLU, ReLU GridSampler3D, GridSampler2D, LayerNorm, LayerNormExt, HShrink, AdamWeightDecay, Dropout, ApplyRotaryPosEmb, GroupTopk, PagedAttention, PagedAttentionMask, ReshapeAndCache, KvScaleCache, FlashAttentionScore, PromptFlashAttention, Embedding, UpsampleNearest1D, UpsampleNearest2D, + DynamicNTK, UpsampleNearest3D, UpsampleTrilinear3D, SoftMarginLoss, UpsampleBilinear2D, UpsampleLinear1D, BinaryCrossEntropy, BCEWithLogitsLoss, SoftShrink, AdaptiveMaxPool2D, -- Gitee From 8979416b271b6d0039b5ff41d02a904333d39fb1 Mon Sep 17 00:00:00 2001 From: ckey_Dou Date: Fri, 16 May 2025 17:20:47 +0800 Subject: [PATCH 3/4] support mla --- .../device/ascend/kernel/internal/mla.cc | 50 ++++ .../device/ascend/kernel/internal/mla.h | 31 ++ mindspore/ops/infer/ops_func_impl/mla.cc | 161 +++++++++++ mindspore/ops/infer/ops_func_impl/mla.h | 54 ++++ mindspore/ops/op_def/op_enum.cc | 7 +- mindspore/ops/op_def/op_enum.h | 2 + mindspore/ops/op_def/yaml/mla_op.yaml | 49 ++++ .../ops/ascend/test_internal_ops/test_mla.py | 270 ++++++++++++++++++ tests/ut/cpp/CMakeLists.txt | 2 +- tests/ut/cpp/internal/test_ops_mla.cc | 89 ++++++ 10 files changed, 713 insertions(+), 2 deletions(-) create mode 100644 mindspore/ccsrc/plugin/device/ascend/kernel/internal/mla.cc create mode 100644 mindspore/ccsrc/plugin/device/ascend/kernel/internal/mla.h create mode 100644 mindspore/ops/infer/ops_func_impl/mla.cc create mode 100644 mindspore/ops/infer/ops_func_impl/mla.h create mode 100644 mindspore/ops/op_def/yaml/mla_op.yaml create mode 100644 tests/st/ops/ascend/test_internal_ops/test_mla.py create mode 100644 tests/ut/cpp/internal/test_ops_mla.cc diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/internal/mla.cc b/mindspore/ccsrc/plugin/device/ascend/kernel/internal/mla.cc new file mode 100644 index 00000000000..f3f443836bd --- /dev/null +++ b/mindspore/ccsrc/plugin/device/ascend/kernel/internal/mla.cc @@ -0,0 +1,50 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/ascend/kernel/internal/mla.h" + +#include +#include "common/kernel.h" +#include "plugin/device/ascend/kernel/internal/internal_kernel_in_out_map.h" +#include "utils/llm_manager.h" +#include "plugin/device/ascend/kernel/internal/internal_kernel_utils.h" + +namespace mindspore { +namespace kernel { +internal::InternalOpPtr InternalMla::CreateKernel(const internal::InputsImmutableInfoList &inputs_ii, + const internal::OutputsImmutableInfoList &outputs_ii, + const std::vector &ms_inputs, + const std::vector &ms_outputs) { + internal::MLAParam param; + param.type = internal::MLAParam::kSplitCache; + param.head_size = static_cast(ms_inputs[kIndex10]->GetValueWithCheck()); + param.tor = ms_inputs[kIndex11]->GetValueWithCheck(); + param.kv_head = static_cast(ms_inputs[kIndex12]->GetValueWithCheck()); + param.mask_type = static_cast(ms_inputs[kIndex13]->GetValueWithCheck()); + param.is_ring = static_cast(ms_inputs[kIndex14]->GetValueWithCheck()); + + (void)GetSeqLenFromInputAndCheckUpadate(kernel_name_, {"q_seq_lens"}, ms_inputs[kIndex8], ¶m.q_seq_len); + (void)GetSeqLenFromInputAndCheckUpadate(kernel_name_, {"batch_valid_length"}, ms_inputs[kIndex9], ¶m.kv_seq_len); + + return internal::CreateMLAOp(inputs_ii, outputs_ii, param, internal::kInternalMLAOpName); +} + +MS_INTERNAL_KERNEL_FACTORY_REG(Mla, internal::kInternalMLAOpName, InternalMla); +REG_MS_TO_INTERNAL_IN_TENSOR_IDX_MAP(Mla, INPUT_NUM_8, INDEX_0, INDEX_1, INDEX_2, INDEX_3, INDEX_4, INDEX_5, INDEX_6, + INDEX_7); +REG_MS_TO_INTERNAL_OUT_TENSOR_IDX_MAP(Mla, OUTPUT_NUM_2, INDEX_0, INDEX_1); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/internal/mla.h b/mindspore/ccsrc/plugin/device/ascend/kernel/internal/mla.h new file mode 100644 index 00000000000..27993113e32 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/ascend/kernel/internal/mla.h @@ -0,0 +1,31 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_INTERNAL_KERNEL_INTERNAL_MLA_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_INTERNAL_KERNEL_INTERNAL_MLA_H_ + +#include +#include +#include + +#include "plugin/device/ascend/kernel/internal/internal_kernel_mod.h" +#include "include/internal.h" + +namespace mindspore { +namespace kernel { +DECLARE_INTERNAL_KERNEL_MOD(Mla) +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_INTERNAL_KERNEL_INTERNAL_MLA_H_ diff --git a/mindspore/ops/infer/ops_func_impl/mla.cc b/mindspore/ops/infer/ops_func_impl/mla.cc new file mode 100644 index 00000000000..25c5afb134a --- /dev/null +++ b/mindspore/ops/infer/ops_func_impl/mla.cc @@ -0,0 +1,161 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "infer/ops_func_impl/mla.h" +#include +#include +#include + +#include "mindspore/ops/op_def/nn_ops.h" +#include "mindspore/ops/op_def/op_enum.h" +#include "utils/check_convert_utils.h" +#include "ops/primitive_c.h" +#include "mindapi/helper.h" +#include "include/api/data_type.h" + +namespace mindspore { +namespace ops { +static constexpr auto kMLAQshapeRank = 3; +static constexpr auto kMLAKVshapeRank = 4; +static constexpr auto kMLABlockTablesRank = 2; +static constexpr auto kMLAMaskRank = 2; +static constexpr auto kMLADeqScaleRank = 1; +static constexpr auto kMLAMaskFreeLastDim = 128; +static constexpr auto kMLAQKVnopeHiddenSize = 512; +static constexpr auto kMLAQKropeHiddenSize = 64; + +static void CheckShape(const PrimitivePtr &primitive, const InferInfoPtrList &input_infos) { + auto q_nope_shape = input_infos[kMlaInputQnopeIndex]->GetShape(); + auto q_rope_shape = input_infos[kMlaInputQropeIndex]->GetShape(); + auto ctkv_shape = input_infos[kMlaInputKvCacheIndex]->GetShape(); + auto k_rope_shape = input_infos[kMlaInputKropeIndex]->GetShape(); + auto block_tables_shape = input_infos[kMlaInputBlockTablesIndex]->GetShape(); + auto q_len_shape = input_infos[kMlaInputQueryLensIndex]->GetShape(); + auto context_len_shape = input_infos[kMlaInputContextLensIndex]->GetShape(); + + if (!input_infos[kMlaInputQnopeIndex]->IsDynamic()) { + MS_CHECK_VALUE(q_nope_shape.size() == kMLAQshapeRank, + CheckAndConvertUtils::FormatCommMsg("For MLA The rank of q_nope must be ", kMLAQshapeRank, + ", but got shape: ", q_nope_shape)); + MS_CHECK_VALUE(q_nope_shape[q_nope_shape.size() - 1] == kMLAQKVnopeHiddenSize, + CheckAndConvertUtils::FormatCommMsg("For MLA The last dim of q_nope must be ", kMLAQKVnopeHiddenSize, + ", but got shape: ", q_nope_shape)); + } + + if (!input_infos[kMlaInputQropeIndex]->IsDynamic()) { + MS_CHECK_VALUE(q_rope_shape.size() == kMLAQshapeRank, + CheckAndConvertUtils::FormatCommMsg("For MLA The rank of q_rope must be ", kMLAQshapeRank, + ", but got shape: ", q_rope_shape)); + MS_CHECK_VALUE(q_rope_shape[q_rope_shape.size() - 1] == kMLAQKropeHiddenSize, + CheckAndConvertUtils::FormatCommMsg("For MLA The last dim of q_rope must be ", kMLAQKropeHiddenSize, + ", but got shape: ", q_rope_shape)); + } + + if (!input_infos[kMlaInputKvCacheIndex]->IsDynamic()) { + MS_CHECK_VALUE(ctkv_shape.size() == kMLAKVshapeRank, + CheckAndConvertUtils::FormatCommMsg("For MLA The rank of ctkv must be ", kMLAKVshapeRank, + ", but got shape: ", ctkv_shape)); + MS_CHECK_VALUE(ctkv_shape[ctkv_shape.size() - 1] == kMLAQKVnopeHiddenSize, + CheckAndConvertUtils::FormatCommMsg("For MLA The last dim of ctkv must be ", kMLAQKVnopeHiddenSize, + ", but got shape: ", ctkv_shape)); + } + + if (!input_infos[kMlaInputKropeIndex]->IsDynamic()) { + MS_CHECK_VALUE(k_rope_shape.size() == kMLAKVshapeRank, + CheckAndConvertUtils::FormatCommMsg("For MLA The rank of k_rope must be ", kMLAKVshapeRank, + ", but got shape: ", k_rope_shape)); + MS_CHECK_VALUE(k_rope_shape[k_rope_shape.size() - 1] == kMLAQKropeHiddenSize, + CheckAndConvertUtils::FormatCommMsg("For MLA The last dim of k_rope must be ", kMLAQKropeHiddenSize, + ", but got shape: ", k_rope_shape)); + } + + if (!input_infos[kMlaInputBlockTablesIndex]->IsDynamic()) { + MS_CHECK_VALUE(block_tables_shape.size() == kMLABlockTablesRank, + CheckAndConvertUtils::FormatCommMsg("For MLA The rank of block_tables must be ", kMLABlockTablesRank, + ", but got shape: ", block_tables_shape)); + } + + if (!input_infos[kMlaInputAttnMaskIndex]->IsNone() && !input_infos[kMlaInputAttnMaskIndex]->IsDynamic()) { + auto mask_shape = input_infos[kMlaInputAttnMaskIndex]->GetShape(); + auto mask_mode_value = input_infos[kMlaInputMaskModeIndex]->GetScalarValue(); + if (!mask_mode_value.has_value()) { + MS_EXCEPTION(ValueError) << "For MLA mask_mode must be constant but got variable."; + } + + auto mask_mode = mask_mode_value.value(); + if (mask_mode == MLAMode::MASK_SPEC || mask_mode == MLAMode::MASK_FREE) { + MS_CHECK_VALUE(mask_shape.size() == kMLAMaskRank, + CheckAndConvertUtils::FormatCommMsg("For MLA The rank of mask must be ", kMLAMaskRank, + ", but got shape: ", mask_shape)); + } + + if (mask_mode == MLAMode::MASK_FREE) { + MS_CHECK_VALUE(mask_shape[mask_shape.size() - 1] == kMLAMaskFreeLastDim, + CheckAndConvertUtils::FormatCommMsg("For MLA The last dim of mask must be ", kMLAMaskFreeLastDim, + ", when mask_mode is MASK_FREE but got shape: ", mask_shape)); + } + } + + if (!input_infos[kMlaInputDeqScaleQkIndex]->IsNone()) { + auto deq_scale_qk_shape = input_infos[kMlaInputDeqScaleQkIndex]->GetShape(); + MS_CHECK_VALUE(deq_scale_qk_shape.size() == kMLADeqScaleRank, + CheckAndConvertUtils::FormatCommMsg("For MLA The rank of deq_scale_qk must be ", kMLADeqScaleRank, + ", but got shape: ", deq_scale_qk_shape)); + } + + if (!input_infos[kMlaInputDeqScalePvIndex]->IsNone()) { + auto deq_scale_pv_shape = input_infos[kMlaInputDeqScalePvIndex]->GetShape(); + + MS_CHECK_VALUE(deq_scale_pv_shape.size() == kMLADeqScaleRank, + CheckAndConvertUtils::FormatCommMsg("For MLA The rank of deq_scale_pv must be ", kMLADeqScaleRank, + ", but got shape: ", deq_scale_pv_shape)); + } + + MS_CHECK_VALUE(q_len_shape.size() == kMLADeqScaleRank, + CheckAndConvertUtils::FormatCommMsg("For MLA The rank of q_seq_lens must be ", kMLADeqScaleRank, + ", but got shape: ", q_len_shape)); + MS_CHECK_VALUE(context_len_shape.size() == kMLADeqScaleRank, + CheckAndConvertUtils::FormatCommMsg("For MLA The rank of context_lengths must be ", kMLADeqScaleRank, + ", but got shape: ", context_len_shape)); +} + +ShapeArray MlaFuncImpl::InferShape(const PrimitivePtr &primitive, const InferInfoPtrList &input_infos) const { + auto &q_nope_info = input_infos[kMlaInputQnopeIndex]; + auto q_nope_shape = q_nope_info->GetShape(); + auto is_ring_value = input_infos[kMlaInputIsRingIndex]->GetScalarValue(); + if (!is_ring_value.has_value()) { + MS_EXCEPTION(ValueError) << "For MLA, the ring must be a constant, but got a variable."; + } + + auto is_ring = is_ring_value.value(); + if (is_ring != 0) { + MS_EXCEPTION(ValueError) << "For MLA, ir_ring must be 0 now, but got: " << is_ring; + } + + CheckShape(primitive, input_infos); + + ShapeVector lse_out_shape{0}; + return {q_nope_shape, lse_out_shape}; +} + +std::vector MlaFuncImpl::InferType(const PrimitivePtr &primitive, const InferInfoPtrList &input_infos) const { + auto q_nope_type = input_infos[kMlaInputQnopeIndex]->GetType(); + auto q_rope_type = input_infos[kMlaInputQropeIndex]->GetType(); + + return {q_rope_type, q_nope_type}; +} +} // namespace ops +} // namespace mindspore diff --git a/mindspore/ops/infer/ops_func_impl/mla.h b/mindspore/ops/infer/ops_func_impl/mla.h new file mode 100644 index 00000000000..c501fa312ea --- /dev/null +++ b/mindspore/ops/infer/ops_func_impl/mla.h @@ -0,0 +1,54 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_OPS_FUNC_IMPL_MLA_H_ +#define MINDSPORE_CORE_OPS_OPS_FUNC_IMPL_MLA_H_ +#include + +#include "ops/ops_func_impl/op_func_impl.h" + +namespace mindspore { +namespace ops { +enum MlaInputIndex : size_t { + kMlaInputQnopeIndex = 0, + kMlaInputQropeIndex, + kMlaInputKvCacheIndex, + kMlaInputKropeIndex, + kMlaInputBlockTablesIndex, + kMlaInputAttnMaskIndex, + kMlaInputDeqScaleQkIndex, + kMlaInputDeqScalePvIndex, + kMlaInputQueryLensIndex, + kMlaInputContextLensIndex, + kMlaInputNumHeadIndex, + kMlaInputScaleValueIndex, + kMlaInputNumKVHeadIndex, + kMlaInputMaskModeIndex, + kMlaInputIsRingIndex, + kMlaInputsNum +}; + +class OPS_API MlaFuncImpl : public OpFuncImpl { + public: + ShapeArray InferShape(const PrimitivePtr &primitive, const InferInfoPtrList &input_infos) const override; + std::vector InferType(const PrimitivePtr &primitive, const InferInfoPtrList &input_infos) const override; + bool GeneralInferRegistered() const override { return true; } +}; + +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPS_OPS_FUNC_IMPL_MLA_H_ diff --git a/mindspore/ops/op_def/op_enum.cc b/mindspore/ops/op_def/op_enum.cc index 8aa367afd4a..e4213bf34ce 100644 --- a/mindspore/ops/op_def/op_enum.cc +++ b/mindspore/ops/op_def/op_enum.cc @@ -212,7 +212,12 @@ REG_STRING_TO_ENUM_COMMON(kv_cache_quant_mode, StrToEnumMap{{"DEFAULT", PagedAtt // PagedAttentionMaskMode REG_STRING_TO_ENUM_COMMON(mask_mode, StrToEnumMap{{"MASK_DEFAULT", PagedAttentionMaskMode::MASK_DEFAULT}, - {"TRAPEZOIDAL", PagedAttentionMaskMode::TRAPEZOIDAL}}) + {"TRAPEZOIDAL", PagedAttentionMaskMode::TRAPEZOIDAL}, + {"MASK_NONE", MLAMode::MASK_NONE}, + {"MASK_NORM", MLAMode::MASK_NORM}, + {"MASK_ALIBI", MLAMode::MASK_ALIBI}, + {"MASK_SPEC", MLAMode::MASK_SPEC}, + {"MASK_FREE", MLAMode::MASK_FREE}}) // ErrorMode REG_STRING_TO_ENUM_SPECIAL(error_mode, StrToEnumMap{{"CYCLE", ErrorMode::CYCLE}, {"SPECIFIC", ErrorMode::SPECIFIC}}); diff --git a/mindspore/ops/op_def/op_enum.h b/mindspore/ops/op_def/op_enum.h index d1ac09a97b1..b253ae451bd 100644 --- a/mindspore/ops/op_def/op_enum.h +++ b/mindspore/ops/op_def/op_enum.h @@ -67,6 +67,8 @@ enum PagedAttentionKVCacheQuantMode : int64_t { DEFAULT = 0, PERTOKEN = 1 }; enum PagedAttentionMaskMode : int64_t { MASK_DEFAULT = 0, TRAPEZOIDAL = 1 }; +enum MLAMode : int64_t { MASK_NONE = 0, MASK_NORM = 1, MASK_ALIBI = 2, MASK_SPEC = 3, MASK_FREE = 4 }; + enum ErrorMode : int64_t { CYCLE = 0, SPECIFIC = 1 }; enum FlipMode : int64_t { BITFLIP = 0, BITFLIP_DESIGNED = 1, MULTIPLY = 2, MULTIPLY_MAX = 3 }; diff --git a/mindspore/ops/op_def/yaml/mla_op.yaml b/mindspore/ops/op_def/yaml/mla_op.yaml new file mode 100644 index 00000000000..6b498821661 --- /dev/null +++ b/mindspore/ops/op_def/yaml/mla_op.yaml @@ -0,0 +1,49 @@ +#operator Mla +mla: + args: + query: + dtype: tensor + q_rope: + dtype: tensor + kv_cache: + dtype: tensor + k_rope: + dtype: tensor + block_tables: + dtype: tensor + attn_mask: + dtype: tensor + default: None + deq_scale_qk: + dtype: tensor + default: None + deq_scale_pv: + dtype: tensor + default: None + q_seq_lens: + dtype: tensor + default: None + context_lens: + dtype: tensor + default: None + head_num: + dtype: int + default: 32 + scale_value: + dtype: float + default: 0.0 + kv_head_num: + dtype: int + default: 1 + mask_mode: + dtype: int + default: "'MASK_NONE'" + arg_handler: str_to_enum + is_ring: + dtype: int + default: 0 + returns: + attention_out: + dtype: tensor + lse: + dtype: tensor \ No newline at end of file diff --git a/tests/st/ops/ascend/test_internal_ops/test_mla.py b/tests/st/ops/ascend/test_internal_ops/test_mla.py new file mode 100644 index 00000000000..55d2703d890 --- /dev/null +++ b/tests/st/ops/ascend/test_internal_ops/test_mla.py @@ -0,0 +1,270 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""test mla""" + +import mindspore as ms +from mindspore import nn, ops, Tensor, context +from mindspore.ops.operations.nn_ops import PagedAttention +import numpy as np +import pytest + + +class MlaTestParam: + """MlaTestParam""" + def __init__(self, num_heads, kv_heads, block_size, head_size_nope, head_size_rope, num_blocks, + q_seq_lens : list, context_lengths : list, tor, nope_ms_dtype, rope_ms_dtype, mask_mode : str): + + self.num_heads = num_heads + self.kv_heads = kv_heads + self.block_size = block_size + self.head_size_nope = head_size_nope + self.head_size_rope = head_size_rope + self.num_blocks = num_blocks + self.q_seq_lens = q_seq_lens + self.context_lengths = context_lengths + self.tor = tor + self.nope_ms_dtype = nope_ms_dtype + self.rope_ms_dtype = rope_ms_dtype + self.mask_mode = mask_mode + self.mask_factor = -10000.0 if rope_ms_dtype == ms.float16 else 1.0 + + self.batch = len(q_seq_lens) + + self.max_context_len = max(context_lengths) + self.max_num_blocks_per_seq = (self.max_context_len + block_size - 1) // block_size + + self.num_tokens = (int)(np.array(q_seq_lens).sum()) + self.block_tables = self._build_block_tables() + + self._build_tensor_inputs() + + + def _build_np_mask(self): + """_build_np_mask""" + if self.mask_mode == "MASK_NONE": + return None + + if self.mask_mode == "MASK_SPEC": + pre_qseqlen = 0 + np_mask = np.zeros(shape=(self.num_tokens, self.max_context_len)).astype(np.float32) + for i in range(self.batch): + qseqlen = self.q_seq_lens[i] + kseqlen = self.context_lengths[i] + tri = np.ones((qseqlen, qseqlen)) + tri = np.triu(tri, 1) + tri *= self.mask_factor + np_mask[pre_qseqlen:(pre_qseqlen + qseqlen), kseqlen-qseqlen:kseqlen] = tri + pre_qseqlen += qseqlen + return np_mask + + if self.mask_mode == "MASK_FREE": + pass + + return None + + + def _build_block_tables(self): + """_build_block_tables""" + block_tables_list = [] + for i in range(self.num_tokens): + block_table = [ + i * self.max_num_blocks_per_seq + _ for _ in range(self.max_num_blocks_per_seq) + ] + block_tables_list.append(block_table) + return block_tables_list + + + def _build_tensor_inputs(self): + """_build_tensor_inputs""" + np_q_nope = np.random.uniform(-1.0, 1.0, size=(self.num_tokens, self.num_heads, self.head_size_nope)) + np_q_rope = np.random.uniform(-1.0, 1.0, size=(self.num_tokens, self.num_heads, self.head_size_rope)) + np_ctkv = np.random.uniform(-1.0, 1.0, size=(self.num_blocks, self.block_size, + self.kv_heads, self.head_size_nope)) + np_k_rope = np.random.uniform(-1.0, 1.0, size=(self.num_blocks, self.block_size, + self.kv_heads, self.head_size_rope)) + + np_context_lens = np.array(self.context_lengths).astype(np.int32) + np_q_seq_lens = np.array(self.q_seq_lens).astype(np.int32) + + self.q_nope_tensor = Tensor(np_q_nope, dtype=self.nope_ms_dtype) + self.q_rope_tensor = Tensor(np_q_rope, dtype=self.rope_ms_dtype) + self.ctkv_tensor = Tensor(np_ctkv, dtype=self.nope_ms_dtype) + self.k_rope_tensor = Tensor(np_k_rope, dtype=self.rope_ms_dtype) + + self.block_tables_tensor = Tensor(np.array(self.block_tables).astype(np.int32)) + + np_mask = self._build_np_mask() + self.mask_tensor = None if np_mask is None else Tensor(np_mask, dtype=self.rope_ms_dtype) + + if self.nope_ms_dtype == ms.int8: + self.deq_scale_qk_tensor = Tensor(np.random.uniform(-1.0, 1.0, size=(self.num_heads, )), dtype=ms.float32) + self.deq_scale_pv_tensor = Tensor(np.random.uniform(-1.0, 1.0, size=(self.num_heads, )), dtype=ms.float32) + else: + self.deq_scale_qk_tensor = None + self.deq_scale_pv_tensor = None + + self.q_seq_lens_tensor = Tensor(np_q_seq_lens) + self.context_lengths_tensor = Tensor(np_context_lens) + + +class Net(nn.Cell): + """Net""" + def __init__(self, q_head_num, kv_head_num, mask_type, tor): + super().__init__() + self.q_head_num = q_head_num + self.kv_head_num = kv_head_num + self.mask_type = mask_type + self.tor = tor + + def construct(self, q_nope, q_rope, ctkv, k_rope, block_tables, mask, deq_scale_qk, deq_scale_pv, + q_seq_lens, batch_valid_length): + return ops.auto_generate.mla(q_nope, q_rope, ctkv, k_rope, block_tables, mask, deq_scale_qk, + deq_scale_pv, q_seq_lens, batch_valid_length, self.q_head_num, self.tor, + self.kv_head_num, self.mask_type) + + +class GoldenNet(nn.Cell): + """GoldenNet""" + def __init__(self, q_head_num, kv_head_num, mask_mode, tor, mla_v_dim): + super().__init__() + self.q_head_num = q_head_num + self.kv_head_num = kv_head_num + self.mask_mode = mask_mode + self.tor = tor + self.mla_v_dim = mla_v_dim + self.op = PagedAttention(self.q_head_num, self.tor, self.kv_head_num, 'DEFAULT', 'MASK_DEFAULT', + self.mla_v_dim) + + def construct(self, query, key_cache, value_cache, block_tables, batch_valid_length, antiquant_scale, + antiquant_offset, attn_mask, q_seq_lens, alibi_mask): + return self.op(query, key_cache, value_cache, block_tables, batch_valid_length, antiquant_scale, + antiquant_offset, attn_mask, q_seq_lens, alibi_mask) + + +def run_mla(test_param : MlaTestParam): + """run mla""" + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + context.set_context(jit_config={"jit_level": "O0", "infer_boost": "on"}) + dyn_q_nope_shape = [None for _ in test_param.q_nope_tensor.shape] + dyn_q_nope_tensor = Tensor(shape=dyn_q_nope_shape, dtype=test_param.q_nope_tensor.dtype) + + net = Net(test_param.num_heads, test_param.kv_heads, test_param.mask_mode, test_param.tor) + net.set_inputs(q_nope=dyn_q_nope_tensor) + + out, _ = net(test_param.q_nope_tensor, test_param.q_rope_tensor, test_param.ctkv_tensor, test_param.k_rope_tensor, + test_param.block_tables_tensor, test_param.mask_tensor, test_param.deq_scale_qk_tensor, + test_param.deq_scale_pv_tensor, test_param.q_seq_lens_tensor, test_param.context_lengths_tensor) + return out + + +def run_golden(test_param : MlaTestParam): + """run_golden""" + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + context.set_context(jit_config={"jit_level": "O0", "infer_boost": "on"}) + + mla_v_dim = 512 + query = ops.reshape(ops.concat((test_param.q_nope_tensor, test_param.q_rope_tensor), axis=-1), + (test_param.num_tokens, 1, -1)) + key_cache = ops.concat((test_param.ctkv_tensor, test_param.k_rope_tensor), axis=-1) + dyn_q_shape = [None for _ in test_param.q_nope_tensor.shape] + dyn_q_nope_tensor = Tensor(shape=dyn_q_shape, dtype=test_param.q_nope_tensor.dtype) + golden_net = GoldenNet(test_param.num_heads, test_param.kv_heads, "MASK_DEFAULT", test_param.tor, mla_v_dim) + golden_net.set_inputs(query=dyn_q_nope_tensor) + + out_golden = golden_net(query, key_cache, key_cache, test_param.block_tables_tensor, + test_param.context_lengths_tensor, None, None, test_param.mask_tensor, + test_param.q_seq_lens_tensor, None) + + return out_golden + + +def run_test(test_param : MlaTestParam): + """run test""" + out_actual = run_mla(test_param) + out_golden = run_golden(test_param) + + assert np.allclose(out_actual.astype(ms.float32).asnumpy().reshape(-1), + out_golden.astype(ms.float32).asnumpy().reshape(-1), 0.001, 0.001) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('dtype', [ms.float16, ms.bfloat16]) +@pytest.mark.parametrize('mask_mode', ["MASK_NONE", "MASK_SPEC"]) +def test_mla_base(dtype, mask_mode): + """ + Feature: test mla + Description: test mla. + Expectation: the result is correct + """ + q_seq_lens = [1, 1, 1, 1] + context_lengths = [192, 193, 194, 195] + test_param = MlaTestParam(32, 1, 128, 512, 64, 1024, q_seq_lens, context_lengths, 0.001, dtype, dtype, mask_mode) + run_test(test_param) + + +# @pytest.mark.level0 +# @pytest.mark.platform_arm_ascend910b_training +# @pytest.mark.env_onecard +# @pytest.mark.parametrize('mask_mode', ["MASK_NONE", "MASK_SPEC"]) +# def test_mla_int8(mask_mode): +# """ +# Feature: test mla +# Description: test mla. +# Expectation: the result is correct +# """ +# q_seq_lens = [1, 1, 1, 1] +# context_lengths = [192, 193, 194, 195] +# test_param = MlaTestParam(32, 1, 128, 512, 64, 1024, q_seq_lens, context_lengths, 0.001, ms.int8, ms.bfloat16, mask_mode) +# run_test(test_param) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('block_size', [16, 32, 64, 128]) +@pytest.mark.parametrize('mask_mode', ["MASK_NONE", "MASK_SPEC"]) +def test_mla_block_size(block_size, mask_mode): + """ + Feature: test mla + Description: test mla. + Expectation: the result is correct + """ + q_seq_lens = [1, 1, 1, 1] + context_lengths = [192, 193, 194, 195] + test_param = MlaTestParam(32, 1, block_size, 512, 64, 1024, q_seq_lens, context_lengths, + 0.001, ms.float16, ms.float16, mask_mode) + run_test(test_param) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('dtype', [ms.bfloat16, ms.float16]) +@pytest.mark.parametrize('mask_mode', ["MASK_NONE"]) +@pytest.mark.parametrize('block_size', [16, 128]) +def test_mla_mtp(dtype, mask_mode, block_size): + """ + Feature: test mla + Description: test mla. + Expectation: the result is correct + """ + q_seq_lens = [1, 1, 2, 1] + context_lengths = [192, 193, 194, 195] + test_param = MlaTestParam(4, 1, block_size, 512, 64, 128, q_seq_lens, context_lengths, + 0.001, dtype, dtype, mask_mode) + run_test(test_param) diff --git a/tests/ut/cpp/CMakeLists.txt b/tests/ut/cpp/CMakeLists.txt index 595cede9a73..ee43e5e9208 100644 --- a/tests/ut/cpp/CMakeLists.txt +++ b/tests/ut/cpp/CMakeLists.txt @@ -83,7 +83,7 @@ file(GLOB_RECURSE UT_BACKEND_SRCS RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} ./backend file(GLOB_RECURSE UT_GRAPH_KERNEL_SRCS RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} ./graph_kernel/*.cc) if(DEFINED ENV{MS_INTERNAL_KERNEL_HOME}) - file(GLOB_RECURSE UT_INTERNAL_KERNEL_SRCS RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} ./internal/*.cc + file(GLOB_RECURSE UT_INTERNAL_KERNEL_SRCS RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} ./internal/*.cc ./ops/utils/*.cc ./backend/graph_optimizer_test_framework.cc) endif() diff --git a/tests/ut/cpp/internal/test_ops_mla.cc b/tests/ut/cpp/internal/test_ops_mla.cc new file mode 100644 index 00000000000..26369d962b9 --- /dev/null +++ b/tests/ut/cpp/internal/test_ops_mla.cc @@ -0,0 +1,89 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ops/utils/general_infer_utils.h" + +namespace mindspore::ops { +namespace { +std::vector prepare_params() { + GeneralInferParamGenerator generator; + generator + .FeedInputArgs({ + InferInfoParam{ShapeVector{4, 32, 512}, kNumberTypeFloat16}, // q_nope + InferInfoParam{ShapeVector{4, 32, 64}, kNumberTypeFloat16}, // q_rope + InferInfoParam{ShapeVector{1024, 128, 1, 512}, kNumberTypeFloat16}, // ctkv + InferInfoParam{ShapeVector{1024, 128, 1, 64}, kNumberTypeFloat16}, // k_rope + InferInfoParam{ShapeVector{4, 128}, kNumberTypeInt32}, // block_tables + InferInfoParam{ShapeVector{126, 512}, kNumberTypeInt32}, // mask + InferInfoParam{ShapeVector{4}, kNumberTypeFloat32}, // deq_scale_qk + InferInfoParam{ShapeVector{4}, kNumberTypeFloat32}, // deq_scale_pv + InferInfoParam{ShapeVector{4}, kNumberTypeInt32}, // q_seq_lens + InferInfoParam{ShapeVector{4}, kNumberTypeInt32}, // context_lengths + InferInfoParam{ShapeVector{}, kNumberTypeInt64, CreateScalar(32)}, // q_head_num + InferInfoParam{ShapeVector{}, kNumberTypeFloat32, CreateScalar(0.01)}, // scale_value + InferInfoParam{ShapeVector{}, kNumberTypeInt64, CreateScalar(1)}, // kv_head_num + InferInfoParam{ShapeVector{}, kNumberTypeInt64, CreateScalar(1)}, // mask_mode + InferInfoParam{ShapeVector{}, kNumberTypeInt64, CreateScalar(0)}, // is_ring + }) + .FeedExpectedOutput({{4, 32, 512} /* attention_out*/, {0} /* lse_out */}, {kNumberTypeFloat16, kNumberTypeFloat16}); + + generator + .FeedInputArgs({ + InferInfoParam{ShapeVector{-1, -1, -1}, kNumberTypeFloat16}, // q_nope + InferInfoParam{ShapeVector{-1, -1, -1}, kNumberTypeFloat16}, // q_rope + InferInfoParam{ShapeVector{-1, -1, -1, -1}, kNumberTypeFloat16}, // ctkv + InferInfoParam{ShapeVector{-1, -1, -1, -1}, kNumberTypeFloat16}, // k_rope + InferInfoParam{ShapeVector{-1, -1}, kNumberTypeInt32}, // block_tables + InferInfoParam{ShapeVector{-1, -1}, kNumberTypeInt32}, // mask + InferInfoParam{ShapeVector{-1}, kNumberTypeFloat32}, // deq_scale_qk + InferInfoParam{ShapeVector{-1}, kNumberTypeFloat32}, // deq_scale_pv + InferInfoParam{ShapeVector{-1}, kNumberTypeInt32}, // q_seq_lens + InferInfoParam{ShapeVector{-1}, kNumberTypeInt32}, // context_lengths + InferInfoParam{ShapeVector{}, kNumberTypeInt64, CreateScalar(32)}, // q_head_num + InferInfoParam{ShapeVector{}, kNumberTypeFloat32, CreateScalar(0.01)}, // scale_value + InferInfoParam{ShapeVector{}, kNumberTypeInt64, CreateScalar(1)}, // kv_head_num + InferInfoParam{ShapeVector{}, kNumberTypeInt64, CreateScalar(1)}, // mask_mode + InferInfoParam{ShapeVector{}, kNumberTypeInt64, CreateScalar(0)}, // is_ring + }) + .FeedExpectedOutput({{-1, -1, -1} /* attention_out*/, {0} /* lse_out */}, {kNumberTypeFloat16, kNumberTypeFloat16}); + + // MASK_FREE + generator + .FeedInputArgs({ + InferInfoParam{ShapeVector{4, 32, 512}, kNumberTypeFloat16}, // q_nope + InferInfoParam{ShapeVector{4, 32, 64}, kNumberTypeFloat16}, // q_rope + InferInfoParam{ShapeVector{1024, 128, 1, 512}, kNumberTypeFloat16}, // ctkv + InferInfoParam{ShapeVector{1024, 128, 1, 64}, kNumberTypeFloat16}, // k_rope + InferInfoParam{ShapeVector{4, 128}, kNumberTypeInt32}, // block_tables + InferInfoParam{ShapeVector{128, 128}, kNumberTypeInt32}, // mask + InferInfoParam{ShapeVector{4}, kNumberTypeFloat32}, // deq_scale_qk + InferInfoParam{ShapeVector{4}, kNumberTypeFloat32}, // deq_scale_pv + InferInfoParam{ShapeVector{4}, kNumberTypeInt32}, // q_seq_lens + InferInfoParam{ShapeVector{4}, kNumberTypeInt32}, // context_lengths + InferInfoParam{ShapeVector{}, kNumberTypeInt64, CreateScalar(32)}, // q_head_num + InferInfoParam{ShapeVector{}, kNumberTypeFloat32, CreateScalar(0.01)}, // scale_value + InferInfoParam{ShapeVector{}, kNumberTypeInt64, CreateScalar(1)}, // kv_head_num + InferInfoParam{ShapeVector{}, kNumberTypeInt64, CreateScalar(2)}, // mask_mode + InferInfoParam{ShapeVector{}, kNumberTypeInt64, CreateScalar(0)}, // is_ring + }) + .FeedExpectedOutput({{4, 32, 512} /* attention_out*/, {0} /* lse_out */}, {kNumberTypeFloat16, kNumberTypeFloat16}); + + return generator.Generate(); +} +} // namespace + +INSTANTIATE_TEST_CASE_P(Mla, GeneralInferTest, testing::ValuesIn(prepare_params())); +} // namespace mindspore::ops -- Gitee From 44e7a8c0f555843505da612941450f8e8ae5833e Mon Sep 17 00:00:00 2001 From: zhangshucheng Date: Tue, 20 May 2025 11:12:26 +0000 Subject: [PATCH 4/4] ms kernel __commit_id__ = [sha1]:2c6890d4,[branch]: (HEAD, mr-origin-371) Signed-off-by: zhangshucheng --- .../ascend/kernel/internal/kv_scale_cache.h | 2 +- .../aarch64/ms_kernels_internal.tar.gz | 4 +- .../x86_64/ms_kernels_internal.tar.gz | 4 +- .../ops/infer/ops_func_impl/kv_scale_cache.cc | 46 +++++++++------ .../mindspore/ops/operations/__init__.py | 4 +- .../python/mindspore/ops/operations/nn_ops.py | 4 +- .../ops/ascend/test_internal_ops/test_mla.py | 59 ++++++++++++------- tests/ut/cpp/ops/test_ops_kv_scale_cache.cc | 3 +- 8 files changed, 74 insertions(+), 52 deletions(-) diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/internal/kv_scale_cache.h b/mindspore/ccsrc/plugin/device/ascend/kernel/internal/kv_scale_cache.h index 97fc8741a91..c8798649b7b 100644 --- a/mindspore/ccsrc/plugin/device/ascend/kernel/internal/kv_scale_cache.h +++ b/mindspore/ccsrc/plugin/device/ascend/kernel/internal/kv_scale_cache.h @@ -27,4 +27,4 @@ namespace kernel { DECLARE_INTERNAL_KERNEL_MOD(KvScaleCache) } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_KERNEL_INTERNAL_KV_SCALE_CACHE_H_ \ No newline at end of file +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_KERNEL_INTERNAL_KV_SCALE_CACHE_H_ diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/internal/prebuild/aarch64/ms_kernels_internal.tar.gz b/mindspore/ccsrc/plugin/device/ascend/kernel/internal/prebuild/aarch64/ms_kernels_internal.tar.gz index 76f0dcba0d8..f5d8a03bcb3 100644 --- a/mindspore/ccsrc/plugin/device/ascend/kernel/internal/prebuild/aarch64/ms_kernels_internal.tar.gz +++ b/mindspore/ccsrc/plugin/device/ascend/kernel/internal/prebuild/aarch64/ms_kernels_internal.tar.gz @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:451d866c57a3feae9355cbd1d06a5fa05013bd1ef86c126a70e4ec40dfec909e -size 93074923 +oid sha256:873cc61dda40870780b46effdb9a9b0c1f582d0e02656ebe34d1775a5de8ffe7 +size 3689131 diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/internal/prebuild/x86_64/ms_kernels_internal.tar.gz b/mindspore/ccsrc/plugin/device/ascend/kernel/internal/prebuild/x86_64/ms_kernels_internal.tar.gz index d7b65635db0..506438d88e5 100644 --- a/mindspore/ccsrc/plugin/device/ascend/kernel/internal/prebuild/x86_64/ms_kernels_internal.tar.gz +++ b/mindspore/ccsrc/plugin/device/ascend/kernel/internal/prebuild/x86_64/ms_kernels_internal.tar.gz @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:f700baceb18a67f5256d96c2d122c7e11316f4c3a57fd136f930b4c8b284eeb9 -size 92517983 +oid sha256:ed71add8b1956f8f6780955b2bdfe75f1209b9a696a3d1da762f942a2e38fcf0 +size 3699824 diff --git a/mindspore/ops/infer/ops_func_impl/kv_scale_cache.cc b/mindspore/ops/infer/ops_func_impl/kv_scale_cache.cc index d460b4a4c29..bbff4e0d57a 100644 --- a/mindspore/ops/infer/ops_func_impl/kv_scale_cache.cc +++ b/mindspore/ops/infer/ops_func_impl/kv_scale_cache.cc @@ -27,7 +27,7 @@ namespace ops { namespace { static constexpr int32_t prefill_mode = 1; static constexpr int32_t incremental_mode = 0; -} +} // namespace BaseShapePtr KvScaleCacheFuncImpl::InferShape(const PrimitivePtr &primitive, const std::vector &input_args) const { auto op_name = primitive->name(); @@ -52,34 +52,41 @@ BaseShapePtr KvScaleCacheFuncImpl::InferShape(const PrimitivePtr &primitive, const int64_t input_num_dims = 2; MS_CHECK_VALUE(key_scale_shape_ptr->GetShapeVector().size() == input_num_dims, CheckAndConvertUtils::FormatCommMsg("rank of kscale must be 2, but got:", - key_scale_shape_ptr->GetShapeVector().size())); + key_scale_shape_ptr->GetShapeVector().size())); MS_CHECK_VALUE(value_scale_shape_ptr->GetShapeVector().size() == input_num_dims, CheckAndConvertUtils::FormatCommMsg("rank of vscale must be 2, but got:", - value_scale_shape_ptr->GetShapeVector().size())); + value_scale_shape_ptr->GetShapeVector().size())); const size_t batch_valid_size = batch_valid_shape.size(); - (void)CheckAndConvertUtils::CheckInteger(batch_valid_size + "batch_valid_size must be greater than 0, but got:", batch_valid_size, kGreaterEqual, 0, - op_name); + (void)CheckAndConvertUtils::CheckInteger(batch_valid_size + "batch_valid_size must be greater than 0, but got:", + batch_valid_size, kGreaterEqual, 0, op_name); if (!IsDynamic(key_scale_cache_shape) && !IsDynamic(batch_valid_shape)) { const size_t key_scale_cache_dim = key_scale_cache_shape[0]; const size_t max_batch_size = key_scale_cache_shape[1]; // max_batch_size 约束 - MS_CHECK_VALUE(batch_valid_size <= max_batch_size, CheckAndConvertUtils::FormatCommMsg("The batch_size must not bigger than max_batch_size, but got batch_valid_size: ", batch_valid_size, ", max_batch_size: ", max_batch_size)); - MS_CHECK_VALUE(key_scale_cache_dim == input_num_dims, CheckAndConvertUtils::FormatCheckIntegerMsg("key_scale_cache_dim", SizeToLong(key_scale_cache_dim), kEqual, 2, primitive)); - MS_CHECK_VALUE(max_batch_size != 0, CheckAndConvertUtils::FormatCheckIntegerMsg("max_batch_size", SizeToLong(max_batch_size), kNotEqual, 0, primitive)); + MS_CHECK_VALUE(batch_valid_size <= max_batch_size, + CheckAndConvertUtils::FormatCommMsg( + "The batch_size must not bigger than max_batch_size, but got batch_valid_size: ", batch_valid_size, + ", max_batch_size: ", max_batch_size)); + MS_CHECK_VALUE(key_scale_cache_dim == input_num_dims, + CheckAndConvertUtils::FormatCheckIntegerMsg("key_scale_cache_dim", SizeToLong(key_scale_cache_dim), + kEqual, 2, primitive)); + MS_CHECK_VALUE(max_batch_size != 0, CheckAndConvertUtils::FormatCheckIntegerMsg( + "max_batch_size", SizeToLong(max_batch_size), kNotEqual, 0, primitive)); // max_seqlens约束 const size_t max_seqlens = key_scale_cache_shape[2]; - MS_CHECK_VALUE(max_seqlens != 0, CheckAndConvertUtils::FormatCheckIntegerMsg("max_seqlens", SizeToLong(max_seqlens), kNotEqual, 0, primitive)); + MS_CHECK_VALUE(max_seqlens != 0, CheckAndConvertUtils::FormatCheckIntegerMsg("max_seqlens", SizeToLong(max_seqlens), + kNotEqual, 0, primitive)); auto batch_valid_tensor = input_args[KvScaleCacheInputBatchVaildLengthIndex]; - //获取 batch_valid_length 的最大值 + // 获取 batch_valid_length 的最大值 if (batch_valid_tensor->GetValue() != nullptr) { auto shape_ptr = batch_valid_tensor->GetShape()->cast(); MS_EXCEPTION_IF_NULL(shape_ptr); const auto &shape = shape_ptr->shape(); auto max_value = *std::max_element(shape.begin(), shape.end()); - MS_CHECK_VALUE(max_value <= static_cast(max_seqlens), CheckAndConvertUtils::FormatCommMsg( - "Max seqlen in batch exceeds limit:", max_value, - " > max_seqlens:", max_seqlens)); + MS_CHECK_VALUE(static_cast(max_value) <= static_cast(max_seqlens), + CheckAndConvertUtils::FormatCommMsg("Max seqlen in batch exceeds limit:", max_value, + " > max_seqlens:", max_seqlens)); } } @@ -90,15 +97,16 @@ BaseShapePtr KvScaleCacheFuncImpl::InferShape(const PrimitivePtr &primitive, if (cache_mode_scalar.has_value()) { auto cache_mode = static_cast(cache_mode_scalar.value()); MS_LOG(INFO) << "cache_mode: " << cache_mode; - if (cache_mode != incremental_mode && cache_mode != prefill_mode && cache_mode != -1){ + if (cache_mode != incremental_mode && cache_mode != prefill_mode && cache_mode != -1) { MS_LOG(EXCEPTION) << "this cache_mode is not supported, but got cache_mode: " << cache_mode; } if (cache_mode == incremental_mode) { - MS_CHECK_VALUE((decode_batch >= batch_valid_size) && (seqlens == 1), - CheckAndConvertUtils::FormatCommMsg( - "For ", op_name, - ", decode_batch must be more than or equal to batch_valid_size, seqlens must be 1, but got decode_batch: ", decode_batch, ", batch_valid_size: ", batch_valid_size, "seqlens: ", seqlens) - ); + MS_CHECK_VALUE( + (decode_batch >= batch_valid_size) && (seqlens == 1), + CheckAndConvertUtils::FormatCommMsg( + "For ", op_name, + ", decode_batch must be more than or equal to batch_valid_size, seqlens must be 1, but got decode_batch: ", + decode_batch, ", batch_valid_size: ", batch_valid_size, "seqlens: ", seqlens)); } } diff --git a/mindspore/python/mindspore/ops/operations/__init__.py b/mindspore/python/mindspore/ops/operations/__init__.py index 82aa06eae36..e93bd8c3939 100644 --- a/mindspore/python/mindspore/ops/operations/__init__.py +++ b/mindspore/python/mindspore/ops/operations/__init__.py @@ -118,8 +118,8 @@ from .nn_ops import (LSTM, SGD, Adam, AdamWeightDecay, FusedSparseAdam, FusedSpa Dilation2D, DataFormatVecPermute, DeformableOffsets, Dense, FractionalAvgPool, FractionalMaxPool, FractionalMaxPool3DWithFixedKsize, FractionalMaxPoolWithFixedKsize, GridSampler2D, TripletMarginLoss, UpsampleNearest3D, UpsampleTrilinear3D, PadV3, ChannelShuffle, - GLU, MaxUnpool3D, Pdist, RmsNorm, PagedAttention, PagedAttentionMask, ReshapeAndCache, KvScaleCache, - ApplyRotaryPosEmb, GroupTopk) + GLU, MaxUnpool3D, Pdist, RmsNorm, PagedAttention, PagedAttentionMask, ReshapeAndCache, + ApplyRotaryPosEmb, GroupTopk, KvScaleCache) from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, ConfusionMatrix, UpdateState, Load, StopGradient, Reusing, CheckValid, Partial, Depend, MoveTo, Push, Pull, PyExecute, PyFunc, _DynamicLossScale, diff --git a/mindspore/python/mindspore/ops/operations/nn_ops.py b/mindspore/python/mindspore/ops/operations/nn_ops.py index d483f910b49..1a029098444 100644 --- a/mindspore/python/mindspore/ops/operations/nn_ops.py +++ b/mindspore/python/mindspore/ops/operations/nn_ops.py @@ -37,9 +37,9 @@ from ..auto_generate import (CeLU, Flatten, LogSoftmax, LogSoftmaxExt, GLU, ReLU Elu, Sigmoid, Softmax, SoftplusExt, HSwish, HSigmoid, AvgPool, BiasAdd, NLLLoss, OneHot, GeLU, FastGeLU, PReLU, RmsNorm, IncreFlashAttention, MSELossExt, GridSampler3D, GridSampler2D, LayerNorm, LayerNormExt, HShrink, AdamWeightDecay, Dropout, - ApplyRotaryPosEmb, GroupTopk, PagedAttention, PagedAttentionMask, ReshapeAndCache, KvScaleCache, + ApplyRotaryPosEmb, GroupTopk, PagedAttention, PagedAttentionMask, ReshapeAndCache, FlashAttentionScore, PromptFlashAttention, Embedding, UpsampleNearest1D, UpsampleNearest2D, - DynamicNTK, + DynamicNTK, KvScaleCache, UpsampleNearest3D, UpsampleTrilinear3D, SoftMarginLoss, UpsampleBilinear2D, UpsampleLinear1D, BinaryCrossEntropy, BCEWithLogitsLoss, SoftShrink, AdaptiveMaxPool2D, diff --git a/tests/st/ops/ascend/test_internal_ops/test_mla.py b/tests/st/ops/ascend/test_internal_ops/test_mla.py index 55d2703d890..46ce5b82c29 100644 --- a/tests/st/ops/ascend/test_internal_ops/test_mla.py +++ b/tests/st/ops/ascend/test_internal_ops/test_mla.py @@ -24,8 +24,9 @@ import pytest class MlaTestParam: """MlaTestParam""" + def __init__(self, num_heads, kv_heads, block_size, head_size_nope, head_size_rope, num_blocks, - q_seq_lens : list, context_lengths : list, tor, nope_ms_dtype, rope_ms_dtype, mask_mode : str): + q_seq_lens: list, context_lengths: list, tor, nope_ms_dtype, rope_ms_dtype, mask_mode: str): self.num_heads = num_heads self.kv_heads = kv_heads @@ -44,14 +45,14 @@ class MlaTestParam: self.batch = len(q_seq_lens) self.max_context_len = max(context_lengths) - self.max_num_blocks_per_seq = (self.max_context_len + block_size - 1) // block_size + self.max_num_blocks_per_seq = ( + self.max_context_len + block_size - 1) // block_size self.num_tokens = (int)(np.array(q_seq_lens).sum()) self.block_tables = self._build_block_tables() self._build_tensor_inputs() - def _build_np_mask(self): """_build_np_mask""" if self.mask_mode == "MASK_NONE": @@ -59,14 +60,16 @@ class MlaTestParam: if self.mask_mode == "MASK_SPEC": pre_qseqlen = 0 - np_mask = np.zeros(shape=(self.num_tokens, self.max_context_len)).astype(np.float32) + np_mask = np.zeros( + shape=(self.num_tokens, self.max_context_len)).astype(np.float32) for i in range(self.batch): qseqlen = self.q_seq_lens[i] kseqlen = self.context_lengths[i] tri = np.ones((qseqlen, qseqlen)) tri = np.triu(tri, 1) tri *= self.mask_factor - np_mask[pre_qseqlen:(pre_qseqlen + qseqlen), kseqlen-qseqlen:kseqlen] = tri + np_mask[pre_qseqlen:(pre_qseqlen + qseqlen), + kseqlen-qseqlen:kseqlen] = tri pre_qseqlen += qseqlen return np_mask @@ -75,7 +78,6 @@ class MlaTestParam: return None - def _build_block_tables(self): """_build_block_tables""" block_tables_list = [] @@ -86,11 +88,12 @@ class MlaTestParam: block_tables_list.append(block_table) return block_tables_list - def _build_tensor_inputs(self): """_build_tensor_inputs""" - np_q_nope = np.random.uniform(-1.0, 1.0, size=(self.num_tokens, self.num_heads, self.head_size_nope)) - np_q_rope = np.random.uniform(-1.0, 1.0, size=(self.num_tokens, self.num_heads, self.head_size_rope)) + np_q_nope = np.random.uniform(-1.0, 1.0, size=( + self.num_tokens, self.num_heads, self.head_size_nope)) + np_q_rope = np.random.uniform(-1.0, 1.0, size=( + self.num_tokens, self.num_heads, self.head_size_rope)) np_ctkv = np.random.uniform(-1.0, 1.0, size=(self.num_blocks, self.block_size, self.kv_heads, self.head_size_nope)) np_k_rope = np.random.uniform(-1.0, 1.0, size=(self.num_blocks, self.block_size, @@ -104,14 +107,18 @@ class MlaTestParam: self.ctkv_tensor = Tensor(np_ctkv, dtype=self.nope_ms_dtype) self.k_rope_tensor = Tensor(np_k_rope, dtype=self.rope_ms_dtype) - self.block_tables_tensor = Tensor(np.array(self.block_tables).astype(np.int32)) + self.block_tables_tensor = Tensor( + np.array(self.block_tables).astype(np.int32)) np_mask = self._build_np_mask() - self.mask_tensor = None if np_mask is None else Tensor(np_mask, dtype=self.rope_ms_dtype) + self.mask_tensor = None if np_mask is None else Tensor( + np_mask, dtype=self.rope_ms_dtype) if self.nope_ms_dtype == ms.int8: - self.deq_scale_qk_tensor = Tensor(np.random.uniform(-1.0, 1.0, size=(self.num_heads, )), dtype=ms.float32) - self.deq_scale_pv_tensor = Tensor(np.random.uniform(-1.0, 1.0, size=(self.num_heads, )), dtype=ms.float32) + self.deq_scale_qk_tensor = Tensor( + np.random.uniform(-1.0, 1.0, size=(self.num_heads, )), dtype=ms.float32) + self.deq_scale_pv_tensor = Tensor( + np.random.uniform(-1.0, 1.0, size=(self.num_heads, )), dtype=ms.float32) else: self.deq_scale_qk_tensor = None self.deq_scale_pv_tensor = None @@ -122,6 +129,7 @@ class MlaTestParam: class Net(nn.Cell): """Net""" + def __init__(self, q_head_num, kv_head_num, mask_type, tor): super().__init__() self.q_head_num = q_head_num @@ -138,6 +146,7 @@ class Net(nn.Cell): class GoldenNet(nn.Cell): """GoldenNet""" + def __init__(self, q_head_num, kv_head_num, mask_mode, tor, mla_v_dim): super().__init__() self.q_head_num = q_head_num @@ -154,14 +163,16 @@ class GoldenNet(nn.Cell): antiquant_offset, attn_mask, q_seq_lens, alibi_mask) -def run_mla(test_param : MlaTestParam): +def run_mla(test_param: MlaTestParam): """run mla""" context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") context.set_context(jit_config={"jit_level": "O0", "infer_boost": "on"}) dyn_q_nope_shape = [None for _ in test_param.q_nope_tensor.shape] - dyn_q_nope_tensor = Tensor(shape=dyn_q_nope_shape, dtype=test_param.q_nope_tensor.dtype) + dyn_q_nope_tensor = Tensor( + shape=dyn_q_nope_shape, dtype=test_param.q_nope_tensor.dtype) - net = Net(test_param.num_heads, test_param.kv_heads, test_param.mask_mode, test_param.tor) + net = Net(test_param.num_heads, test_param.kv_heads, + test_param.mask_mode, test_param.tor) net.set_inputs(q_nope=dyn_q_nope_tensor) out, _ = net(test_param.q_nope_tensor, test_param.q_rope_tensor, test_param.ctkv_tensor, test_param.k_rope_tensor, @@ -170,7 +181,7 @@ def run_mla(test_param : MlaTestParam): return out -def run_golden(test_param : MlaTestParam): +def run_golden(test_param: MlaTestParam): """run_golden""" context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") context.set_context(jit_config={"jit_level": "O0", "infer_boost": "on"}) @@ -178,10 +189,13 @@ def run_golden(test_param : MlaTestParam): mla_v_dim = 512 query = ops.reshape(ops.concat((test_param.q_nope_tensor, test_param.q_rope_tensor), axis=-1), (test_param.num_tokens, 1, -1)) - key_cache = ops.concat((test_param.ctkv_tensor, test_param.k_rope_tensor), axis=-1) + key_cache = ops.concat( + (test_param.ctkv_tensor, test_param.k_rope_tensor), axis=-1) dyn_q_shape = [None for _ in test_param.q_nope_tensor.shape] - dyn_q_nope_tensor = Tensor(shape=dyn_q_shape, dtype=test_param.q_nope_tensor.dtype) - golden_net = GoldenNet(test_param.num_heads, test_param.kv_heads, "MASK_DEFAULT", test_param.tor, mla_v_dim) + dyn_q_nope_tensor = Tensor( + shape=dyn_q_shape, dtype=test_param.q_nope_tensor.dtype) + golden_net = GoldenNet(test_param.num_heads, test_param.kv_heads, + "MASK_DEFAULT", test_param.tor, mla_v_dim) golden_net.set_inputs(query=dyn_q_nope_tensor) out_golden = golden_net(query, key_cache, key_cache, test_param.block_tables_tensor, @@ -191,7 +205,7 @@ def run_golden(test_param : MlaTestParam): return out_golden -def run_test(test_param : MlaTestParam): +def run_test(test_param: MlaTestParam): """run test""" out_actual = run_mla(test_param) out_golden = run_golden(test_param) @@ -213,7 +227,8 @@ def test_mla_base(dtype, mask_mode): """ q_seq_lens = [1, 1, 1, 1] context_lengths = [192, 193, 194, 195] - test_param = MlaTestParam(32, 1, 128, 512, 64, 1024, q_seq_lens, context_lengths, 0.001, dtype, dtype, mask_mode) + test_param = MlaTestParam(32, 1, 128, 512, 64, 1024, q_seq_lens, + context_lengths, 0.001, dtype, dtype, mask_mode) run_test(test_param) diff --git a/tests/ut/cpp/ops/test_ops_kv_scale_cache.cc b/tests/ut/cpp/ops/test_ops_kv_scale_cache.cc index cbaf9389da2..79b3e26c7a3 100644 --- a/tests/ut/cpp/ops/test_ops_kv_scale_cache.cc +++ b/tests/ut/cpp/ops/test_ops_kv_scale_cache.cc @@ -71,7 +71,6 @@ INSTANTIATE_TEST_CASE_P( KvScaleCacheShapeParams{ {3, 4, 20}, kFloat32, {3, 4, 20}, kFloat32, {3, 4, 20}, kFloat32, {12}, kInt32, CreateScalar(1)}, KvScaleCacheShapeParams{ - {3, 4, 20}, kFloat32, {3, 4, 20}, kFloat32, {3, 4, 20}, kFloat32, {12}, kInt32, CreateScalar(0)} - )); + {3, 4, 20}, kFloat32, {3, 4, 20}, kFloat32, {3, 4, 20}, kFloat32, {12}, kInt32, CreateScalar(0)})); } // namespace ops } // namespace mindspore -- Gitee