From 500fc9919cd7ebf3cf21400949572dafcc317b7e Mon Sep 17 00:00:00 2001 From: Yanzhi_YI Date: Sat, 22 Feb 2025 17:49:11 +0800 Subject: [PATCH 01/14] support pa mtp debug: pa cache tiling force enable enable_lookahead for mtp pa mtp use q length --- .../plugin/device/ascend/kernel/internal/paged_attention.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/internal/paged_attention.cc b/mindspore/ccsrc/plugin/device/ascend/kernel/internal/paged_attention.cc index e18e346b9ef..f821a410264 100644 --- a/mindspore/ccsrc/plugin/device/ascend/kernel/internal/paged_attention.cc +++ b/mindspore/ccsrc/plugin/device/ascend/kernel/internal/paged_attention.cc @@ -50,7 +50,7 @@ internal::InternalOpPtr InternalPagedAttention::CreateKernel(const internal::Inp has_attn_mask_ = (!(ms_inputs[kIndex7]->GetType()->isa())); has_alibi_mask_ = (!(ms_inputs[kIndex9]->GetType()->isa())); - (void)GetSeqLenFromGraphAndCheckUpadate(kernel_name_, {"q_seq_lens"}, ¶m_.q_seq_len); + param_.has_q_seq_lens = GetSeqLenFromGraphAndCheckUpadate(kernel_name_, {"q_seq_lens"}, ¶m_.q_seq_len); (void)GetSeqLenFromGraphAndCheckUpadate(kernel_name_, {"batch_valid_length"}, ¶m_.kv_seq_len); CheckMask(); @@ -84,7 +84,7 @@ bool InternalPagedAttention::UpdateParam(const std::vector &inpu uint64_t InternalPagedAttention::GenerateTilingKey(const std::vector &inputs) { // User defined CacheKey, the inputs should include all the factors which will affect tiling result. - return InternalTilingCache::GenerateKey(kernel_name_, inputs, param_.q_seq_len, param_.kv_seq_len); + return InternalTilingCache::GenerateKey(kernel_name_, inputs, param_.q_seq_len, param_.kv_seq_len, param_.has_q_seq_lens, param_.mla_v_dim); } MS_INTERNAL_KERNEL_FACTORY_REG(PagedAttention, internal::kInternalPagedAttentionOpName, InternalPagedAttention); -- Gitee From 61119fbe66b42a6dc8926407ac13137fdc683e28 Mon Sep 17 00:00:00 2001 From: shanfeng Date: Fri, 28 Feb 2025 10:15:51 +0800 Subject: [PATCH 02/14] Add empty cache --- .../include/backend/mem_reuse/dynamic_mem_pool.h | 2 ++ .../hal/hardware/ascend_device_res_manager.cc | 5 +++++ .../hal/hardware/ascend_device_res_manager.h | 2 ++ .../res_manager/ascend/ascend_res_manager.cc | 7 +++++++ .../res_manager/ascend/ascend_res_manager.h | 2 ++ .../abstract_ascend_memory_pool_support.cc | 9 +++++++++ .../abstract_ascend_memory_pool_support.h | 2 ++ .../ascend/mem_manager/ascend_memory_manager.cc | 1 + .../ascend/mem_manager/ascend_memory_manager.h | 1 + .../ascend/mem_manager/ascend_memory_pool.cc | 8 ++++++++ .../ascend/mem_manager/ascend_memory_pool.h | 2 ++ .../ascend/mem_manager/ascend_vmm_adapter.cc | 16 ++++++++++++++++ .../ascend/mem_manager/ascend_vmm_adapter.h | 2 ++ mindspore/ccsrc/pybind_api/hal/memory_py.cc | 12 ++++++++++++ .../runtime/device/res_manager/hal_res_base.h | 2 ++ .../runtime/device/res_manager/memory_manager.h | 1 + .../ccsrc/runtime/hardware/device_context.h | 2 ++ mindspore/python/mindspore/hal/memory.py | 7 ++++--- mindspore/python/mindspore/runtime/memory.py | 9 +++++++-- 19 files changed, 87 insertions(+), 5 deletions(-) diff --git a/mindspore/ccsrc/include/backend/mem_reuse/dynamic_mem_pool.h b/mindspore/ccsrc/include/backend/mem_reuse/dynamic_mem_pool.h index e5a42f4d0c0..4aab8af9ecf 100644 --- a/mindspore/ccsrc/include/backend/mem_reuse/dynamic_mem_pool.h +++ b/mindspore/ccsrc/include/backend/mem_reuse/dynamic_mem_pool.h @@ -130,6 +130,8 @@ class BACKEND_EXPORT DynamicMemPool { return {}; } + virtual size_t EmptyCache() { return -1L; } + // Element in vector : memory_stream_id, address virtual bool RecordEvent(int64_t task_id_on_stream, uint32_t user_stream_id, const std::vector> &memory_stream_addresses, diff --git a/mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_device_res_manager.cc b/mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_device_res_manager.cc index 4ba86ed1ca0..2ba6de07663 100644 --- a/mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_device_res_manager.cc +++ b/mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_device_res_manager.cc @@ -199,6 +199,11 @@ void AscendDeviceResManager::ResetMaxMemoryAllocated() { return ascend_res_manager_->ResetMaxMemoryAllocated(); } +size_t AscendDeviceResManager::EmptyCache() { + MS_EXCEPTION_IF_NULL(ascend_res_manager_); + return ascend_res_manager_->EmptyCache(); +} + void AscendDeviceResManager::SwapIn(const void *host_ptr, void *device_ptr, size_t mem_size, void *stream) { MS_EXCEPTION_IF_NULL(ascend_res_manager_); return ascend_res_manager_->SwapIn(host_ptr, device_ptr, mem_size, stream); diff --git a/mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_device_res_manager.h b/mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_device_res_manager.h index 3478961eb28..708f87ee9fb 100644 --- a/mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_device_res_manager.h +++ b/mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_device_res_manager.h @@ -76,6 +76,8 @@ class AscendDeviceResManager : public DeviceResManager { const std::vector &keep_addr_sizes) const override; void DefragMemory() override; + size_t EmptyCache() override; + size_t GetMaxUsedMemorySize() const override; // Relevant function to manage memory statistics diff --git a/mindspore/ccsrc/plugin/res_manager/ascend/ascend_res_manager.cc b/mindspore/ccsrc/plugin/res_manager/ascend/ascend_res_manager.cc index e1ed0dce542..9bfb663cc22 100644 --- a/mindspore/ccsrc/plugin/res_manager/ascend/ascend_res_manager.cc +++ b/mindspore/ccsrc/plugin/res_manager/ascend/ascend_res_manager.cc @@ -400,6 +400,13 @@ void AscendResManager::ResetMaxMemoryAllocated() { memory_pool->ResetMaxMemAllocated(); } +size_t AscendResManager::EmptyCache() { + MS_EXCEPTION_IF_NULL(mem_manager_); + auto memory_pool = mem_manager_->GetMemoryPool(); + MS_EXCEPTION_IF_NULL(memory_pool); + return memory_pool->EmptyCache(); +} + void AscendResManager::SwapIn(const void *host_ptr, void *device_ptr, size_t mem_size, void *stream) { (void)mem_manager_->SwapIn(host_ptr, device_ptr, mem_size, stream); } diff --git a/mindspore/ccsrc/plugin/res_manager/ascend/ascend_res_manager.h b/mindspore/ccsrc/plugin/res_manager/ascend/ascend_res_manager.h index 1ff57144f44..775c55553f2 100644 --- a/mindspore/ccsrc/plugin/res_manager/ascend/ascend_res_manager.h +++ b/mindspore/ccsrc/plugin/res_manager/ascend/ascend_res_manager.h @@ -100,6 +100,8 @@ class ASCEND_RES_MANAGER_EXPORT AscendResManager : public HalResBase { void ResetMaxMemoryReserved() override; void ResetMaxMemoryAllocated() override; + size_t EmptyCache() override; + void SwapIn(const void *host_ptr, void *device_ptr, size_t mem_size, void *stream) override; void SwapOut(const void *device_ptr, void *host_ptr, size_t mem_size, void *stream) override; diff --git a/mindspore/ccsrc/plugin/res_manager/ascend/mem_manager/abstract_ascend_memory_pool_support.cc b/mindspore/ccsrc/plugin/res_manager/ascend/mem_manager/abstract_ascend_memory_pool_support.cc index 0eabed42d95..e26546e09d2 100644 --- a/mindspore/ccsrc/plugin/res_manager/ascend/mem_manager/abstract_ascend_memory_pool_support.cc +++ b/mindspore/ccsrc/plugin/res_manager/ascend/mem_manager/abstract_ascend_memory_pool_support.cc @@ -181,6 +181,15 @@ size_t AbstractAscendMemoryPoolSupport::FreeDeviceMemByEagerFree(const DeviceMem } } +size_t AbstractAscendMemoryPoolSupport::EmptyCache() { + if (IsEnableVmm()) { + return AscendVmmAdapter::GetInstance().EmptyCache(); + } else { + MS_LOG(ERROR) << "Empty cache is not support as vmm is not enabled."; + } + return -1L; +} + size_t AbstractAscendMemoryPoolSupport::MmapDeviceMem(const size_t size, const DeviceMemPtr addr) { return AscendVmmAdapter::GetInstance().MmapDeviceMem(size, addr, total_mem_size()); } diff --git a/mindspore/ccsrc/plugin/res_manager/ascend/mem_manager/abstract_ascend_memory_pool_support.h b/mindspore/ccsrc/plugin/res_manager/ascend/mem_manager/abstract_ascend_memory_pool_support.h index 42bc8da170b..0643f3a3db5 100644 --- a/mindspore/ccsrc/plugin/res_manager/ascend/mem_manager/abstract_ascend_memory_pool_support.h +++ b/mindspore/ccsrc/plugin/res_manager/ascend/mem_manager/abstract_ascend_memory_pool_support.h @@ -60,6 +60,8 @@ class ASCEND_RES_MANAGER_EXPORT AbstractAscendMemoryPoolSupport : virtual public size_t AllocDeviceMemByEagerFree(size_t size, DeviceMemPtr *addr) override; size_t FreeDeviceMemByEagerFree(const DeviceMemPtr addr, const size_t size) override; + + size_t EmptyCache() override; }; using AbstractAscendMemoryPoolSupportPtr = std::shared_ptr; } // namespace ascend diff --git a/mindspore/ccsrc/plugin/res_manager/ascend/mem_manager/ascend_memory_manager.cc b/mindspore/ccsrc/plugin/res_manager/ascend/mem_manager/ascend_memory_manager.cc index bb84892137b..2e91732c4ea 100644 --- a/mindspore/ccsrc/plugin/res_manager/ascend/mem_manager/ascend_memory_manager.cc +++ b/mindspore/ccsrc/plugin/res_manager/ascend/mem_manager/ascend_memory_manager.cc @@ -90,6 +90,7 @@ AscendMemoryManager::GetPersistentMemBlocksInfoStatistics() const { } void AscendMemoryManager::ResetMaxMemoryReserved() { AscendMemoryPool::GetInstance().ResetMaxMemReserved(); } void AscendMemoryManager::ResetMaxMemoryAllocated() { AscendMemoryPool::GetInstance().ResetMaxMemAllocated(); } +size_t AscendMemoryManager::EmptyCache() { return AscendMemoryPool::GetInstance().EmptyCache(); } uint8_t *AscendMemoryManager::MallocStaticMem(size_t size, bool communication_mem, uint32_t graph_id) { size_t align_size = 0; diff --git a/mindspore/ccsrc/plugin/res_manager/ascend/mem_manager/ascend_memory_manager.h b/mindspore/ccsrc/plugin/res_manager/ascend/mem_manager/ascend_memory_manager.h index 78fb889416b..a803f1f259f 100644 --- a/mindspore/ccsrc/plugin/res_manager/ascend/mem_manager/ascend_memory_manager.h +++ b/mindspore/ccsrc/plugin/res_manager/ascend/mem_manager/ascend_memory_manager.h @@ -68,6 +68,7 @@ class ASCEND_RES_MANAGER_EXPORT AscendMemoryManager : public MemoryManager { GetPersistentMemBlocksInfoStatistics() const override; void ResetMaxMemoryReserved() override; void ResetMaxMemoryAllocated() override; + size_t EmptyCache() override; DynamicMemPool *GetMemoryPool() override; diff --git a/mindspore/ccsrc/plugin/res_manager/ascend/mem_manager/ascend_memory_pool.cc b/mindspore/ccsrc/plugin/res_manager/ascend/mem_manager/ascend_memory_pool.cc index d6882888120..91926f830a6 100644 --- a/mindspore/ccsrc/plugin/res_manager/ascend/mem_manager/ascend_memory_pool.cc +++ b/mindspore/ccsrc/plugin/res_manager/ascend/mem_manager/ascend_memory_pool.cc @@ -71,6 +71,14 @@ DefaultAscendMemoryPool::DefaultAscendMemoryPool() { SetEnableVmm(AscendVmmAdapter::GetInstance().IsEnabled()); } +size_t DefaultAscendMemoryPool::EmptyCache() { + LockGuard lock(AbstractDynamicMemPool::lock()); + AbstractEnhancedDynamicMemPool::WaitPipelineHelper(); + AbstractAscendMemoryPoolSupport::SyncAllStreams(); + AbstractEnhancedDynamicMemPool::FreeIdleMemsByEagerFree(); + return AbstractAscendMemoryPoolSupport::EmptyCache(); +} + AscendMemoryTimeEvent::AscendMemoryTimeEvent(int32_t device_id, const MemoryTimeEventPtr &memory_time_event) : BaseReportData(device_id, static_cast(profiler::ascend::ReportFileType::MEMORY_USAGE)), memory_time_event_(memory_time_event) { diff --git a/mindspore/ccsrc/plugin/res_manager/ascend/mem_manager/ascend_memory_pool.h b/mindspore/ccsrc/plugin/res_manager/ascend/mem_manager/ascend_memory_pool.h index 111a19c243a..2e8446eeba8 100644 --- a/mindspore/ccsrc/plugin/res_manager/ascend/mem_manager/ascend_memory_pool.h +++ b/mindspore/ccsrc/plugin/res_manager/ascend/mem_manager/ascend_memory_pool.h @@ -54,6 +54,8 @@ class ASCEND_RES_MANAGER_EXPORT DefaultAscendMemoryPool : public AbstractAscendM } const bool IsEnableEagerFree() const override { return AbstractAscendMemoryPoolSupport::IsEnableEagerFree(); } + + size_t EmptyCache() override; }; using DefaultAscendMemoryPoolPtr = std::shared_ptr; diff --git a/mindspore/ccsrc/plugin/res_manager/ascend/mem_manager/ascend_vmm_adapter.cc b/mindspore/ccsrc/plugin/res_manager/ascend/mem_manager/ascend_vmm_adapter.cc index f8d84087df7..f99cd7df452 100644 --- a/mindspore/ccsrc/plugin/res_manager/ascend/mem_manager/ascend_vmm_adapter.cc +++ b/mindspore/ccsrc/plugin/res_manager/ascend/mem_manager/ascend_vmm_adapter.cc @@ -263,6 +263,22 @@ size_t AscendVmmAdapter::EagerFreeDeviceMem(const DeviceMemPtr addr, const size_ << ", expected free size : " << size << ", real size : " << ret_size << "."; return ret_size; } +size_t AscendVmmAdapter::EmptyCache() { + size_t empty_size = 0L; + while (!cached_handle_sets_.empty()) { + auto handle = *cached_handle_sets_.begin(); + cached_handle_sets_.erase(cached_handle_sets_.begin()); + physical_handle_size_--; + auto ret = CALL_ASCEND_API(aclrtFreePhysical, handle); + if (ret != ACL_ERROR_NONE) { + MS_LOG(ERROR) << "Free physical memory failed."; + } else { + empty_size += kDefaultAlignSize; + } + } + MS_LOG(INFO) << "Empty cache size : " << empty_size << "."; + return empty_size; +} } // namespace ascend } // namespace device } // namespace mindspore diff --git a/mindspore/ccsrc/plugin/res_manager/ascend/mem_manager/ascend_vmm_adapter.h b/mindspore/ccsrc/plugin/res_manager/ascend/mem_manager/ascend_vmm_adapter.h index 7208b130589..7813c906ab6 100644 --- a/mindspore/ccsrc/plugin/res_manager/ascend/mem_manager/ascend_vmm_adapter.h +++ b/mindspore/ccsrc/plugin/res_manager/ascend/mem_manager/ascend_vmm_adapter.h @@ -71,6 +71,8 @@ class ASCEND_RES_MANAGER_EXPORT AscendVmmAdapter { size_t EagerFreeDeviceMem(const DeviceMemPtr addr, const size_t size); size_t GetAllocatedSize() { return physical_handle_size_ * kVmmAlignSize; } + size_t EmptyCache(); + static const bool IsEnabled() { static bool is_enable_vmm = IsVmmEnabled(); return is_enable_vmm; diff --git a/mindspore/ccsrc/pybind_api/hal/memory_py.cc b/mindspore/ccsrc/pybind_api/hal/memory_py.cc index b1adeba60d6..092df6249fb 100644 --- a/mindspore/ccsrc/pybind_api/hal/memory_py.cc +++ b/mindspore/ccsrc/pybind_api/hal/memory_py.cc @@ -115,12 +115,24 @@ void ResetMaxMemoryAllocated(const std::string &device_target) { device_ctx->device_res_manager_->ResetMaxMemoryAllocated(); } +size_t EmptyCache(const std::string &device_target) { + runtime::Pipeline::Get().WaitAll(); + auto device_ctx = device::DeviceContextManager::GetInstance().GetDeviceContext(device_target); + if (device_ctx == nullptr) { + MS_LOG(INFO) << "Device context of device " << device_target << " is not created yet."; + return -1L; + } + + return device_ctx->device_res_manager_->EmptyCache(); +} + void RegMemory(py::module *m) { (void)m->def("_memory_stats", &mindspore::hal::MemoryStats, "Get memory pool's statistics."); (void)m->def("_reset_max_mem_reserved", &mindspore::hal::ResetMaxMemoryReserved, "Reset the maximum recorded memory reserved."); (void)m->def("_reset_max_mem_allocated", &mindspore::hal::ResetMaxMemoryAllocated, "Reset the maximum recorded memory allocated."); + (void)m->def("_empty_cache", &mindspore::hal::EmptyCache, "Empty memory pool cache."); } } // namespace hal } // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/res_manager/hal_res_base.h b/mindspore/ccsrc/runtime/device/res_manager/hal_res_base.h index 7e86a99471d..2eb7c8027f4 100644 --- a/mindspore/ccsrc/runtime/device/res_manager/hal_res_base.h +++ b/mindspore/ccsrc/runtime/device/res_manager/hal_res_base.h @@ -107,6 +107,8 @@ class RES_EXPORT HalResBase { virtual void ResetMaxMemoryReserved() {} virtual void ResetMaxMemoryAllocated() {} + virtual size_t EmptyCache() { return -1L; } + // Allocate host memory with raii and ref count virtual std::shared_ptr AllocateHostMemory(size_t size) const { return std::shared_ptr(::malloc(size), ::free); diff --git a/mindspore/ccsrc/runtime/device/res_manager/memory_manager.h b/mindspore/ccsrc/runtime/device/res_manager/memory_manager.h index c5f130da283..b658662a9aa 100644 --- a/mindspore/ccsrc/runtime/device/res_manager/memory_manager.h +++ b/mindspore/ccsrc/runtime/device/res_manager/memory_manager.h @@ -132,6 +132,7 @@ class RES_EXPORT MemoryManager { } virtual void ResetMaxMemoryReserved() {} virtual void ResetMaxMemoryAllocated() {} + virtual size_t EmptyCache() { return -1L; } protected: virtual uint8_t *MallocStaticMem(size_t size, bool communication_mem, uint32_t graph_id) = 0; diff --git a/mindspore/ccsrc/runtime/hardware/device_context.h b/mindspore/ccsrc/runtime/hardware/device_context.h index a97922aed88..fce833f1013 100644 --- a/mindspore/ccsrc/runtime/hardware/device_context.h +++ b/mindspore/ccsrc/runtime/hardware/device_context.h @@ -238,6 +238,8 @@ class BACKEND_COMMON_EXPORT DeviceResManager { virtual void ResetMaxMemoryReserved() {} virtual void ResetMaxMemoryAllocated() {} + virtual size_t EmptyCache() { return -1L; } + // Allocate host memory with raii and ref count virtual std::shared_ptr AllocateHostMemory(size_t size) const { return std::shared_ptr(::malloc(size), ::free); diff --git a/mindspore/python/mindspore/hal/memory.py b/mindspore/python/mindspore/hal/memory.py index 9bf76b31b5e..43574aa42e4 100644 --- a/mindspore/python/mindspore/hal/memory.py +++ b/mindspore/python/mindspore/hal/memory.py @@ -14,7 +14,7 @@ # ============================================================================ """Hardware memory interfaces.""" -from mindspore._c_expression import _memory_stats, _reset_max_mem_reserved, _reset_max_mem_allocated +from mindspore._c_expression import _memory_stats, _reset_max_mem_reserved, _reset_max_mem_allocated, _empty_cache from mindspore import log as logger from .device import _check_inputs_validation, is_initialized @@ -146,7 +146,7 @@ def max_memory_reserved(device_target=None): @_check_inputs_validation -def empty_cache(): +def empty_cache(device_target=None): """ Release all memory fragments in the memory pool, so that memory arrangement will be optimized. @@ -160,7 +160,8 @@ def empty_cache(): """ if not function_memory_status['empty_cache']: function_memory_status['empty_cache'] = True - logger.warning(f"The empty_cache operation is currently not supported.") + logger.warning(f"The empty_cache operation is executing.") + return _empty_cache(device_target) @_check_inputs_validation diff --git a/mindspore/python/mindspore/runtime/memory.py b/mindspore/python/mindspore/runtime/memory.py index 823824e1dd7..48eaca1c9a1 100644 --- a/mindspore/python/mindspore/runtime/memory.py +++ b/mindspore/python/mindspore/runtime/memory.py @@ -16,7 +16,7 @@ """Memory interfaces.""" from mindspore._c_expression import RuntimeConf, DeviceManagerConf, _memory_stats, \ - _reset_max_mem_reserved, _reset_max_mem_allocated, DeviceContextManager + _reset_max_mem_reserved, _reset_max_mem_allocated, DeviceContextManager, _empty_cache from mindspore import _checkparam as Validator from mindspore._checkparam import args_type_check from mindspore import log as logger @@ -215,7 +215,12 @@ def empty_cache(): Currently, the MindSpore memory pool does not have the function of releasing memory fragments. This interface is reserved but implemented as an empty method and prompted in log mode when using. """ - logger.warning(f"The empty_cache operation is currently not supported.") + logger.info(f"The empty_cache operation is executing.") + device_target = ms.context.get_context("device_target") + if not _is_initialized(device_target): + logger.warning(f"Backend {device_target} is not initialized yet. Return 0.") + return 0 + return _empty_cache(device_target) def reset_peak_memory_stats(): -- Gitee From 63ab645d3d98fa0ef03cccc6adb8805990ea9de9 Mon Sep 17 00:00:00 2001 From: linux Date: Mon, 10 Mar 2025 20:23:31 +0800 Subject: [PATCH 03/14] for int4 quant --- .../ops_func_impl/grouped_matmul_base.cc | 16 +++- .../infer/ops_func_impl/grouped_matmul_base.h | 2 +- .../infer/ops_func_impl/grouped_matmul_v4.cc | 2 +- .../aclnn/grouped_matmul_v4_aclnn_kernel.cc | 29 ++++++- .../pyboost/customize/grouped_matmul_v4.cc | 22 +++++- tests/st/ops/test_ops_grouped_matmul_v4.py | 78 ++++++++++++++++++- 6 files changed, 138 insertions(+), 11 deletions(-) diff --git a/mindspore/ops/infer/ops_func_impl/grouped_matmul_base.cc b/mindspore/ops/infer/ops_func_impl/grouped_matmul_base.cc index 9648a74d9ab..0166a8ecc3e 100644 --- a/mindspore/ops/infer/ops_func_impl/grouped_matmul_base.cc +++ b/mindspore/ops/infer/ops_func_impl/grouped_matmul_base.cc @@ -105,7 +105,7 @@ void GroupedMatmulBaseFuncImpl::CheckInputAndWeightShapeForSingleOutput(const Pr ShapeArray GroupedMatmulBaseFuncImpl::InferShapeForSingleOutput(const PrimitivePtr &primitive, const ShapeArray &x_shapes, const ShapeArray &w_shapes, int64_t group_list_size, int64_t group_type, - bool transpose_b) const { + bool transpose_b, bool is_int4) const { if (MS_UNLIKELY(x_shapes.size() != kIndex1 || w_shapes.size() != kIndex1)) { MS_EXCEPTION(ValueError) << "For '" << primitive->name() << "', when split_item is 3. the size of x and weight should both be 1, but got x's size " @@ -119,6 +119,9 @@ ShapeArray GroupedMatmulBaseFuncImpl::InferShapeForSingleOutput(const PrimitiveP auto n = abstract::Shape::kShapeDimAny; if (!IsDynamicRank(w_shape)) { n = transpose_b ? w_shape[w_shape.size() - kInputIndex2] : w_shape.back(); + if (is_int4) { + n = n << 1; + } } std::vector res_shape; @@ -194,7 +197,16 @@ ShapeArray GroupedMatmulBaseFuncImpl::InferShape(const PrimitivePtr &primitive, } auto group_list_size = FetchGroupListSize(primitive, input_infos); auto transpose_b = GetTransposeValue(input_infos, idxes_.transpose_b); - return InferShapeForSingleOutput(primitive, x_shapes, w_shapes, group_list_size, group_type, transpose_b); + bool is_int4 = false; + if (MS_LIKELY(input_infos[idxes_.weight]->IsSequence())) { + const auto &w_tensors = input_infos[idxes_.weight]->GetSequenceElements(); + MS_ASSERT(w_tensors.size() > 0); + is_int4 = w_tensors[0]->GetType() == kNumberTypeInt4; + } else { + is_int4 = input_infos[idxes_.weight]->GetType() == kNumberTypeInt4; + } + + return InferShapeForSingleOutput(primitive, x_shapes, w_shapes, group_list_size, group_type, transpose_b, is_int4); } bool GroupedMatmulBaseFuncImpl::EnableInternal(const std::string &op_name) const { diff --git a/mindspore/ops/infer/ops_func_impl/grouped_matmul_base.h b/mindspore/ops/infer/ops_func_impl/grouped_matmul_base.h index 48ae470957b..81731a65593 100644 --- a/mindspore/ops/infer/ops_func_impl/grouped_matmul_base.h +++ b/mindspore/ops/infer/ops_func_impl/grouped_matmul_base.h @@ -77,7 +77,7 @@ class OPS_API GroupedMatmulBaseFuncImpl : public OpFuncImpl { ShapeArray InferShapeForSingleOutput(const PrimitivePtr &primitive, const ShapeArray &x_shapes, const ShapeArray &w_shapes, int64_t group_list_size, int64_t group_type, - bool transpose_b) const; + bool transpose_b, bool is_int4 = false) const; void CheckInputAndWeightShapeForMultiOutput(const PrimitivePtr &primitive, const ShapeVector &x_shape, const ShapeVector &w_shape, size_t i) const; diff --git a/mindspore/ops/infer/ops_func_impl/grouped_matmul_v4.cc b/mindspore/ops/infer/ops_func_impl/grouped_matmul_v4.cc index ee730fd197c..bf2753edaa9 100644 --- a/mindspore/ops/infer/ops_func_impl/grouped_matmul_v4.cc +++ b/mindspore/ops/infer/ops_func_impl/grouped_matmul_v4.cc @@ -82,7 +82,7 @@ TypeIdList GroupedMatmulV4FuncImpl::InferType(const PrimitivePtr &primitive, [](const InferInfoPtr &info) { return kNumberTypeFloat16; }); } else { MS_EXCEPTION(ValueError) << "For '" << primitive->name() - << "', the scale only support Uint16, BFloat16 and Float32."; + << "', the scale only support Uint16, BFloat16 and Float32, but got " << scale_type; } } return output_types; diff --git a/mindspore/ops/kernel/ascend/opapi/aclnn/grouped_matmul_v4_aclnn_kernel.cc b/mindspore/ops/kernel/ascend/opapi/aclnn/grouped_matmul_v4_aclnn_kernel.cc index c401b45d63b..32d279fe4e6 100644 --- a/mindspore/ops/kernel/ascend/opapi/aclnn/grouped_matmul_v4_aclnn_kernel.cc +++ b/mindspore/ops/kernel/ascend/opapi/aclnn/grouped_matmul_v4_aclnn_kernel.cc @@ -59,6 +59,23 @@ std::vector> DealWithGroupedMatmulListTensors(const } } // namespace +static inline void UnifyWeightShape(const std::vector &ori_weights, + std::vector> *new_weights_shared_ptr, + std::vector *new_weights_raw_ptr) { + for (const auto &w : ori_weights) { + if (w->dtype_id() == kNumberTypeInt4) { + auto new_w = std::make_shared(*w); + auto w_shape = w->GetShapeVector(); + w_shape.back() *= 2; + new_w->SetShapeVector(w_shape); + new_weights_shared_ptr->emplace_back(new_w); + new_weights_raw_ptr->emplace_back(new_w.get()); + } else { + new_weights_raw_ptr->emplace_back(w); + } + } +} + void GroupedMatmulV4Ascend::GetWorkSpaceInfo(const std::vector &inputs, const std::vector &outputs) { group_info_ = GetValue>(primitive_->GetAttr("group_info")); @@ -89,7 +106,11 @@ void GroupedMatmulV4Ascend::GetWorkSpaceInfo(const std::vector & MS_EXCEPTION_IF_NULL(act_type_tensor); act_type_ = act_type_tensor->GetValueWithCheck(); - GetWorkspaceForResize(list_inputs[kInputXIdx], list_inputs[kInputWeightIdx], list_inputs[kInputBiasIdx], + std::vector> new_weights; + std::vector new_weights_raw; + UnifyWeightShape(list_inputs[kInputWeightIdx], &new_weights, &new_weights_raw); + + GetWorkspaceForResize(list_inputs[kInputXIdx], new_weights_raw, list_inputs[kInputBiasIdx], list_inputs[kInputScaleIdx], list_inputs[kInputOffsetIdx], list_inputs[kInputAntiquantScaleIdx], list_inputs[kInputAntiquantOffsetIdx], list_inputs[kInputPreTokenScaleIdx], group_list_tensor, list_inputs[kInputActivationInputIdx], list_inputs[kInputActivationQuantScaleIdx], @@ -103,7 +124,11 @@ bool GroupedMatmulV4Ascend::Launch(const std::vector &inputs, MS_EXCEPTION_IF_NULL(stream_ptr); auto list_inputs = DealWithGroupedMatmulListTensors(group_info_, start_idxs_, inputs); auto group_list_tensor = *(inputs.begin() + start_idxs_[kInputGroupListIdx]); - RunOp(stream_ptr, workspace, list_inputs[kInputXIdx], list_inputs[kInputWeightIdx], list_inputs[kInputBiasIdx], + std::vector> new_weights; + std::vector new_weights_raw; + UnifyWeightShape(list_inputs[kInputWeightIdx], &new_weights, &new_weights_raw); + + RunOp(stream_ptr, workspace, list_inputs[kInputXIdx], new_weights_raw, list_inputs[kInputBiasIdx], list_inputs[kInputScaleIdx], list_inputs[kInputOffsetIdx], list_inputs[kInputAntiquantScaleIdx], list_inputs[kInputAntiquantOffsetIdx], list_inputs[kInputPreTokenScaleIdx], group_list_tensor, list_inputs[kInputActivationInputIdx], list_inputs[kInputActivationQuantScaleIdx], diff --git a/mindspore/ops/kernel/ascend/pyboost/customize/grouped_matmul_v4.cc b/mindspore/ops/kernel/ascend/pyboost/customize/grouped_matmul_v4.cc index a03d078c374..e51d5e1c745 100644 --- a/mindspore/ops/kernel/ascend/pyboost/customize/grouped_matmul_v4.cc +++ b/mindspore/ops/kernel/ascend/pyboost/customize/grouped_matmul_v4.cc @@ -28,7 +28,22 @@ namespace kernel { namespace pyboost { namespace { std::vector ConvertOptiaonlValueTupleToVector(const std::optional &tensor_list_opt); + +void UnifyWeightShape(const std::vector &ori_weights, std::vector *new_weights) { + for (const auto &ori_weight : ori_weights) { + if (ori_weight->data_type() == kNumberTypeInt4) { + auto new_weight = std::make_shared(*ori_weight); + auto ori_weight_shape = ori_weight->shape(); + ori_weight_shape.back() *= 2; + (void)new_weight->set_shape(ori_weight_shape); + (void)new_weights->emplace_back(new_weight); + } else { + (void)new_weights->emplace_back(ori_weight); + } + } +} } // namespace + void GroupedMatmulV4AscendCustomize( const std::shared_ptr &op, const ValueTuplePtr &x_tensor_list, const ValueTuplePtr &weight_tensor_list, const std::optional &bias_tensor_list, const std::optional &scale_tensor_list, @@ -75,9 +90,12 @@ void GroupedMatmulV4AscendCustomize( std::vector dyn_quant_scale_out; PyBoostUtils::PrepareOpOutputs(op->device_context(), op->stream_id(), op->outputs()); + std::vector new_weights; + UnifyWeightShape(weight, &new_weights); + // Async PyBoostUtils::DispatchRun(std::make_shared( - [op, x, weight, bias, scale, offset, antiquant_scale, antiquant_offset, pre_token_scale, group_list, + [op, x, new_weights, weight, bias, scale, offset, antiquant_scale, antiquant_offset, pre_token_scale, group_list, activation_input, activation_quant_scale, activation_quant_offset, split_item, group_type, group_list_type, act_type, activation_feature_out, dyn_quant_scale_out]() { auto device_context = op->device_context(); @@ -89,7 +107,7 @@ void GroupedMatmulV4AscendCustomize( // Malloc for output tensors PyBoostUtils::MallocOpOutputs(device_context, outputs); - LAUNCH_ACLNN(aclnnGroupedMatmulV4, device_context, op->stream_id(), x, weight, bias, scale, offset, + LAUNCH_ACLNN(aclnnGroupedMatmulV4, device_context, op->stream_id(), x, new_weights, bias, scale, offset, antiquant_scale, antiquant_offset, pre_token_scale, group_list, activation_input, activation_quant_scale, activation_quant_offset, split_item, group_type, group_list_type, act_type, outputs, activation_feature_out, dyn_quant_scale_out); diff --git a/tests/st/ops/test_ops_grouped_matmul_v4.py b/tests/st/ops/test_ops_grouped_matmul_v4.py index f225253e140..a76c9e6cf96 100644 --- a/tests/st/ops/test_ops_grouped_matmul_v4.py +++ b/tests/st/ops/test_ops_grouped_matmul_v4.py @@ -209,6 +209,76 @@ def test_grouped_matmul_v4_x2d_w3d_splititem3_grouptype0_a16w8(mode): np.testing.assert_allclose(except_np, res[0][:30].asnumpy(), rtol=1e-3) +@arg_mark(plat_marks=['platform_ascend910b'], level_mark='level0', card_mark='onecard', essential_mark='unessential') +@pytest.mark.parametrize('mode', ['KBK', 'pynative']) +def test_grouped_matmul_v4_x2d_w3d_splititem3_grouptype0_a16w4(mode): + """ + Feature: Test grouped_matmul + Description: semi_auto_parallel + Expectation: shape is as expected. + """ + context.set_context(device_target="Ascend") + if mode == 'KBK': + ms.set_context(mode=ms.GRAPH_MODE) + ms.set_context(jit_level='O0') + elif mode == 'pynative': + ms.set_context(mode=ms.PYNATIVE_MODE) + gmm_v4_net = GroupedMatmulV4Net() + + split_item = 3 + group_type = 0 + group_list_type = 0 + + M0 = 32 + K0 = 256 + N0 = 128 + E0 = 8 + group_list_np = [1, 3, 10, 14, 18, 22, 24, 30] # last value can be less than total token numbers + + # numpy calculate + np_x_all = np.random.uniform(-128, 127, size=[M0, K0]).astype(np.float16) + np_w_all = np.random.uniform(0, 2, size=[E0, K0, N0]).astype(np.int8) + antiquant_scale0 = np.array(np.full([E0, N0], 0.01)).astype(np.float16) + antiquant_offset0 = np.array(np.full([E0, N0], 1)).astype(np.float16) + + for i in range(E0): + for j in range(K0): + for k in range(N0): + np_w_all[i, j, k] = np_w_all[i, j, k] & 0xf + + np_w_all_int4 = np.ones((E0 * K0 * N0 // 2,), dtype=np.int8) + np_w_all_one_rank = np_w_all.reshape(-1,) + for i in range(E0 * K0 * N0 // 2): + np_w_all_int4[i] = np_w_all_one_rank[i * 2] | ((np_w_all_one_rank[(i * 2) + 1] & 15) << 4) + + np_w_all_int4_3_rank = np_w_all_int4.reshape((E0, K0, N0 // 2)) + + np_x = split_x(np_x_all, group_list_np) + np_w = split_w(np_w_all) + np_s = split_w(antiquant_scale0) + np_o = split_w(antiquant_offset0) + res_np = [np.matmul(x0, (w0 + o0) * s0) for x0, w0, s0, o0 in zip(np_x, np_w, np_s, np_o)] + expect_np = np.concatenate(res_np, axis=0) + + # ms calculate + x = [ms.Tensor(np_x_all)] + w = [ms.Tensor(np_w_all_int4_3_rank, dtype=ms.qint4x2)] + antiquant_scale = [ms.Tensor(antiquant_scale0)] + antiquant_offset = [ms.Tensor(antiquant_offset0)] + + b = None + scale = None + offset = None + pertoken_scale = None + group_list = ms.Tensor(group_list_np, dtype=mstype.int64) + + res = gmm_v4_net(x, w, b, scale, offset, antiquant_scale, antiquant_offset, pertoken_scale, group_list, + split_item, group_type, group_list_type) + + # compare + np.testing.assert_allclose(expect_np, res[0][:30].asnumpy(), rtol=1e-3, atol=1e-3) + + @arg_mark(plat_marks=['platform_ascend910b'], level_mark='level0', card_mark='onecard', essential_mark='unessential') @pytest.mark.parametrize('mode', ['KBK', 'pynative']) def test_grouped_matmul_v4_x2d_w3d_splititem3_grouptype0_none_pertoken(mode): @@ -296,25 +366,27 @@ def test_grouped_matmul_v4_x2d_w3d_splititem3_grouptype0_none_perchannel(mode): np_x_all = np.random.uniform(-128, 127, size=[M0, K0]).astype(np.int8) np_w_all = np.random.uniform(-128, 127, size=[E0, K0, N0]).astype(np.int8) np_s_all = np.array(np.full([E0, N0], 10)).astype(np.float32) + np_b_all = np.array(np.full([E0, N0], 1)).astype(np.float32) np_x = split_x(np_x_all, np.cumsum(group_list_np)) np_w = split_w(np_w_all) np_s = split_w(np_s_all) - res_np = [np.matmul(x0, w0 * s0) for x0, w0, s0 in zip(np_x, np_w, np_s)] + np_b = split_w(np_b_all) + res_np = [np.matmul(x0, w0 * s0) + b0 * s0 for x0, w0, s0, b0 in zip(np_x, np_w, np_s, np_b)] except_np = np.concatenate(res_np, axis=0) # ms calculate x = [ms.Tensor(np_x_all)] w = [ms.Tensor(np_w_all)] scale = [ms.Tensor(np_s_all, dtype=mstype.bfloat16)] + bias = [ms.Tensor(np_b, dtype=mstype.int32)] - b = None offset = None antiquant_scale = None antiquant_offset = None group_list = ms.Tensor(group_list_np, dtype=mstype.int64) - res = gmm_v4_net(x, w, b, scale, offset, antiquant_scale, antiquant_offset, None, group_list, + res = gmm_v4_net(x, w, bias, scale, offset, antiquant_scale, antiquant_offset, None, group_list, split_item, group_type, group_list_type) # compare -- Gitee From 2e240eb84b6035057cd8275e19134a60c742b956 Mon Sep 17 00:00:00 2001 From: yyyyrf Date: Thu, 13 Mar 2025 11:22:49 +0800 Subject: [PATCH 04/14] safetensor & initialize support qint4x2 --- mindspore/python/mindspore/common/tensor.py | 10 ++++++++-- .../python/mindspore/train/serialization.py | 18 ++++++++++++++---- 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/mindspore/python/mindspore/common/tensor.py b/mindspore/python/mindspore/common/tensor.py index 18f41369869..16c0a8fbe60 100644 --- a/mindspore/python/mindspore/common/tensor.py +++ b/mindspore/python/mindspore/common/tensor.py @@ -2165,11 +2165,13 @@ class Tensor(TensorPy_, metaclass=_TensorMeta): from mindspore.common.initializer import Zero as ZeroInitializer + is_qint4x2 = self.dtype == mstype.qint4x2 try: + dtype_ = mstype.int8 if is_qint4x2 else self.dtype if isinstance(self.init, ZeroInitializer): - data = np.zeros(data_shape, dtype=mstype.dtype_to_nptype(self.dtype)) + data = np.zeros(data_shape, dtype=mstype.dtype_to_nptype(dtype_)) else: - data = np.ndarray(data_shape, dtype=mstype.dtype_to_nptype(self.dtype)) + data = np.ndarray(data_shape, dtype=mstype.dtype_to_nptype(dtype_)) except ValueError as e: msg = "Error shape={}".format(shape) logger.critical(msg) @@ -2214,6 +2216,10 @@ class Tensor(TensorPy_, metaclass=_TensorMeta): self.assign_value(TensorPy_.persistent_data_from_numpy(data, slice_num_of_persistent_data)) else: self.assign_value(TensorPy_.from_numpy(data)) + + if is_qint4x2: + self.set_dtype(mstype.qint4x2) + return self def resize(self, *new_shape): diff --git a/mindspore/python/mindspore/train/serialization.py b/mindspore/python/mindspore/train/serialization.py index 6c5cd7f6159..0dfde5bf97c 100644 --- a/mindspore/python/mindspore/train/serialization.py +++ b/mindspore/python/mindspore/train/serialization.py @@ -96,6 +96,8 @@ mindir_to_tensor_type = {1: mstype.float32, 2: mstype.uint8, 3: mstype.int8, 4: 5: mstype.int16, 6: mstype.int32, 7: mstype.int64, 10: mstype.float16, 11: mstype.float64, 12: mstype.uint32, 13: mstype.uint64} +safetensors_to_mstype = {'Int4': mstype.qint4x2} + _ckpt_mutex = RLock() # unit is KB @@ -425,11 +427,14 @@ def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM", map_ elif format == "safetensors": save_dict = {} crc_num = 0 + meta_data = {} for name in sorted(data_list.keys()): value = data_list[name] if isinstance(value[2], np.ndarray): save_dict[name] = value[2] else: + if value[2].dtype == mstype.qint4x2: + meta_data[name] = str(mstype.qint4x2) save_dict[name] = value[2].asnumpy() if crc_check: @@ -438,10 +443,10 @@ def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM", map_ bytes(save_dict[name]), crc_num) safetensors_save_time_start = time.time() if crc_check: - save_file(save_dict, tmp_name, metadata={ - "crc_num": str(crc_num)}) + meta_data.update({"crc_num": str(crc_num)}) + save_file(save_dict, tmp_name, metadata=meta_data) else: - save_file(save_dict, tmp_name) + save_file(save_dict, tmp_name, metadata=meta_data) safetensors_save_time_end = time.time() cost_time = safetensors_save_time_end - safetensors_save_time_start vlog_print("1", "ME", __file__, sys._getframe().f_lineno, f"Save safetensors io cost time:{cost_time}.") @@ -1227,7 +1232,12 @@ def _load_into_param_dict(ckpt_file_name, parameter_dict, specify_prefix, filter io_end_time = time.time() io_cost_time = io_end_time - io_start_time total_io_cost_time += io_cost_time - parameter_dict[k] = Parameter(Tensor.from_numpy(value)) + if f.metadata() is not None and k in f.metadata().keys(): + sf_dtype = f.metadata()[k] + ms_dtype = safetensors_to_mstype[sf_dtype] + parameter_dict[k] = Parameter(Tensor(value, dtype=ms_dtype)) + else: + parameter_dict[k] = Parameter(Tensor.from_numpy(value)) vlog_print("1", "ME", __file__, sys._getframe().f_lineno, f"Load safetensors io cost time:{total_io_cost_time}.") -- Gitee From 80675358b257296958cdca70dac550dcb413dac8 Mon Sep 17 00:00:00 2001 From: Yanzhi_YI Date: Wed, 12 Mar 2025 23:28:19 +0800 Subject: [PATCH 05/14] update pa index --- .../plugin/device/ascend/kernel/internal/paged_attention.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/internal/paged_attention.cc b/mindspore/ccsrc/plugin/device/ascend/kernel/internal/paged_attention.cc index f821a410264..8ad2347051a 100644 --- a/mindspore/ccsrc/plugin/device/ascend/kernel/internal/paged_attention.cc +++ b/mindspore/ccsrc/plugin/device/ascend/kernel/internal/paged_attention.cc @@ -88,8 +88,8 @@ uint64_t InternalPagedAttention::GenerateTilingKey(const std::vector Date: Tue, 4 Mar 2025 17:07:28 +0800 Subject: [PATCH 06/14] parallel dispatch kernel support comm ops --- .jenkins/check/config/whitelizard.txt | 1 + .../actor/super_kernel_actor.cc | 112 ++++++++++++++++-- .../actor/super_kernel_actor.h | 3 + 3 files changed, 108 insertions(+), 8 deletions(-) diff --git a/.jenkins/check/config/whitelizard.txt b/.jenkins/check/config/whitelizard.txt index fa30e7a987e..c191fbae9d8 100644 --- a/.jenkins/check/config/whitelizard.txt +++ b/.jenkins/check/config/whitelizard.txt @@ -31,6 +31,7 @@ mindspore/mindspore/ccsrc/runtime/graph_scheduler/actor/super_kernel_actor.cc:mi mindspore/mindspore/ccsrc/runtime/graph_scheduler/actor/kernel_actor.cc:mindspore::runtime::KernelActor::ExecuteLaunchKernelTask mindspore/mindspore/ccsrc/runtime/graph_scheduler/actor/super_kernel_actor.cc:mindspore::runtime::SuperKernelActor::LaunchAllKernels mindspore/mindspore/ccsrc/runtime/graph_scheduler/actor/data_prepare_actor.cc:mindspore::runtime::DataPrepareActor::PrepareDataForHostTensorQueueNew +mindspore/mindspore/ccsrc/runtime/graph_scheduler/actor/super_kernel_actor.cc:mindspore::runtime::SuperKernelActor::PartitionParallelDispatchKernels mindspore/mindspore/ccsrc/pybind_api/init.cc:PYBIND11_MODULE mindspore/mindspore/ccsrc/pipeline/jit/ps/parse/resolve.cc:mindspore::parse::Resolver::ResolveObjectToNode mindspore/mindspore/ccsrc/pipeline/jit/ps/parse/parse.cc:mindspore::parse::Parser::ParseIf diff --git a/mindspore/ccsrc/runtime/graph_scheduler/actor/super_kernel_actor.cc b/mindspore/ccsrc/runtime/graph_scheduler/actor/super_kernel_actor.cc index 236c303f746..c693dd98171 100644 --- a/mindspore/ccsrc/runtime/graph_scheduler/actor/super_kernel_actor.cc +++ b/mindspore/ccsrc/runtime/graph_scheduler/actor/super_kernel_actor.cc @@ -32,6 +32,7 @@ #include "op_def/framework_ops.h" #include "pybind_api/gil_scoped_long_running.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" +#include "include/backend/distributed/collective/collective_manager.h" namespace mindspore { namespace runtime { @@ -1587,6 +1588,7 @@ void SuperKernelActor::PartitionParallelDispatchKernels() { } // Get serial launch kernels. + static bool enable_multi_comm_group = common::IsEnableRuntimeConfig("communication_launch_group"); for (auto &kernel_actor : kernel_actors_) { if (!kernel_actor) { continue; @@ -1594,15 +1596,109 @@ void SuperKernelActor::PartitionParallelDispatchKernels() { auto &llm_manager = LLMManager::GetInstance(); const auto &kernel_name = kernel_actor->kernel_mod_->kernel_name(); bool need_force_resize = llm_manager.need_force_resize(kernel_name); - if (need_force_resize || (common::AnfAlgo::IsCommunicationOp(kernel_actor->kernel_) || - kernel_name == "QbmmAllReduceAdd" || kernel_name == "MatmulAllReduceAddRmsNorm")) { + if (need_force_resize || (kernel_name == kMatMulAllReduceOpName) || (kernel_name == "QbmmAllReduceAdd") || + (kernel_name == "MatmulAllReduceAddRmsNorm")) { serial_launch_kernels_.push_back(kernel_actor); - } else if (kernel_name.find(kAllReduceOpName) != std::string::npos || - kernel_name.find(kAllGatherOpName) != std::string::npos || - kernel_name.find(kReduceScatterOpName) != std::string::npos || - kernel_name.find(kAllToAllOpName) != std::string::npos || - kernel_name.find(kAlltoAllOpName) != std::string::npos) { - MS_LOG(WARNING) << "Find parallel dispatch communication op: " << kernel_name; + continue; + } + if (common::AnfAlgo::IsCommunicationOp(kernel_actor->kernel_)) { + if (!enable_multi_comm_group) { + serial_launch_kernels_.push_back(kernel_actor); + } + continue; + } + + if (kernel_name.find(kAllReduceOpName) != std::string::npos || + kernel_name.find(kAllGatherOpName) != std::string::npos || + kernel_name.find(kReduceScatterOpName) != std::string::npos || + kernel_name.find(kAllToAllOpName) != std::string::npos || + kernel_name.find(kAlltoAllOpName) != std::string::npos) { + MS_LOG(WARNING) << "Find not support parallel launch communication op: " << kernel_name; + serial_launch_kernels_.push_back(kernel_actor); + } + } + + if (enable_multi_comm_group) { + RecreateCommunicationGroup(); + } +} + +void SuperKernelActor::RecreateCommunicationGroup() { + std::vector> parallel_launch_comm_kernels(parallel_dispatch_num_); + HashSet group_set; + // 1. Collect communication ops. + for (size_t i = 0; i < parallel_dispatch_num_; i++) { + for (size_t j = 0; j < parallel_slice_num_; j++) { + auto &kernel_actors = parallel_launch_kernels_[i + j * parallel_dispatch_num_]; + for (auto &kernel_actor : kernel_actors) { + if (!kernel_actor) { + continue; + } + + const auto &kernel_name = kernel_actor->kernel_mod_->kernel_name(); + // MC2 kernels do not support multi communication group now. + bool is_naive_comm_op = common::AnfAlgo::IsCommunicationOp(kernel_actor->kernel_) && + (kernel_name != kMatMulAllReduceOpName) && (kernel_name != "QbmmAllReduceAdd") && + (kernel_name != "MatmulAllReduceAddRmsNorm"); + if (!is_naive_comm_op) { + continue; + } + + parallel_launch_comm_kernels[i].push_back(kernel_actor); + if (common::AnfAlgo::HasNodeAttr(kAttrGroup, kernel_actor->kernel_)) { + auto group_name = common::AnfAlgo::GetNodeAttr(kernel_actor->kernel_, kAttrGroup); + group_set.insert(group_name); + } else { + MS_LOG(EXCEPTION) << "Can not get communication group for kernel: " + << kernel_actor->kernel_->fullname_with_scope(); + } + } + } + } + + if (group_set.size() > 1) { + MS_LOG(WARNING) << "Communication ops parallel dispatch doesn't support multi communication group now, enable " + "parallel dispatch for communication ops with risk."; + } + std::string old_group_name = ""; + std::vector group_ranks = {}; + if (group_set.size() == 1) { + old_group_name = *group_set.begin(); + group_ranks = distributed::collective::CollectiveManager::instance()->GetGroupRanks(old_group_name); + MS_LOG(INFO) << "Old group name: " << old_group_name << ", group ranks: " << group_ranks; + } + + for (size_t i = 0; i < parallel_launch_comm_kernels.size(); i++) { + auto &comm_kernel_actors = parallel_launch_comm_kernels[i]; + if (comm_kernel_actors.empty()) { + continue; + } + // 2. New communication group. + const std::string new_group_name = std::string("parallel_dispatch_group_") + std::to_string(i); + distributed::collective::CollectiveManager::instance()->CreateCommunicationGroup(new_group_name, group_ranks); + + // 3. Repalce old communication group and re-init kernel mod for communication ops. + for (auto &kernel_actor : comm_kernel_actors) { + auto &kernel = kernel_actor->kernel_; + MS_EXCEPTION_IF_NULL(kernel_actor); + common::AnfAlgo::SetNodeAttr(kAttrGroup, MakeValue(new_group_name), kernel); + + std::vector input_kernel_tensors = AnfAlgo::GetOrCreateAllInputKernelTensors(kernel); + std::vector output_kernel_tensors = AnfAlgo::GetOrCreateAllOutputKernelTensors(kernel); + + MS_LOG(INFO) << "Begin init kernel: " << kernel->fullname_with_scope(); + if (!kernel_actor->kernel_mod_->Init(common::AnfAlgo::GetCNodePrimitive(kernel), input_kernel_tensors, + output_kernel_tensors)) { + MS_LOG_WITH_NODE(EXCEPTION, kernel) + << "#dmsg#Kernel build failed:#dmsg#Initialize kernel op[" << kernel->fullname_with_scope() << "] failed."; + } + MS_LOG(INFO) << "End init kernel: " << kernel->fullname_with_scope(); + + if (kernel::CheckResizeCondition(kernel)) { + MS_LOG(INFO) << "Begin Resize kernel: " << kernel->fullname_with_scope(); + kernel_actor->kernel_mod_->Resize(input_kernel_tensors, output_kernel_tensors); + MS_LOG(INFO) << "End Resize kernel: " << kernel->fullname_with_scope(); + } } } } diff --git a/mindspore/ccsrc/runtime/graph_scheduler/actor/super_kernel_actor.h b/mindspore/ccsrc/runtime/graph_scheduler/actor/super_kernel_actor.h index ea79059791a..6804cf9d919 100644 --- a/mindspore/ccsrc/runtime/graph_scheduler/actor/super_kernel_actor.h +++ b/mindspore/ccsrc/runtime/graph_scheduler/actor/super_kernel_actor.h @@ -198,6 +198,9 @@ class SuperKernelActor : public DebugAwareActor { void InitParallelDispatchResource(); void PartitionParallelDispatchKernels(); + // Recreate the communication group for the communication operators, ensuring that the communication group is the + // same for the communication operators on each concurrent thread. + void RecreateCommunicationGroup(); void ClearParallelDispatchResource(); friend class GraphScheduler; -- Gitee From 79a1af27223fcf36dbb2a4a98fe08d0964ce458f Mon Sep 17 00:00:00 2001 From: Yanzhi_YI Date: Thu, 13 Mar 2025 17:38:22 +0800 Subject: [PATCH 07/14] clean code --- .../plugin/device/ascend/kernel/internal/paged_attention.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/internal/paged_attention.cc b/mindspore/ccsrc/plugin/device/ascend/kernel/internal/paged_attention.cc index 8ad2347051a..9208958ba45 100644 --- a/mindspore/ccsrc/plugin/device/ascend/kernel/internal/paged_attention.cc +++ b/mindspore/ccsrc/plugin/device/ascend/kernel/internal/paged_attention.cc @@ -84,7 +84,8 @@ bool InternalPagedAttention::UpdateParam(const std::vector &inpu uint64_t InternalPagedAttention::GenerateTilingKey(const std::vector &inputs) { // User defined CacheKey, the inputs should include all the factors which will affect tiling result. - return InternalTilingCache::GenerateKey(kernel_name_, inputs, param_.q_seq_len, param_.kv_seq_len, param_.has_q_seq_lens, param_.mla_v_dim); + return InternalTilingCache::GenerateKey(kernel_name_, inputs, param_.q_seq_len, param_.kv_seq_len, + param_.has_q_seq_lens, param_.mla_v_dim); } MS_INTERNAL_KERNEL_FACTORY_REG(PagedAttention, internal::kInternalPagedAttentionOpName, InternalPagedAttention); -- Gitee From ae8d4a6e6fa7fa3461029adf83dac2790c8a6876 Mon Sep 17 00:00:00 2001 From: lizhenyu Date: Sat, 15 Mar 2025 17:27:17 +0800 Subject: [PATCH 08/14] add group number check for parallel launch comm ops --- .../runtime/graph_scheduler/actor/super_kernel_actor.cc | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/mindspore/ccsrc/runtime/graph_scheduler/actor/super_kernel_actor.cc b/mindspore/ccsrc/runtime/graph_scheduler/actor/super_kernel_actor.cc index c693dd98171..95cee2490fc 100644 --- a/mindspore/ccsrc/runtime/graph_scheduler/actor/super_kernel_actor.cc +++ b/mindspore/ccsrc/runtime/graph_scheduler/actor/super_kernel_actor.cc @@ -1657,8 +1657,9 @@ void SuperKernelActor::RecreateCommunicationGroup() { } if (group_set.size() > 1) { - MS_LOG(WARNING) << "Communication ops parallel dispatch doesn't support multi communication group now, enable " - "parallel dispatch for communication ops with risk."; + MS_LOG(EXCEPTION) + << "Communication ops parallel dispatch doesn't support multi communication group now, please disable " + "parallel dispatch for communication ops by: export MS_DEV_RUNTIME_CONF='communication_launch_group:False'"; } std::string old_group_name = ""; std::vector group_ranks = {}; @@ -1666,8 +1667,12 @@ void SuperKernelActor::RecreateCommunicationGroup() { old_group_name = *group_set.begin(); group_ranks = distributed::collective::CollectiveManager::instance()->GetGroupRanks(old_group_name); MS_LOG(INFO) << "Old group name: " << old_group_name << ", group ranks: " << group_ranks; + } else { + MS_LOG(WARNING) << "There is no communication ops can parallel launch."; + return; } + MS_LOG(INFO) << "Enable parallel launch communication ops."; for (size_t i = 0; i < parallel_launch_comm_kernels.size(); i++) { auto &comm_kernel_actors = parallel_launch_comm_kernels[i]; if (comm_kernel_actors.empty()) { -- Gitee From afdd286fcc5dac5a97dfe4c83d9ef0ff937033e2 Mon Sep 17 00:00:00 2001 From: shanfeng Date: Sun, 16 Mar 2025 09:14:09 +0800 Subject: [PATCH 09/14] [bugfix] fix empty cache Signed-off-by: shanfeng --- .../memory/mem_pool/abstract_dynamic_mem_pool.cc | 5 +++++ .../ascend/mem_manager/ascend_memory_pool.cc | 13 ++++++++----- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/mindspore/ccsrc/memory/mem_pool/abstract_dynamic_mem_pool.cc b/mindspore/ccsrc/memory/mem_pool/abstract_dynamic_mem_pool.cc index b1e8f48d030..914249c6876 100644 --- a/mindspore/ccsrc/memory/mem_pool/abstract_dynamic_mem_pool.cc +++ b/mindspore/ccsrc/memory/mem_pool/abstract_dynamic_mem_pool.cc @@ -987,6 +987,11 @@ void AbstractDynamicMemPool::DumpDynamicMemPoolDebugInfo() { } const std::pair AbstractDynamicMemPool::FreeIdleMemsByEagerFree() { + if (!IsEnableVmm() && !IsEnableEagerFree()) { + MS_LOG(WARNING) << "FreeIdleMemsByEagerFree is not allowed since vmm is not enabled."; + return std::make_pair(0L, 0L); + } + MS_LOG(INFO) << "Free idle mems by eager free start, allocator size : " << stream_id_allocators_.size() << "."; eager_free_count_++; diff --git a/mindspore/ccsrc/plugin/res_manager/ascend/mem_manager/ascend_memory_pool.cc b/mindspore/ccsrc/plugin/res_manager/ascend/mem_manager/ascend_memory_pool.cc index 91926f830a6..7e7b4e5f340 100644 --- a/mindspore/ccsrc/plugin/res_manager/ascend/mem_manager/ascend_memory_pool.cc +++ b/mindspore/ccsrc/plugin/res_manager/ascend/mem_manager/ascend_memory_pool.cc @@ -72,11 +72,14 @@ DefaultAscendMemoryPool::DefaultAscendMemoryPool() { } size_t DefaultAscendMemoryPool::EmptyCache() { - LockGuard lock(AbstractDynamicMemPool::lock()); - AbstractEnhancedDynamicMemPool::WaitPipelineHelper(); - AbstractAscendMemoryPoolSupport::SyncAllStreams(); - AbstractEnhancedDynamicMemPool::FreeIdleMemsByEagerFree(); - return AbstractAscendMemoryPoolSupport::EmptyCache(); + if (IsEnableVmm() || IsEnableEagerFree()) { + LockGuard lock(AbstractDynamicMemPool::lock()); + AbstractEnhancedDynamicMemPool::WaitPipelineHelper(); + AbstractAscendMemoryPoolSupport::SyncAllStreams(); + AbstractEnhancedDynamicMemPool::FreeIdleMemsByEagerFree(); + return AbstractAscendMemoryPoolSupport::EmptyCache(); + } + return 0L; } AscendMemoryTimeEvent::AscendMemoryTimeEvent(int32_t device_id, const MemoryTimeEventPtr &memory_time_event) -- Gitee From 97279667e7f600253bd77b2e124803324f2ef3a4 Mon Sep 17 00:00:00 2001 From: Yanzhi_YI Date: Thu, 13 Mar 2025 21:50:18 +0800 Subject: [PATCH 10/14] update op lib: deepseek_20250313213442_6fdb123648f1c0753a8bb324b46744fac17f503e --- .../internal/prebuild/aarch64/ms_kernels_dependency.tar.gz | 4 ++-- .../internal/prebuild/aarch64/ms_kernels_internal.tar.gz | 4 ++-- .../internal/prebuild/x86_64/ms_kernels_dependency.tar.gz | 4 ++-- .../internal/prebuild/x86_64/ms_kernels_internal.tar.gz | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/internal/prebuild/aarch64/ms_kernels_dependency.tar.gz b/mindspore/ccsrc/plugin/device/ascend/kernel/internal/prebuild/aarch64/ms_kernels_dependency.tar.gz index 09f43168805..8fb6d525464 100644 --- a/mindspore/ccsrc/plugin/device/ascend/kernel/internal/prebuild/aarch64/ms_kernels_dependency.tar.gz +++ b/mindspore/ccsrc/plugin/device/ascend/kernel/internal/prebuild/aarch64/ms_kernels_dependency.tar.gz @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:1d9270cccbee79ca9e7e6371972a41b185f01f17f8e67572eba3ce08a1261a4c -size 156334771 +oid sha256:d2ec6ef1b7a16c942dd8866690ed22a9866963629d3444874abaeae5c3343e1f +size 158989788 diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/internal/prebuild/aarch64/ms_kernels_internal.tar.gz b/mindspore/ccsrc/plugin/device/ascend/kernel/internal/prebuild/aarch64/ms_kernels_internal.tar.gz index 0f8fb3f0166..cf90c0cf19a 100644 --- a/mindspore/ccsrc/plugin/device/ascend/kernel/internal/prebuild/aarch64/ms_kernels_internal.tar.gz +++ b/mindspore/ccsrc/plugin/device/ascend/kernel/internal/prebuild/aarch64/ms_kernels_internal.tar.gz @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:414dd1228aebb6ab857f66d4644c842e43fcda6281e68491878d099618ae715f -size 3473076 +oid sha256:2cfe95c43541d3182a56a9179d8eee18c682fb865b69915d6cbdf789ae3ede45 +size 3524146 diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/internal/prebuild/x86_64/ms_kernels_dependency.tar.gz b/mindspore/ccsrc/plugin/device/ascend/kernel/internal/prebuild/x86_64/ms_kernels_dependency.tar.gz index 3a11caa3320..1c52205c03b 100644 --- a/mindspore/ccsrc/plugin/device/ascend/kernel/internal/prebuild/x86_64/ms_kernels_dependency.tar.gz +++ b/mindspore/ccsrc/plugin/device/ascend/kernel/internal/prebuild/x86_64/ms_kernels_dependency.tar.gz @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:8baa40647b28b25ea58f8a52799683ac5359a28481c416bc133852c4dd5d264b -size 155270996 +oid sha256:db5185772e8480fee7965f285d53710640036a013f0a0afaab35ee870d4c5c78 +size 158251514 diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/internal/prebuild/x86_64/ms_kernels_internal.tar.gz b/mindspore/ccsrc/plugin/device/ascend/kernel/internal/prebuild/x86_64/ms_kernels_internal.tar.gz index 831eacdadf1..b4465472abf 100644 --- a/mindspore/ccsrc/plugin/device/ascend/kernel/internal/prebuild/x86_64/ms_kernels_internal.tar.gz +++ b/mindspore/ccsrc/plugin/device/ascend/kernel/internal/prebuild/x86_64/ms_kernels_internal.tar.gz @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:bb72b466a9deef0a9cd66b6b2c7ae05d9477e4cb100c4062cf4a841f65ee60b6 -size 3495553 +oid sha256:c4e534c301040ce5dc1d378c53aa201be6638e30b79154892032552ca30be1e9 +size 3534164 -- Gitee From 357a0734c610fd999823ef83f416c78b767d6287 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=AD=99=E6=98=8A=E8=BE=B0?= Date: Fri, 14 Mar 2025 11:09:30 +0800 Subject: [PATCH 11/14] internal add asdop fused_add_topk_div kernel --- .../ccsrc/plugin/device/ascend/CMakeLists.txt | 1 + .../kernel/internal/fused_add_topk_div.cc | 58 ++++++++++++++++ .../kernel/internal/fused_add_topk_div.h | 31 +++++++++ .../aarch64/ms_kernels_internal.tar.gz | 4 +- .../x86_64/ms_kernels_internal.tar.gz | 4 +- mindspore/core/utils/ms_context.cc | 2 +- .../infer/ops_func_impl/fused_add_topk_div.cc | 67 +++++++++++++++++++ .../infer/ops_func_impl/fused_add_topk_div.h | 53 +++++++++++++++ .../op_def/yaml/fused_add_topk_div_op.yaml | 34 ++++++++++ 9 files changed, 249 insertions(+), 5 deletions(-) create mode 100644 mindspore/ccsrc/plugin/device/ascend/kernel/internal/fused_add_topk_div.cc create mode 100644 mindspore/ccsrc/plugin/device/ascend/kernel/internal/fused_add_topk_div.h create mode 100644 mindspore/ops/infer/ops_func_impl/fused_add_topk_div.cc create mode 100644 mindspore/ops/infer/ops_func_impl/fused_add_topk_div.h create mode 100644 mindspore/ops/op_def/yaml/fused_add_topk_div_op.yaml diff --git a/mindspore/ccsrc/plugin/device/ascend/CMakeLists.txt b/mindspore/ccsrc/plugin/device/ascend/CMakeLists.txt index 6ec54b0ea1a..63ee1523185 100644 --- a/mindspore/ccsrc/plugin/device/ascend/CMakeLists.txt +++ b/mindspore/ccsrc/plugin/device/ascend/CMakeLists.txt @@ -82,6 +82,7 @@ target_link_libraries(mindspore_ascend PRIVATE mindspore_backend_common mindspor mindspore_ops_kernel_common mindspore_ops_ascend mindspore_profiler mindspore_runtime_pipeline mindspore_ms_backend) target_link_libraries(mindspore_ascend PRIVATE mindspore_ascend_res_manager) +target_link_libraries(mindspore_ascend PRIVATE mindspore_res_manager) target_link_libraries(mindspore_ascend PRIVATE -Wl,--no-as-needed mindspore_pyboost -Wl,--as-needed) target_link_libraries(mindspore_ascend PRIVATE proto_input mindspore::protobuf) target_link_libraries(mindspore_ascend PRIVATE securec d_collective) diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/internal/fused_add_topk_div.cc b/mindspore/ccsrc/plugin/device/ascend/kernel/internal/fused_add_topk_div.cc new file mode 100644 index 00000000000..be7d5a8b561 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/ascend/kernel/internal/fused_add_topk_div.cc @@ -0,0 +1,58 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/ascend/kernel/internal/fused_add_topk_div.h" + +#include +#include "kernel/kernel.h" +#include "plugin/device/ascend/kernel/internal/internal_kernel_in_out_map.h" + +namespace mindspore { +namespace kernel { +internal::InternalOpPtr InternalFusedAddTopKDiv::CreateKernel(const internal::InputsImmutableInfoList &inputs_ii, + const internal::OutputsImmutableInfoList &outputs_ii, + const std::vector &ms_inputs, + const std::vector &ms_outputs) { + internal::FusedAddTopkDivParam param; + auto group_num = ms_inputs.at(kIndex2); + auto group_topk = ms_inputs.at(kIndex3); + auto n = ms_inputs.at(kIndex4); + auto k = ms_inputs.at(kIndex5); + auto activate_type = ms_inputs.at(kIndex6); + auto is_norm = ms_inputs.at(kIndex7); + auto scale = ms_inputs.at(kIndex8); + if (group_num->dtype_id() == TypeId::kNumberTypeInt64 && group_topk->dtype_id() == TypeId::kNumberTypeInt64 && + n->dtype_id() == TypeId::kNumberTypeInt64 && k->dtype_id() == TypeId::kNumberTypeInt64 && + activate_type->dtype_id() == TypeId::kNumberTypeInt64 && is_norm->dtype_id() == TypeId::kNumberTypeBool && + scale->dtype_id() == TypeId::kNumberTypeFloat32) { + param.group_num = static_cast(group_num->GetValue().value()); + param.group_topk = static_cast(group_topk->GetValue().value()); + param.n = static_cast(n->GetValue().value()); + param.k = static_cast(k->GetValue().value()); + param.activate_type = static_cast(activate_type->GetValue().value()); + param.is_norm = is_norm->GetValue().value(); + param.scale = scale->GetValue().value(); + } else { + MS_LOG(EXCEPTION) << "FusedAddTopKDiv [group_num, group_topk, n, k, activate_type, is_norm, scale]'s dtype wrong"; + } + return internal::CreateFusedAddTopkDivOp(inputs_ii, outputs_ii, param, internal::kInternalFusedAddTopkDivOpName); +} + +MS_INTERNAL_KERNEL_FACTORY_REG(FusedAddTopKDiv, internal::kInternalFusedAddTopkDivOpName, InternalFusedAddTopKDiv); +REG_MS_TO_INTERNAL_IN_TENSOR_IDX_MAP(FusedAddTopKDiv, INPUT_NUM_2, INDEX_0, INDEX_1); +REG_MS_TO_INTERNAL_OUT_TENSOR_IDX_MAP(FusedAddTopKDiv, OUTPUT_NUM_2, INDEX_0, INDEX_1); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/internal/fused_add_topk_div.h b/mindspore/ccsrc/plugin/device/ascend/kernel/internal/fused_add_topk_div.h new file mode 100644 index 00000000000..8d195ffd406 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/ascend/kernel/internal/fused_add_topk_div.h @@ -0,0 +1,31 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_INTERNAL_KERNEL_INTERNAL_FUSED_ADD_TOPK_DIV_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_INTERNAL_KERNEL_INTERNAL_FUSED_ADD_TOPK_DIV_H_ + +#include +#include +#include + +#include "plugin/device/ascend/kernel/internal/internal_kernel_mod.h" +#include "include/internal.h" + +namespace mindspore { +namespace kernel { +DECLARE_INTERNAL_KERNEL_MOD(FusedAddTopKDiv) +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_INTERNAL_KERNEL_INTERNAL_FUSED_ADD_TOPK_DIV_H_ diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/internal/prebuild/aarch64/ms_kernels_internal.tar.gz b/mindspore/ccsrc/plugin/device/ascend/kernel/internal/prebuild/aarch64/ms_kernels_internal.tar.gz index cf90c0cf19a..bb03cbf1b2f 100644 --- a/mindspore/ccsrc/plugin/device/ascend/kernel/internal/prebuild/aarch64/ms_kernels_internal.tar.gz +++ b/mindspore/ccsrc/plugin/device/ascend/kernel/internal/prebuild/aarch64/ms_kernels_internal.tar.gz @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:2cfe95c43541d3182a56a9179d8eee18c682fb865b69915d6cbdf789ae3ede45 -size 3524146 +oid sha256:ada25502bc7d86d68ad9347761f2eeeb80f64ef7130a8b40b93e15d79ba855b4 +size 3608058 diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/internal/prebuild/x86_64/ms_kernels_internal.tar.gz b/mindspore/ccsrc/plugin/device/ascend/kernel/internal/prebuild/x86_64/ms_kernels_internal.tar.gz index b4465472abf..a80bc517dbf 100644 --- a/mindspore/ccsrc/plugin/device/ascend/kernel/internal/prebuild/x86_64/ms_kernels_internal.tar.gz +++ b/mindspore/ccsrc/plugin/device/ascend/kernel/internal/prebuild/x86_64/ms_kernels_internal.tar.gz @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:c4e534c301040ce5dc1d378c53aa201be6638e30b79154892032552ca30be1e9 -size 3534164 +oid sha256:4a881baaada79fa42b83b7647e6d4677f4b7bcf9062dc6062d64007978b0504e +size 3621100 diff --git a/mindspore/core/utils/ms_context.cc b/mindspore/core/utils/ms_context.cc index af9b14e96b9..1a31d057868 100644 --- a/mindspore/core/utils/ms_context.cc +++ b/mindspore/core/utils/ms_context.cc @@ -742,7 +742,7 @@ void MsContext::SetMsInternalEnableCustomKernelList() { const std::string kDefaultEnabledOpList = "MatMul,RmsNorm,Add,Sub,FlashAttentionScore,PagedAttention,PagedAttentionMask,AddRmsNorm,AddLayerNorm," "MatMulAllReduce,InferenceMatmulSplit,AddRmsNormQuantV2,InferenceSwiGLU,QbmmAllReduceAdd,QbmmAdd," - "AddRmsNormDynamicQuant,MatMulElemwise,RmsNormQuant,MatMulSigmoidCastAdd,TransposeBatchMatmulTranspose"; + "AddRmsNormDynamicQuant,MatMulElemwise,RmsNormQuant,MatMulSigmoidCastAdd,TransposeBatchMatmulTranspose,FusedAddTopKDiv"; const std::string k310pDefaultEnabledOpList = "MatMul,QuantBatchMatmul,QuantLinearSparse,QbmmAllReduceAdd,QbmmAdd"; auto internal_op_boost_env = common::GetEnv("MS_ENABLE_INTERNAL_BOOST"); bool is_enable_internal_op = true; diff --git a/mindspore/ops/infer/ops_func_impl/fused_add_topk_div.cc b/mindspore/ops/infer/ops_func_impl/fused_add_topk_div.cc new file mode 100644 index 00000000000..3538151c22e --- /dev/null +++ b/mindspore/ops/infer/ops_func_impl/fused_add_topk_div.cc @@ -0,0 +1,67 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "infer/ops_func_impl/fused_add_topk_div.h" +#include +#include +#include +#include "abstract/ops/primitive_infer_map.h" +#include "mindspore/ops/op_def/nn_ops.h" +#include "utils/check_convert_utils.h" +#include "ops/primitive_c.h" +#include "mindapi/helper.h" +#include "include/api/data_type.h" + +namespace mindspore { +namespace ops { +BaseShapePtr FusedAddTopKDivFuncImpl::InferShape(const PrimitivePtr &primitive, + const std::vector &input_args) const { + auto op_name = primitive->name(); + auto ordinary_input_num = CheckAndConvertUtils::GetRemoveUMonadAbsNum(input_args); + (void)CheckAndConvertUtils::CheckInteger("inputs num", SizeToLong(ordinary_input_num), kEqual, + kFusedAddTopKDivInputsNum, op_name); + auto x_shape_ptr = input_args[kFusedAddTopKDivXIndex]->GetShape(); + if (MS_UNLIKELY(IsDynamicRank(x_shape_ptr->GetShapeVector()))) { + ShapeVector dyn_output{abstract::Shape::kShapeRankAny}; + return std::make_shared(std::move(dyn_output)); + } + + auto add_num_shape_ptr = input_args[kFusedAddTopKDivAddNumIndex]->GetShape(); + if (MS_UNLIKELY(IsDynamicRank(add_num_shape_ptr->GetShapeVector()))) { + ShapeVector dyn_output{abstract::Shape::kShapeRankAny}; + return std::make_shared(std::move(dyn_output)); + } + + auto a = x_shape_ptr->GetShapeVector()[0]; + auto k = GetScalarValue(input_args[kFusedAddTopKDivKIndex]->GetValue()); + if (MS_UNLIKELY(!k.has_value())) { + ShapeVector dyn_output{abstract::Shape::kShapeRankAny}; + return std::make_shared(std::move(dyn_output)); + } + // output_shape = {{a, param.k}, {a, param.k}} + ShapeVector weight_indices_shape{a, k.value()}; + auto output_shape = std::make_shared(weight_indices_shape); + return std::make_shared(abstract::BaseShapePtrList({output_shape, output_shape})); +} + +TypePtr FusedAddTopKDivFuncImpl::InferType(const PrimitivePtr &primitive, + const std::vector &input_args) const { + auto weight_out_type = std::make_shared(kFloat32); + auto indices_out_type = std::make_shared(kInt32); + return std::make_shared(std::vector{weight_out_type, indices_out_type}); +} +} // namespace ops +} // namespace mindspore diff --git a/mindspore/ops/infer/ops_func_impl/fused_add_topk_div.h b/mindspore/ops/infer/ops_func_impl/fused_add_topk_div.h new file mode 100644 index 00000000000..7f5274f33bd --- /dev/null +++ b/mindspore/ops/infer/ops_func_impl/fused_add_topk_div.h @@ -0,0 +1,53 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_OPS_FUNC_IMPL_FUSED_ADD_TOPK_DIV_H_ +#define MINDSPORE_CORE_OPS_OPS_FUNC_IMPL_FUSED_ADD_TOPK_DIV_H_ +#include +#include +#include + +#include "ops/base_operator.h" +#include "mindapi/base/types.h" +#include "ops/primitive_c.h" +#include "abstract/abstract_value.h" +#include "mindspore/ops/op_def/op_name.h" +#include "ops/ops_func_impl/op_func_impl.h" + +namespace mindspore { +namespace ops { +enum FusedAddTopKDivInputIndex : size_t { + kFusedAddTopKDivXIndex = 0, + kFusedAddTopKDivAddNumIndex, + kFusedAddTopKDivGroupNumIndex, + kFusedAddTopKDivGroupTopKIndex, + kFusedAddTopKDivNIndex, + kFusedAddTopKDivKIndex, + kFusedAddTopKDivActivateTypeIndex, + kFusedAddTopKDivIsNormIndex, + kFusedAddTopKDivScaleIndex, + kFusedAddTopKDivInputsNum, +}; + +class OPS_API FusedAddTopKDivFuncImpl : public OpFuncImpl { + public: + BaseShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) const override; + TypePtr InferType(const PrimitivePtr &primitive, const std::vector &input_args) const override; +}; +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPS_OPS_FUNC_IMPL_FUSED_ADD_TOPK_DIV_H_ diff --git a/mindspore/ops/op_def/yaml/fused_add_topk_div_op.yaml b/mindspore/ops/op_def/yaml/fused_add_topk_div_op.yaml new file mode 100644 index 00000000000..ff1398d7f2b --- /dev/null +++ b/mindspore/ops/op_def/yaml/fused_add_topk_div_op.yaml @@ -0,0 +1,34 @@ +#operator FusedAddTopKDiv +fused_add_topk_div: + args: + x: + dtype: tensor + add_num: + dtype: tensor + type_cast: number + group_num: + dtype: int + group_topk: + dtype: int + n: + dtype: int + k: + dtype: int + activate_type: + dtype: int + default: 0 + is_norm: + dtype: bool + default: True + scale: + dtype: float + default: 2.5 + returns: + weight: + dtype: tensor + indices: + dtype: tensor + function: + name: fused_add_topk_div + class: + name: FusedAddTopKDiv -- Gitee From 4352d3273f6821980e067acee8d1fa004398284a Mon Sep 17 00:00:00 2001 From: ckey_Dou Date: Thu, 20 Mar 2025 21:25:25 +0800 Subject: [PATCH 12/14] rms_norm_quant_no_beta --- .../aarch64/ms_kernels_internal.tar.gz | 4 +- .../x86_64/ms_kernels_internal.tar.gz | 4 +- .../optimizer/backend_common_unify_mindir.cc | 1 + .../inference_weight_preprocess_utils.cc | 21 +- .../inference_weight_preprocess_utils.h | 2 +- .../ir_fusion_infer/rms_norm_quant_fusion.cc | 239 +++++++++++++----- .../ir_fusion_infer/rms_norm_quant_fusion.h | 25 +- 7 files changed, 215 insertions(+), 81 deletions(-) diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/internal/prebuild/aarch64/ms_kernels_internal.tar.gz b/mindspore/ccsrc/plugin/device/ascend/kernel/internal/prebuild/aarch64/ms_kernels_internal.tar.gz index bb03cbf1b2f..ff1fc9c7041 100644 --- a/mindspore/ccsrc/plugin/device/ascend/kernel/internal/prebuild/aarch64/ms_kernels_internal.tar.gz +++ b/mindspore/ccsrc/plugin/device/ascend/kernel/internal/prebuild/aarch64/ms_kernels_internal.tar.gz @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:ada25502bc7d86d68ad9347761f2eeeb80f64ef7130a8b40b93e15d79ba855b4 -size 3608058 +oid sha256:67330f0231f7f892ee024cd1fdc6f5cb1aae88a6558d1effb62ed30a0c9b4d82 +size 3609650 diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/internal/prebuild/x86_64/ms_kernels_internal.tar.gz b/mindspore/ccsrc/plugin/device/ascend/kernel/internal/prebuild/x86_64/ms_kernels_internal.tar.gz index a80bc517dbf..228aeaf7b61 100644 --- a/mindspore/ccsrc/plugin/device/ascend/kernel/internal/prebuild/x86_64/ms_kernels_internal.tar.gz +++ b/mindspore/ccsrc/plugin/device/ascend/kernel/internal/prebuild/x86_64/ms_kernels_internal.tar.gz @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:4a881baaada79fa42b83b7647e6d4677f4b7bcf9062dc6062d64007978b0504e -size 3621100 +oid sha256:a1beaace96b9ac2ba9c3cae9b450155d85219ccdbcae6197d24c01759140863d +size 3623078 diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/backend_common_unify_mindir.cc b/mindspore/ccsrc/plugin/device/ascend/optimizer/backend_common_unify_mindir.cc index 2bfd6ec27dd..a265821fbc2 100644 --- a/mindspore/ccsrc/plugin/device/ascend/optimizer/backend_common_unify_mindir.cc +++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/backend_common_unify_mindir.cc @@ -181,6 +181,7 @@ PassManagerPtr GetBackendFusionGroupPassManager() { pm->AddFusionPass(std::make_shared(), infer_boost); pm->AddFusionPass(std::make_shared(), infer_boost); pm->AddFusionPass(std::make_shared(), infer_boost); + pm->AddFusionPass(std::make_shared(), infer_boost); pm->AddFusionPass(std::make_shared(), infer_boost); pm->AddFusionPass(std::make_shared(), infer_boost); pm->AddFusionPass(std::make_shared(), infer_boost); diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fusion_infer/inference_weight_preprocess_utils.cc b/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fusion_infer/inference_weight_preprocess_utils.cc index 6dead1142ff..b4c94f2ebbd 100644 --- a/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fusion_infer/inference_weight_preprocess_utils.cc +++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fusion_infer/inference_weight_preprocess_utils.cc @@ -36,6 +36,16 @@ void ConvertDataType(void *dst_data, void *ori_data, int64_t len, bool need_rank dst_data_t[i] = static_cast(ori_data_t[i]); } } +float int32_to_float(std::int32_t int_value) { + union { + std::int32_t i; + float f; + } converter; + converter.i = int_value; + return converter.f; +} + +} // namespace std::shared_ptr CreateValueNode(const tensor::TensorPtr &assist_tensor, const TensorTypePtr &tensor_type) { MS_EXCEPTION_IF_NULL(assist_tensor); @@ -55,17 +65,6 @@ std::shared_ptr CreateValueNode(const tensor::TensorPtr &assist_tenso return assist_const; } -float int32_to_float(std::int32_t int_value) { - union { - std::int32_t i; - float f; - } converter; - converter.i = int_value; - return converter.f; -} - -} // namespace - std::shared_ptr ConvertWeightsToNewType(const AnfNodePtr &weight_node) { auto w_param = GetParamFromLoad(weight_node->cast(), true); MS_EXCEPTION_IF_NULL(w_param); diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fusion_infer/inference_weight_preprocess_utils.h b/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fusion_infer/inference_weight_preprocess_utils.h index 43088158eff..480d6fe3a37 100644 --- a/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fusion_infer/inference_weight_preprocess_utils.h +++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fusion_infer/inference_weight_preprocess_utils.h @@ -42,7 +42,7 @@ tensor::TensorPtr GetParamFromLoad(const CNodePtr &load, const bool unused); bool CheckFusionValid(const CNodePtr &matmul, int64_t *k, const int trans_a_pos, const int trans_b_pos, const std::vector &valid_dtypes); - +std::shared_ptr CreateValueNode(const tensor::TensorPtr &assist_tensor, const TensorTypePtr &tensor_type); std::shared_ptr CreateWeightTensor(TypeId type_id, const std::vector &weight_shape, const std::vector &data_c_list, const std::vector &n_len_list, const int64_t &k_len, diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fusion_infer/rms_norm_quant_fusion.cc b/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fusion_infer/rms_norm_quant_fusion.cc index 28e0028f299..742c1d80ff5 100644 --- a/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fusion_infer/rms_norm_quant_fusion.cc +++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fusion_infer/rms_norm_quant_fusion.cc @@ -15,6 +15,7 @@ */ #include "plugin/device/ascend/optimizer/ir_fusion_infer/rms_norm_quant_fusion.h" +#include #include #include #include @@ -30,12 +31,21 @@ #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_r.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_t.h" +#include "plugin/device/ascend/optimizer/ir_fusion_infer/inference_weight_preprocess_utils.h" namespace mindspore { namespace opt { -std::vector RmsNormQuantFusion::MustExistPrimitiveName() const { - std::vector ret{prim::kPrimRmsNorm->name(), prim::kPrimAdd->name(), prim::kPrimQuantV2->name()}; - return ret; +template +std::shared_ptr CreateZeroTensor(const ShapeVector &gamma_shape, TypeId gamma_type) { + tensor::TensorPtr assist_tensor = std::make_shared(gamma_type, gamma_shape); + TensorTypePtr tensor_type = std::make_shared(TypeIdToType(gamma_type)); + T *dst_data_t = reinterpret_cast(assist_tensor->data_c()); + const auto data_size = sizeof(T); + auto set_ret = memset_s(dst_data_t, gamma_shape[0] * data_size, 0, gamma_shape[0] * data_size); + if (set_ret != EOK) { + MS_LOG(EXCEPTION) << "Failed to set tensor to zeros."; + } + return CreateValueNode(assist_tensor, tensor_type); } inline bool IsZero(const BaseRef &n) { @@ -55,20 +65,6 @@ inline bool IsZero(const BaseRef &n) { return false; } -const BaseRef RmsNormQuantFusion::DefinePattern() const { - auto index0 = std::make_shared(IsConstant); - auto rms_norm = VectorRef({prim::kPrimRmsNorm, x1_, gamma_, eps_}); - - auto tuple_get_item_0 = VectorRef({prim::kPrimTupleGetItem, rms_norm, index0}); - auto add = VectorRef({prim::kPrimAdd, tuple_get_item_0, beta0_}); - - auto sqrt_mode0 = std::make_shared(IsConstant); // not used - auto rounding_mode0 = std::make_shared(IsConstant); // not used - auto dst_type0 = std::make_shared(IsConstant); // not used - auto quant = VectorRef({prim::kPrimQuantV2, add, scale0_, offset0_, sqrt_mode0, rounding_mode0, dst_type0}); - return quant; -} - static bool IsSupport(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &rms_norm) { auto x_dtype = common::AnfAlgo::GetPrevNodeOutputInferDataType(rms_norm, 0); auto gamma_dtype = common::AnfAlgo::GetPrevNodeOutputInferDataType(rms_norm, 1); @@ -133,6 +129,144 @@ static bool IsSupport(const FuncGraphPtr &graph, const AnfNodePtr &node, const A return true; } +static const AnfNodePtr CreateRmsNormQuantNode(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &x1, + const AnfNodePtr &gamma, const AnfNodePtr &beta, const AnfNodePtr &scale, + const AnfNodePtr &offset, const AnfNodePtr &eps) { + auto prim = std::make_shared("RmsNormQuant"); + std::vector inputs = {NewValueNode(prim), x1, gamma, beta, scale, offset, eps}; + auto rms_norm_quant = graph->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(rms_norm_quant); + + std::vector types; + std::vector shapes; + auto output_num = AnfAlgo::GetOutputElementNum(node); + for (size_t i = 0; i < output_num; i++) { + types.push_back(common::AnfAlgo::GetOutputInferDataType(node, i)); + shapes.push_back(AnfAlgo::GetOutputDetailShape(node, i)); + } + + common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, rms_norm_quant.get()); + rms_norm_quant->set_scope(node->scope()); + + auto build_info = GenerateKernelBuildInfo(rms_norm_quant); + AnfAlgo::SetSelectKernelBuildInfo(build_info, rms_norm_quant.get()); + + return rms_norm_quant; +} + +std::vector RmsNormQuantFusion::MustExistPrimitiveName() const { + std::vector ret{prim::kPrimRmsNorm->name(), prim::kPrimQuantV2->name()}; + return ret; +} + +const BaseRef RmsNormQuantFusion::DefinePattern() const { + auto index0 = std::make_shared(IsConstant); + auto rms_norm = VectorRef({prim::kPrimRmsNorm, x1_, gamma_, eps_}); + + auto tuple_get_item_0 = VectorRef({prim::kPrimTupleGetItem, rms_norm, index0}); + + auto sqrt_mode0 = std::make_shared(IsConstant); // not used + auto rounding_mode0 = std::make_shared(IsConstant); // not used + auto dst_type0 = std::make_shared(IsConstant); // not used + auto quant = + VectorRef({prim::kPrimQuantV2, tuple_get_item_0, scale0_, offset0_, sqrt_mode0, rounding_mode0, dst_type0}); + return quant; +} + +const AnfNodePtr RmsNormQuantFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, + const EquivPtr &equiv) const { + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + if (!ms_context->IsEnableInferBoost()) { + MS_LOG(INFO) << "Internal op is disabled."; + return nullptr; + } + + const std::string fusion_op_name = "RmsNormQuant"; + auto enable_op_list = ms_context->ms_internal_enable_custom_kernel_list(); + bool enable_add_rmsnorm = + (std::find(enable_op_list.begin(), enable_op_list.end(), fusion_op_name) != enable_op_list.end()); + if (!enable_add_rmsnorm) { + MS_LOG(INFO) << "Internal RmsNormQuant is disabled."; + return nullptr; + } + + auto rms_norm_out0 = common::AnfAlgo::GetInputNode(utils::cast(node), 0); + auto rms_norm_node = common::AnfAlgo::GetInputNode(utils::cast(rms_norm_out0), 0); + MS_EXCEPTION_IF_NULL(rms_norm_node); + + if (!IsSupport(graph, node, rms_norm_node)) { + MS_LOG(INFO) << "Can't fused to RmsNormQuant because of unsupported case."; + return nullptr; + } + + auto rms_norm_out0_users = GetRealNodeUsedList(graph, rms_norm_out0); + if (rms_norm_out0_users->size() > 1) { + MS_LOG(INFO) << "RmsNormQuant fused failed because the number of users of rms_norm_out0 is more than 1: " + << rms_norm_out0_users->size(); + return nullptr; + } + + auto x1 = utils::cast((*equiv)[x1_]); + auto gamma = utils::cast((*equiv)[gamma_]); + auto scale = utils::cast((*equiv)[scale0_]); + auto offset = utils::cast((*equiv)[offset0_]); + auto eps = utils::cast((*equiv)[eps_]); + + auto kernel_graph = graph->cast(); + MS_EXCEPTION_IF_NULL(kernel_graph); + + TypeId gamma_type = common::AnfAlgo::GetOutputInferDataType(gamma, 0); + auto gamma_shape = common::AnfAlgo::GetOutputInferShape(gamma, kIndex0); + if (gamma_shape.size() != 1) { + MS_LOG(INFO) << "gamma_shape.size():" << gamma_shape.size() << " != 1."; + return nullptr; + } + + ValueNodePtr beta; + if (gamma_type == kNumberTypeFloat16) { + beta = CreateZeroTensor(gamma_shape, gamma_type); + } else if (gamma_type == kNumberTypeBFloat16) { + beta = CreateZeroTensor(gamma_shape, gamma_type); + } else { + MS_LOG(INFO) << "gamma_type:" << TypeIdToString(gamma_type) << " != kNumberTypeFloat16 && != kNumberTypeBFloat16."; + return nullptr; + } + if (!beta) { + MS_LOG(INFO) << "beta is nullptr."; + return nullptr; + } + kernel_graph->AddValueNodeToGraph(beta); + + auto rms_norm_quant = CreateRmsNormQuantNode(graph, node, x1, gamma, beta, scale, offset, eps); + if (rms_norm_quant != nullptr) { + MS_LOG(INFO) << "RmsNormQuant fused successfully."; + } else { + MS_LOG(INFO) << "RmsNormQuant fused failed."; + } + + return rms_norm_quant; +} + +std::vector RmsNormAddQuantFusion::MustExistPrimitiveName() const { + std::vector ret{prim::kPrimRmsNorm->name(), prim::kPrimAdd->name(), prim::kPrimQuantV2->name()}; + return ret; +} + +const BaseRef RmsNormAddQuantFusion::DefinePattern() const { + auto index0 = std::make_shared(IsConstant); + auto rms_norm = VectorRef({prim::kPrimRmsNorm, x1_, gamma_, eps_}); + + auto tuple_get_item_0 = VectorRef({prim::kPrimTupleGetItem, rms_norm, index0}); + auto add = VectorRef({prim::kPrimAdd, tuple_get_item_0, beta0_}); + + auto sqrt_mode0 = std::make_shared(IsConstant); // not used + auto rounding_mode0 = std::make_shared(IsConstant); // not used + auto dst_type0 = std::make_shared(IsConstant); // not used + auto quant = VectorRef({prim::kPrimQuantV2, add, scale0_, offset0_, sqrt_mode0, rounding_mode0, dst_type0}); + return quant; +} + static constexpr auto kRmsNormOut2OneAddQuant = 1; static constexpr auto kRmsNormOut2TwoAddQuant = 2; static constexpr auto kRmsNormOut2OneAddQuantAndOneShape = 3; @@ -149,12 +283,12 @@ void GetAddAndShapeNum(const FuncGraphPtr &graph, if (IsPrimitiveCNode(user_node, prim::kPrimAdd)) { auto add_users = GetRealNodeUsedList(graph, user_node); if (add_users->size() != 1) { - MS_LOG(INFO) << "RmsNormQuant fuse failed because the user of Add is more than one: " << add_users->size(); + MS_LOG(INFO) << "RmsNormAddQuant fuse failed because the user of Add is more than one: " << add_users->size(); return; } if (!IsPrimitiveCNode(add_users->at(0).first, prim::kPrimQuantV2)) { - MS_LOG(INFO) << "RmsNormQuant fuse failed because the user of Add is not Quant: " + MS_LOG(INFO) << "RmsNormAddQuant fuse failed because the user of Add is not Quant: " << add_users->at(0).first->fullname_with_scope(); return; } @@ -178,7 +312,7 @@ inline size_t GetOpsCaseAfterRmsNorm(const FuncGraphPtr &graph, const AnfNodePtr if (user_num == 1) { if (add_num != 1) { - MS_LOG(INFO) << "RmsNormQuant fuse failed because the user of RmsNorm is not Add-Quant"; + MS_LOG(INFO) << "RmsNormAddQuant fuse failed because the user of RmsNorm is not Add-Quant"; return kUnsupportedTag; } @@ -195,7 +329,7 @@ inline size_t GetOpsCaseAfterRmsNorm(const FuncGraphPtr &graph, const AnfNodePtr } MS_LOG(INFO) - << "RmsNormQuant fuse failed because the num of Add and shape in users of RmsNorm is invalid, add_num: " + << "RmsNormAddQuant fuse failed because the num of Add and shape in users of RmsNorm is invalid, add_num: " << add_num << ", shape_num: " << shape_num; return kUnsupportedTag; } @@ -206,7 +340,7 @@ inline size_t GetOpsCaseAfterRmsNorm(const FuncGraphPtr &graph, const AnfNodePtr } MS_LOG(INFO) - << "RmsNormQuant fuse failed because the num of Add and shape in users of RmsNorm is invalid, add_num: " + << "RmsNormAddQuant fuse failed because the num of Add and shape in users of RmsNorm is invalid, add_num: " << add_num << ", shape_num: " << shape_num; return kUnsupportedTag; } @@ -214,34 +348,9 @@ inline size_t GetOpsCaseAfterRmsNorm(const FuncGraphPtr &graph, const AnfNodePtr return kUnsupportedTag; } -static const AnfNodePtr CreateRmsNormQuantNode(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &x1, - const AnfNodePtr &gamma, const AnfNodePtr &beta, const AnfNodePtr &scale, - const AnfNodePtr &offset, const AnfNodePtr &eps) { - auto prim = std::make_shared("RmsNormQuant"); - std::vector inputs = {NewValueNode(prim), x1, gamma, beta, scale, offset, eps}; - auto rms_norm_quant = graph->NewCNode(inputs); - MS_EXCEPTION_IF_NULL(rms_norm_quant); - - std::vector types; - std::vector shapes; - auto output_num = AnfAlgo::GetOutputElementNum(node); - for (size_t i = 0; i < output_num; i++) { - types.push_back(common::AnfAlgo::GetOutputInferDataType(node, i)); - shapes.push_back(AnfAlgo::GetOutputDetailShape(node, i)); - } - - common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, rms_norm_quant.get()); - rms_norm_quant->set_scope(node->scope()); - - auto build_info = GenerateKernelBuildInfo(rms_norm_quant); - AnfAlgo::SetSelectKernelBuildInfo(build_info, rms_norm_quant.get()); - - return rms_norm_quant; -} - -const AnfNodePtr RmsNormQuantFusion::RmsNormQuantFuseWithOnePath(const FuncGraphPtr &graph, const AnfNodePtr &node, - const EquivPtr &equiv, - const AnfNodePtr &shape_node) const { +const AnfNodePtr RmsNormAddQuantFusion::RmsNormQuantFuseWithOnePath(const FuncGraphPtr &graph, const AnfNodePtr &node, + const EquivPtr &equiv, + const AnfNodePtr &shape_node) const { auto x1 = utils::cast((*equiv)[x1_]); auto gamma = utils::cast((*equiv)[gamma_]); auto beta = utils::cast((*equiv)[beta0_]); @@ -253,7 +362,7 @@ const AnfNodePtr RmsNormQuantFusion::RmsNormQuantFuseWithOnePath(const FuncGraph if (shape_node != nullptr) { shape_input_node = common::AnfAlgo::GetInputNode(utils::cast(shape_node), 0); if (shape_input_node == nullptr) { - MS_LOG(INFO) << "RmsNormQuant fused failed because shape_input_node is nullptr"; + MS_LOG(INFO) << "RmsNormAddQuant fused failed because shape_input_node is nullptr"; return nullptr; } } @@ -364,9 +473,10 @@ inline bool ParameterNotEqual(const std::string &name, const AnfNodePtr &load0, return ValueNotEqual(data_c0, data_c1, size0); } -const AnfNodePtr RmsNormQuantFusion::RmsNormQuantFuseWithTwoPath(const FuncGraphPtr &graph, const AnfNodePtr &node, - const EquivPtr &equiv, const AnfNodePtr &rms_norm_out0, - const AnfNodePtr &shape_node) const { +const AnfNodePtr RmsNormAddQuantFusion::RmsNormQuantFuseWithTwoPath(const FuncGraphPtr &graph, const AnfNodePtr &node, + const EquivPtr &equiv, + const AnfNodePtr &rms_norm_out0, + const AnfNodePtr &shape_node) const { auto x1 = utils::cast((*equiv)[x1_]); auto gamma = utils::cast((*equiv)[gamma_]); auto beta0_load = utils::cast((*equiv)[beta0_]); @@ -378,7 +488,7 @@ const AnfNodePtr RmsNormQuantFusion::RmsNormQuantFuseWithTwoPath(const FuncGraph if (shape_node != nullptr) { shape_input_node = common::AnfAlgo::GetInputNode(utils::cast(shape_node), 0); if (shape_input_node == nullptr) { - MS_LOG(INFO) << "RmsNormQuant fused failed because shape_input_node is nullptr"; + MS_LOG(INFO) << "RmsNormAddQuant fused failed because shape_input_node is nullptr"; return nullptr; } } @@ -404,8 +514,9 @@ const AnfNodePtr RmsNormQuantFusion::RmsNormQuantFuseWithTwoPath(const FuncGraph const auto add_node = user.first; auto load = common::AnfAlgo::GetInputNode(utils::cast(add_node), 1); if (!IsPrimitiveCNode(load, prim::kPrimLoad)) { - MS_LOG(INFO) << "RmsNormQuant fuse failed because the input node is not load when add-quant number is 2, input: " - << load->DebugString(); + MS_LOG(INFO) + << "RmsNormAddQuant fuse failed because the input node is not load when add-quant number is 2, input: " + << load->DebugString(); return nullptr; } @@ -419,7 +530,7 @@ const AnfNodePtr RmsNormQuantFusion::RmsNormQuantFuseWithTwoPath(const FuncGraph } if (ParameterNotEqual("beta", beta0_load, beta1_load)) { - MS_LOG(INFO) << "RmsNormQuant fuse failed because the value of beta is not equal."; + MS_LOG(INFO) << "RmsNormAddQuant fuse failed because the value of beta is not equal."; return nullptr; } @@ -435,12 +546,12 @@ const AnfNodePtr RmsNormQuantFusion::RmsNormQuantFuseWithTwoPath(const FuncGraph auto offset1 = common::AnfAlgo::GetInputNode(utils::cast(quant_node1), kOffsetIdx); if (ParameterNotEqual("scale", scale0, scale1)) { - MS_LOG(INFO) << "RmsNormQuant fuse failed because the value of scale is not equal."; + MS_LOG(INFO) << "RmsNormAddQuant fuse failed because the value of scale is not equal."; return nullptr; } if (ParameterNotEqual("offset", offset0, offset1)) { - MS_LOG(INFO) << "RmsNormQuant fuse failed because the value of offset is not equal."; + MS_LOG(INFO) << "RmsNormAddQuant fuse failed because the value of offset is not equal."; return nullptr; } @@ -456,8 +567,8 @@ const AnfNodePtr RmsNormQuantFusion::RmsNormQuantFuseWithTwoPath(const FuncGraph return rms_norm_quant_node; } -const AnfNodePtr RmsNormQuantFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, - const EquivPtr &equiv) const { +const AnfNodePtr RmsNormAddQuantFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, + const EquivPtr &equiv) const { auto ms_context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(ms_context); if (!ms_context->IsEnableInferBoost()) { @@ -480,7 +591,7 @@ const AnfNodePtr RmsNormQuantFusion::Process(const FuncGraphPtr &graph, const An MS_EXCEPTION_IF_NULL(rms_norm_node); if (!IsSupport(graph, node, rms_norm_node)) { - MS_LOG(INFO) << "Can't fused to RmsNormQuant because of unsupported case."; + MS_LOG(INFO) << "Can't fused to RmsNormAddQuant because of unsupported case."; return nullptr; } @@ -496,7 +607,7 @@ const AnfNodePtr RmsNormQuantFusion::Process(const FuncGraphPtr &graph, const An } if (out_node != nullptr) { - MS_LOG(INFO) << "RmsNormQuant fused successfully with RmsNorm out case: " << num_of_add_after_rmsnorm; + MS_LOG(INFO) << "RmsNormAddQuant fused successfully with RmsNorm out case: " << num_of_add_after_rmsnorm; } return out_node; diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fusion_infer/rms_norm_quant_fusion.h b/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fusion_infer/rms_norm_quant_fusion.h index d3a8928bf9e..f2768a79a43 100644 --- a/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fusion_infer/rms_norm_quant_fusion.h +++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fusion_infer/rms_norm_quant_fusion.h @@ -29,7 +29,6 @@ class RmsNormQuantFusion : public PatternProcessPass { x1_ = std::make_shared(); gamma_ = std::make_shared(); eps_ = std::make_shared(); - beta0_ = std::make_shared(); scale0_ = std::make_shared(); offset0_ = std::make_shared(); } @@ -37,6 +36,30 @@ class RmsNormQuantFusion : public PatternProcessPass { const BaseRef DefinePattern() const override; const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + private: + std::vector MustExistPrimitiveName() const override; + + VarPtr x1_; + VarPtr gamma_; + VarPtr eps_; + VarPtr scale0_; + VarPtr offset0_; +}; + +class RmsNormAddQuantFusion : public PatternProcessPass { + public: + explicit RmsNormAddQuantFusion(bool multigraph = true) : PatternProcessPass("rms_norm_add_quant_fusion", multigraph) { + x1_ = std::make_shared(); + gamma_ = std::make_shared(); + eps_ = std::make_shared(); + beta0_ = std::make_shared(); + scale0_ = std::make_shared(); + offset0_ = std::make_shared(); + } + ~RmsNormAddQuantFusion() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + private: std::vector MustExistPrimitiveName() const override; const AnfNodePtr RmsNormQuantFuseWithOnePath(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &, -- Gitee From 5357489f4f9fe20bbdee8240a1cda0b7fcc493da Mon Sep 17 00:00:00 2001 From: gaoyong10 Date: Fri, 14 Mar 2025 14:35:58 +0800 Subject: [PATCH 13/14] Multi input supports batch launch. --- .../actor/data_prepare_actor.cc | 58 ++++- .../actor/data_prepare_actor.h | 1 + .../actor/data_source_actor.cc | 237 +++++++++++------- .../graph_scheduler/actor/data_source_actor.h | 21 +- tests/st/runtime/test_parallel_dispatch.py | 50 +++- 5 files changed, 249 insertions(+), 118 deletions(-) diff --git a/mindspore/ccsrc/runtime/graph_scheduler/actor/data_prepare_actor.cc b/mindspore/ccsrc/runtime/graph_scheduler/actor/data_prepare_actor.cc index 0b9e2045304..ebfb8fdd57a 100644 --- a/mindspore/ccsrc/runtime/graph_scheduler/actor/data_prepare_actor.cc +++ b/mindspore/ccsrc/runtime/graph_scheduler/actor/data_prepare_actor.cc @@ -484,13 +484,25 @@ void DataPrepareActor::UpdateDeviceAddressForDataNode(const AnfNodePtr &input_no } tensor_address->set_new_ref_count(SIZE_MAX); + static const bool enable_infer_boost = MsContext::GetInstance()->IsEnableInferBoost(); + bool is_kv_cache = enable_infer_boost && (input_tensor->name().find("key_cache") != std::string::npos || + input_tensor->name().find("value_cache") != std::string::npos); + auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0, false); MS_EXCEPTION_IF_NULL(device_address); if (tensor_address == device_address) { + if (is_kv_cache) { + MS_LOG(EXCEPTION) << "The tensor address can not set into input node for kv cache: " + << input_node->fullname_with_scope(); + } tensor_address->SetNodeIndex(input_node, 0); tensor_address->set_original_ref_count(SIZE_MAX); tensor_address->ResetRefCount(); return; + } else if (is_kv_cache && tensor_address->pointer_ref_count() == device_address->pointer_ref_count()) { + MS_LOG(EXCEPTION) << "The tensor address can not set into input node for kv cache: " + << input_node->fullname_with_scope() << " tensor address:" << tensor_address->PrintInfo() + << " input address:" << device_address->PrintInfo(); } // If tensor address and device address are different (heterogeneous scenarios), or device address is persisted @@ -506,7 +518,17 @@ void DataPrepareActor::UpdateDeviceAddressForDataNode(const AnfNodePtr &input_no (void)address_modified_input_nodes_.insert(input_node.get()); tensor_address->set_flag(device_address->flag()); DeviceAddressUtils::UpdateDeviceAddressHostInfoByNode(tensor_address, input_node, 0); - AnfAlgo::SetOutputAddr(tensor_address, 0, input_node.get()); + if (is_kv_cache) { + const auto &kernel_tensor = device_address->kernel_tensor(); + MS_EXCEPTION_IF_NULL(kernel_tensor); + kernel_tensor->set_device_ptr(tensor_address->GetMutablePtr()); + MS_EXCEPTION_IF_NULL(kernel_tensor->device_ptr()); + device_address->set_from_mem_pool(false); + device_address->set_new_ref_count(SIZE_MAX); + device_address->set_original_ref_count(SIZE_MAX); + } else { + AnfAlgo::SetOutputAddr(tensor_address, 0, input_node.get()); + } MS_LOG(DEBUG) << "Update device address of " << input_node->DebugString() << " to " << tensor_address.get() << ", kernel tensor addr:" << tensor_address->kernel_tensor().get() << " ptr:" << tensor_address->GetPtr(); @@ -630,7 +652,10 @@ void DataPrepareActor::PrepareData(const std::vector> &in return; } MS_EXCEPTION_IF_NULL(graph_compiler_info_); - if (!address_modified_input_nodes_.empty()) { + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + static const bool enable_infer_boost = ms_context->IsEnableInferBoost(); + if (!address_modified_input_nodes_.empty() && !enable_infer_boost) { UpdateDeviceAddressByRefInputNode(graph_compiler_info_->graphs_, address_modified_input_nodes_); address_modified_input_nodes_.clear(); } @@ -885,6 +910,9 @@ void DataPrepareActor::RecordGraphInputs(const std::vector &host_tens auto &llm_manager = LLMManager::GetInstance(); for (size_t i = 0; i < host_tensors.size(); ++i) { auto host_tensor = host_tensors[i]; + if (host_tensor == nullptr) { + continue; + } auto param_index = host_param_indexes[i]; const auto &origin_parameter = graph_compiler_info_->origin_parameters_order_[param_index]; // host_tensor must not be nullptr @@ -902,6 +930,11 @@ void DataPrepareActor::PrepareDataForHostTensorQueueNew(const VectorRef &args, O host_tensors_.resize(host_data_size); host_param_indexes.resize(host_data_size); bool isDyn = false; + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + static const bool enable_infer_boost = ms_context->IsEnableInferBoost(); + bool first_kv_cache_input = true; + bool early_stop_prepare = false; // Fill host tensors. for (size_t i = 0; i < graph_compiler_info_->origin_parameters_order_.size(); ++i) { if (current_data_num == host_data_size) { @@ -942,11 +975,26 @@ void DataPrepareActor::PrepareDataForHostTensorQueueNew(const VectorRef &args, O MS_LOG(INFO) << "Set host tensor position:" << tensor_position << " for input parameter:" << origin_parameter->fullname_with_scope(); + if (enable_infer_boost && first_kv_cache_input && + (input_tensor->name().find("key_cache") != std::string::npos || + input_tensor->name().find("value_cache") != std::string::npos)) { + first_kv_cache_input = false; + + bool kv_cache_not_change = (input_tensor->shape() == kv_cache_shape_); + kv_cache_shape_ = input_tensor->shape(); + if (kv_cache_not_change) { + host_data_source_actor_->set_is_shape_match(true); + early_stop_prepare = true; + break; + } + } + if (!isDyn) { if (host_tensors_[tensor_position] != input_tensor->shape() || input_tensor->shape().empty()) { isDyn = true; } } + host_tensors_[tensor_position] = input_tensor->shape(); host_tensors[tensor_position] = input_tensor; host_param_indexes[tensor_position] = i; @@ -961,9 +1009,13 @@ void DataPrepareActor::PrepareDataForHostTensorQueueNew(const VectorRef &args, O UpdateDeviceAddressForDataNode(origin_to_backend_pair.second.first, input_tensor); } } + if (early_stop_prepare) { + MS_LOG(DEBUG) << "Early stop prepare in index:" << i << " parameter:" << origin_parameter->DebugString(); + break; + } } - if (is_enable_infer_boost_ && EnableKbkSubGraphExecute()) { + if (enable_infer_boost && EnableKbkSubGraphExecute()) { RecordGraphInputs(host_tensors, host_param_indexes); if (has_dynamic_shape_) { ActorDispatcher::set_enable_static_shape(!isDyn); diff --git a/mindspore/ccsrc/runtime/graph_scheduler/actor/data_prepare_actor.h b/mindspore/ccsrc/runtime/graph_scheduler/actor/data_prepare_actor.h index d84dd6e7cb0..c39113d5cfe 100644 --- a/mindspore/ccsrc/runtime/graph_scheduler/actor/data_prepare_actor.h +++ b/mindspore/ccsrc/runtime/graph_scheduler/actor/data_prepare_actor.h @@ -163,6 +163,7 @@ class DataPrepareActor : public DebugAwareActor { std::set address_modified_input_nodes_; bool first_step_; std::vector host_tensors_; + ShapeVector kv_cache_shape_; bool has_parameter_input_; // The tensor of parameter(weight) maybe update host value by Python phase and need re-prepare to sync new host value diff --git a/mindspore/ccsrc/runtime/graph_scheduler/actor/data_source_actor.cc b/mindspore/ccsrc/runtime/graph_scheduler/actor/data_source_actor.cc index e0b76ff1e72..375c5338948 100644 --- a/mindspore/ccsrc/runtime/graph_scheduler/actor/data_source_actor.cc +++ b/mindspore/ccsrc/runtime/graph_scheduler/actor/data_source_actor.cc @@ -15,6 +15,9 @@ */ #include "runtime/graph_scheduler/actor/data_source_actor.h" + +#include + #include "runtime/graph_scheduler/actor/kernel_actor.h" #include "runtime/graph_scheduler/actor/memory_manager_actor.h" #include "runtime/graph_scheduler/actor/output_actor.h" @@ -43,16 +46,9 @@ void DataSourceActor::FetchData(OpContext *const context) { MS_LOG(INFO) << "Data source actor(" << GetAID().Name() << ") fetches data."; MS_EXCEPTION_IF_NULL(context); device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(AddTask, GetAID().Name(), GetAID().Name(), ""); - // Pop the data of last time. - if (!buffers_.empty()) { - buffers_.pop(); - } // Construct device tensors and fill to the buffers from member nodes. - FillDataBuffer(); - if (buffers_.size() == 0) { - SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "The data queue is empty."); - } + FillDataBuffer(context); // Allocate memory for device tensors. SendMemoryAllocReq(context); @@ -65,18 +61,13 @@ void DataSourceActor::UpdateOutputData(OpData *const output_data, MS_EXCEPTION_IF_NULL(output_node); MS_EXCEPTION_IF_NULL(context); - if (buffers_.size() == 0) { - SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "The data queue is empty."); - } - const auto &output_device_tensors = buffers_.front(); - auto position = FetchNodePosition({output_node, data_arrow->from_output_index_}); // Host data souruce actor uses the node position, device data source actor uses the output index. auto output_position = (position != 0) ? position : IntToSize(data_arrow->from_output_index_); - if (output_position >= output_device_tensors.size()) { + if (output_position >= device_tensors_.size()) { SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "The output index is of range."); } - output_data->data_ = output_device_tensors[output_position]; + output_data->data_ = device_tensors_[output_position]; } void DeviceQueueDataSourceActor::Init() { @@ -101,7 +92,7 @@ void DeviceQueueDataSourceActor::Init() { stream_ = device_contexts_[0]->device_res_manager_->GetStream(kernel_info_->stream_id()); } -void DeviceQueueDataSourceActor::FillDataBuffer() { +void DeviceQueueDataSourceActor::FillDataBuffer(OpContext *const context) { MS_EXCEPTION_IF_NULL(kernel_info_); if (is_dynamic_shape_) { // For GetNext dynamic case, the Resize method finish update output shape and output size in kernel tensor via data @@ -114,38 +105,33 @@ void DeviceQueueDataSourceActor::FillDataBuffer() { } } - // Construct device tensors. - std::vector device_tensors; + device_tensors_.clear(); for (auto &device_tensor : kernel_info_->output_address_list()) { MS_EXCEPTION_IF_NULL(device_tensor); - (void)device_tensors.emplace_back(device_tensor.get()); + (void)device_tensors_.emplace_back(device_tensor.get()); } - - buffers_.push(device_tensors); } void DeviceQueueDataSourceActor::SendMemoryAllocReq(OpContext *const context) { - auto &device_tensors = buffers_.back(); if (ActorDispatcher::is_memory_allocation_sync()) { - ActorDispatcher::SendSync(memory_manager_aid_, &MemoryManagerActor::AllocateMemory, &device_tensors, + ActorDispatcher::SendSync(memory_manager_aid_, &MemoryManagerActor::AllocateMemory, &device_tensors_, device_contexts_[0], context, GetAID()); OnMemoryAllocFinish(context); } else { - ActorDispatcher::Send(memory_manager_aid_, &MemoryManagerActor::AllocateMemory, &device_tensors, + ActorDispatcher::Send(memory_manager_aid_, &MemoryManagerActor::AllocateMemory, &device_tensors_, device_contexts_[0], context, GetAID()); } } void DeviceQueueDataSourceActor::SendMemoryFreeReq(OpContext *const context) { - auto &device_tensors = buffers_.front(); if (device_contexts_.empty()) { SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "Empty device contexts in device data source actor."); } if (ActorDispatcher::is_memory_free_sync()) { - ActorDispatcher::SendSync(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &device_tensors, + ActorDispatcher::SendSync(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &device_tensors_, device_contexts_[0], context, GetAID()); } else { - ActorDispatcher::Send(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &device_tensors, device_contexts_[0], + ActorDispatcher::Send(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &device_tensors_, device_contexts_[0], context, GetAID()); } } @@ -158,29 +144,24 @@ void DeviceQueueDataSourceActor::OnMemoryAllocFinish(OpContext *co if (IsRunningFailed(context)) { return; } - if (buffers_.size() == 0) { - SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "The data queue is empty."); - } - // Construct outputs of data kernel launching. - auto &device_tensors = buffers_.back(); - if (output_kernel_tensors_.size() != device_tensors.size()) { + if (output_kernel_tensors_.size() != device_tensors_.size()) { SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "The outputs number is not equal to the device tensors number."); } - for (size_t i = 0; i < device_tensors.size(); ++i) { + for (size_t i = 0; i < device_tensors_.size(); ++i) { MS_EXCEPTION_IF_NULL(output_kernel_tensors_[i]); - MS_EXCEPTION_IF_NULL(device_tensors[i]); - output_kernel_tensors_[i]->set_device_ptr(device_tensors[i]->GetMutablePtr()); - output_kernel_tensors_[i]->set_size(device_tensors[i]->GetSize()); + MS_EXCEPTION_IF_NULL(device_tensors_[i]); + output_kernel_tensors_[i]->set_device_ptr(device_tensors_[i]->GetMutablePtr()); + output_kernel_tensors_[i]->set_size(device_tensors_[i]->GetSize()); if (recorder_aid_ != nullptr || debug_aid_ != nullptr) { - mem_info_.outputs_[i]->addr = device_tensors[i]->GetMutablePtr(); - mem_info_.outputs_[i]->size = device_tensors[i]->GetSize(); + mem_info_.outputs_[i]->addr = device_tensors_[i]->GetMutablePtr(); + mem_info_.outputs_[i]->size = device_tensors_[i]->GetSize(); } } if (debug_aid_ != nullptr) { ActorDispatcher::SendSync(*debug_aid_, &DebugActor::DebugPreLaunch, data_kernel_, std::vector(), - device_tensors, device_contexts_[0], context, &GetAID()); + device_tensors_, device_contexts_[0], context, &GetAID()); } // Copy data from device queue by data kernel launching. @@ -213,7 +194,7 @@ void DeviceQueueDataSourceActor::OnMemoryAllocFinish(OpContext *co void DeviceQueueDataSourceActor::SendDebugReq(OpContext *const context) { ActorDispatcher::SendSync(*debug_aid_, &DebugActor::DebugPostLaunch, data_kernel_, std::vector(), - buffers_.back(), device_contexts_[0], context, &GetAID()); + device_tensors_, device_contexts_[0], context, &GetAID()); OnDebugFinish(context); } @@ -226,97 +207,113 @@ void DeviceQueueDataSourceActor::SendRecorderInfo(OpContext *const } void DeviceQueueDataSourceActor::IncreaseNewRefCounts(OpContext *const context) { - if (buffers_.size() == 0) { - SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "The device data source actor data queue is empty."); - } - const auto &output_device_tensors = buffers_.front(); for (const auto &data_arrow : output_data_arrows_) { MS_EXCEPTION_IF_NULL(data_arrow); size_t position = IntToSize(data_arrow->from_output_index_); - if (position >= output_device_tensors.size()) { + if (position >= device_tensors_.size()) { SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "Invalid output index:" + std::to_string(position) + - " total size:" + std::to_string(output_device_tensors.size()) + + " total size:" + std::to_string(device_tensors_.size()) + " for device queue data source actor."); } - MS_EXCEPTION_IF_NULL(output_device_tensors[position]); - output_device_tensors[data_arrow->from_output_index_]->IncreaseNewRefCount(); + MS_EXCEPTION_IF_NULL(device_tensors_[position]); + device_tensors_[data_arrow->from_output_index_]->IncreaseNewRefCount(); MS_LOG(DEBUG) << "Increase new ref count for device address:" - << output_device_tensors[data_arrow->from_output_index_]->PrintInfo() << " in actor:" << GetAID(); + << device_tensors_[data_arrow->from_output_index_]->PrintInfo() << " in actor:" << GetAID(); } } void HostQueueDataSourceActor::IncreaseNewRefCounts(OpContext *const context) { - if (buffers_.size() == 0) { - SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "The device data source actor data queue is empty."); - } - const auto &output_device_tensors = buffers_.front(); - if (output_data_arrows_.size() != output_data_nodes_.size()) { + runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kRuntime, runtime::ProfilerEvent::kOutputProcess, + "DataSourceActorIncreaseRefCount"); + if (output_data_arrows_.size() != output_data_nodes_.size() || + need_refresh_device_address_.size() != output_data_arrows_.size()) { SET_OPCONTEXT_FAIL_RET_WITH_ERROR( - (*context), "Invalid data arrow size:" + std::to_string(output_data_arrows_.size()) + " and data node size:" + - std::to_string(output_data_nodes_.size()) + " for host queue data source actor."); + (*context), "Invalid data arrow size:" + std::to_string(output_data_arrows_.size()) + + " and data node size:" + std::to_string(output_data_nodes_.size()) + + " and need refresh flag size:" + std::to_string(need_refresh_device_address_.size()) + + " for host queue data source actor."); } for (size_t i = 0; i < output_data_arrows_.size(); ++i) { + if (!need_refresh_device_address_[i]) { + continue; + } auto &data_arrow = output_data_arrows_[i]; auto output_node = output_data_nodes_[i]; MS_EXCEPTION_IF_NULL(data_arrow); MS_EXCEPTION_IF_NULL(output_node); auto position = FetchNodePosition({output_node, data_arrow->from_output_index_}); - if (position >= output_device_tensors.size()) { + if (position >= device_tensors_.size()) { SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "Invalid output index:" + std::to_string(position) + - " total size:" + std::to_string(output_device_tensors.size()) + + " total size:" + std::to_string(device_tensors_.size()) + " for device queue data source actor."); } - MS_EXCEPTION_IF_NULL(output_device_tensors[position]); - output_device_tensors[position]->IncreaseNewRefCount(); - MS_LOG(DEBUG) << "Increase new ref count for device address:" << output_device_tensors[position]->PrintInfo() + MS_EXCEPTION_IF_NULL(device_tensors_[position]); + device_tensors_[position]->IncreaseNewRefCount(); + MS_LOG(DEBUG) << "Increase new ref count for device address:" << device_tensors_[position]->PrintInfo() << " in actor:" << GetAID(); } } -void HostQueueDataSourceActor::FillDataBuffer() { +void HostQueueDataSourceActor::FillDataBuffer(OpContext *const context) { + runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kRuntime, runtime::ProfilerEvent::kOutputProcess, + "DataSourceActorFillDataBuffer"); // Construct device tensors. - std::vector device_tensors; - for (auto &node_with_index : data_node_with_indexs_) { + if (device_tensors_.size() != data_node_with_indexs_.size()) { + std::stringstream ofs; + ofs << "Invalid device tensor size:" << device_tensors_.size() + << " and data node size:" << data_node_with_indexs_.size() << " for actor:" << GetAID(); + SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), ofs.str()); + } + auto update_device_tensor = [this](const KernelWithIndex &node_with_index, size_t index) { auto device_address = AnfAlgo::GetMutableOutputAddr(node_with_index.first, node_with_index.second, false); MS_EXCEPTION_IF_NULL(device_address); MS_LOG(DEBUG) << "Node:" << node_with_index.first->DebugString() << " index:" << node_with_index.second << " device address:" << device_address->PrintInfo(); - (void)device_tensors.emplace_back(device_address.get()); + device_tensors_[index] = device_address.get(); + }; + if (is_shape_match_) { + MS_LOG(DEBUG) << "Fill data in shape match mode, refreash index:" << need_refresh_input_index_ + << " for actor:" << GetAID(); + std::for_each( + need_refresh_input_index_.begin(), need_refresh_input_index_.end(), + [update_device_tensor, this](size_t index) { update_device_tensor(data_node_with_indexs_[index], index); }); + is_shape_match_ = false; + } else { + for (size_t i = 0; i < data_node_with_indexs_.size(); ++i) { + update_device_tensor(data_node_with_indexs_[i], i); + } } for (const auto &pair : heter_index_pair_) { - if (pair.first >= device_tensors.size() || pair.second >= device_tensors.size()) { + if (pair.first >= device_tensors_.size() || pair.second >= device_tensors_.size()) { MS_LOG(EXCEPTION) << "Invalid index:" << pair.first << " " << pair.second - << " device tensor size:" << device_tensors.size() << " for data source actor."; + << " device tensor size:" << device_tensors_.size() << " for data source actor."; } - MS_LOG(DEBUG) << "Add device tensor copy store for device address:" << device_tensors[pair.second] - << " type:" << device_tensors[pair.second]->GetDeviceType() << " and " << device_tensors[pair.first] - << " type:" << device_tensors[pair.first]->GetDeviceType() << " for actor:" << GetAID(); - DeviceTensorCopyStore::GetInstance().Insert(device_tensors[pair.second], device_tensors[pair.first]); + MS_LOG(DEBUG) << "Add device tensor copy store for device address:" << device_tensors_[pair.second] + << " type:" << device_tensors_[pair.second]->GetDeviceType() << " and " << device_tensors_[pair.first] + << " type:" << device_tensors_[pair.first]->GetDeviceType() << " for actor:" << GetAID(); + DeviceTensorCopyStore::GetInstance().Insert(device_tensors_[pair.second], device_tensors_[pair.first]); } - - buffers_.push(device_tensors); } void HostQueueDataSourceActor::SendMemoryAllocReq(OpContext *const context) { if (device_contexts_.empty()) { SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "Empty device contexts in device data source actor."); } - auto &device_tensors = buffers_.back(); if (ActorDispatcher::is_memory_allocation_sync()) { if (IsSameDeviceType()) { - ActorDispatcher::SendSync(memory_manager_aid_, &MemoryManagerActor::AllocateMemory, &device_tensors, + ActorDispatcher::SendSync(memory_manager_aid_, &MemoryManagerActor::AllocateMemory, &device_tensors_, device_contexts_[0], context, GetAID()); } else { - ActorDispatcher::SendSync(memory_manager_aid_, &MemoryManagerActor::AllocateBatchMemory, &device_tensors, + ActorDispatcher::SendSync(memory_manager_aid_, &MemoryManagerActor::AllocateBatchMemory, &device_tensors_, &device_contexts_, context, GetAID()); } OnMemoryAllocFinish(context); } else { if (IsSameDeviceType()) { - ActorDispatcher::Send(memory_manager_aid_, &MemoryManagerActor::AllocateMemory, &device_tensors, + ActorDispatcher::Send(memory_manager_aid_, &MemoryManagerActor::AllocateMemory, &device_tensors_, device_contexts_[0], context, GetAID()); } else { - ActorDispatcher::Send(memory_manager_aid_, &MemoryManagerActor::AllocateBatchMemory, &device_tensors, + ActorDispatcher::Send(memory_manager_aid_, &MemoryManagerActor::AllocateBatchMemory, &device_tensors_, &device_contexts_, context, GetAID()); } } @@ -326,21 +323,20 @@ void HostQueueDataSourceActor::SendMemoryFreeReq(OpContext *const if (device_contexts_.empty()) { SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "Empty device contexts in device data source actor."); } - auto &device_tensors = buffers_.front(); if (ActorDispatcher::is_memory_free_sync()) { if (IsSameDeviceType()) { - ActorDispatcher::SendSync(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &device_tensors, + ActorDispatcher::SendSync(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &device_tensors_, device_contexts_[0], context, GetAID()); } else { - ActorDispatcher::SendSync(memory_manager_aid_, &MemoryManagerActor::FreeBatchMemory, &device_tensors, + ActorDispatcher::SendSync(memory_manager_aid_, &MemoryManagerActor::FreeBatchMemory, &device_tensors_, &device_contexts_, context, GetAID()); } } else { if (IsSameDeviceType()) { - ActorDispatcher::Send(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &device_tensors, device_contexts_[0], + ActorDispatcher::Send(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &device_tensors_, device_contexts_[0], context, GetAID()); } else { - ActorDispatcher::Send(memory_manager_aid_, &MemoryManagerActor::FreeBatchMemory, &device_tensors, + ActorDispatcher::Send(memory_manager_aid_, &MemoryManagerActor::FreeBatchMemory, &device_tensors_, &device_contexts_, context, GetAID()); } } @@ -370,6 +366,12 @@ void HostQueueDataSourceActor::AddCopyDataCallBack( } } +namespace { +bool IsEmptyTuple(const tensor::TensorPtr &host_tensor, device::DeviceAddress *const device_tensor) { + return host_tensor->data_ptr() == nullptr && device_tensor->GetSize() == 0; +} +} // namespace + void HostQueueDataSourceActor::OnMemoryAllocFinish(OpContext *const context) { auto ms_context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(ms_context); @@ -377,9 +379,6 @@ void HostQueueDataSourceActor::OnMemoryAllocFinish(OpContext *cons if (IsRunningFailed(context)) { return; } - if (buffers_.size() == 0) { - SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "The data queue is empty."); - } // Get host tensors from host queue and get device tensors from buffers. MS_EXCEPTION_IF_NULL(host_queue_); @@ -387,8 +386,7 @@ void HostQueueDataSourceActor::OnMemoryAllocFinish(OpContext *cons SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "Host data queue is empty."); } auto &host_tensors = host_queue_->Pull(); - auto &device_tensors = buffers_.back(); - if (host_tensors.size() != device_tensors.size()) { + if (host_tensors.size() != device_tensors_.size()) { SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "The length of host tensors is not equal to the length of device tensors."); } @@ -398,11 +396,15 @@ void HostQueueDataSourceActor::OnMemoryAllocFinish(OpContext *cons // Copy data from host tensor to device tensor. uint64_t start_time = 0; PROFILER_START(start_time); - auto enable_async_copy = (ms_context->IsEnableInferBoost() || is_infer_phase_) && !sync_copy_input; + static const bool enable_infer_boost = ms_context->IsEnableInferBoost(); + auto enable_async_copy = (enable_infer_boost || is_infer_phase_) && !sync_copy_input; try { for (size_t i = 0; i < host_tensors.size(); ++i) { auto &host_tensor = host_tensors[i]; - auto &device_tensor = device_tensors[i]; + auto &device_tensor = device_tensors_[i]; + if (host_tensor == nullptr && enable_infer_boost) { + continue; + } MS_EXCEPTION_IF_NULL(device_tensor); MS_EXCEPTION_IF_NULL(host_tensor); // No used device address need skip. @@ -422,7 +424,7 @@ void HostQueueDataSourceActor::OnMemoryAllocFinish(OpContext *cons } continue; } - if (host_tensor->data_ptr() == nullptr && device_tensor->GetSize() == 0) { + if (IsEmptyTuple(host_tensor, device_tensor)) { MS_LOG(INFO) << "Empty tuple sync"; continue; } @@ -447,7 +449,7 @@ void HostQueueDataSourceActor::OnMemoryAllocFinish(OpContext *cons device_tensor->set_host_shape(host_tensor->shape()); } } - AddCopyDataCallBack(enable_async_copy, host_tensors, device_tensors); + AddCopyDataCallBack(enable_async_copy, host_tensors, device_tensors_); } catch (const std::exception &e) { MsException::Instance().SetException(); SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "Host data source actor run exception."); @@ -483,6 +485,52 @@ bool HostQueueDataSourceActor::IsSameDeviceType() const { return true; } +void HostQueueDataSourceActor::Init() { + DataSourceActor::Init(); + device_tensors_.resize(data_node_with_indexs_.size()); + need_refresh_device_address_.resize(output_data_nodes_.size(), true); + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + static const bool enable_infer_boost = ms_context->IsEnableInferBoost(); + if (!enable_infer_boost) { + need_refresh_input_index_.resize(data_node_with_indexs_.size()); + std::iota(need_refresh_input_index_.begin(), need_refresh_input_index_.end(), 0); + return; + } + + std::set kv_cache_parameters; + for (size_t i = 0; i < data_node_with_indexs_.size(); ++i) { + MS_EXCEPTION_IF_NULL(data_node_with_indexs_[i].first); + const auto &graph = data_node_with_indexs_[i].first->func_graph(); + const auto &kernel_graph = dynamic_cast(graph.get()); + if (kernel_graph == nullptr) { + need_refresh_input_index_.emplace_back(i); + continue; + } + const auto &front_node_with_index = GetFrontNodeByKernelGraph(data_node_with_indexs_[i].first, kernel_graph); + if (front_node_with_index.first != nullptr && + front_node_with_index.first->fullname_with_scope().find("key_cache") == std::string::npos && + front_node_with_index.first->fullname_with_scope().find("value_cache") == std::string::npos) { + need_refresh_input_index_.emplace_back(i); + continue; + } + kv_cache_parameters.emplace(data_node_with_indexs_[i]); + } + if (output_data_nodes_.size() != output_data_arrows_.size()) { + MS_LOG(EXCEPTION) << "Invalid output data node size:" << output_data_nodes_.size() + << " and output data arrow size:" << output_data_arrows_.size() << " for actor:" << GetAID(); + } + for (size_t i = 0; i < output_data_nodes_.size(); ++i) { + MS_EXCEPTION_IF_NULL(output_data_arrows_[i]); + if (kv_cache_parameters.find({output_data_nodes_[i], output_data_arrows_[i]->from_output_index_}) != + kv_cache_parameters.end()) { + need_refresh_device_address_[i] = false; + } + } + MS_LOG(DEBUG) << "Need refresh input index:" << need_refresh_input_index_ + << " and need update device tensor flag:" << need_refresh_device_address_ << " for actor:" << GetAID(); +} + void HostQueueDataSourceActor::ReleaseData() { runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kRuntime, runtime::ProfilerEvent::kOutputProcess, "DataSourceActorReleaseData"); @@ -491,7 +539,8 @@ void HostQueueDataSourceActor::ReleaseData() { host_queue_->Pop(); // The step end need release data node address. - for (auto &data_node_with_index : data_node_with_indexs_) { + for (size_t i : need_refresh_input_index_) { + const auto &data_node_with_index = data_node_with_indexs_[i]; if (!AnfAlgo::OutputAddrExist(data_node_with_index.first, data_node_with_index.second)) { continue; } diff --git a/mindspore/ccsrc/runtime/graph_scheduler/actor/data_source_actor.h b/mindspore/ccsrc/runtime/graph_scheduler/actor/data_source_actor.h index e494213d0fe..07cf1ab0361 100644 --- a/mindspore/ccsrc/runtime/graph_scheduler/actor/data_source_actor.h +++ b/mindspore/ccsrc/runtime/graph_scheduler/actor/data_source_actor.h @@ -44,8 +44,7 @@ class DataSourceActor : public DebugAwareActor { public: DataSourceActor(const std::string &name, KernelTransformType type, size_t buffer_capacity, const AID &memory_manager_aid, const AID *debug_aid, const AID *recorder_aid) - : DebugAwareActor(name, type, recorder_aid, memory_manager_aid, debug_aid, nullptr), - buffer_capacity_(buffer_capacity) {} + : DebugAwareActor(name, type, recorder_aid, memory_manager_aid, debug_aid, nullptr) {} ~DataSourceActor() override = default; virtual void ReleaseData() {} @@ -63,14 +62,11 @@ class DataSourceActor : public DebugAwareActor { void FetchData(OpContext *const context); // Construct the device tensors and fill to device tensor buffer from the member nodes during the data fetching. - virtual void FillDataBuffer() = 0; + virtual void FillDataBuffer(OpContext *const context) = 0; void UpdateOutputData(OpData *const output_data, const DataArrowPtr &data_arrow, const AnfNodePtr &output_node, OpContext *const context) override; - - // The buffers store the device tensors. - std::queue> buffers_; - size_t buffer_capacity_; + std::vector device_tensors_; }; // The class represents that the data source is device queue. @@ -97,7 +93,7 @@ class DeviceQueueDataSourceActor : public DataSourceActor { protected: void Init() override; - void FillDataBuffer() override; + void FillDataBuffer(OpContext *const context) override; void SendRecorderInfo(OpContext *const context) const override; private: @@ -142,11 +138,13 @@ class HostQueueDataSourceActor : public DataSourceActor { size_t FetchNodePosition(const KernelWithIndex &node) const override; KernelWithIndex FetchNode(size_t node_position) const; const std::vector &data_nodes() const { return data_node_with_indexs_; } - + bool is_shape_match() { return is_shape_match_; } + void set_is_shape_match(bool is_shape_match) { is_shape_match_ = is_shape_match; } void ReleaseData() override; protected: - void FillDataBuffer() override; + void Init() override; + void FillDataBuffer(OpContext *const context) override; void AddCopyDataCallBack(bool enable_async_copy, const mindspore::tensor::TensorPtrList &host_tensors, const std::vector &device_tensors); @@ -170,6 +168,9 @@ class HostQueueDataSourceActor : public DataSourceActor { // Whether the super kernel actor is a infer 'prefill' or 'increment' graph or not. bool is_infer_phase_; + bool is_shape_match_{false}; + std::vector need_refresh_input_index_; + std::vector need_refresh_device_address_; }; using DataSourceActorPtr = std::shared_ptr; diff --git a/tests/st/runtime/test_parallel_dispatch.py b/tests/st/runtime/test_parallel_dispatch.py index 4ed750231c0..7e981f2f893 100644 --- a/tests/st/runtime/test_parallel_dispatch.py +++ b/tests/st/runtime/test_parallel_dispatch.py @@ -18,7 +18,7 @@ import numpy as np import mindspore as ms import mindspore.nn as nn import mindspore.ops as P -from mindspore import Tensor +from mindspore import Tensor, mutable from mindspore import dtype as mstype import mindspore.context as context from tests.mark_utils import arg_mark @@ -43,15 +43,16 @@ class Net(nn.Cell): self.add = P.Add() self.mul = P.Mul() self.sub = P.Sub() + self.add_n = P.AddN() self.reshape = P.Reshape() - def construct(self, x): + def construct(self, x, key_cache_list, value_cache_list): x = self.reshape(x, (1, -1)) for _ in range(g_block_num): x = self.add(x, 1) x = self.sub(x, 1.1) - x = self.reshape(x, (3, -1)) + x = self.reshape(x, (2, -1)) x = self.mul(x, 0.251) x = self.add(x, 1) @@ -61,8 +62,11 @@ class Net(nn.Cell): x = self.mul(x, 2) x = self.add(x, 1) x = self.sub(x, 1.1) - x = self.reshape(x, (6, -1)) + x = self.reshape(x, (4, -1)) x = self.mul(x, 0.051) + x = self.reshape(x, (2, -1)) + x = self.add_n(key_cache_list) + x + x = self.add_n(value_cache_list) + x x = self.reshape(x, (2, -1)) return x @@ -76,22 +80,46 @@ def test_host_bound_for_parallel_dispatch(): internal kernels. Expectation: The program execute and exit normally. """ - input_data = Tensor(np.zeros((2, 3)).astype(np.float32)) + input_data1 = Tensor(np.zeros((2, 2)).astype(np.float32)) + input_data2 = Tensor(np.zeros((2, 4)).astype(np.float32)) dyn_input_data = Tensor(shape=[2, None], dtype=mstype.float32) + k_cache_list1 = [] + v_cache_list1 = [] + k_cache_list2 = [] + v_cache_list2 = [] + dyn_k_cache_list = [] + dyn_v_cache_list = [] + + for _ in range(10): + dyn_k_cache_list.append(dyn_input_data) + dyn_v_cache_list.append(dyn_input_data) + + for _ in range(10): + new_input_data = P.Add()(input_data1, 1) + k_cache_list1.append(new_input_data) + v_cache_list1.append(new_input_data) net = Net() - net.set_inputs(dyn_input_data) + net.set_inputs(dyn_input_data, mutable(dyn_k_cache_list), mutable(dyn_v_cache_list)) net.phase = "increment" + # warm up - output = net(input_data) - output = net(input_data) + output = net(input_data1, mutable(k_cache_list1), mutable(v_cache_list1)) + output = net(input_data1, mutable(k_cache_list1), mutable(v_cache_list1)) print(output) + k_cache_list1 = [] + v_cache_list1 = [] + + for _ in range(10): + new_input_data = P.Add()(input_data2, 1) + k_cache_list2.append(new_input_data) + v_cache_list2.append(new_input_data) for _ in range(steps): - output = net(input_data) + output = net(input_data2, mutable(k_cache_list2), mutable(v_cache_list2)) output.asnumpy() - exp_val = -0.06835 - exp_array = np.array([[exp_val, exp_val, exp_val], [exp_val, exp_val, exp_val]]) + exp_val = 20.191507 + exp_array = np.array([[exp_val, exp_val, exp_val, exp_val], [exp_val, exp_val, exp_val, exp_val]]) assert np.allclose(output.asnumpy(), exp_array, 0.0001, 0.0001) -- Gitee From 544a683221473f0321178d8f887fb3c8a615eb2c Mon Sep 17 00:00:00 2001 From: Erpim Date: Fri, 21 Mar 2025 21:34:54 +0800 Subject: [PATCH 14/14] clean code & fix conflict --- .jenkins/check/config/filter_linklint.txt | 5 ++++- .../device/ascend/kernel/internal/fused_add_topk_div.cc | 2 +- mindspore/core/utils/ms_context.cc | 3 ++- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/.jenkins/check/config/filter_linklint.txt b/.jenkins/check/config/filter_linklint.txt index 2223c5b57c5..68287dd6ae5 100644 --- a/.jenkins/check/config/filter_linklint.txt +++ b/.jenkins/check/config/filter_linklint.txt @@ -70,4 +70,7 @@ https://www.mindspore.cn*/r2.3.1/* https://mindspore.cn*/r2.4.0/* https://www.mindspore.cn*/r2.4.0/* https://www.mindspore.cn/docs/zh-CN/master/api_python/mindspore.dataset.loading.html* -https://www.mindspore.cn/docs/en/master/api_python/mindspore.dataset.loading.html* \ No newline at end of file +https://www.mindspore.cn/docs/en/master/api_python/mindspore.dataset.loading.html* +https://www.mindspore.cn/docs/en/master/model_train/program_form/static_graph.html +https://www.mindspore.cn/docs/en/master/model_train/program_form/static_graph.html#mindspore-user-defined-data-types +https://www.mindspore.cn/docs/en/master/model_train/program_form/static_graph.html#annotation-type \ No newline at end of file diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/internal/fused_add_topk_div.cc b/mindspore/ccsrc/plugin/device/ascend/kernel/internal/fused_add_topk_div.cc index be7d5a8b561..f40e1bd53d2 100644 --- a/mindspore/ccsrc/plugin/device/ascend/kernel/internal/fused_add_topk_div.cc +++ b/mindspore/ccsrc/plugin/device/ascend/kernel/internal/fused_add_topk_div.cc @@ -17,7 +17,7 @@ #include "plugin/device/ascend/kernel/internal/fused_add_topk_div.h" #include -#include "kernel/kernel.h" +#include "common/kernel.h" #include "plugin/device/ascend/kernel/internal/internal_kernel_in_out_map.h" namespace mindspore { diff --git a/mindspore/core/utils/ms_context.cc b/mindspore/core/utils/ms_context.cc index 1a31d057868..6c289b8245e 100644 --- a/mindspore/core/utils/ms_context.cc +++ b/mindspore/core/utils/ms_context.cc @@ -742,7 +742,8 @@ void MsContext::SetMsInternalEnableCustomKernelList() { const std::string kDefaultEnabledOpList = "MatMul,RmsNorm,Add,Sub,FlashAttentionScore,PagedAttention,PagedAttentionMask,AddRmsNorm,AddLayerNorm," "MatMulAllReduce,InferenceMatmulSplit,AddRmsNormQuantV2,InferenceSwiGLU,QbmmAllReduceAdd,QbmmAdd," - "AddRmsNormDynamicQuant,MatMulElemwise,RmsNormQuant,MatMulSigmoidCastAdd,TransposeBatchMatmulTranspose,FusedAddTopKDiv"; + "AddRmsNormDynamicQuant,MatMulElemwise,RmsNormQuant,MatMulSigmoidCastAdd,TransposeBatchMatmulTranspose," + "FusedAddTopKDiv"; const std::string k310pDefaultEnabledOpList = "MatMul,QuantBatchMatmul,QuantLinearSparse,QbmmAllReduceAdd,QbmmAdd"; auto internal_op_boost_env = common::GetEnv("MS_ENABLE_INTERNAL_BOOST"); bool is_enable_internal_op = true; -- Gitee