# Patch summary (free-text preamble; everything from the first "diff --git" header onward is unchanged).
#
# Purpose: add overloads of the softmax host-side tiling helpers that take the plain
# `SoftMaxTiling` struct (header "kernel_tiling/kernel_tiling.h" is newly included for it)
# instead of `optiling::SoftMaxTiling`.
#
# impl/activation/softmax/softmax_tiling.cpp
#   - SoftMaxTilingFunc, SoftMaxFlashTilingFunc, SoftMaxGradTilingFunc,
#     SoftMaxFlashV2TilingFunc, SoftMaxFlashV3TilingFunc: each new overload builds a local
#     `optiling::SoftMaxTiling`, calls the existing overload with the same arguments, then
#     calls `tiling.SaveToBuffer(&<out>, sizeof(SoftMaxTiling))` to fill the caller's struct.
#   - IsBasicBlockInSoftMax(SoftMaxTiling&, dataTypeSize): copies only `srcK` and `tailM`
#     into a local `optiling::SoftMaxTiling` (set_srcK / set_tailM) and delegates to the
#     existing overload.
#
# lib/activation/softmax_tiling.h
#   - Declares the six new overloads with the same default arguments as the existing ones
#     (isUpdate = false for SoftMaxFlashTilingFunc, isFront = false for SoftMaxGradTilingFunc,
#     dataTypeSize = 2 for IsBasicBlockInSoftMax, isBasicBlock = false / isFlashOutputBrc = false
#     for the V2/V3 variants). The new declarations are added without their own
#     "@ingroup/@brief" comment blocks.
#
# tests/tiling/test_tiling.cpp
#   - TestSoftMaxTiling: checks `reduceM` through the new overloads (Flash -> 32, Grad -> 64,
#     SoftMaxTilingFunc -> 64) and that IsBasicBlockInSoftMax returns true.
#   - TestSoftMaxFlashV2TilingBasicBlock: expects reduceM == 16.
#   - TestSoftMaxFlashV3Tiling: expects reduceM == 8.
#
# NOTE(review): the IsBasicBlockInSoftMax wrapper forwards only srcK and tailM; the body of the
#   optiling overload is not fully visible in this patch - confirm it reads no other fields.
# NOTE(review): SaveToBuffer is called with sizeof(SoftMaxTiling); this presumably relies on the
#   plain struct and the serialized optiling layout having the same size/field order - confirm.
#   Any status returned by SaveToBuffer is not inspected here.
# NOTE(review): this copy of the patch appears to have had its line breaks collapsed; the
#   original line structure must be restored before it can be applied.
#
diff --git a/impl/activation/softmax/softmax_tiling.cpp b/impl/activation/softmax/softmax_tiling.cpp index d213175ee4a3c7de2deada79fb3d806384471195..cbedff82bc4514d5eb2d41d9ec741f8eb37fadc6 100644 --- a/impl/activation/softmax/softmax_tiling.cpp +++ b/impl/activation/softmax/softmax_tiling.cpp @@ -138,6 +138,14 @@ void SoftMaxTilingFunc(const ge::Shape& srcShape, const uint32_t dataTypeSize, c softmaxTiling.set_tailReduceSize(tail * elementNumPerBlk); } +void SoftMaxTilingFunc(const ge::Shape& srcShape, const uint32_t dataTypeSize, const uint32_t localWorkSpaceSize, + SoftMaxTiling& softmaxTiling) +{ + optiling::SoftMaxTiling tiling; + SoftMaxTilingFunc(srcShape, dataTypeSize, localWorkSpaceSize, tiling); + tiling.SaveToBuffer(&softmaxTiling, sizeof(SoftMaxTiling)); +} + uint32_t GetSoftMaxFlashMaxTmpSize(const ge::Shape& srcShape, const uint32_t dataTypeSize, const bool isUpdate, UNUSED const bool isReuseSource) { @@ -224,6 +232,14 @@ void SoftMaxFlashTilingFunc(const ge::Shape& srcShape, const uint32_t dataTypeSi softmaxFlashTiling.set_tailReduceSize(tail * elementNumPerBlk); } +void SoftMaxFlashTilingFunc(const ge::Shape& srcShape, const uint32_t dataTypeSize, const uint32_t localWorkSpaceSize, + SoftMaxTiling& softmaxFlashTiling, const bool isUpdate) +{ + optiling::SoftMaxTiling tiling; + SoftMaxFlashTilingFunc(srcShape, dataTypeSize, localWorkSpaceSize, tiling, isUpdate); + tiling.SaveToBuffer(&softmaxFlashTiling, sizeof(SoftMaxTiling)); +} + uint32_t GetSoftMaxGradMaxTmpSize(const ge::Shape& srcShape, const uint32_t dataTypeSize, const bool isFront, UNUSED const bool isReuseSource) { @@ -325,6 +341,14 @@ void SoftMaxGradTilingFunc(const ge::Shape& srcShape, const uint32_t dataTypeSiz softmaxGradTiling.set_tailReduceSize(tail * elementNumPerBlk); } +void SoftMaxGradTilingFunc(const ge::Shape& srcShape, const uint32_t dataTypeSize, const uint32_t localWorkSpaceSize, + SoftMaxTiling& softmaxGradTiling, const bool isFront) +{ + optiling::SoftMaxTiling tiling; + 
SoftMaxGradTilingFunc(srcShape, dataTypeSize, localWorkSpaceSize, tiling, isFront); + tiling.SaveToBuffer(&softmaxGradTiling, sizeof(SoftMaxTiling)); +} + bool IsBasicBlockInSoftMax(optiling::SoftMaxTiling& tiling, const uint32_t dataTypeSize) { constexpr uint32_t SOFTMAX_BASICBLOCK_MAX_SIZE = 2048; @@ -341,6 +365,14 @@ bool IsBasicBlockInSoftMax(optiling::SoftMaxTiling& tiling, const uint32_t dataT } } +bool IsBasicBlockInSoftMax(SoftMaxTiling& tiling, const uint32_t dataTypeSize) +{ + optiling::SoftMaxTiling opTiling; + opTiling.set_srcK(tiling.srcK); + opTiling.set_tailM(tiling.tailM); + return IsBasicBlockInSoftMax(opTiling, dataTypeSize); +} + uint32_t GetSoftMaxFlashV2MaxTmpSize(const ge::Shape& srcShape, const uint32_t dataTypeSize1, const uint32_t dataTypeSize2, UNUSED const bool isUpdate, const bool isBasicBlock, UNUSED const bool isFlashOutputBrc) @@ -503,6 +535,16 @@ void SoftMaxFlashV2TilingFunc(const ge::Shape& srcShape, const uint32_t dataType } } +void SoftMaxFlashV2TilingFunc(const ge::Shape& srcShape, const uint32_t dataTypeSize1, const uint32_t dataTypeSize2, + const uint32_t localWorkSpaceSize, SoftMaxTiling& softmaxFlashTiling, UNUSED const bool isUpdate, + const bool isBasicBlock, const bool isFlashOutputBrc) +{ + optiling::SoftMaxTiling tiling; + SoftMaxFlashV2TilingFunc(srcShape, dataTypeSize1, dataTypeSize2, localWorkSpaceSize, tiling, isUpdate, + isBasicBlock, isFlashOutputBrc); + tiling.SaveToBuffer(&softmaxFlashTiling, sizeof(SoftMaxTiling)); +} + void GetSoftMaxFlashV3MaxMinTmpSize(const ge::Shape &srcShape, const uint32_t dataTypeSize1, const uint32_t dataTypeSize2, uint32_t &maxValue, uint32_t &minValue, UNUSED const bool isUpdate, UNUSED const bool isBasicBlock) @@ -576,4 +618,15 @@ void SoftMaxFlashV3TilingFunc(const ge::Shape &srcShape, const uint32_t dataType softmaxFlashV3Tiling.set_tailReduceSize(tail * elementNumPerBlk); } } + +void SoftMaxFlashV3TilingFunc(const ge::Shape &srcShape, const uint32_t dataTypeSize1, const uint32_t 
dataTypeSize2, + const uint32_t localWorkSpaceSize, SoftMaxTiling &softmaxFlashV3Tiling, UNUSED const bool isUpdate, + UNUSED const bool isBasicBlock) +{ + optiling::SoftMaxTiling tiling; + SoftMaxFlashV3TilingFunc(srcShape, dataTypeSize1, dataTypeSize2, localWorkSpaceSize, tiling, isUpdate, + isBasicBlock); + tiling.SaveToBuffer(&softmaxFlashV3Tiling, sizeof(SoftMaxTiling)); +} + } \ No newline at end of file diff --git a/lib/activation/softmax_tiling.h b/lib/activation/softmax_tiling.h index 3b166f2a12667758c7ceaed8f8f5bdfe38389bc3..7333b1bf6471f73792405d94bb2d62521207dd7e 100644 --- a/lib/activation/softmax_tiling.h +++ b/lib/activation/softmax_tiling.h @@ -17,6 +17,7 @@ #define LIB_SOFTMAX_SOFTMAX_TILING_H #include "graph/tensor.h" #include "softmax_tilingdata.h" +#include "kernel_tiling/kernel_tiling.h" namespace AscendC { /* * @ingroup GetSoftMaxMaxTmpSize @@ -46,6 +47,9 @@ uint32_t GetSoftMaxMinTmpSize(const ge::Shape& srcShape, const uint32_t dataType */ void SoftMaxTilingFunc(const ge::Shape& srcShape, const uint32_t dataTypeSize, const uint32_t localWorkSpaceSize, optiling::SoftMaxTiling& softmaxTiling); + +void SoftMaxTilingFunc(const ge::Shape& srcShape, const uint32_t dataTypeSize, const uint32_t localWorkSpaceSize, + SoftMaxTiling& softmaxTiling); /* * @ingroup GetSoftMaxFlashV3MaxMinTmpSize * @brief calculate SoftmaxFlashV3 api need min/max temporary local space size @@ -72,6 +76,10 @@ void GetSoftMaxFlashV3MaxMinTmpSize(const ge::Shape& srcShape, const uint32_t da void SoftMaxFlashV3TilingFunc(const ge::Shape& srcShape, const uint32_t dataTypeSize1, const uint32_t dataTypeSize2, const uint32_t localWorkSpaceSize, optiling::SoftMaxTiling& softmaxFlashV3Tiling, const bool isUpdate, const bool isBasicBlock = false); + +void SoftMaxFlashV3TilingFunc(const ge::Shape& srcShape, const uint32_t dataTypeSize1, const uint32_t dataTypeSize2, + const uint32_t localWorkSpaceSize, SoftMaxTiling& softmaxFlashV3Tiling, const bool isUpdate, + const bool isBasicBlock 
= false); /* * @ingroup GetSoftMaxFlashMaxTmpSize * @brief calculate SoftmaxFlash api need max temporary local space size @@ -106,6 +114,9 @@ uint32_t GetSoftMaxFlashMinTmpSize(const ge::Shape& srcShape, const uint32_t dat void SoftMaxFlashTilingFunc(const ge::Shape& srcShape, const uint32_t dataTypeSize, const uint32_t localWorkSpaceSize, optiling::SoftMaxTiling& softmaxFlashTiling, const bool isUpdate = false); +void SoftMaxFlashTilingFunc(const ge::Shape& srcShape, const uint32_t dataTypeSize, const uint32_t localWorkSpaceSize, + SoftMaxTiling& softmaxFlashTiling, const bool isUpdate = false); + /* * @ingroup GetSoftMaxGradMaxTmpSize * @brief get SoftmaxGrad api need max temporary local space size @@ -140,6 +151,8 @@ uint32_t GetSoftMaxGradMinTmpSize(const ge::Shape& srcShape, const uint32_t data void SoftMaxGradTilingFunc(const ge::Shape& srcShape, const uint32_t dataTypeSize, const uint32_t localWorkSpaceSize, optiling::SoftMaxTiling& softmaxGradTiling, const bool isFront = false); +void SoftMaxGradTilingFunc(const ge::Shape& srcShape, const uint32_t dataTypeSize, const uint32_t localWorkSpaceSize, + SoftMaxTiling& softmaxGradTiling, const bool isFront = false); /* * @ingroup IsBasicBlockInSoftMax * @brief judge tiling is basicBlock or not @@ -149,6 +162,7 @@ void SoftMaxGradTilingFunc(const ge::Shape& srcShape, const uint32_t dataTypeSiz */ bool IsBasicBlockInSoftMax(optiling::SoftMaxTiling& tiling, const uint32_t dataTypeSize = 2); +bool IsBasicBlockInSoftMax(SoftMaxTiling& tiling, const uint32_t dataTypeSize = 2); /* * @ingroup GetSoftMaxFlashV2MinTmpSize * @brief get SoftmaxFlashV2 api need min temporary local space size @@ -192,5 +206,9 @@ uint32_t GetSoftMaxFlashV2MaxTmpSize(const ge::Shape& srcShape, const uint32_t d void SoftMaxFlashV2TilingFunc(const ge::Shape& srcShape, const uint32_t dataTypeSize1, const uint32_t dataTypeSize2, const uint32_t localWorkSpaceSize, optiling::SoftMaxTiling& softmaxFlashTiling, const bool isUpdate, const bool 
isBasicBlock = false, const bool isFlashOutputBrc = false); + +void SoftMaxFlashV2TilingFunc(const ge::Shape& srcShape, const uint32_t dataTypeSize1, const uint32_t dataTypeSize2, + const uint32_t localWorkSpaceSize, SoftMaxTiling& softmaxFlashTiling, const bool isUpdate, + const bool isBasicBlock = false, const bool isFlashOutputBrc = false); } #endif // LIB_SOFTMAX_SOFTMAX_TILING_H diff --git a/tests/tiling/test_tiling.cpp b/tests/tiling/test_tiling.cpp index 50d4faa0fc247d615512ec2c2502a7f6967fa3e5..34be9f4c0e5ef4ed13e6b76d1bf571026340b59e 100644 --- a/tests/tiling/test_tiling.cpp +++ b/tests/tiling/test_tiling.cpp @@ -1314,6 +1314,16 @@ TEST_F(TestTiling, TestSoftMaxTiling) EXPECT_EQ(tilingData.get_reduceM(), 64); SoftMaxGradTilingFunc(softmaxShape, 2, 133120, tilingData, true); EXPECT_EQ(tilingData.get_reduceM(), 64); + + SoftMaxTiling tilingDataNotOp; + SoftMaxFlashTilingFunc(softmaxShape, 2, 77952, tilingDataNotOp, true); + EXPECT_EQ(tilingDataNotOp.reduceM, 32); + flag = IsBasicBlockInSoftMax(tilingDataNotOp); + EXPECT_EQ(flag, true); + SoftMaxGradTilingFunc(softmaxShape, 2, 133120, tilingDataNotOp, true); + EXPECT_EQ(tilingDataNotOp.reduceM, 64); + SoftMaxTilingFunc(softmaxShape, 2, softmaxTmpSize, tilingDataNotOp); + EXPECT_EQ(tilingDataNotOp.reduceM, 64); } TEST_F(TestTiling, TestLogSoftMaxTiling) @@ -1620,6 +1630,10 @@ TEST_F(TestTiling, TestSoftMaxFlashV2TilingBasicBlock) EXPECT_EQ(softmaxflashV2NeedMinLength, (8 + 512 + 64) * 4 * 16); SoftMaxFlashV2TilingFunc(softmaxShape, inputTypeSize, maxSumTypeSize, workLength, tilingData, true, true, true); EXPECT_EQ(tilingData.get_reduceM(), 16); + + SoftMaxTiling tilingDataNotOp; + SoftMaxFlashV2TilingFunc(softmaxShape, inputTypeSize, maxSumTypeSize, workLength, tilingDataNotOp, true, true, true); + EXPECT_EQ(tilingDataNotOp.reduceM, 16); } TEST_F(TestTiling, TestSoftMaxFlashV3Tiling) @@ -1648,6 +1662,10 @@ TEST_F(TestTiling, TestSoftMaxFlashV3Tiling) uint32_t workLength = 76 * 1024; 
SoftMaxFlashV3TilingFunc(softmaxShape, inputTypeSize, maxSumTypeSize, workLength, tilingData, true, true); EXPECT_EQ(tilingData.get_reduceM(), 8); + + SoftMaxTiling tilingDataNotOp; + SoftMaxFlashV3TilingFunc(softmaxShape, inputTypeSize, maxSumTypeSize, workLength, tilingDataNotOp, true, true); + EXPECT_EQ(tilingDataNotOp.reduceM, 8); } TEST_F(TestTiling, TestAsinTmpBufferFacotrHalfWithoutBasicBlock) {