// Review notes for this patch (text placed ahead of the first "diff --git" header).
//
// Purpose: make the multi-core matmul tiling treat NZ-format inputs as a case that needs
// single-core shape alignment, and add a unit test covering that case.
//
// Hunk 1 - impl/matmul/tiling/matmul_tiling_algorithm.cpp, MatmulTilingAlgorithm::GetSingleShape:
//   * Before: needAlign was true only when the multi-core scenario was SPLIT_MN or SPLIT_SMALL_MN.
//   * After:  needAlign is additionally true when aType_.type or bType_.type is CubeFormat::NZ.
//   * needAlign is consumed by the AlignSingleShape(...) call visible in the context lines, which
//     operates on singleCoreN using bAlignSize and is further gated by
//     (!bType_.isTrans || needOutputAlign).
//   * NOTE(review): the only consumer of needAlign visible in this hunk adjusts singleCoreN with a
//     B-derived alignment, yet the new condition also fires on A's format. Other uses of needAlign
//     may exist outside the hunk - confirm that an NZ A matrix is meant to influence N alignment.
//
// Hunk 2 - tests/tiling/test_matmul_api_tiling.cpp, new TEST_F TestMatmulApiTilngMultiCoreNZINOUT:
//   * Configures MultiCoreMatmulTiling with SetDim(10); A, B and C in GM / NZ / DT_FLOAT; bias in
//     GM / ND / DT_FLOAT16 and enabled; SetOrgShape and SetShape both (736, 7168, 4096);
//     SetBufferSpace(-1, -1, -1).
//   * SetAType receives a fourth argument `true` that SetBType does not - presumably the transpose
//     flag; TODO confirm against the SetAType declaration.
//   * Expects GetTiling(tilingData) to return 0, then reads shapeM/shapeN/shapeK through
//     GetSingleShape and expects shapeN to be a multiple of 16.
//   * NOTE(review): the result of GetSingleShape is discarded with (void). A failure would still
//     surface here, because shapeN is initialised to 1 and 1 % 16 != 0, but checking the result
//     directly would give a clearer diagnostic.
//   * NOTE(review): `shapeN % 16 == 0 ? true : false` followed by EXPECT_EQ(isAlign, true) can be
//     written as EXPECT_EQ(shapeN % 16, 0), which also prints the offending remainder.
//   * NOTE(review): 16 is hard-coded; presumably it corresponds to the alignment computed in
//     hunk 1 for this data type - confirm against DATA_COPY_ALIGN_SIZE / DTYPE_BIT_TAB.
//
// NOTE(review): the hunk headers declare 7 -> 9 and 6 -> 33 lines, but the hunk bodies below sit on
// two physical lines; the original line breaks must be restored before this patch can be applied.
diff --git a/impl/matmul/tiling/matmul_tiling_algorithm.cpp b/impl/matmul/tiling/matmul_tiling_algorithm.cpp index 14145e00883bc2bf5b0cab57b2f4fdc532e75898..9bcfac18474dec6ec7583fd5feaafc3f82b0ec89 100644 --- a/impl/matmul/tiling/matmul_tiling_algorithm.cpp +++ b/impl/matmul/tiling/matmul_tiling_algorithm.cpp @@ -3383,7 +3383,9 @@ void MatmulTilingAlgorithm::GetSingleShape(const CoreStatusPack &coreStatus, con int32_t bAlignSize = DATA_COPY_ALIGN_SIZE / DTYPE_BIT_TAB.at(tilingIns_->bType_.dataType) * BITS_PER_BYTE; auto multiCoreScenario = GetMultiCoreScenario(param); bool needAlign = multiCoreScenario == MultiCoreScenario::SPLIT_MN || - multiCoreScenario == MultiCoreScenario::SPLIT_SMALL_MN; + multiCoreScenario == MultiCoreScenario::SPLIT_SMALL_MN || + tilingIns_->aType_.type == CubeFormat::NZ || + tilingIns_->bType_.type == CubeFormat::NZ; bool needOutputAlign = NeedOutputAlign(singleCoreM, singleCoreN, singleCoreK); (void)AlignSingleShape(needAlign && (!tilingIns_->bType_.isTrans || needOutputAlign), param.n32 * C0_SIZE, coreStatus.nDim, bAlignSize, singleCoreN); diff --git a/tests/tiling/test_matmul_api_tiling.cpp b/tests/tiling/test_matmul_api_tiling.cpp index d3da1fcc509128d293f2f13ca2f7b2f45cca396d..ee08bc04828f84d3df39ad5e9423832a933fb61c 100644 --- a/tests/tiling/test_matmul_api_tiling.cpp +++ b/tests/tiling/test_matmul_api_tiling.cpp @@ -24,6 +24,33 @@ protected: virtual void SetUp() {} void TearDown() {} }; + +TEST_F(TestMatmulAPITiling, TestMatmulApiTilngMultiCoreNZINOUT) +{ + optiling::TCubeTiling tilingData; + matmul_tiling::MultiCoreMatmulTiling tilingApi; + tilingApi.SetDim(10); + tilingApi.SetAType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::NZ, matmul_tiling::DataType::DT_FLOAT, true); + tilingApi.SetBType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::NZ, matmul_tiling::DataType::DT_FLOAT); + tilingApi.SetCType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::NZ, matmul_tiling::DataType::DT_FLOAT); + 
tilingApi.SetBiasType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, matmul_tiling::DataType::DT_FLOAT16); + tilingApi.SetOrgShape(736, 7168, 4096); + tilingApi.SetShape(736, 7168, 4096); + tilingApi.EnableBias(true); + tilingApi.SetBufferSpace(-1, -1, -1); + + int64_t res = tilingApi.GetTiling(tilingData); + EXPECT_EQ(res, 0); + tilingApi.PrintTilingData(); + + int32_t shapeM = 1; + int32_t shapeN = 1; + int32_t shapeK = 1; + (void)tilingApi.GetSingleShape(shapeM, shapeN, shapeK); + bool isAlign = shapeN % 16 == 0 ? true : false; + EXPECT_EQ(isAlign, true); +} + TEST_F(TestMatmulAPITiling, TestMatmulApiTilngMultiCoreBTSCM) { optiling::TCubeTiling tilingData;