diff --git a/docs/README.md b/docs/README.md index b688a7f7960ad7bec14e33c5eeb99b51da3e196b..b90de102a7bd4bd32cad0d34380f6b855122a0ba 100644 --- a/docs/README.md +++ b/docs/README.md @@ -334,10 +334,14 @@ 给定两个源操作数src0和src1,根据maskTensor相应位置的值(非bit位)选取元素,得到目的操作数dst。 - 变形 + 变形 ConfusionTranspose 对输入数据进行数据排布及Reshape操作。 + + TransData + 对输入数据排布格式转换为输出所需的数据排布格式 + 索引操作 ArithProgression diff --git a/impl/CMakeLists.txt b/impl/CMakeLists.txt index c29d95d33e3a958ae66a96848fc92ab7cc2df526..1ab2cc72b6dc11b43344c7529011edcaa3002e4e 100644 --- a/impl/CMakeLists.txt +++ b/impl/CMakeLists.txt @@ -92,6 +92,7 @@ add_library(tiling_api STATIC ${CMAKE_CURRENT_SOURCE_DIR}/math/axpy/axpy_tiling_impl.cpp ${CMAKE_CURRENT_SOURCE_DIR}/math/ceil/ceil_tiling_impl.cpp ${CMAKE_CURRENT_SOURCE_DIR}/math/floor/floor_tiling_impl.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/transpose/transdata/transdata_tiling.cpp ${CMAKE_CURRENT_SOURCE_DIR}/math/fmod/fmod_tiling_impl.cpp ${CMAKE_CURRENT_SOURCE_DIR}/math/trunc/trunc_tiling_impl.cpp $<$:$> diff --git a/impl/reduce/mean/mean_tiling.cpp b/impl/reduce/mean/mean_tiling.cpp index c22cacec074b9811d9809985aa87a5ac86eaa054..6838fb099b180eeb5f515a295fef925b59c40232 100644 --- a/impl/reduce/mean/mean_tiling.cpp +++ b/impl/reduce/mean/mean_tiling.cpp @@ -10,16 +10,37 @@ #include "lib/reduce/mean_tiling.h" #include "register/tilingdata_base.h" +#include "impl/host_log.h" + namespace AscendC { constexpr uint32_t MEAN_CALC_PROC = 1; const uint32_t MEAN_ONE_BLK_SIZE = 32; const uint32_t MEAN_ONE_REPEAT_BYTE_SIZE = 256; const uint32_t HALF_TYPE_SIZE = 2; const uint32_t FLOAT_TYPE_SIZE = 4; + +inline void CheckMeanHostParams(const uint32_t n, const uint32_t srcTypeSize, + const uint32_t accTypeSize, const bool isReuseSource) +{ + ASCENDC_HOST_ASSERT( + ((srcTypeSize == 2U && accTypeSize == 2U) || + (srcTypeSize == 4U && accTypeSize == 4U) || + (srcTypeSize == 2U && srcTypeSize == 4U)), + return, + "[Mean][GetMeanMaxMinTmpSize] The parameter (srcTypeSize, accTypeSize) is (%u, %u), expected is (2, 2)/(4, 4)/(2, 4).", + srcTypeSize, accTypeSize + ); + ASCENDC_HOST_ASSERT(n > 0, + return, "[Mean][GetMeanMaxMinTmpSize] The parameter n is %u, expected is greater than 0!", n); + if (isReuseSource) { + TILING_LOG_WARNING("[Mean][GetMeanMaxMinTmpSize] The parameter isReuseSource is true, which is not effective!"); + } +} + void GetMeanMaxMinTmpSize(const uint32_t n, const uint32_t srcTypeSize, const uint32_t accTypeSize, const bool isReuseSource, uint32_t& maxSize, uint32_t& minSize) { - (void)isReuseSource; + CheckMeanHostParams(n, srcTypeSize, accTypeSize, isReuseSource); if (srcTypeSize == 0) { return; } diff --git a/impl/reduce/reduce_all/reduce_all_v220_impl.h b/impl/reduce/reduce_all/reduce_all_v220_impl.h index 223639d81a0d045b1dc23027e43adaff296f44a3..aa5386bb32996f9e556a8303391f5ae897ed8105 100644 --- a/impl/reduce/reduce_all/reduce_all_v220_impl.h +++ b/impl/reduce/reduce_all/reduce_all_v220_impl.h @@ -64,9 +64,9 @@ __aicore__ inline void ReduceAllImpl(const LocalTensor& dstTensor, const Loca BinaryReduceByFirstAxis>( dstTensor, srcTensor, tmpTensor, first, last, padLast); } - SetMaskNorm(); - ResetMask(); } + SetMaskNorm(); + ResetMask(); } } // namespace Internal } // namespace AscendC diff --git a/impl/reduce/reduce_any/reduce_any_v220_impl.h b/impl/reduce/reduce_any/reduce_any_v220_impl.h index b8d7f59dcc5501bd58eb617bf13897fac93125ed..5c3bd25208cde6324fd2b49892198051fca00d50 100644 --- a/impl/reduce/reduce_any/reduce_any_v220_impl.h +++ b/impl/reduce/reduce_any/reduce_any_v220_impl.h @@ -64,9 +64,9 @@ __aicore__ inline void ReduceAnyImpl(const LocalTensor& dstTensor, const Loca BinaryReduceByFirstAxis>( dstTensor, srcTensor, tmpTensor, first, last, padLast); } - SetMaskNorm(); - ResetMask(); } + SetMaskNorm(); + ResetMask(); } } // namespace Internal } // namespace AscendC diff --git a/impl/reduce/reduce_sum/reduce_sum_v220_impl.h b/impl/reduce/reduce_sum/reduce_sum_v220_impl.h index 75890b67bd6874eeb1f04ae4e7d606382513d9fa..f29211873c3b887f94091092e91715171f6261c6 100644 --- a/impl/reduce/reduce_sum/reduce_sum_v220_impl.h +++ b/impl/reduce/reduce_sum/reduce_sum_v220_impl.h @@ -268,6 +268,8 @@ __aicore__ inline void ReduceSumImpl(const LocalTensor& dstTensor, const Loca BinaryReduceByFirstAxis>( dstTensor, srcTensor, tmpBuf, first, last, padLast); } + SetMaskNorm(); + ResetMask(); } } // namespace Internal } // namespace AscendC diff --git a/impl/reduce/reduce_tiling.cpp b/impl/reduce/reduce_tiling.cpp index b427d5b41220b56ffda6cfda5cafb9d81b8e4344..71722928c3a4c898614e4c5f2ba0528dcbd113c0 100644 --- a/impl/reduce/reduce_tiling.cpp +++ b/impl/reduce/reduce_tiling.cpp @@ -10,6 +10,7 @@ #include "lib/reduce/reduce_tiling.h" +#include #include #include @@ -44,27 +45,31 @@ uint32_t FindK(uint32_t n) { } inline void CheckParams(std::vector shapeDims, bool isSrcInnerPad, ReducePattern pattern, - uint32_t first, uint32_t last) + uint32_t first, uint32_t last, std::string apiName, std::string funcName) { - ASCENDC_HOST_ASSERT(shapeDims.size() == ALLOWED_SHAPE_DIM, return, "srcShape dims must be 2."); - ASCENDC_HOST_ASSERT(isSrcInnerPad, return, "isSrcInnerPad must be true on this platform."); + ASCENDC_HOST_ASSERT(shapeDims.size() == ALLOWED_SHAPE_DIM, return, + "[%s][%s] srcShape dims must be 2.", apiName.c_str(), funcName.c_str()); + ASCENDC_HOST_ASSERT(isSrcInnerPad, return, + "[%s][%s] isSrcInnerPad must be true on this platform.", apiName.c_str(), funcName.c_str()); ASCENDC_HOST_ASSERT(pattern == ReducePattern::AR || pattern == ReducePattern::RA, return, - "Currently only support AR and RA pattern."); - ASCENDC_HOST_ASSERT(first > 0 && last > 0, return, "both first and last axis must be greater than 0."); + "[%s][%s] Currently only support AR and RA pattern.", apiName.c_str(), funcName.c_str()); + ASCENDC_HOST_ASSERT(first > 0 && last > 0, return, + "[%s][%s] both first and last axis must be greater than 0.", apiName.c_str(), funcName.c_str()); } } // namespace void GetReduceCommonMaxMinTmpSize(const ge::Shape &srcShape, const ge::DataType dataType, ReducePattern pattern, bool isSrcInnerPad, bool isReuseSource, - uint32_t &maxValue, uint32_t &minValue, bool isBinaryAdd) + uint32_t &maxValue, uint32_t &minValue, bool isBinaryAdd, + std::string apiName, std::string funcName) { std::vector shapeDims = srcShape.GetDims(); const uint32_t first = static_cast(shapeDims[0]); const uint32_t last = static_cast(shapeDims[1]); - CheckParams(shapeDims, isSrcInnerPad, pattern, first, last); + CheckParams(shapeDims, isSrcInnerPad, pattern, first, last, apiName, funcName); if (isReuseSource) { maxValue = minValue = 0U; return; @@ -96,18 +101,77 @@ void GetReduceCommonMaxMinTmpSize(const ge::Shape &srcShape, maxValue = minValue = k * ((last * GetTypeSize(dataType) + ONE_BLK_SIZE - 1u) / ONE_BLK_SIZE * ONE_BLK_SIZE); } +inline void GetReduceSumMeanCommonTmpSize(const ge::Shape &srcShape, + ReducePattern pattern, bool isSrcInnerPad, bool isReuseSource, + uint32_t &maxValue, uint32_t &minValue, std::string apiName, std::string funcName) +{ + std::vector shapeDims = srcShape.GetDims(); + const uint32_t first = static_cast(shapeDims[0]); + const uint32_t last = static_cast(shapeDims[1]); + CheckParams(shapeDims, isSrcInnerPad, pattern, first, last, apiName, funcName); + if (isReuseSource) { + maxValue = minValue = 0U; + return; + } + uint32_t elePerBlk = ONE_BLK_SIZE / FLOAT_TYPE_SIZE; + if (pattern == ReducePattern::AR) { + uint32_t k = FindK(last); + if (k == last && first > 1U) { + k >>= 1U; + } + if (last <= B32_ELEM_NUM_PER_REPEAT) { + maxValue = minValue = 0U; + } else { + maxValue = minValue = (first * k) * FLOAT_TYPE_SIZE; + } + } else { + uint32_t k = FindK(first); + uint32_t padLast = (last + elePerBlk - 1U) / elePerBlk * elePerBlk; + if (first == k && first > 1U) { + k >>= 1U; + } + maxValue = minValue = (k * padLast) * FLOAT_TYPE_SIZE; + } + return; +} + +inline void GetReduceAnyAllCommonTmpSize(const ge::Shape &srcShape, + ReducePattern pattern, bool isSrcInnerPad, bool isReuseSource, + uint32_t &maxValue, uint32_t &minValue, std::string apiName, std::string funcName) +{ + std::vector shapeDims = srcShape.GetDims(); + const uint32_t first = static_cast(shapeDims[0]); + const uint32_t last = static_cast(shapeDims[1]); + CheckParams(shapeDims, isSrcInnerPad, pattern, first, last, apiName, funcName); + if (pattern == ReducePattern::AR) { + uint32_t elePerBlk = static_cast(ONE_BLK_SIZE / sizeof(uint8_t)); + uint32_t padLast = (last + elePerBlk - 1U) / elePerBlk * elePerBlk; + minValue = maxValue = static_cast(padLast * sizeof(uint16_t)) + (first * elePerBlk); + } else { + if (isReuseSource) { + maxValue = minValue = 0U; + return; + } + uint32_t k = FindK(first); + if (k == first && first > 1U) { + k >>= 1U; + } + maxValue = minValue = k * ((last + ONE_BLK_SIZE - 1U) / ONE_BLK_SIZE * ONE_BLK_SIZE); + } + return; +} + void GetReduceProdMaxMinTmpSize(const ge::Shape &srcShape, const ge::DataType dataType, ReducePattern pattern, bool isSrcInnerPad, bool isReuseSource, uint32_t &maxValue, uint32_t &minValue) { - ASCENDC_HOST_ASSERT(dataType == ge::DT_FLOAT, return, "it only supports float type on this platform."); - + ASCENDC_HOST_ASSERT(dataType == ge::DT_FLOAT, return, + "[ReduceProd][GetReduceProdMaxMinTmpSize] it only supports float type on this platform."); std::vector shapeDims = srcShape.GetDims(); - const uint32_t first = static_cast(shapeDims[0]); const uint32_t last = static_cast(shapeDims[1]); - CheckParams(shapeDims, isSrcInnerPad, pattern, first, last); + CheckParams(shapeDims, isSrcInnerPad, pattern, first, last, "ReduceProd", "GetReduceProdMaxMinTmpSize"); if (isReuseSource) { minValue = pattern == ReducePattern::AR ? ONE_REPEAT_BYTE_SIZE : 0U; maxValue = minValue; @@ -137,8 +201,9 @@ void GetReduceMaxMaxMinTmpSize(const ge::Shape &srcShape, { ASCENDC_HOST_ASSERT(dataType == ge::DT_FLOAT || dataType == ge::DT_FLOAT16, return, - "it only supports float and half type on this platform."); - GetReduceCommonMaxMinTmpSize(srcShape, dataType, pattern, isSrcInnerPad, isReuseSource, maxValue, minValue, false); + "[ReduceMax][GetReduceMaxMaxMinTmpSize] it only supports float and half type on this platform."); + GetReduceCommonMaxMinTmpSize(srcShape, dataType, pattern, isSrcInnerPad, isReuseSource, maxValue, minValue, false, + "ReduceMax", "GetReduceMaxMaxMinTmpSize"); } void GetReduceMinMaxMinTmpSize(const ge::Shape &srcShape, @@ -148,8 +213,9 @@ void GetReduceMinMaxMinTmpSize(const ge::Shape &srcShape, { ASCENDC_HOST_ASSERT(dataType == ge::DT_FLOAT || dataType == ge::DT_FLOAT16, return, - "it only supports float and half type on this platform."); - GetReduceCommonMaxMinTmpSize(srcShape, dataType, pattern, isSrcInnerPad, isReuseSource, maxValue, minValue, false); + "[ReduceMin][GetReduceMinMaxMinTmpSize] it only supports float and half type on this platform."); + GetReduceCommonMaxMinTmpSize(srcShape, dataType, pattern, isSrcInnerPad, isReuseSource, maxValue, minValue, false, + "ReduceMin", "GetReduceMinMaxMinTmpSize"); } void GetReduceAnyMaxMinTmpSize(const ge::Shape &srcShape, @@ -157,13 +223,15 @@ void GetReduceAnyMaxMinTmpSize(const ge::Shape &srcShape, ReducePattern pattern, bool isSrcInnerPad, bool isReuseSource, uint32_t &maxValue, uint32_t &minValue) { + ASCENDC_HOST_ASSERT(dataType == ge::DT_FLOAT || dataType == ge::DT_UINT8, + return, + "[ReduceAny][GetReduceAnyMaxMinTmpSize] it only supports float and uint8_t type on this platform."); if (dataType == ge::DT_UINT8) { - GetReduceAllMaxMinTmpSize(srcShape, dataType, pattern, isSrcInnerPad, isReuseSource, maxValue, minValue); + GetReduceAnyAllCommonTmpSize(srcShape, pattern, isSrcInnerPad, isReuseSource, maxValue, minValue, + "ReduceAny", "GetReduceAnyMaxMinTmpSize"); } else { - ASCENDC_HOST_ASSERT(dataType == ge::DT_FLOAT || dataType == ge::DT_UINT8, - return, - "it only supports float and uint8_t type on this platform."); - GetReduceCommonMaxMinTmpSize(srcShape, dataType, pattern, isSrcInnerPad, isReuseSource, maxValue, minValue, false); + GetReduceCommonMaxMinTmpSize(srcShape, dataType, pattern, isSrcInnerPad, isReuseSource, maxValue, minValue, + false, "ReduceAny", "GetReduceAnyMaxMinTmpSize"); } } @@ -173,31 +241,13 @@ void GetReduceAllMaxMinTmpSize(const ge::Shape &srcShape, uint32_t &maxValue, uint32_t &minValue) { ASCENDC_HOST_ASSERT((dataType == ge::DT_FLOAT || dataType == ge::DT_UINT8), return, - "it only supports float uint8 type on this platform."); - std::vector shapeDims = srcShape.GetDims(); - const uint32_t first = static_cast(shapeDims[0]); - const uint32_t last = static_cast(shapeDims[1]); - CheckParams(shapeDims, isSrcInnerPad, pattern, first, last); + "[ReduceAll][GetReduceAllMaxMinTmpSize] it only supports float and uint8 type on this platform."); if (dataType == ge::DT_UINT8) { - if (pattern == ReducePattern::AR) { - uint32_t elePerBlk = static_cast(ONE_BLK_SIZE / sizeof(uint8_t)); - uint32_t padLast = (last + elePerBlk - 1U) / elePerBlk * elePerBlk; - minValue = maxValue = static_cast(padLast * sizeof(uint16_t)) + (first * elePerBlk); - } else { - if (isReuseSource) { - maxValue = minValue = 0U; - return; - } - uint32_t k = FindK(first); - if (k == first && first > 1U) { - k >>= 1U; - } - maxValue = minValue = k * ((last + ONE_BLK_SIZE - 1U) / ONE_BLK_SIZE * ONE_BLK_SIZE); - } - return; + GetReduceAnyAllCommonTmpSize(srcShape, pattern, isSrcInnerPad, isReuseSource, maxValue, minValue, + "ReduceAll", "GetReduceAllMaxMinTmpSize"); } else { - GetReduceCommonMaxMinTmpSize( - srcShape, dataType, pattern, isSrcInnerPad, isReuseSource, maxValue, minValue, false); + GetReduceCommonMaxMinTmpSize(srcShape, dataType, pattern, isSrcInnerPad, isReuseSource, maxValue, minValue, + false, "ReduceAll", "GetReduceAllMaxMinTmpSize"); } } @@ -206,35 +256,10 @@ void GetReduceSumMaxMinTmpSize(const ge::Shape &srcShape, ReducePattern pattern, bool isSrcInnerPad, bool isReuseSource, uint32_t &maxValue, uint32_t &minValue) { - ASCENDC_HOST_ASSERT(dataType == ge::DT_FLOAT, return, "it only supports float type on this platform."); - std::vector shapeDims = srcShape.GetDims(); - const uint32_t first = static_cast(shapeDims[0]); - const uint32_t last = static_cast(shapeDims[1]); - CheckParams(shapeDims, isSrcInnerPad, pattern, first, last); - if (isReuseSource) { - maxValue = minValue = 0U; - return; - } - uint32_t elePerBlk = ONE_BLK_SIZE / FLOAT_TYPE_SIZE; - if (pattern == ReducePattern::AR) { - uint32_t k = FindK(last); - if (k == last && first > 1U) { - k >>= 1U; - } - if (last <= B32_ELEM_NUM_PER_REPEAT) { - maxValue = minValue = 0U; - } else { - maxValue = minValue = (first * k) * FLOAT_TYPE_SIZE; - } - } else { - uint32_t k = FindK(first); - uint32_t padLast = (last + elePerBlk - 1U) / elePerBlk * elePerBlk; - if (first == k && first > 1U) { - k >>= 1U; - } - maxValue = minValue = (k * padLast) * FLOAT_TYPE_SIZE; - } - return; + ASCENDC_HOST_ASSERT(dataType == ge::DT_FLOAT, return, + "[ReduceSum][GetReduceSumMaxMinTmpSize] it only supports float type on this platform."); + GetReduceSumMeanCommonTmpSize(srcShape, pattern, isSrcInnerPad, isReuseSource, maxValue, minValue, + "ReduceSum", "GetReduceSumMaxMinTmpSize"); } void GetReduceMeanMaxMinTmpSize(const ge::Shape &srcShape, @@ -242,6 +267,9 @@ void GetReduceMeanMaxMinTmpSize(const ge::Shape &srcShape, ReducePattern pattern, bool isSrcInnerPad, bool isReuseSource, uint32_t &maxValue, uint32_t &minValue) { - GetReduceSumMaxMinTmpSize(srcShape, dataType, pattern, isSrcInnerPad, isReuseSource, maxValue, minValue); + ASCENDC_HOST_ASSERT(dataType == ge::DT_FLOAT, return, + "[ReduceMean][GetReduceMeanMaxMinTmpSize] it only supports float type on this platform."); + GetReduceSumMeanCommonTmpSize(srcShape, pattern, isSrcInnerPad, isReuseSource, maxValue, minValue, + "ReduceMean", "GetReduceMeanMaxMinTmpSize"); } } // namespace AscendC diff --git a/impl/reduce/reduce_xor_sum/reduce_xor_sum_tiling.cpp b/impl/reduce/reduce_xor_sum/reduce_xor_sum_tiling.cpp index d705e532a8b02a9c55e29143bb5a2a346e1abbdd..244456092f3ea848a0fcbd09f77ffaa73af861e3 100644 --- a/impl/reduce/reduce_xor_sum/reduce_xor_sum_tiling.cpp +++ b/impl/reduce/reduce_xor_sum/reduce_xor_sum_tiling.cpp @@ -32,7 +32,16 @@ void GetReduceXorSumMaxMinTmpSize(const ge::Shape &srcShape, const uint32_t type uint32_t &maxValue, uint32_t &minValue) { const uint32_t inputSize = srcShape.GetShapeSize(); - ASCENDC_HOST_ASSERT(inputSize > 0, return, "ReduceXorSum input Shape size must be greater than 0."); + + ASCENDC_HOST_ASSERT(inputSize > 0U, return, + "[ReduceXorSum][GetReduceXorSumMaxMinTmpSize] The parameter srcShape size is %u, expected is greater than 0!", + inputSize); + std::vector shapeDims = srcShape.GetDims(); + ASCENDC_HOST_ASSERT(shapeDims.size() > 0UL, return, + "[ReduceXorSum][GetReduceXorSumMaxMinTmpSize] The parameter srcShape dimension number is %lli, expected is greater than 0!", + shapeDims.size()); + ASCENDC_HOST_ASSERT(typeSize == 2U, return, + "[ReduceXorSum][GetReduceXorSumMaxMinTmpSize] The parameter typeSize is %u, expected is 2!", typeSize); maxValue = GetTmpSize(inputSize, typeSize, isReuseSource); minValue = maxValue; diff --git a/impl/reduce/sum/sum_tiling.cpp b/impl/reduce/sum/sum_tiling.cpp index 95e087acdf5e2e067197672e247c91a44c231b60..817031cd1025239c84280ad2f680ab1e1de51afa 100644 --- a/impl/reduce/sum/sum_tiling.cpp +++ b/impl/reduce/sum/sum_tiling.cpp @@ -19,8 +19,14 @@ namespace AscendC { void GetSumMaxMinTmpSize( const uint32_t n, const uint32_t typeSize, const bool isReuseSource, uint32_t &maxSize, uint32_t &minSize) { - (void)isReuseSource; - ASCENDC_HOST_ASSERT(typeSize > 0, return, "typeSize must be greater than 0."); + if (isReuseSource) { + TILING_LOG_WARNING("[Sum][GetSumMaxMinTmpSize] The parameter isReuseSource is true, which is not effective!"); + } + ASCENDC_HOST_ASSERT(typeSize > 0, return, + "[Sum][GetSumMaxMinTmpSize] The parameter typeSize is %u, expected is 2 or 4!", typeSize); + ASCENDC_HOST_ASSERT(n > 0, + return, "[Sum][GetSumMaxMinTmpSize] The parameter n is %u, expected is greater than 0!", n); + constexpr uint32_t sumOneBlkSize = 32; constexpr uint32_t sumOneRepeatByteSize = 256; diff --git a/impl/transpose/transdata/transdata_impl.h b/impl/transpose/transdata/transdata_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..d8e5d41e273cfaaa40b5f88191409bd51e873c2f --- /dev/null +++ b/impl/transpose/transdata/transdata_impl.h @@ -0,0 +1,526 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#ifndef IMPL_TRANSPOSE_TRANSDATA_TRANSDATA_IMPL_H +#define IMPL_TRANSPOSE_TRANSDATA_TRANSDATA_IMPL_H + +#include "kernel_tensor.h" +#include "kernel_operator_intf.h" +#include "kernel_tiling/kernel_tiling.h" +#include "../../common/check.h" +#include "../../api_check/kernel_api_check.h" + +namespace AscendC { +namespace Internal { + +namespace { +constexpr int32_t n0 = 16; +constexpr int32_t c0 = 16; +constexpr int32_t hw0 = 16; +constexpr int32_t ncdhwDims = 5; +constexpr int32_t fractalZ3DDims = 7; +constexpr int32_t ndc1hwc0Dims = 6; +} + +struct TransDataTmpParams { + int32_t n; + int32_t c; + int32_t d; + int32_t h; + int32_t w; + int32_t n1; + int32_t c1; + int32_t padHw; +}; + +constexpr int32_t DEFAULT_TRANSDATA_5HD_LIST = 16; + +template +__aicore__ inline void DC1Hwn1n0c0ToC1DHwn1n0c0HWAlign(const LocalTensor& dst, const LocalTensor& src, + const TransDataTmpParams& params) +{ + // d, c1, h w n1 n0 c0 -> c1, d, hw1*hw0 n1 n0 c0 + int32_t d = params.d; + int32_t h = params.h; + int32_t w = params.w; + int32_t n1 = params.n1; + int32_t c1 = params.c1; + int32_t padHw = params.padHw; + + uint32_t dim0 = d; + uint32_t dim1 = c1; + uint32_t lastDim = h * w * n1 * n0 * c0; + + // dim0, dim1, lastDim -> dim1, dim0, lastDim + int32_t n1n0c0DimElems = n1 * n0 * c0; + int32_t hwAlignElems = padHw * n1n0c0DimElems; + int32_t hwPadElems = (padHw - h * w) * n1n0c0DimElems; + + uint16_t blockCount = dim1; + uint16_t blockLen = lastDim * sizeof(T) / ONE_BLK_SIZE; + uint16_t srcGap = 0; + uint16_t dstGap = ((dim0 - 1) * hwAlignElems + hwPadElems) * sizeof(T) / ONE_BLK_SIZE; + + uint32_t dstSize = c1 * d * padHw * n1 * n0 * c0; + Duplicate(dst, static_cast(0), dstSize); + PipeBarrier(); + + DataCopyParams dataCopyParams = { blockCount, blockLen, srcGap, dstGap }; + for (uint32_t d0 = 0; d0 < dim0; d0++) { + DataCopy(dst[d0 * hwAlignElems], src[d0 * dim1 * lastDim], dataCopyParams); + } + PipeBarrier(); +} + +template +__aicore__ inline void C1Dhwn1n0c0ToC1C0Dhwn1n0(const LocalTensor& dst, const LocalTensor& src, + const TransDataTmpParams& params) +{ + // C1 DHWN1N0 C0 -> C1 C0 DHWN1N0 + int32_t d = params.d; + int32_t n1 = params.n1; + int32_t c1 = params.c1; + int32_t padHw = params.padHw; + + TransDataTo5HDParams transDataParams; + transDataParams.dstHighHalf = false; + transDataParams.srcHighHalf = false; + transDataParams.repeatTimes = d * padHw * n1; + if (transDataParams.repeatTimes == 1) { + transDataParams.srcRepStride = 0; + transDataParams.dstRepStride = 0; + } else { + transDataParams.srcRepStride = DEFAULT_TRANSDATA_5HD_LIST * c0 * sizeof(T) / ONE_BLK_SIZE; + transDataParams.dstRepStride = n0 * sizeof(T) / ONE_BLK_SIZE; + } + + uint64_t srcOffsetArr[DEFAULT_TRANSDATA_5HD_LIST]; + uint64_t dstOffsetArr[DEFAULT_TRANSDATA_5HD_LIST]; + uint64_t srcAddr = (uint64_t)src.GetPhyAddr(); + uint64_t dstAddr = (uint64_t)dst.GetPhyAddr(); + for (uint32_t j = 0; j < c1; j++) { + uint32_t outOffset = j * d * padHw * n1 * n0 * c0; + for (uint8_t i = 0; i < DEFAULT_TRANSDATA_5HD_LIST; i++) { + srcOffsetArr[i] = (uint64_t)(srcAddr + (outOffset + i * n0) * sizeof(T)); + dstOffsetArr[i] = (uint64_t)(dstAddr + (outOffset + i * d * padHw * n1 * n0) * sizeof(T)); + } + TransDataTo5HD(dstOffsetArr, srcOffsetArr, transDataParams); + } + PipeBarrier(); +} + +template +__aicore__ inline void C1c0dhwN1n0ToNcdhw(const LocalTensor& dst, const LocalTensor& src, + const LocalTensor& tmp, const TransDataTmpParams& params) +{ + // C1C0DHW N1N0 -> N CDHW + int32_t d = params.d; + int32_t n1 = params.n1; + int32_t padHw = params.padHw; + int32_t currN = params.n; + int32_t c = params.c; + + TransDataTo5HDParams transDataParams; + transDataParams.dstHighHalf = false; + transDataParams.srcHighHalf = false; + transDataParams.repeatTimes = c * d * padHw / n0; + if (transDataParams.repeatTimes == 1) { + transDataParams.srcRepStride = 0; + transDataParams.dstRepStride = 0; + } else { + transDataParams.srcRepStride = DEFAULT_TRANSDATA_5HD_LIST * n1 * n0 * sizeof(T) / ONE_BLK_SIZE; + transDataParams.dstRepStride = c0 * sizeof(T) / ONE_BLK_SIZE; + } + + uint64_t srcOffsetArr[DEFAULT_TRANSDATA_5HD_LIST]; + uint64_t dstOffsetArr[DEFAULT_TRANSDATA_5HD_LIST]; + uint64_t srcAddr = (uint64_t)src.GetPhyAddr(); + uint64_t dstAddr = (uint64_t)dst.GetPhyAddr(); + uint64_t tmpAddr = (uint64_t)tmp.GetPhyAddr(); + for (uint32_t j = 0; j < n1; j++) { + if (n0 - currN > 0) { + for (uint8_t i = 0; i < currN; i++) { + dstOffsetArr[i] = (uint64_t)(dstAddr + (j * d * c * padHw * n0 + i * c * d * padHw) * sizeof(T)); + } + for (uint8_t i = currN; i < DEFAULT_TRANSDATA_5HD_LIST; i++) { + dstOffsetArr[i] = (uint64_t)(tmpAddr + i * ONE_BLK_SIZE * sizeof(T)); + } + } else { + for (uint8_t i = 0; i < DEFAULT_TRANSDATA_5HD_LIST; i++) { + dstOffsetArr[i] = (uint64_t)(dstAddr + (j * d * c * padHw * n0 + i * c * d * padHw) * sizeof(T)); + } + } + for (uint8_t i = 0; i < DEFAULT_TRANSDATA_5HD_LIST; i++) { + srcOffsetArr[i] = (uint64_t)(srcAddr + (j * n0 + i * n0 * n1) * sizeof(T)); + } + TransDataTo5HD(dstOffsetArr, srcOffsetArr, transDataParams); + currN -= n0; + } + PipeBarrier(); +} + +template +__aicore__ inline void N1n0C1c0DHWToNCDHW(const LocalTensor& dst, const LocalTensor& src, + const TransDataTmpParams& params) +{ + // N1N0 C1C0 D H W -> N C D H W + int32_t n = params.n; + int32_t c = params.c; + int32_t d = params.d; + int32_t c1 = params.c1; + int32_t padHw = params.padHw; + + uint16_t blockCount = n; + uint16_t blockLen = (c * (d * padHw)) * sizeof(T) / ONE_BLK_SIZE; + uint16_t srcGap = ((c1 * c0 - c) * (d * padHw)) * sizeof(T) /ONE_BLK_SIZE; + uint16_t dstGap = 0; + DataCopyParams dataCopyParams = { blockCount, blockLen, srcGap, dstGap }; + DataCopy(dst, src, dataCopyParams); + PipeBarrier(); +} + +template +__aicore__ inline void TransDataFractalToNcdhw(const LocalTensor& dst, const LocalTensor& src, + const LocalTensor& tmpBuffer, const TransDataTmpParams& params) +{ + int32_t d = params.d; + int32_t n1 = params.n1; + int32_t c1 = params.c1; + int32_t padHw = params.padHw; + int32_t n = params.n; + int32_t c = params.c; + + LocalTensor tmp = tmpBuffer.template ReinterpretCast(); + LocalTensor srcTmp = src.template ReinterpretCast(); + if (c == c1 * c0 && n == n1 * n0) { + LocalTensor dstTmp = dst.template ReinterpretCast(); + // D C1 HWN1N0C0 -> C1 D HWN1N0C0 (H*W 32B ALIGN -> HW1*HW0) + DC1Hwn1n0c0ToC1DHwn1n0c0HWAlign(dstTmp, srcTmp, params); + // C1 DHWN1N0 C0 -> C1 C0 DHWN1N0 + C1Dhwn1n0c0ToC1C0Dhwn1n0(tmp, dstTmp, params); + // C1C0DHW N1N0 -> N CDHW + C1c0dhwN1n0ToNcdhw(dstTmp, tmp, tmp, params); + } else { + LocalTensor transDataTmp = tmp[n1 * n0 * c1 * c0 * d * padHw]; + LocalTensor dstTmp = dst.template ReinterpretCast(); + // D C1 HWN1N0C0 -> C1 D HWN1N0C0 (H*W 32B ALIGN -> HW1*HW0) + DC1Hwn1n0c0ToC1DHwn1n0c0HWAlign(tmp, srcTmp, params); + // C1 DHWN1N0 C0 -> C1 C0 DHWN1N0 + C1Dhwn1n0c0ToC1C0Dhwn1n0(transDataTmp, tmp, params); + // C1C0DHW N1N0 -> N CDHW + C1c0dhwN1n0ToNcdhw(dstTmp, transDataTmp, tmp, params); + } +} + +// Transdata NCDHW -> FRACTAL_Z_3D +template +__aicore__ inline void TransDataImplNcdhwToFractal(const LocalTensor& dst, const LocalTensor& src, const LocalTensor& tmpBuffer, + const TransDataTmpParams& param) +{ + constexpr int32_t elePerBlk = ONE_BLK_SIZE / sizeof(T); + const int32_t n = param.n, c = param.c, d = param.d, h = param.h, w = param.w; + constexpr int32_t c0 = 16; + constexpr int32_t n0 = 16; + const int32_t c1 = DivCeil(c, c0); + const int32_t n1 = DivCeil(n, n0); + int32_t padHw = AlignUp(h * w, elePerBlk); + int32_t currAxis = c * d * padHw; + Duplicate(tmpBuffer.ReinterpretCast(), static_cast(0), currAxis); + PipeBarrier(); + auto tmpDstTensor = tmpBuffer[currAxis * sizeof(T)].ReinterpretCast(); + uint64_t dstLocalList[DEFAULT_TRANSDATA_5HD_LIST]; + uint64_t srcLocalList[DEFAULT_TRANSDATA_5HD_LIST]; + + uint64_t dstTensorAddr = (uint64_t)dst.GetPhyAddr(); + uint64_t srcTensorAddr = (uint64_t)src.GetPhyAddr(); + uint64_t tmpDstTensorAddr = (uint64_t)tmpDstTensor.GetPhyAddr(); + uint64_t tmpBufferAddr = (uint64_t)tmpBuffer.GetPhyAddr(); + // step1, NCDHW -> CDHW, N1, N0 + // Do n1 times Transpose to split axis N, and fill with 0 on padding data. + TransDataTo5HDParams transDataParams; + transDataParams.dstHighHalf = false; + transDataParams.srcHighHalf = false; + transDataParams.repeatTimes = currAxis / elePerBlk; + // if repeat = 1, start offset is auto incremental by stride. + transDataParams.dstRepStride = transDataParams.repeatTimes == 1 ? 0 : n1 * n0; + transDataParams.srcRepStride = transDataParams.repeatTimes == 1 ? 0 : 1; + + bool isPadded = padHw != h * w; + // dst tensor is unable to fill all padded data. + auto tmpIfPadAddr = isPadded ? tmpDstTensorAddr : dstTensorAddr; + for (int j = 0; j < n1; j++) { + uint64_t currDstAddr = tmpIfPadAddr + j * n0 * sizeof(T); + uint64_t currSrcAddr = srcTensorAddr + j * currAxis * n0 * sizeof(T); + // handle the last axis if N is not even splited by n0. + int remain = j == n1 - 1 ? n - j * n0 : n0; + for (int32_t i = 0; i < n0; i++) { + dstLocalList[i] = currDstAddr + (i * n1 * n0) * sizeof(T); + } + for (int32_t i = 0; i < remain; i++) { + srcLocalList[i] = currSrcAddr + i * currAxis * sizeof(T); + } + for (int32_t i = remain; i < n0; i++) { + srcLocalList[i] = tmpBufferAddr; + } + TransDataTo5HD(dstLocalList, srcLocalList, transDataParams); + } + PipeBarrier(); + // step1.5 collapse padded H,W axis for CDHW, N1N0 + DataCopyParams copyParams; + if (isPadded) { + currAxis = h * w * n1 * n0; + copyParams.blockCount = c * d; + copyParams.blockLen = currAxis / elePerBlk; + // Merge axis by skiping padded H,W. + copyParams.srcStride = (padHw - h * w) * n1 * n0 / elePerBlk; + copyParams.dstStride = 0; + DataCopy(dst, tmpDstTensor, copyParams); + } + PipeBarrier(); + + // step2, CDHWN1N0 -> C1DHW, N1N0, C0 + currAxis = d * h * w * n1 * n0; + transDataParams.repeatTimes = currAxis / elePerBlk; + transDataParams.dstRepStride = transDataParams.repeatTimes == 1 ? 0 : c0; + transDataParams.srcRepStride = transDataParams.repeatTimes == 1 ? 0 : 1; + for (int32_t j = 0; j < c1; j++) { + uint64_t currDstAddr = tmpDstTensorAddr + j * currAxis * c0 * sizeof(T); + uint64_t currSrcAddr = dstTensorAddr + j * currAxis * c0 * sizeof(T); + int remain = j == c1 - 1 ? c - j * c0 : c0; + for (int32_t i = 0; i < DEFAULT_TRANSDATA_5HD_LIST; i++) { + dstLocalList[i] = currDstAddr + i * c0 * sizeof(T); + } + for (int32_t i = 0; i < remain; i++) { + srcLocalList[i] = currSrcAddr + i * currAxis * sizeof(T); + } + for (int32_t i = remain; i < DEFAULT_TRANSDATA_5HD_LIST; i++) { + srcLocalList[i] = tmpBufferAddr; + } + TransDataTo5HD(dstLocalList, srcLocalList, transDataParams); + } + PipeBarrier(); + // steo3 C1DHW, N1N0, C0 -> DC1HW, N1N0, C0 + currAxis = c0 * h * w * n1 * n0; + copyParams.blockCount = d; + copyParams.blockLen = currAxis / elePerBlk; + // Merge axis by skiping padding padHW -> h, w + copyParams.srcStride = 0; + copyParams.dstStride = (c1 - 1) * currAxis / elePerBlk; + for (int32_t i = 0; i < c1; i++) { + DataCopy(dst[i * currAxis], tmpDstTensor[i * d * currAxis], copyParams); + } + PipeBarrier(); +} + +// Transdata NCDHW -> NDC1HWC0 +template +__aicore__ inline void TransDataImplNcdhwTo6Hd(const LocalTensor& dst, const LocalTensor& src, const LocalTensor& tmpBuffer, + const TransDataTmpParams& param) +{ + constexpr int32_t c0 = 16; + constexpr int32_t elePerBlk = ONE_BLK_SIZE / sizeof(T); + const int32_t n = param.n, c = param.c, d = param.d, h = param.h, w = param.w; + const int32_t c1 = DivCeil(c, c0); + const int32_t padHw = AlignUp(h * w, elePerBlk); + int32_t currAxis = d * padHw; + + int32_t axisHwd = h * w * d; + int32_t axisHwc0 = h * w * c0; + int32_t axisC1hwc0 = axisHwc0 * c1; + int32_t axisC1hwdc0 = axisC1hwc0 * d; + int32_t axisPadHwd = padHw * d; + int32_t axisPadHwc0 = padHw * c0; + int32_t axisPadHwdc0 = padHw * c0 * d; + Duplicate(tmpBuffer.ReinterpretCast(), static_cast(0), axisPadHwd); + PipeBarrier(); + + // reserve for padded 0 on additional axis c. + auto tmpDstTensor = tmpBuffer[axisPadHwd * sizeof(T)].ReinterpretCast(); + + uint64_t dstTensorAddr = (uint64_t)dst.GetPhyAddr(); + uint64_t srcTensorAddr = (uint64_t)src.GetPhyAddr(); + uint64_t tmpDstTensorAddr = (uint64_t)tmpDstTensor.GetPhyAddr(); + uint64_t tmpBufferAddr = (uint64_t)tmpBuffer.GetPhyAddr(); + uint64_t dstLocalList[DEFAULT_TRANSDATA_5HD_LIST]; + uint64_t srcLocalList[DEFAULT_TRANSDATA_5HD_LIST]; + TransDataTo5HDParams transDataParams; + transDataParams.dstHighHalf = false; + transDataParams.srcHighHalf = false; + transDataParams.repeatTimes = axisPadHwd / elePerBlk; + transDataParams.dstRepStride = transDataParams.repeatTimes == 1 ? 0 : c0; + transDataParams.srcRepStride = transDataParams.repeatTimes == 1 ? 0 : 1; + + DataCopyParams copyParams; + copyParams.blockCount = d; + copyParams.blockLen = axisHwc0 / elePerBlk; + copyParams.srcStride = (padHw - h * w) * c0 / elePerBlk; + copyParams.dstStride = (c1 - 1) * axisHwc0 / elePerBlk; + // iterates N times CDHW -> C1DHWC0 + for (int32_t k = 0; k < n; k++) { + int32_t currSrcStart = k * axisPadHwd * c; + int32_t currDstStart = k * axisC1hwdc0; + // it's impossible to have calculation size exceed max 255 repeats due to the total memory size. + // step1, CDHW -> C1DHWC0 with pad data + for (int32_t j = 0; j < c1; j++) { + uint64_t currDstAddr = tmpDstTensorAddr + j * axisPadHwdc0 * sizeof(T); + uint64_t currSrcAddr = srcTensorAddr + (currSrcStart + j * axisPadHwdc0) * sizeof(T); + int remain = j == c1 - 1 ? c - j * c0 : c0; + for (int32_t i = 0; i < DEFAULT_TRANSDATA_5HD_LIST; i++) { + dstLocalList[i] = currDstAddr + i * c0 * sizeof(T); + } + for (int32_t i = 0; i < remain; i++) { + srcLocalList[i] = currSrcAddr + i * axisPadHwd * sizeof(T); + } + for (int32_t i = remain; i < DEFAULT_TRANSDATA_5HD_LIST; i++) { + srcLocalList[i] = tmpBufferAddr; + } + TransDataTo5HD(dstLocalList, srcLocalList, transDataParams); + } + PipeBarrier(); + // step2, C1DHWC0 -> DC1HWC0 + for (int32_t i = 0; i < c1; i++) { + DataCopy(dst[currDstStart + i * axisHwc0], tmpDstTensor[i * axisPadHwdc0], copyParams); + } + PipeBarrier(); + } +} + +// Transdata NDC1HWC0 -> NCDHW +template +__aicore__ inline void TransDataImpl6HdToNcdhw(const LocalTensor& dst, const LocalTensor& src, const LocalTensor& tmpBuffer, + const TransDataTmpParams& param) +{ + const int32_t n = param.n, c = param.c, d = param.d, h = param.h, w = param.w; + constexpr int32_t c0 = 16; + constexpr int32_t elePerBlk = ONE_BLK_SIZE / sizeof(T); + const int32_t c1 = DivCeil(c, c0); + const int32_t padHw = AlignUp(h * w, elePerBlk); + constexpr int32_t reservedDummy = 512; + auto tmpDstTensor = tmpBuffer[reservedDummy].template ReinterpretCast(); + uint64_t dstLocalList[DEFAULT_TRANSDATA_5HD_LIST]; + uint64_t srcLocalList[DEFAULT_TRANSDATA_5HD_LIST]; + + uint64_t dstTensorAddr = (uint64_t)dst.GetPhyAddr(); + uint64_t tmpDstTensorAddr = (uint64_t)tmpDstTensor.GetPhyAddr(); + uint64_t tmpBufferAddr = (uint64_t)tmpBuffer.GetPhyAddr(); + + int32_t axisHwd = h * w * d; + int32_t axisHwc0 = h * w * c0; + int32_t axisC1hwc0 = axisHwc0 * c1; + int32_t axisC1hwdc0 = axisC1hwc0 * d; + int32_t axisPadHwd = padHw * d; + int32_t axisPadHwc0 = padHw * c0; + int32_t axisPadHwdc0 = padHw * c0 * d; + TransDataTo5HDParams transDataParams; + transDataParams.dstHighHalf = false; + transDataParams.srcHighHalf = false; + transDataParams.repeatTimes = padHw * d / elePerBlk; + transDataParams.srcRepStride = transDataParams.repeatTimes == 1 ? 0 : c0; + transDataParams.dstRepStride = transDataParams.repeatTimes == 1 ? 0 : 1; + + DataCopyParams copyParams; + copyParams.blockCount = c1; + copyParams.blockLen = h * w * c0 / elePerBlk; + copyParams.srcStride = 0; + copyParams.dstStride = (d * padHw - h * w) * c0 / elePerBlk; + // iterates N times C1DHWC0 -> CDHW + for (int32_t k = 0; k < n; k++) { + // step1 DC1HWC0 -> C1DHWC0 + int32_t currSrcStart = k * axisC1hwdc0; + int32_t currDstStart = k * axisPadHwd * c; + for (int32_t i = 0; i < d; i++) { + DataCopy(tmpDstTensor[i * axisPadHwc0], src[currSrcStart + i * axisC1hwc0], copyParams); + } + PipeBarrier(); + // step2, C1DHWC0 -> C1C0DHW + // it's impossible to have calculation size exceed max 255 repeats due to the total memory size. + for (int32_t j = 0; j < c1; j++) { + int32_t remain = j == c1 - 1 ? c - j * c0 : c0; + uint64_t currDstAddr = dstTensorAddr + (currDstStart + j * axisPadHwdc0) * sizeof(T); + uint64_t currSrcAddr = tmpDstTensorAddr + j * axisPadHwdc0 * sizeof(T); + for (int32_t i = 0; i < remain; i++) { + dstLocalList[i] = currDstAddr + i * axisPadHwd * sizeof(T); + } + for (int32_t i = remain; i < DEFAULT_TRANSDATA_5HD_LIST; i++) { + // temp for reserve redundant data. + dstLocalList[i] = tmpBufferAddr + i * ONE_BLK_SIZE; + } + for (int32_t i = 0; i < DEFAULT_TRANSDATA_5HD_LIST; i++) { + srcLocalList[i] = currSrcAddr + i * c0 * sizeof(T); + } + TransDataTo5HD(dstLocalList, srcLocalList, transDataParams); + } + PipeBarrier(); + } +} + +template +__aicore__ inline void TransDataCheck(const TransDataParams& params) +{ + static_assert(SupportType(), + "Currents only supports half/bfloat16_t/uint16_t/int16_t types."); + static_assert(is_layout_v, "srcLayout must be a layout"); + static_assert(is_layout_v, "dstLayout must be a layout"); + using SrcShapeTuple = Std::remove_cvref_t; + using DstShapeTuple = Std::remove_cvref_t; + static_assert(Std::is_tuple_v, "srcLayout.GetShape() must be a shape."); + static_assert(Std::is_tuple_v, "dstLayout.GetShape() must be a shape."); +} + +template +__aicore__ inline void TransDataImpl(const LocalTensor& dstTensor, const LocalTensor& srcTensor, + const LocalTensor& sharedTmpBuffer, const TransDataParams& params) +{ + TransDataCheck(params); + auto srcShape = params.srcLayout.GetShape(); + auto dstShape = params.dstLayout.GetShape(); + constexpr uint32_t srcShapeSize = static_cast(Std::tuple_size::value); + constexpr uint32_t dstShapeSize = static_cast(Std::tuple_size::value); + CHECK_FUNC_HIGHLEVEL_API(TransData, (config, T, U, S), (dstTensor, srcTensor, sharedTmpBuffer, params)); + using srcType = decltype(srcShape); + using dstType = decltype(dstShape); + using ncdhwType = Std::conditional_t; + ncdhwType ncdhwShape; + if constexpr (config.srcFormat == DataFormat::NCDHW) { + ncdhwShape = params.srcLayout.GetShape(); + } else { + ncdhwShape = params.dstLayout.GetShape(); + } + int32_t n = Std::get<0>(ncdhwShape); + int32_t c = Std::get<1>(ncdhwShape); + int32_t d = Std::get<2>(ncdhwShape); + int32_t h = Std::get<3>(ncdhwShape); + int32_t w = Std::get<4>(ncdhwShape); + int32_t n1 = (n + n0 - 1) / n0; + int32_t c1 = (c + c0 - 1) / c0; + int32_t hw1 = (h * w + hw0 - 1) / hw0; + int32_t padHw = hw1 * hw0; + TransDataTmpParams tmpParams = { n, c, d, h, w, n1, c1, padHw }; + if constexpr (config.srcFormat == DataFormat::NCDHW && config.dstFormat == DataFormat::FRACTAL_Z_3D) { + static_assert(srcShapeSize == ncdhwDims, "srcLayout's shape dims must be equal to 5!"); + static_assert(dstShapeSize == fractalZ3DDims, "dstLayout's shape dims must be equal to 7!"); + TransDataImplNcdhwToFractal(dstTensor, srcTensor, sharedTmpBuffer, tmpParams); + } else if constexpr (config.srcFormat == DataFormat::FRACTAL_Z_3D && config.dstFormat == DataFormat::NCDHW) { + static_assert(srcShapeSize == fractalZ3DDims, "srcLayout's shape dims must be equal to 7!"); + static_assert(dstShapeSize == ncdhwDims, "dstLayout's shape dims must be equal to 5!"); + TransDataFractalToNcdhw(dstTensor, srcTensor, sharedTmpBuffer, tmpParams); + } else if constexpr (config.srcFormat == DataFormat::NCDHW && config.dstFormat == DataFormat::NDC1HWC0) { + static_assert(srcShapeSize == ncdhwDims, "srcLayout's shape dims must be equal to 5!"); + static_assert(dstShapeSize == ndc1hwc0Dims, "dstLayout's shape dims must be equal to 6!"); + TransDataImplNcdhwTo6Hd(dstTensor, srcTensor, sharedTmpBuffer, tmpParams); + } else if constexpr (config.srcFormat == DataFormat::NDC1HWC0 && config.dstFormat == DataFormat::NCDHW) { + static_assert(srcShapeSize == ndc1hwc0Dims, "srcLayout's shape dims must be equal to 6!"); + static_assert(dstShapeSize == ncdhwDims, "dstLayout's shape dims must be equal to 5!"); + TransDataImpl6HdToNcdhw(dstTensor, srcTensor, sharedTmpBuffer, tmpParams); + } +} + +} // namespace Internal +} // namespace AscendC +#endif // IMPL_TRANSPOSE_TRANSDATA_TRANSDATA_IMPL_H \ No newline at end of file diff --git a/impl/transpose/transdata/transdata_tiling.cpp b/impl/transpose/transdata/transdata_tiling.cpp new file mode 100644 index 0000000000000000000000000000000000000000..dbfc281e48b3c8e9770a2f55bb2bea36d545ae38 --- /dev/null +++ b/impl/transpose/transdata/transdata_tiling.cpp @@ -0,0 +1,197 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#include "lib/transpose/transdata_tiling.h" + +#include +#include + +#include "graph/tensor.h" +#include "impl/host_log.h" +#include "tiling/platform/platform_ascendc.h" +namespace AscendC { +namespace { +constexpr int32_t PAD_ELE_FOR_HALF = 16; +constexpr int32_t N_INDEX = 0; +constexpr int32_t C_INDEX = 1; +constexpr int32_t D_INDEX = 2; +constexpr int32_t H_INDEX = 3; +constexpr int32_t W_INDEX = 4; + +struct TmpTransDataParams { + int32_t n = 0; + int32_t c = 0; + int32_t d = 0; + int32_t h = 0; + int32_t w = 0; +}; + +int32_t DivCeil(int32_t a, int32_t b) +{ + if (b == 0) { + return a; + } + return (a + b - 1) / b; +} + +int32_t AlignUp(int32_t a, int32_t b) +{ + return DivCeil(a, b) * b; +} + +bool GenerateFractalZ3DToNcdhwShapeInfo(const std::vector& dstDims, const std::vector& srcDims, + TmpTransDataParams ¶m, const int32_t c0, const int32_t n0) +{ + ASCENDC_HOST_ASSERT(srcDims.size() == 7 && dstDims.size() == 5, return false, + "[TransData][GetTransDataMaxMinTmpSize] input shapes are not matched with DataFormat."); + param.n = dstDims[N_INDEX]; + param.c = dstDims[C_INDEX]; + param.d = dstDims[D_INDEX]; + param.h = dstDims[H_INDEX]; + param.w = dstDims[W_INDEX]; + // validate d, h, w + ASCENDC_HOST_ASSERT(param.d == srcDims[0] && param.h == srcDims[2] && param.w == srcDims[3], return false, + "[TransData][GetTransDataMaxMinTmpSize] shapeInfo d,h,w is not matched."); + ASCENDC_HOST_ASSERT(srcDims[6] == c0 && srcDims[1] * c0 == AlignUp(param.c, c0), return false, + "[TransData][GetTransDataMaxMinTmpSize] src c0, c1 is not able to be converted to c."); + ASCENDC_HOST_ASSERT(srcDims[5] == n0 && srcDims[4] * n0 == AlignUp(param.n, n0), return false, + "[TransData][GetTransDataMaxMinTmpSize] src n0, n1 is not able to be converted to n."); + return true; +} + +bool GenerateNcdhwToFractalZ3DShapeInfo(const std::vector& dstDims, const std::vector& srcDims, + TmpTransDataParams ¶m, const int32_t c0, const int32_t n0) +{ + ASCENDC_HOST_ASSERT(srcDims.size() == 5 && dstDims.size() == 7, return false, + "[TransData][GetTransDataMaxMinTmpSize] input shapes are not matched with DataFormat."); + param.n = srcDims[N_INDEX]; + param.c = srcDims[C_INDEX]; + param.d = srcDims[D_INDEX]; + param.h = srcDims[H_INDEX]; + param.w = srcDims[W_INDEX]; + // validate d, h, w + ASCENDC_HOST_ASSERT(param.d == dstDims[0] && param.h == dstDims[2] && param.w == dstDims[3], return false, + "[TransData][GetTransDataMaxMinTmpSize] shapeInfo d,h,w is not matched."); + ASCENDC_HOST_ASSERT(dstDims[6] == c0 && dstDims[1] * c0 == AlignUp(param.c, c0), return false, + "[TransData][GetTransDataMaxMinTmpSize] dst c0, c1 is not able to be converted to c."); + ASCENDC_HOST_ASSERT(dstDims[5] == n0 && dstDims[4] * n0 == AlignUp(param.n, n0), return false, + "[TransData][GetTransDataMaxMinTmpSize] dst n0, n1 is not able to be converted to n."); + return true; +} + +bool GenerateShapeInfo(const TransDataConfig &config, const ge::Shape &srcShape, const ge::Shape &dstShape, ge::DataType type, + TmpTransDataParams ¶m) +{ + (void)type; + constexpr int32_t c0 = 16, n0 = 16; + std::vector srcDims = srcShape.GetDims(); + std::vector dstDims = dstShape.GetDims(); + if (config.srcFormat == DataFormat::NCDHW && config.dstFormat == DataFormat::NDC1HWC0) { + ASCENDC_HOST_ASSERT(srcDims.size() == 5 && dstDims.size() == 6, return false, + "[TransData][GetTransDataMaxMinTmpSize] input shapes are not matched with DataFormat."); + param.n = srcDims[N_INDEX]; + param.c = srcDims[C_INDEX]; + param.d = srcDims[D_INDEX]; + param.h = srcDims[H_INDEX]; + param.w = srcDims[W_INDEX]; + // validate n, d, h, w + ASCENDC_HOST_ASSERT(param.n == dstDims[0] && param.d == dstDims[1] && param.h == dstDims[3] && param.w == dstDims[4], + return false, "[TransData][GetTransDataMaxMinTmpSize] shapeInfo n,d,h,w is not matched."); + ASCENDC_HOST_ASSERT(dstDims[5] == c0 && dstDims[2] * c0 == AlignUp(param.c, c0), return false, + "[TransData][GetTransDataMaxMinTmpSize] dst c0, c1 is not able to be converted to c."); + return true; + } + if (config.srcFormat == DataFormat::NCDHW && config.dstFormat == DataFormat::FRACTAL_Z_3D) { + return GenerateNcdhwToFractalZ3DShapeInfo(dstDims, srcDims, param, c0, n0); + } + if (config.srcFormat == DataFormat::FRACTAL_Z_3D && config.dstFormat == DataFormat::NCDHW) { + return GenerateFractalZ3DToNcdhwShapeInfo(dstDims, srcDims, param, c0, n0); + } + if (config.srcFormat == DataFormat::NDC1HWC0 && config.dstFormat == DataFormat::NCDHW) { + ASCENDC_HOST_ASSERT(srcDims.size() == 6 && dstDims.size() == 5, return false, + "[TransData][GetTransDataMaxMinTmpSize] input shapes are not matched with DataFormat."); + param.n = dstDims[N_INDEX]; + param.c = dstDims[C_INDEX]; + param.d = dstDims[D_INDEX]; + param.h = dstDims[H_INDEX]; + param.w = dstDims[W_INDEX]; + // validate n, d, h, w + ASCENDC_HOST_ASSERT(param.n == srcDims[0] && param.d == srcDims[1] && param.h == srcDims[3] && param.w == srcDims[4], + return false, "[TransData][GetTransDataMaxMinTmpSize] shapeInfo n,d,h,w is not matched."); + ASCENDC_HOST_ASSERT(srcDims[5] == c0 && srcDims[2] * c0 == AlignUp(param.c, c0), return false, + "[TransData][GetTransDataMaxMinTmpSize] src c0, c1 is not able to be converted to c."); + return true; + } + return false; +} + +int32_t GetTmpBufferSize(const TransDataConfig &config, const TmpTransDataParams ¶m) +{ + constexpr int32_t dataSize = 2; + int32_t n = param.n, c = param.c, d = param.d, h = param.h, w = param.w; + constexpr int32_t c0 = 16, n0 = 16; + int32_t c1 = DivCeil(c, c0), n1 = DivCeil(n, n0); + int32_t padHw = AlignUp(h * w, 16); + if (config.srcFormat == DataFormat::NCDHW && config.dstFormat == DataFormat::NDC1HWC0) + { + return d * padHw * dataSize + d * c1 * c0 * padHw * dataSize; + } + if (config.srcFormat == DataFormat::NDC1HWC0 && config.dstFormat == DataFormat::NCDHW) + { + constexpr int32_t redundantDataBuffer = 512; + return d * c1 * c0 * padHw * dataSize + redundantDataBuffer; + } + if (config.srcFormat == DataFormat::NCDHW && config.dstFormat == DataFormat::FRACTAL_Z_3D) + { + return c * d * padHw * dataSize + n1 * n0 * d * c1 * c0 * padHw * dataSize; + } + if (config.srcFormat == DataFormat::FRACTAL_Z_3D && config.dstFormat == DataFormat::NCDHW) + { + constexpr int32_t doubleTmpSize = 2; + if (n == n0 * n1 && c == c0 * c1) { + return n1 * n0 * c1 * c0 * d * padHw * dataSize; + } + return n1 * n0 * c1 * c0 * d * padHw * dataSize * doubleTmpSize; + } + return 0; +} +} // namespace + +bool GetTransDataMaxMinTmpSize(const platform_ascendc::PlatformAscendC &platform, + const ge::Shape &srcShape, + const ge::Shape &dstShape, + const ge::DataType dataType, + const TransDataConfig &config, + uint32_t &maxValue, uint32_t &minValue) +{ + ASCENDC_HOST_ASSERT(dataType == ge::DataType::DT_FLOAT16 || dataType == ge::DataType::DT_BF16 || + dataType == ge::DataType::DT_UINT16 || dataType == ge::DataType::DT_INT16, return false, + "[TransData][GetTransDataMaxMinTmpSize] it only supports DT_FLOAT16/DT_BF16/DT_UINT16/DT_INT16 data type"); + + ASCENDC_HOST_ASSERT(((config.srcFormat == DataFormat::NCDHW && config.dstFormat == DataFormat::FRACTAL_Z_3D) || + (config.srcFormat == DataFormat::FRACTAL_Z_3D && config.dstFormat == DataFormat::NCDHW) || + (config.srcFormat == DataFormat::NCDHW && config.dstFormat == DataFormat::NDC1HWC0) || + (config.srcFormat == DataFormat::NDC1HWC0 && config.dstFormat == DataFormat::NCDHW)), return false, + "[TransData][GetTransDataMaxMinTmpSize] The parameter config srcFormat/dstFormat only supports " + "(NCDHW, FRACTAL_Z_3D)/(FRACTAL_Z_3D, NCDHW)/(NCDHW, NDC1HWC0)/(NDC1HWC0, NCDHW)!"); + + platform_ascendc::SocVersion socVersion = platform.GetSocVersion(); + ASCENDC_HOST_ASSERT(socVersion == platform_ascendc::SocVersion::ASCEND910B, return false, + "[TransData][GetTransDataMaxMinTmpSize] Unsupported SocVersion for TransData API."); + + TmpTransDataParams tmpParam; + + ASCENDC_HOST_ASSERT(GenerateShapeInfo(config, srcShape, dstShape, dataType, tmpParam), return false, + "[TransData][GetTransDataMaxMinTmpSize] failed to validate inputs informations."); + maxValue = GetTmpBufferSize(config, tmpParam); + minValue = maxValue; + return true; +} +} // namespace AscendC diff --git a/lib/reduce/reduce_tiling.h b/lib/reduce/reduce_tiling.h index a43f6f8330bc2f7d1e5199e28fca4ce6eed5ca16..c27591e4a087be179fc7a2a5b3d2fc409b582493 100644 --- a/lib/reduce/reduce_tiling.h +++ b/lib/reduce/reduce_tiling.h @@ -15,6 +15,7 @@ #ifndef LIB_REDUCE_REDUCE_TILING_H #define LIB_REDUCE_REDUCE_TILING_H #include +#include "graph/types.h" #include "graph/tensor.h" namespace AscendC { diff --git a/lib/sort/topk_tiling.h b/lib/sort/topk_tiling.h index 07e36dbbfde70517691d0f9b7a7b41a3f483d62e..e08d1af42d502deb8543af313f56bdaafe053328 100644 --- a/lib/sort/topk_tiling.h +++ b/lib/sort/topk_tiling.h @@ -11,6 +11,7 @@ #define LIB_SORT_TOPK_TILING_H #include "topk_tilingdata.h" #include "tiling/platform/platform_ascendc.h" +#include "graph/types.h" namespace AscendC { diff --git a/lib/transpose/transdata.h b/lib/transpose/transdata.h new file mode 100644 index 0000000000000000000000000000000000000000..c0075cf5276392e5ca5399c7e122203059d393ca --- /dev/null +++ b/lib/transpose/transdata.h @@ -0,0 +1,48 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#ifndef LIB_TRANSPOSE_TRANSDATA_H +#define LIB_TRANSPOSE_TRANSDATA_H +#if __CCE_AICORE__ == 220 +#include "transdata_common.h" +#include "kernel_tensor.h" +#include "kernel_operator_intf.h" +#include "kernel_pop_stack_buffer.h" +#include "../../impl/transpose/transdata/transdata_impl.h" +#if ASCENDC_CPU_DEBUG +#include "kernel_log.h" +#include +#endif + +namespace AscendC { + +template +__aicore__ inline void TransData(const LocalTensor& dstTensor, const LocalTensor& srcTensor, + const LocalTensor& sharedTmpBuffer, const TransDataParams& params) +{ + Internal::TransDataImpl(dstTensor, srcTensor, sharedTmpBuffer, params); +} + +template +__aicore__ inline void TransData(const LocalTensor& dstTensor, const LocalTensor& srcTensor, + const TransDataParams& params) +{ + // Only for AI Vector Core. + if ASCEND_IS_AIC { + return; + } + LocalTensor tmp; + const bool ret = PopStackBuffer(tmp); + ASCENDC_ASSERT((ret), { KERNEL_LOG(KERNEL_ERROR, "PopStackBuffer Error!"); }); + + TransData(dstTensor, srcTensor, tmp, params); +} +} // namespace AscendC +#endif +#endif // LIB_TRANSPOSE_TRANSDATA_H \ No newline at end of file diff --git a/lib/transpose/transdata_common.h b/lib/transpose/transdata_common.h new file mode 100644 index 0000000000000000000000000000000000000000..0421a3ca724cf4f6feaae1ab14e7f8fe3e89692d --- /dev/null +++ b/lib/transpose/transdata_common.h @@ -0,0 +1,29 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#ifndef LIB_TRANSPOSE_TRANSDATA_COMMON_H +#define LIB_TRANSPOSE_TRANSDATA_COMMON_H + +namespace AscendC { +template +struct TransDataParams { + T srcLayout; + U dstLayout; +}; + +#ifndef ASCC_PARAM_TRANSDATACONFIG +#define ASCC_PARAM_TRANSDATACONFIG +struct TransDataConfig { + DataFormat srcFormat; + DataFormat dstFormat; +}; +#endif // ASCC_PARAM_TRANSDATACONFIG +} // namespace AscendC + +#endif // LIB_TRANSPOSE_TRANSDATA_COMMON_H \ No newline at end of file diff --git a/lib/transpose/transdata_tiling.h b/lib/transpose/transdata_tiling.h new file mode 100644 index 0000000000000000000000000000000000000000..f2d722218a7c5556d30910c293a1f808c54b3d7e --- /dev/null +++ b/lib/transpose/transdata_tiling.h @@ -0,0 +1,67 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file transdata_tiling.h + * \brief + */ +#ifndef LIB_TRANSPOSE_TRANSDATA_TILING_H +#define LIB_TRANSPOSE_TRANSDATA_TILING_H +#include +#include "graph/tensor.h" +#include "tiling/platform/platform_ascendc.h" + +namespace AscendC { +/* + * @brief DataFormat +*/ +#ifndef ASCC_ENUM_DATAFORMAT +#define ASCC_ENUM_DATAFORMAT +enum class DataFormat : uint8_t { + ND = 0, + NZ, + NCHW, + NC1HWC0, + NHWC, + NCDHW, + NDC1HWC0, + FRACTAL_Z_3D, +}; +#endif // ASCC_ENUM_DATAFORMAT + +#ifndef ASCC_PARAM_TRANSDATACONFIG +#define ASCC_PARAM_TRANSDATACONFIG +struct TransDataConfig { + DataFormat srcFormat; + DataFormat dstFormat; +}; +#endif // ASCC_PARAM_TRANSDATACONFIG + +/*! + * \brief This interface is used to obtain the maximum and minimum temporary space reserved or applied. + * The developer selects a proper space size based on this range as the tiling parameter. + * + * \param [in] platform, targeted platform information + * \param [in] srcShape, src tensor shape + * \param [in] dstShape, src tensor shape + * \param [in] dataType, actual data type of the input + * \param [in] config, transdata config + * \param [out] maxValue, maximum temporary space required + * \param [out] minValue, minimum temporary space required + * \return whether get the max/min value successfully + */ +bool GetTransDataMaxMinTmpSize(const platform_ascendc::PlatformAscendC &platform, + const ge::Shape &srcShape, + const ge::Shape &dstShape, + const ge::DataType dataType, + const TransDataConfig &config, + uint32_t &maxValue, uint32_t &minValue); +} // AscendC +#endif // LIB_TRANSPOSE_TRANSDATA_TILING_H \ No newline at end of file diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index e1f9cd3763fd4813a21ef92eb4d2b6ca7de408c2..a721159336fcde0e830a02746b01f0956ad97038 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -154,6 +154,7 @@ file(GLOB ASCENDC_TEST_ascend910B1_AIV_CASE_SRC_FILES ${ASCENDC_TESTS_DIR}/normalization/welfordfinalize/test_operator_welfordfinalize.cpp ${ASCENDC_TESTS_DIR}/utils/init_global_memory/test_operator_init_global_memory.cpp ${ASCENDC_TESTS_DIR}/normalization/layernormV2/test_operator_layernormV2.cpp + ${ASCENDC_TESTS_DIR}/transpose/transdata/*cpp ) # ascend910B1 aic test cases @@ -405,6 +406,7 @@ file(GLOB ASCENDC_TILING_SRC_FILES ${ASCENDC_API_DIR}/impl/sort/topk/*.cpp ${ASCENDC_API_DIR}/impl/reduce/reduce_tiling.cpp ${ASCENDC_API_DIR}/impl/normalization/layernormV2/*.cpp + ${ASCENDC_API_DIR}/impl/transpose/transdata/transdata_tiling.cpp ) # ascendc_tiling_utest diff --git a/tests/tiling/test_tiling.cpp b/tests/tiling/test_tiling.cpp index 50d4faa0fc247d615512ec2c2502a7f6967fa3e5..b28356198233bce88a10eacf6c8e1f4b75c44a52 100644 --- a/tests/tiling/test_tiling.cpp +++ b/tests/tiling/test_tiling.cpp @@ -32,7 +32,7 @@ protected: void TearDown() {} }; - +extern void platfrom_stub_set_chip_version(const char *num); TEST_F(TestTiling, MultiCoreSmallMN) { matmul_tiling::MultiCoreMatmulTiling rnnMatmul3,rnnMatmul4,rnnMatmul5; @@ -5400,6 +5400,106 @@ TEST_F(TestTiling, TestOneElementBroadCast200) } #endif +TEST_F(TestTiling, testTransDataTilingUnalignedHw) +{ + platfrom_stub_set_chip_version("Ascend910B"); + uint32_t maxSize; + uint32_t minSize; + int32_t n = 16; + int32_t c = 16; + int32_t d = 3; + int32_t h = 3; + int32_t w = 3; + int32_t c0 = 16; + int32_t n0 = 16; + int32_t c1 = (c + c0 - 1) / c0; + int32_t n1 = (n + n0 - 1) / n0; + int32_t hw0 = 16; + int32_t hw1 = (h * w + hw0 - 1) / hw0; + auto ncdhwShape = ge::Shape({ n, c, d, h, w }); + auto ndc1hwc0Shape = ge::Shape({ n, d, c1, h, w, c0}); + auto fractalzShape = ge::Shape({ d, c1, h, w, n1, n0, c0}); + fe::PlatFormInfos platform_info; + auto plat = platform_ascendc::PlatformAscendC(&platform_info); + TransDataConfig config = {DataFormat::NCDHW, DataFormat::NDC1HWC0}; + bool ret = GetTransDataMaxMinTmpSize(plat, ncdhwShape, ndc1hwc0Shape, ge::DataType::DT_FLOAT16, config, maxSize, minSize); + + EXPECT_TRUE(ret); + EXPECT_EQ(maxSize, 1632); + EXPECT_EQ(minSize, 1632); + + config = {DataFormat::NDC1HWC0, DataFormat::NCDHW}; + ret = GetTransDataMaxMinTmpSize(plat, ndc1hwc0Shape, ncdhwShape, ge::DataType::DT_FLOAT16, config, maxSize, minSize); + + EXPECT_TRUE(ret); + EXPECT_EQ(maxSize, 2048); + EXPECT_EQ(minSize, 2048); + + config = {DataFormat::NCDHW, DataFormat::FRACTAL_Z_3D}; + ret = GetTransDataMaxMinTmpSize(plat, ncdhwShape, fractalzShape, ge::DataType::DT_FLOAT16, config, maxSize, minSize); + + EXPECT_TRUE(ret); + EXPECT_EQ(maxSize, 26112); + EXPECT_EQ(minSize, 26112); + + config = {DataFormat::FRACTAL_Z_3D, DataFormat::NCDHW}; + ret = GetTransDataMaxMinTmpSize(plat, fractalzShape, ncdhwShape, ge::DataType::DT_FLOAT16, config, maxSize, minSize); + + EXPECT_TRUE(ret); + EXPECT_EQ(maxSize, n1 * n0 * c1 * c0 * d * hw0 * hw1 * 2); + EXPECT_EQ(minSize, n1 * n0 * c1 * c0 * d * hw0 * hw1 * 2); +} + +TEST_F(TestTiling, testTransDataTilingAlignedHw) +{ + platfrom_stub_set_chip_version("Ascend910B"); + uint32_t maxSize; + uint32_t minSize; + int32_t n = 5; + int32_t c = 30; + int32_t d = 2; + int32_t h = 4; + int32_t w = 8; + int32_t c0 = 16; + int32_t n0 = 16; + int32_t c1 = (c + c0 - 1) / c0; + int32_t n1 = (n + n0 - 1) / n0; + int32_t hw0 = 16; + int32_t hw1 = (h * w + hw0 - 1) / hw0; + auto ncdhwShape = ge::Shape({ n, c, d, h, w }); + auto ndc1hwc0Shape = ge::Shape({ n, d, c1, h, w, c0}); + auto fractalzShape = ge::Shape({ d, c1, h, w, n1, n0, c0}); + fe::PlatFormInfos platform_info; + auto plat = platform_ascendc::PlatformAscendC(&platform_info); + TransDataConfig config = {DataFormat::NCDHW, DataFormat::NDC1HWC0}; + bool ret = GetTransDataMaxMinTmpSize(plat, ncdhwShape, ndc1hwc0Shape, ge::DataType::DT_FLOAT16, config, maxSize, minSize); + + EXPECT_TRUE(ret); + EXPECT_EQ(maxSize, 4224); + EXPECT_EQ(minSize, 4224); + + config = {DataFormat::NDC1HWC0, DataFormat::NCDHW}; + ret = GetTransDataMaxMinTmpSize(plat, ndc1hwc0Shape, ncdhwShape, ge::DataType::DT_FLOAT16, config, maxSize, minSize); + + EXPECT_TRUE(ret); + EXPECT_EQ(maxSize, 4608); + EXPECT_EQ(minSize, 4608); + + config = {DataFormat::NCDHW, DataFormat::FRACTAL_Z_3D}; + ret = GetTransDataMaxMinTmpSize(plat, ncdhwShape, fractalzShape, ge::DataType::DT_FLOAT16, config, maxSize, minSize); + + EXPECT_TRUE(ret); + EXPECT_EQ(maxSize, 69376); + EXPECT_EQ(minSize, 69376); + + config = {DataFormat::FRACTAL_Z_3D, DataFormat::NCDHW}; + ret = GetTransDataMaxMinTmpSize(plat, fractalzShape, ncdhwShape, ge::DataType::DT_FLOAT16, config, maxSize, minSize); + + EXPECT_TRUE(ret); + EXPECT_EQ(maxSize, n1 * n0 * c1 * c0 * d * hw0 * hw1 * 2 * 2); + EXPECT_EQ(minSize, n1 * n0 * c1 * c0 * d * hw0 * hw1 * 2 * 2); +} + TEST_F(TestTiling, TestReduceXorSumTilingInt16) { std::vector shapeDims = { 128, 128 }; diff --git a/tests/transpose/transdata/test_operator_transdata.cpp b/tests/transpose/transdata/test_operator_transdata.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d50408a97e8ab49666de9d50fb3fb65d5d74e982 --- /dev/null +++ b/tests/transpose/transdata/test_operator_transdata.cpp @@ -0,0 +1,267 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#include +#include "kernel_operator.h" + +#include +#include + +namespace AscendC { + +namespace { + +constexpr uint32_t NCDHW_FractalZ3D = 1; +constexpr uint32_t FractalZ3D_NCDHW = 2; +constexpr uint32_t NCDHW_NDC1HWC0 = 3; +constexpr uint32_t NDC1HWC0_NCDHW = 4; + + +constexpr TransDataConfig config1 = {DataFormat::NCDHW, DataFormat::FRACTAL_Z_3D}; +constexpr TransDataConfig config2 = {DataFormat::FRACTAL_Z_3D, DataFormat::NCDHW}; +constexpr TransDataConfig config3 = {DataFormat::NCDHW, DataFormat::NDC1HWC0}; +constexpr TransDataConfig config4 = {DataFormat::NDC1HWC0, DataFormat::NCDHW}; + +} + +template +class KernelTransData { +public: +__aicore__ inline KernelTransData() {} +__aicore__ inline void Init(GM_ADDR srcGm, GM_ADDR dstGm, + int32_t n, int32_t c, int32_t d, int32_t h, int32_t w, TPipe *tpipe) +{ + this->d = d; + this->c = c; + this->h = h; + this->w = w; + this->n = n; + this->c1 = (c + c0 - 1) / c0; + this->n1 = (n + n0 - 1) / n0; + this->hw1 = (h*w + hw0 - 1) / hw0; + + if (mode == NDC1HWC0_NCDHW) { + this->srcShapeSize = n * c1 * c0 * d * h * w; + this->dstShapeSize = n * d * c * hw0; + this->tmpShapeSize = 512 + d * c1 * c0 * hw0 * hw1; + uint32_t dstGmSize = n * c * d * h * w; + srcGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ T*>(srcGm), srcShapeSize * sizeof(T)); + dstGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ T*>(dstGm), dstGmSize * sizeof(T)); + } else { + if constexpr (mode == NCDHW_FractalZ3D) { + srcShapeSize = n * d * c * hw0 * hw1; + dstShapeSize = n1 * n0 * c1 * c0 * d * h * w; + if ((h*w) % 16 != 0 ) { + needPad = true; + dstShapeSize = n1 * n0 * c1 * c0 * d * hw0 * hw1; + } + tmpShapeSize = c * d * hw0 * hw1 + n1 * n0 * d * c1 * c0 * hw0 * hw1; + } else if constexpr (mode == FractalZ3D_NCDHW) { + this->srcShapeSize = d * c1 * h * w * n1 * n0 * c0; + this->dstShapeSize = n * c * d * (hw1 * hw0); + this->tmpShapeSize = d * c1 * (hw1 * hw0) * n1 * n0 * c0 * 2; + } else if constexpr (mode == NCDHW_NDC1HWC0) { + this->srcShapeSize = n * d * c * hw0; + this->dstShapeSize = n * c1 * c0 * d * h * w; + this->tmpShapeSize = d * hw0 * hw1 + d * c1 * c0 * hw0 * hw1; + } + srcGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ T*>(srcGm), srcShapeSize * sizeof(T)); + dstGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ T*>(dstGm), dstShapeSize * sizeof(T)); + } + + + this->pipe = tpipe; + pipe->InitBuffer(inQueue, 1, srcShapeSize * sizeof(T)); + pipe->InitBuffer(outQueue, 1, dstShapeSize * sizeof(T)); + pipe->InitBuffer(tmpBuf, tmpShapeSize * sizeof(T)); + +} +__aicore__ inline void Process() +{ + CopyIn(); + Compute(); + CopyOut(); +} + +private: +__aicore__ inline void CopyIn() +{ + LocalTensor srcLocal = inQueue.AllocTensor(); + if constexpr (mode == NCDHW_FractalZ3D || mode == NCDHW_NDC1HWC0) { + DataCopyExtParams extParam = {static_cast(n * c * d), + static_cast(h * w * sizeof(T)), 0, 0, 0}; + DataCopyPadExtParams padParam = {true, 0, 0, 0}; + if (needPad) { + DataCopyPad(srcLocal, srcGlobal, extParam, padParam); + } else { + DataCopy(srcLocal, srcGlobal, srcShapeSize); + } + } else if constexpr (mode == FractalZ3D_NCDHW || mode == NDC1HWC0_NCDHW) { + DataCopy(srcLocal, srcGlobal, srcShapeSize); + } + + inQueue.EnQue(srcLocal); +} +__aicore__ inline void Compute() +{ + LocalTensor dstLocal = outQueue.AllocTensor(); + LocalTensor tmp = tmpBuf.Get(); + LocalTensor srcLocal = inQueue.DeQue(); + PipeBarrier(); + + Layout ncdhwLayout = MakeLayout(MakeShape(n, c, d, h, w), MakeStride()); + Layout ndc1hwc0Layout = MakeLayout(MakeShape(n, d, c1, h, w, c0), MakeStride()); + Layout fractalLayout = MakeLayout(MakeShape(d, c1, h, w, n1, n0, c0), MakeStride()); + + if constexpr (mode == NCDHW_FractalZ3D) { + TransDataParams params = {ncdhwLayout, fractalLayout}; + TransData(dstLocal, srcLocal, tmp, params); + } else if constexpr (mode == FractalZ3D_NCDHW) { + TransDataParams params = {fractalLayout, ncdhwLayout}; + TransData(dstLocal, srcLocal, tmp, params); + } else if constexpr (mode == NCDHW_NDC1HWC0) { + TransDataParams params = {ncdhwLayout, ndc1hwc0Layout}; + TransData(dstLocal, srcLocal, tmp, params); + } else if constexpr (mode == NDC1HWC0_NCDHW) { + TransDataParams params = {ndc1hwc0Layout, ncdhwLayout}; + TransData(dstLocal, srcLocal, tmp, params); + } + + outQueue.EnQue(dstLocal); + inQueue.FreeTensor(srcLocal); + +} +__aicore__ inline void CopyOut() +{ + LocalTensor dstLocal = outQueue.DeQue(); + DataCopyExtParams extParam {static_cast(n * c * d), static_cast(h*w*sizeof(T)), 0, 0, 0}; + if constexpr (mode == NCDHW_FractalZ3D) { + DataCopy(dstGlobal, dstLocal, n1 * n0 * c1); + } else if constexpr (mode == FractalZ3D_NCDHW) { + DataCopy(dstGlobal, dstLocal, dstShapeSize); + } else if constexpr (mode == NCDHW_NDC1HWC0) { + DataCopy(dstGlobal, dstLocal, dstShapeSize); + } else if constexpr (mode == NDC1HWC0_NCDHW) { + DataCopyPad(dstGlobal, dstLocal, extParam); + } + outQueue.FreeTensor(dstLocal); +} + +private: + GlobalTensor srcGlobal; + GlobalTensor dstGlobal; + TPipe *pipe; + TQue inQueue; + TQue outQueue; + TBuf tmpBuf; + bool needPad = false; + int32_t n = 0; + int32_t c = 0; + int32_t d = 0; + int32_t h = 0; + int32_t w = 0; + int32_t n1 = 0; + int32_t c1 = 0; + int32_t hw1 = 0; + int32_t c0 = 16; + int32_t n0 = 16; + int32_t hw0 = 16; + uint32_t srcShapeSize = 0; + uint32_t dstShapeSize = 0; + uint32_t tmpShapeSize = 0; +}; +} // namespace AscendC + +template +__global__ __aicore__ void MainTransdata( + __gm__ uint8_t* dstGm, __gm__ uint8_t* srcGm, uint64_t n, uint64_t c, uint64_t d, uint64_t h, uint64_t w) +{ + if (g_coreType == AscendC::AIC || AscendC::GetBlockIdx() > 0) { + return; + } + AscendC::TPipe pipe; + AscendC::KernelTransData op; + op.Init(srcGm, dstGm, n, c, d, h, w, &pipe); + op.Process(); +} + +struct TransDataTestParams { + int32_t n; + int32_t c; + int32_t d; + int32_t h; + int32_t w; + uint32_t mode; + void (*cal_func)(uint8_t*, uint8_t*, uint64_t, uint64_t, uint64_t, uint64_t, uint64_t); +}; + +class TransDataTestsuite : public testing::Test, public testing::WithParamInterface { +protected: + void SetUp() + { + AscendC::SetGCoreType(2); + } + + void TearDown() + { + AscendC::SetGCoreType(0); + } +}; + +INSTANTIATE_TEST_CASE_P(TEST_OPERATTION_TRANSDATA, TransDataTestsuite, + ::testing::Values( + TransDataTestParams { 5, 32, 2, 1, 16, 1, MainTransdata }, + TransDataTestParams { 4, 31, 1, 6, 7, 2, MainTransdata }, + TransDataTestParams { 4, 20, 2, 3, 1, 3, MainTransdata }, + TransDataTestParams { 8, 14, 2, 1, 16, 4, MainTransdata }, + TransDataTestParams { 5, 32, 2, 1, 16, 1, MainTransdata }, + TransDataTestParams { 4, 31, 1, 6, 7, 2, MainTransdata }, + TransDataTestParams { 4, 20, 2, 3, 1, 3, MainTransdata }, + TransDataTestParams { 8, 14, 2, 1, 16, 4, MainTransdata } + + )); + +TEST_P(TransDataTestsuite, TransDataOpTestCase) +{ + auto params = GetParam(); + auto n = params.n; + auto c = params.c; + auto d = params.d; + auto h = params.h; + auto w = params.w; + auto mode = params.mode; + uint32_t srcShapeSize; + uint32_t dstShapeSize; + int32_t hw0 = 16; + int32_t hw1 = (h * w + hw0 - 1) / hw0; + int32_t c0 = 16; + int32_t n0 = 16; + int32_t c1 = (c + c0 - 1) / c0; + int32_t n1 = (n + n0 - 1) / n0; + if (mode == 1) { + srcShapeSize = n * d * c * hw0 * hw1; + dstShapeSize = n1 * n0 * c1 * c0 * d * h * w; + if ((h*w) % 16 != 0 ) { + dstShapeSize = n1 * n0 * c1 * c0 * d * hw0 * hw1; + } + } else if (mode == 2) { + srcShapeSize = d * c1 * h * w * n1 * n0 * c0; + dstShapeSize = n * c * d * (hw1 * hw0); + } else if (mode == 3) { + srcShapeSize = n * d * c * hw0; + dstShapeSize = n * c1 * c0 * d * h * w; + } else if (mode == 4) { + srcShapeSize = n * c1 * c0 * d * h * w; + dstShapeSize = n * d * c * hw0; + } + uint8_t srcGm[srcShapeSize * sizeof(half)] = {0}; // 外部保证inner是32B对齐 + uint8_t dstGm[dstShapeSize * sizeof(half)] = {0}; + params.cal_func(dstGm, srcGm, n, c, d, h, w); + EXPECT_EQ(dstGm[0], 0); +}