From a42eb1169d7dd875c033e5caa046cbd96b12667f Mon Sep 17 00:00:00 2001
From: caixilong <2508418876@qq.com>
Date: Mon, 16 Jun 2025 21:03:46 +0800
Subject: [PATCH 1/6] add test case for rope dirty data

---
 .../unittest/core/test_graph_launch_mode.cpp | 137 ++++++++++++++++++
 1 file changed, 137 insertions(+)

diff --git a/tests/unittest/core/test_graph_launch_mode.cpp b/tests/unittest/core/test_graph_launch_mode.cpp
index fde77cd3..dfe3843e 100644
--- a/tests/unittest/core/test_graph_launch_mode.cpp
+++ b/tests/unittest/core/test_graph_launch_mode.cpp
@@ -702,4 +702,141 @@ TEST(TestGraphLaunchMode, CapturedByUserAndChangeWorkspace)
     aclrtFree(workSpace3);
     aclrtDestroyStream(exeStream);
     aclrtResetDevice(deviceId);
+}
+
+TEST(TestGraphLaunchMode, RopeWorkspaceFullOfDirtyData)
+{
+    const uint32_t batchSize = 1;
+    const uint32_t nTokens = 4;
+    const uint32_t hiddenSizeQ = 16;
+    const uint32_t hiddenSizeK = 16;
+    const uint32_t headSize = 8;
+
+    if (!atb::GetSingleton().Is910A()) {
+        GTEST_SKIP() << "This test case does not support 910A";
+    }
+
+    uint32_t deviceId = 0;
+    aclError status = aclrtSetDevice(deviceId);
+    ASSERT_EQ(status, 0);
+    atb::infer::RopeParam param;
+    param.cosFormat = 1;
+    param.rotaryCoeff = 4;
+    atb::Operation * op = nullptr;
+    atb::Status st = atb::CreateOperation(param, &op);
+    ASSERT_EQ(st, 0);
+
+    atb::Tensor query;
+    query.desc.dtype = ACL_FLOAT16;
+    query.desc.format = ACL_FORMAT_ND;
+    query.desc.shape.dimNum = 2;
+    query.desc.shape.dims[0] = nTokens;
+    query.desc.shape.dims[1] = hiddenSizeQ;
+    query.dataSize = Utils::GetTensorSize(query);
+    status = aclrtMalloc(&query.deviceData, query.dataSize, ACL_MEM_MALLOC_HUGE_FIRST);
+    ASSERT_EQ(status, 0);
+
+    atb::Tensor key;
+    key.desc.dtype = ACL_FLOAT16;
+    key.desc.format = ACL_FORMAT_ND;
+    key.desc.shape.dimNum = 2;
+    key.desc.shape.dims[0] = nTokens;
+    key.desc.shape.dims[1] = hiddenSizeK;
+    key.dataSize = Utils::GetTensorSize(key);
+    status = aclrtMalloc(&key.deviceData, key.dataSize, ACL_MEM_MALLOC_HUGE_FIRST);
+    ASSERT_EQ(status, 0);
+
+    atb::Tensor cos;
+    cos.desc.dtype = ACL_FLOAT16;
+    cos.desc.format = ACL_FORMAT_ND;
+    cos.desc.shape.dimNum = 2;
+    cos.desc.shape.dims[0] = nTokens;
+    cos.desc.shape.dims[1] = headSize;
+    cos.dataSize = Utils::GetTensorSize(cos);
+    status = aclrtMalloc(&cos.deviceData, cos.dataSize, ACL_MEM_MALLOC_HUGE_FIRST);
+    ASSERT_EQ(status, 0);
+
+    atb::Tensor sin;
+    sin.desc.dtype = ACL_FLOAT16;
+    sin.desc.format = ACL_FORMAT_ND;
+    sin.desc.shape.dimNum = 2;
+    sin.desc.shape.dims[0] = nTokens;
+    sin.desc.shape.dims[1] = headSize;
+    sin.dataSize = Utils::GetTensorSize(sin);
+    status = aclrtMalloc(&sin.deviceData, sin.dataSize, ACL_MEM_MALLOC_HUGE_FIRST);
+    ASSERT_EQ(status, 0);
+
+    atb::Tensor seqLen;
+    seqLen.desc.dtype = ACL_INT32;
+    seqLen.desc.format = ACL_FORMAT_ND;
+    seqLen.desc.shape.dimNum = 1;
+    seqLen.desc.shape.dims[0] = batchSize;
+    seqLen.dataSize = Utils::GetTensorSize(seqLen);
+    status = aclrtMalloc(&seqLen.deviceData, seqLen.dataSize, ACL_MEM_MALLOC_HUGE_FIRST);
+    ASSERT_EQ(status, 0);
+
+    atb::Tensor ropeQ;
+    ropeQ.desc.dtype = ACL_FLOAT16;
+    ropeQ.desc.format = ACL_FORMAT_ND;
+    ropeQ.desc.shape.dimNum = 2;
+    ropeQ.desc.shape.dims[0] = nTokens;
+    ropeQ.desc.shape.dims[1] = hiddenSizeQ;
+    ropeQ.dataSize = Utils::GetTensorSize(ropeQ);
+    status = aclrtMalloc(&ropeQ.deviceData, ropeQ.dataSize, ACL_MEM_MALLOC_HUGE_FIRST);
+    ASSERT_EQ(status, 0);
+
+    atb::Tensor ropeK;
+    ropeK.desc.dtype = ACL_FLOAT16;
+    ropeK.desc.format = ACL_FORMAT_ND;
+    ropeK.desc.shape.dimNum = 2;
+    ropeK.desc.shape.dims[0] = nTokens;
+    ropeK.desc.shape.dims[1] = hiddenSizeK;
+    ropeK.dataSize = Utils::GetTensorSize(ropeK);
+    status = aclrtMalloc(&ropeK.deviceData, ropeK.dataSize, ACL_MEM_MALLOC_HUGE_FIRST);
+    ASSERT_EQ(status, 0);
+
+    atb::VariantPack variantPack;
+    variantPack.inTensors = {query, key, cos, sin, seqLen};
+    variantPack.outTensors = {ropeQ, ropeK};
+
+    atb::Context *context = nullptr;
+    st = atb::CreateContext(&context);
+    ASSERT_EQ(st, 0);
+    aclrtStream stream = nullptr;
+    status = aclrtCreateStream(&stream);
+    ASSERT_EQ(status, 0);
+    context.SetExecuteStream(stream);
+    uint64_t workspaceSize = 0;
+    st = op->Setup(variantPack, workspaceSize, context);
+    ASSERT_EQ(st, 0);
+    void *workspace = nullptr;
+    status = aclrtMalloc(&workspace, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
+    ASSERT_EQ(status, 0);
+    st = op->Execute(variantPack, (uint8_t*)workspace, workspaceSize, context);
+    ASSERT_EQ(st, 0);
+
+    status = aclrtDestroyStream(stream);
+    ASSERT_EQ(status, 0);
+    status = aclrtFree(workspace);
+    ASSERT_EQ(status, 0);
+    st = atb::DestroyOperation(op);
+    ASSERT_EQ(st, 0);
+    st = atb::DestroyContext(context);
+    ASSERT_EQ(st, 0);
+
+    for (size_t i = 0; i < variantPack.inTensors.size(); ++i) {
+        tensor = variantPack.inTensors.at(i);
+        status.aclrtFree(tensor);
+        ASSERT_EQ(status, 0);
+        tensor.deviceData = nullptr;
+        tensor.dataSize = 0;
+    }
+    for (size_t i = 0; i < variantPack.outTensors.size(); ++i) {
+        tensor = variantPack.outTensors.at(i);
+        status.aclrtFree(tensor);
+        ASSERT_EQ(status, 0);
+        tensor.deviceData = nullptr;
+        tensor.dataSize = 0;
+    }
+    aclrtResetDevice(deviceId);
 }
\ No newline at end of file
-- 
Gitee

From 94ac50da15f7162af8ecacb30f2ae60b5efa627a Mon Sep 17 00:00:00 2001
From: caixilong <2508418876@qq.com>
Date: Tue, 17 Jun 2025 10:49:39 +0800
Subject: [PATCH 2/6] fix error

---
 .../unittest/core/test_graph_launch_mode.cpp | 26 +++++++++----------
 1 file changed, 13 insertions(+), 13 deletions(-)

diff --git a/tests/unittest/core/test_graph_launch_mode.cpp b/tests/unittest/core/test_graph_launch_mode.cpp
index dfe3843e..f1df8ee9 100644
--- a/tests/unittest/core/test_graph_launch_mode.cpp
+++ b/tests/unittest/core/test_graph_launch_mode.cpp
@@ -712,7 +712,7 @@ TEST(TestGraphLaunchMode, RopeWorkspaceFullOfDirtyData)
     const uint32_t hiddenSizeK = 16;
     const uint32_t headSize = 8;
 
-    if (!atb::GetSingleton().Is910A()) {
+    if (atb::GetSingleton().Is910A()) {
         GTEST_SKIP() << "This test case does not support 910A";
     }
 
@@ -732,7 +732,7 @@ TEST(TestGraphLaunchMode, RopeWorkspaceFullOfDirtyData)
     query.desc.shape.dimNum = 2;
     query.desc.shape.dims[0] = nTokens;
     query.desc.shape.dims[1] = hiddenSizeQ;
-    query.dataSize = Utils::GetTensorSize(query);
+    query.dataSize = atb::Utils::GetTensorSize(query);
     status = aclrtMalloc(&query.deviceData, query.dataSize, ACL_MEM_MALLOC_HUGE_FIRST);
     ASSERT_EQ(status, 0);
 
@@ -742,7 +742,7 @@ TEST(TestGraphLaunchMode, RopeWorkspaceFullOfDirtyData)
     key.desc.shape.dimNum = 2;
     key.desc.shape.dims[0] = nTokens;
     key.desc.shape.dims[1] = hiddenSizeK;
-    key.dataSize = Utils::GetTensorSize(key);
+    key.dataSize = atb::Utils::GetTensorSize(key);
     status = aclrtMalloc(&key.deviceData, key.dataSize, ACL_MEM_MALLOC_HUGE_FIRST);
     ASSERT_EQ(status, 0);
 
@@ -752,7 +752,7 @@ TEST(TestGraphLaunchMode, RopeWorkspaceFullOfDirtyData)
     cos.desc.shape.dimNum = 2;
     cos.desc.shape.dims[0] = nTokens;
     cos.desc.shape.dims[1] = headSize;
-    cos.dataSize = Utils::GetTensorSize(cos);
+    cos.dataSize = atb::Utils::GetTensorSize(cos);
     status = aclrtMalloc(&cos.deviceData, cos.dataSize, ACL_MEM_MALLOC_HUGE_FIRST);
     ASSERT_EQ(status, 0);
 
@@ -762,7 +762,7 @@ TEST(TestGraphLaunchMode, RopeWorkspaceFullOfDirtyData)
     sin.desc.shape.dimNum = 2;
     sin.desc.shape.dims[0] = nTokens;
     sin.desc.shape.dims[1] = headSize;
-    sin.dataSize = Utils::GetTensorSize(sin);
+    sin.dataSize = atb::Utils::GetTensorSize(sin);
     status = aclrtMalloc(&sin.deviceData, sin.dataSize, ACL_MEM_MALLOC_HUGE_FIRST);
     ASSERT_EQ(status, 0);
 
@@ -771,7 +771,7 @@ TEST(TestGraphLaunchMode, RopeWorkspaceFullOfDirtyData)
     seqLen.desc.format = ACL_FORMAT_ND;
     seqLen.desc.shape.dimNum = 1;
     seqLen.desc.shape.dims[0] = batchSize;
-    seqLen.dataSize = Utils::GetTensorSize(seqLen);
+    seqLen.dataSize = atb::Utils::GetTensorSize(seqLen);
     status = aclrtMalloc(&seqLen.deviceData, seqLen.dataSize, ACL_MEM_MALLOC_HUGE_FIRST);
     ASSERT_EQ(status, 0);
 
@@ -781,7 +781,7 @@ TEST(TestGraphLaunchMode, RopeWorkspaceFullOfDirtyData)
     ropeQ.desc.shape.dimNum = 2;
     ropeQ.desc.shape.dims[0] = nTokens;
     ropeQ.desc.shape.dims[1] = hiddenSizeQ;
-    ropeQ.dataSize = Utils::GetTensorSize(ropeQ);
+    ropeQ.dataSize = atb::Utils::GetTensorSize(ropeQ);
     status = aclrtMalloc(&ropeQ.deviceData, ropeQ.dataSize, ACL_MEM_MALLOC_HUGE_FIRST);
     ASSERT_EQ(status, 0);
 
@@ -791,7 +791,7 @@ TEST(TestGraphLaunchMode, RopeWorkspaceFullOfDirtyData)
     ropeK.desc.shape.dimNum = 2;
     ropeK.desc.shape.dims[0] = nTokens;
     ropeK.desc.shape.dims[1] = hiddenSizeK;
-    ropeK.dataSize = Utils::GetTensorSize(ropeK);
+    ropeK.dataSize = atb::Utils::GetTensorSize(ropeK);
     status = aclrtMalloc(&ropeK.deviceData, ropeK.dataSize, ACL_MEM_MALLOC_HUGE_FIRST);
     ASSERT_EQ(status, 0);
 
@@ -805,7 +805,7 @@ TEST(TestGraphLaunchMode, RopeWorkspaceFullOfDirtyData)
     aclrtStream stream = nullptr;
     status = aclrtCreateStream(&stream);
     ASSERT_EQ(status, 0);
-    context.SetExecuteStream(stream);
+    context->SetExecuteStream(stream);
     uint64_t workspaceSize = 0;
     st = op->Setup(variantPack, workspaceSize, context);
     ASSERT_EQ(st, 0);
@@ -825,15 +825,15 @@ TEST(TestGraphLaunchMode, RopeWorkspaceFullOfDirtyData)
     ASSERT_EQ(st, 0);
 
     for (size_t i = 0; i < variantPack.inTensors.size(); ++i) {
-        tensor = variantPack.inTensors.at(i);
-        status.aclrtFree(tensor);
+        atb::Tensor tensor = variantPack.inTensors.at(i);
+        status = aclrtFree(tensor.deviceData);
         ASSERT_EQ(status, 0);
         tensor.deviceData = nullptr;
         tensor.dataSize = 0;
     }
     for (size_t i = 0; i < variantPack.outTensors.size(); ++i) {
-        tensor = variantPack.outTensors.at(i);
-        status.aclrtFree(tensor);
+        atb::Tensor tensor = variantPack.outTensors.at(i);
+        status = aclrtFree(tensor.deviceData);
         ASSERT_EQ(status, 0);
         tensor.deviceData = nullptr;
         tensor.dataSize = 0;
-- 
Gitee

From e7bd5187deaf6b59b323da3dc6ed0e17c19760fa Mon Sep 17 00:00:00 2001
From: caixilong <2508418876@qq.com>
Date: Tue, 17 Jun 2025 11:36:02 +0800
Subject: [PATCH 3/6] add synchronize stream

---
 tests/unittest/core/test_graph_launch_mode.cpp | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/tests/unittest/core/test_graph_launch_mode.cpp b/tests/unittest/core/test_graph_launch_mode.cpp
index f1df8ee9..6b2c9dfd 100644
--- a/tests/unittest/core/test_graph_launch_mode.cpp
+++ b/tests/unittest/core/test_graph_launch_mode.cpp
@@ -814,6 +814,8 @@ TEST(TestGraphLaunchMode, RopeWorkspaceFullOfDirtyData)
     ASSERT_EQ(status, 0);
     st = op->Execute(variantPack, (uint8_t*)workspace, workspaceSize, context);
     ASSERT_EQ(st, 0);
+    status = aclrtSynchronizeStream(stream);
+    ASSERT_EQ(status, 0);
 
     status = aclrtDestroyStream(stream);
     ASSERT_EQ(status, 0);
-- 
Gitee

From 910f33264a592d47693b395c715fa3a68a1f9c9a Mon Sep 17 00:00:00 2001
From: caixilong <2508418876@qq.com>
Date: Tue, 17 Jun 2025 15:35:39 +0800
Subject: [PATCH 4/6] add value and edit param

---
 .../unittest/core/test_graph_launch_mode.cpp | 37 +++++++++++++------
 1 file changed, 26 insertions(+), 11 deletions(-)

diff --git a/tests/unittest/core/test_graph_launch_mode.cpp b/tests/unittest/core/test_graph_launch_mode.cpp
index 6b2c9dfd..94001a95 100644
--- a/tests/unittest/core/test_graph_launch_mode.cpp
+++ b/tests/unittest/core/test_graph_launch_mode.cpp
@@ -708,9 +708,9 @@ TEST(TestGraphLaunchMode, RopeWorkspaceFullOfDirtyData)
 {
     const uint32_t batchSize = 1;
     const uint32_t nTokens = 4;
-    const uint32_t hiddenSizeQ = 16;
-    const uint32_t hiddenSizeK = 16;
-    const uint32_t headSize = 8;
+    const uint32_t hiddenSizeQ = 4096;
+    const uint32_t hiddenSizeK = 4096;
+    const uint32_t headSize = 128;
 
     if (atb::GetSingleton().Is910A()) {
         GTEST_SKIP() << "This test case does not support 910A";
@@ -719,11 +719,18 @@ TEST(TestGraphLaunchMode, RopeWorkspaceFullOfDirtyData)
     uint32_t deviceId = 0;
    aclError status = aclrtSetDevice(deviceId);
     ASSERT_EQ(status, 0);
+    atb::Context *context = nullptr;
+    atb::Status st = atb::CreateContext(&context);
+    ASSERT_EQ(st, 0);
+    aclrtStream stream = nullptr;
+    status = aclrtCreateStream(&stream);
+    ASSERT_EQ(status, 0);
+    context->SetExecuteStream(stream);
     atb::infer::RopeParam param;
     param.cosFormat = 1;
     param.rotaryCoeff = 4;
     atb::Operation * op = nullptr;
-    atb::Status st = atb::CreateOperation(param, &op);
+    st = atb::CreateOperation(param, &op);
     ASSERT_EQ(st, 0);
 
     atb::Tensor query;
@@ -733,8 +740,11 @@ TEST(TestGraphLaunchMode, RopeWorkspaceFullOfDirtyData)
     query.desc.shape.dims[0] = nTokens;
     query.desc.shape.dims[1] = hiddenSizeQ;
     query.dataSize = atb::Utils::GetTensorSize(query);
+    std::vector<uint16_t> queryData(atb::Utils::GetTensorNumel(query), 1);
     status = aclrtMalloc(&query.deviceData, query.dataSize, ACL_MEM_MALLOC_HUGE_FIRST);
     ASSERT_EQ(status, 0);
+    status = aclrtMemcpy(query.deviceData, query.dataSize, queryData.data(), queryData.size() * sizeof(uint16_t), ACL_MEMCPY_HOST_TO_DEVICE);
+    ASSERT_EQ(status, 0);
 
     atb::Tensor key;
@@ -743,8 +753,11 @@ TEST(TestGraphLaunchMode, RopeWorkspaceFullOfDirtyData)
     key.desc.shape.dims[0] = nTokens;
     key.desc.shape.dims[1] = hiddenSizeK;
     key.dataSize = atb::Utils::GetTensorSize(key);
+    std::vector<uint16_t> keyData(atb::Utils::GetTensorNumel(key), 1);
     status = aclrtMalloc(&key.deviceData, key.dataSize, ACL_MEM_MALLOC_HUGE_FIRST);
     ASSERT_EQ(status, 0);
+    status = aclrtMemcpy(key.deviceData, key.dataSize, keyData.data(), keyData.size() * sizeof(uint16_t), ACL_MEMCPY_HOST_TO_DEVICE);
+    ASSERT_EQ(status, 0);
 
     atb::Tensor cos;
@@ -753,8 +766,11 @@ TEST(TestGraphLaunchMode, RopeWorkspaceFullOfDirtyData)
     cos.desc.shape.dims[0] = nTokens;
     cos.desc.shape.dims[1] = headSize;
     cos.dataSize = atb::Utils::GetTensorSize(cos);
+    std::vector<uint16_t> cosData(atb::Utils::GetTensorNumel(cos), 1);
     status = aclrtMalloc(&cos.deviceData, cos.dataSize, ACL_MEM_MALLOC_HUGE_FIRST);
     ASSERT_EQ(status, 0);
+    status = aclrtMemcpy(cos.deviceData, cos.dataSize, cosData.data(), cosData.size() * sizeof(uint16_t), ACL_MEMCPY_HOST_TO_DEVICE);
+    ASSERT_EQ(status, 0);
 
     atb::Tensor sin;
@@ -763,8 +779,11 @@ TEST(TestGraphLaunchMode, RopeWorkspaceFullOfDirtyData)
     sin.desc.shape.dims[0] = nTokens;
     sin.desc.shape.dims[1] = headSize;
     sin.dataSize = atb::Utils::GetTensorSize(sin);
+    std::vector<uint16_t> sinData(atb::Utils::GetTensorNumel(sin), 1);
     status = aclrtMalloc(&sin.deviceData, sin.dataSize, ACL_MEM_MALLOC_HUGE_FIRST);
     ASSERT_EQ(status, 0);
+    status = aclrtMemcpy(sin.deviceData, sin.dataSize, sinData.data(), sinData.size() * sizeof(uint16_t), ACL_MEMCPY_HOST_TO_DEVICE);
+    ASSERT_EQ(status, 0);
 
     atb::Tensor seqLen;
@@ -772,8 +791,11 @@ TEST(TestGraphLaunchMode, RopeWorkspaceFullOfDirtyData)
     seqLen.desc.shape.dimNum = 1;
     seqLen.desc.shape.dims[0] = batchSize;
     seqLen.dataSize = atb::Utils::GetTensorSize(seqLen);
+    std::vector<int32_t> seqLenData(atb::Utils::GetTensorNumel(seqLen), 4);
     status = aclrtMalloc(&seqLen.deviceData, seqLen.dataSize, ACL_MEM_MALLOC_HUGE_FIRST);
     ASSERT_EQ(status, 0);
+    status = aclrtMemcpy(seqLen.deviceData, seqLen.dataSize, seqLenData.data(), seqLenData.size() * sizeof(uint16_t), ACL_MEMCPY_HOST_TO_DEVICE);
+    ASSERT_EQ(status, 0);
 
     atb::Tensor ropeQ;
@@ -799,13 +821,6 @@ TEST(TestGraphLaunchMode, RopeWorkspaceFullOfDirtyData)
     variantPack.inTensors = {query, key, cos, sin, seqLen};
     variantPack.outTensors = {ropeQ, ropeK};
 
-    atb::Context *context = nullptr;
-    st = atb::CreateContext(&context);
-    ASSERT_EQ(st, 0);
-    aclrtStream stream = nullptr;
-    status = aclrtCreateStream(&stream);
-    ASSERT_EQ(status, 0);
-    context->SetExecuteStream(stream);
     uint64_t workspaceSize = 0;
     st = op->Setup(variantPack, workspaceSize, context);
     ASSERT_EQ(st, 0);
-- 
Gitee

From 206ce5dc9aed266a805b0b25c5f63c3bc0e05840 Mon Sep 17 00:00:00 2001
From: caixilong <2508418876@qq.com>
Date: Tue, 17 Jun 2025 15:43:38 +0800
Subject: [PATCH 5/6] fix bug

---
 tests/unittest/core/test_graph_launch_mode.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/unittest/core/test_graph_launch_mode.cpp b/tests/unittest/core/test_graph_launch_mode.cpp
index 94001a95..2c69a292 100644
--- a/tests/unittest/core/test_graph_launch_mode.cpp
+++ b/tests/unittest/core/test_graph_launch_mode.cpp
@@ -794,7 +794,7 @@ TEST(TestGraphLaunchMode, RopeWorkspaceFullOfDirtyData)
     std::vector<int32_t> seqLenData(atb::Utils::GetTensorNumel(seqLen), 4);
     status = aclrtMalloc(&seqLen.deviceData, seqLen.dataSize, ACL_MEM_MALLOC_HUGE_FIRST);
     ASSERT_EQ(status, 0);
-    status = aclrtMemcpy(seqLen.deviceData, seqLen.dataSize, seqLenData.data(), seqLenData.size() * sizeof(uint16_t), ACL_MEMCPY_HOST_TO_DEVICE);
+    status = aclrtMemcpy(seqLen.deviceData, seqLen.dataSize, seqLenData.data(), seqLenData.size() * sizeof(int32_t), ACL_MEMCPY_HOST_TO_DEVICE);
     ASSERT_EQ(status, 0);
 
     atb::Tensor ropeQ;
-- 
Gitee

From b739682db00c9bbcec93df28f6afb5181f315a4f Mon Sep 17 00:00:00 2001
From: caixilong <2508418876@qq.com>
Date: Thu, 19 Jun 2025 20:37:20 +0800
Subject: [PATCH 6/6] add graph mode testcase

---
 .../unittest/core/test_graph_launch_mode.cpp | 299 ++++++++++++++++++
 1 file changed, 299 insertions(+)

diff --git a/tests/unittest/core/test_graph_launch_mode.cpp b/tests/unittest/core/test_graph_launch_mode.cpp
index 2c69a292..bb026450 100644
--- a/tests/unittest/core/test_graph_launch_mode.cpp
+++ b/tests/unittest/core/test_graph_launch_mode.cpp
@@ -13,6 +13,9 @@
 #include
 #include
 #include
+#include <random>
+#include <numeric>
+#include <algorithm>
 #include
 #include
 #include
@@ -827,6 +830,11 @@ TEST(TestGraphLaunchMode, RopeWorkspaceFullOfDirtyData)
     void *workspace = nullptr;
     status = aclrtMalloc(&workspace, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
     ASSERT_EQ(status, 0);
+
+    // Fill the workspace with dirty data
+    status = aclrtMemset(workspace, workspaceSize, 0xFF, workspaceSize);
+    ASSERT_EQ(status, 0);
+
     st = op->Execute(variantPack, (uint8_t*)workspace, workspaceSize, context);
     ASSERT_EQ(st, 0);
     status = aclrtSynchronizeStream(stream);
@@ -841,6 +849,297 @@ TEST(TestGraphLaunchMode, RopeWorkspaceFullOfDirtyData)
     st = atb::DestroyContext(context);
     ASSERT_EQ(st, 0);
 
+    for (size_t i = 0; i < variantPack.inTensors.size(); ++i) {
+        atb::Tensor tensor = variantPack.inTensors.at(i);
+        status = aclrtFree(tensor.deviceData);
+        ASSERT_EQ(status, 0);
+        tensor.deviceData = nullptr;
+        tensor.dataSize = 0;
+    }
+    for (size_t i = 0; i < variantPack.outTensors.size(); ++i) {
+        atb::Tensor tensor = variantPack.outTensors.at(i);
+        status = aclrtFree(tensor.deviceData);
+        ASSERT_EQ(status, 0);
+        tensor.deviceData = nullptr;
+        tensor.dataSize = 0;
+    }
+    aclrtResetDevice(deviceId);
+}
+
+TEST(TestGraphLaunchMode, GatingWorkspaceFullOfDirtyData)
+{
+    const uint32_t tokenNum = 512;
+    const uint32_t expertNum = 1024;
+    const uint32_t topKNum = 4;
+
+    if (atb::GetSingleton().Is910A() || atb::GetSingleton().Is310B()) {
+        GTEST_SKIP() << "This test case does not support 910A and 310B";
+    }
+
+    uint32_t deviceId = 0;
+    aclError status = aclrtSetDevice(deviceId);
+    ASSERT_EQ(status, 0);
+    atb::Context *context = nullptr;
+    atb::Status st = atb::CreateContext(&context);
+    ASSERT_EQ(st, 0);
+    aclrtStream stream = nullptr;
+    status = aclrtCreateStream(&stream);
+    ASSERT_EQ(status, 0);
+    context->SetExecuteStream(stream);
+    atb::infer::GatingParam param;
+    param.topkExpertNum = topKNum;
+    param.cumSumNum = 1024;
+    param.cumSumInt64 = false;
+    param.deviceExpert = std::vector<int32_t>();
+    atb::Operation * op = nullptr;
+    st = atb::CreateOperation(param, &op);
+    ASSERT_EQ(st, 0);
+
+    atb::Tensor topK;
+    topK.desc.dtype = ACL_INT32;
+    topK.desc.format = ACL_FORMAT_ND;
+    topK.desc.shape.dimNum = 1;
+    topK.desc.shape.dims[0] = tokenNum * topKNum;
+    topK.dataSize = atb::Utils::GetTensorSize(topK);
+    std::vector<int32_t> topKData(atb::Utils::GetTensorNumel(topK), 0);
+    std::random_device rd;
+    std::mt19937 gen(rd());
+    for (int i = 0; i < tokenNum; ++i) {
+        std::vector<int32_t> pool(expertNum);
+        std::iota(pool.begin(), pool.end(), 0);
+        std::shuffle(pool.begin(), pool.end(), gen);
+        for (int j = 0; j < topKNum; ++j) {
+            topKData[i * topKNum + j] = pool[j];
+        }
+    }
+    status = aclrtMalloc(&topK.deviceData, topK.dataSize, ACL_MEM_MALLOC_HUGE_FIRST);
+    ASSERT_EQ(status, 0);
+    status = aclrtMemcpy(topK.deviceData, topK.dataSize, topKData.data(), topKData.size() * sizeof(int32_t), ACL_MEMCPY_HOST_TO_DEVICE);
+    ASSERT_EQ(status, 0);
+
+    atb::Tensor idxArr;
+    idxArr.desc.dtype = ACL_INT32;
+    idxArr.desc.format = ACL_FORMAT_ND;
+    idxArr.desc.shape.dimNum = 1;
+    idxArr.desc.shape.dims[0] = tokenNum * topKNum;
+    idxArr.dataSize = atb::Utils::GetTensorSize(idxArr);
+    std::vector<int32_t> idxArrData(atb::Utils::GetTensorNumel(idxArr), 0);
+    for (int32_t i = 0; i < idxArrData.size(); ++i) {
+        idxArrData[i] = i;
+    }
+    status = aclrtMalloc(&idxArr.deviceData, idxArr.dataSize, ACL_MEM_MALLOC_HUGE_FIRST);
+    ASSERT_EQ(status, 0);
+    status = aclrtMemcpy(idxArr.deviceData, idxArr.dataSize, idxArrData.data(), idxArrData.size() * sizeof(int32_t), ACL_MEMCPY_HOST_TO_DEVICE);
+    ASSERT_EQ(status, 0);
+
+    atb::Tensor tokenIndex;
+    tokenIndex.desc.dtype = ACL_INT32;
+    tokenIndex.desc.format = ACL_FORMAT_ND;
+    tokenIndex.desc.shape.dimNum = 1;
+    tokenIndex.desc.shape.dims[0] = tokenNum * topKNum;
+    tokenIndex.dataSize = atb::Utils::GetTensorSize(tokenIndex);
+    status = aclrtMalloc(&tokenIndex.deviceData, tokenIndex.dataSize, ACL_MEM_MALLOC_HUGE_FIRST);
+    ASSERT_EQ(status, 0);
+
+    atb::Tensor cumSum;
+    cumSum.desc.dtype = ACL_INT32;
+    cumSum.desc.format = ACL_FORMAT_ND;
+    cumSum.desc.shape.dimNum = 1;
+    cumSum.desc.shape.dims[0] = expertNum;
+    cumSum.dataSize = atb::Utils::GetTensorSize(cumSum);
+    status = aclrtMalloc(&cumSum.deviceData, cumSum.dataSize, ACL_MEM_MALLOC_HUGE_FIRST);
+    ASSERT_EQ(status, 0);
+
+    atb::Tensor originalIndex;
+    originalIndex.desc.dtype = ACL_INT32;
+    originalIndex.desc.format = ACL_FORMAT_ND;
+    originalIndex.desc.shape.dimNum = 1;
+    originalIndex.desc.shape.dims[0] = tokenNum * topKNum;
+    originalIndex.dataSize = atb::Utils::GetTensorSize(originalIndex);
+    status = aclrtMalloc(&originalIndex.deviceData, originalIndex.dataSize, ACL_MEM_MALLOC_HUGE_FIRST);
+    ASSERT_EQ(status, 0);
+
+    // validIndex is not output in this scenario
+
+    atb::VariantPack variantPack;
+    variantPack.inTensors = {topK, idxArr};
+    variantPack.outTensors = {tokenIndex, cumSum, originalIndex};
+
+    uint64_t workspaceSize = 0;
+    st = op->Setup(variantPack, workspaceSize, context);
+    ASSERT_EQ(st, 0);
+    void *workspace = nullptr;
+    status = aclrtMalloc(&workspace, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
+    ASSERT_EQ(status, 0);
+
+    // Fill the workspace with dirty data
+    status = aclrtMemset(workspace, workspaceSize, 0xFF, workspaceSize);
+    ASSERT_EQ(status, 0);
+
+    st = op->Execute(variantPack, (uint8_t*)workspace, workspaceSize, context);
+    ASSERT_EQ(st, 0);
+    status = aclrtSynchronizeStream(stream);
+    ASSERT_EQ(status, 0);
+
+    status = aclrtDestroyStream(stream);
+    ASSERT_EQ(status, 0);
+    status = aclrtFree(workspace);
+    ASSERT_EQ(status, 0);
+    st = atb::DestroyOperation(op);
+    ASSERT_EQ(st, 0);
+    st = atb::DestroyContext(context);
+    ASSERT_EQ(st, 0);
+
+    for (size_t i = 0; i < variantPack.inTensors.size(); ++i) {
+        atb::Tensor tensor = variantPack.inTensors.at(i);
+        status = aclrtFree(tensor.deviceData);
+        ASSERT_EQ(status, 0);
+        tensor.deviceData = nullptr;
+        tensor.dataSize = 0;
+    }
+    for (size_t i = 0; i < variantPack.outTensors.size(); ++i) {
+        atb::Tensor tensor = variantPack.outTensors.at(i);
+        status = aclrtFree(tensor.deviceData);
+        ASSERT_EQ(status, 0);
+        tensor.deviceData = nullptr;
+        tensor.dataSize = 0;
+    }
+    aclrtResetDevice(deviceId);
+}
+
+TEST(TestGraphLaunchMode, GatingGraphMode)
+{
+    const uint32_t tokenNum = 512;
+    const uint32_t expertNum = 1024;
+    const uint32_t topKNum = 4;
+
+    if (atb::GetSingleton().Is910A() || atb::GetSingleton().Is310B()) {
+        GTEST_SKIP() << "This test case does not support 910A and 310B";
+    }
+
+    uint32_t deviceId = 0;
+    aclError status = aclrtSetDevice(deviceId);
+    ASSERT_EQ(status, 0);
+    atb::Context *context = nullptr;
+    atb::Status st = atb::CreateContext(&context);
+    ASSERT_EQ(st, 0);
+    aclrtStream stream = nullptr;
+    status = aclrtCreateStream(&stream);
+    ASSERT_EQ(status, 0);
+    context->SetExecuteStream(stream);
+    context->SetLaunchMode(atb::GRAPH_LAUNCH_MODE);
+    atb::infer::GatingParam param;
+    param.topkExpertNum = topKNum;
+    param.cumSumNum = 1024;
+    param.cumSumInt64 = false;
+    param.deviceExpert = std::vector<int32_t>();
+    atb::Operation * op = nullptr;
+    st = atb::CreateOperation(param, &op);
+    ASSERT_EQ(st, 0);
+
+    atb::Tensor topK;
+    topK.desc.dtype = ACL_INT32;
+    topK.desc.format = ACL_FORMAT_ND;
+    topK.desc.shape.dimNum = 1;
+    topK.desc.shape.dims[0] = tokenNum * topKNum;
+    topK.dataSize = atb::Utils::GetTensorSize(topK);
+    std::vector<int32_t> topKData(atb::Utils::GetTensorNumel(topK), 0);
+    std::random_device rd;
+    std::mt19937 gen(rd());
+    for (int i = 0; i < tokenNum; ++i) {
+        std::vector<int32_t> pool(expertNum);
+        std::iota(pool.begin(), pool.end(), 0);
+        std::shuffle(pool.begin(), pool.end(), gen);
+        for (int j = 0; j < topKNum; ++j) {
+            topKData[i * topKNum + j] = pool[j];
+        }
+    }
+    status = aclrtMalloc(&topK.deviceData, topK.dataSize, ACL_MEM_MALLOC_HUGE_FIRST);
+    ASSERT_EQ(status, 0);
+    status = aclrtMemcpy(topK.deviceData, topK.dataSize, topKData.data(), topKData.size() * sizeof(int32_t), ACL_MEMCPY_HOST_TO_DEVICE);
+    ASSERT_EQ(status, 0);
+
+    atb::Tensor idxArr;
+    idxArr.desc.dtype = ACL_INT32;
+    idxArr.desc.format = ACL_FORMAT_ND;
+    idxArr.desc.shape.dimNum = 1;
+    idxArr.desc.shape.dims[0] = tokenNum * topKNum;
+    idxArr.dataSize = atb::Utils::GetTensorSize(idxArr);
+    std::vector<int32_t> idxArrData(atb::Utils::GetTensorNumel(idxArr), 0);
+    for (int32_t i = 0; i < idxArrData.size(); ++i) {
+        idxArrData[i] = i;
+    }
+    status = aclrtMalloc(&idxArr.deviceData, idxArr.dataSize, ACL_MEM_MALLOC_HUGE_FIRST);
+    ASSERT_EQ(status, 0);
+    status = aclrtMemcpy(idxArr.deviceData, idxArr.dataSize, idxArrData.data(), idxArrData.size() * sizeof(int32_t), ACL_MEMCPY_HOST_TO_DEVICE);
+    ASSERT_EQ(status, 0);
+
+    atb::Tensor tokenIndex;
+    tokenIndex.desc.dtype = ACL_INT32;
+    tokenIndex.desc.format = ACL_FORMAT_ND;
+    tokenIndex.desc.shape.dimNum = 1;
+    tokenIndex.desc.shape.dims[0] = tokenNum * topKNum;
+    tokenIndex.dataSize = atb::Utils::GetTensorSize(tokenIndex);
+    status = aclrtMalloc(&tokenIndex.deviceData, tokenIndex.dataSize, ACL_MEM_MALLOC_HUGE_FIRST);
+    ASSERT_EQ(status, 0);
+
+    atb::Tensor cumSum;
+    cumSum.desc.dtype = ACL_INT32;
+    cumSum.desc.format = ACL_FORMAT_ND;
+    cumSum.desc.shape.dimNum = 1;
+    cumSum.desc.shape.dims[0] = expertNum;
+    cumSum.dataSize = atb::Utils::GetTensorSize(cumSum);
+    status = aclrtMalloc(&cumSum.deviceData, cumSum.dataSize, ACL_MEM_MALLOC_HUGE_FIRST);
+    ASSERT_EQ(status, 0);
+
+    atb::Tensor originalIndex;
+    originalIndex.desc.dtype = ACL_INT32;
+    originalIndex.desc.format = ACL_FORMAT_ND;
+    originalIndex.desc.shape.dimNum = 1;
+    originalIndex.desc.shape.dims[0] = tokenNum * topKNum;
+    originalIndex.dataSize = atb::Utils::GetTensorSize(originalIndex);
+    status = aclrtMalloc(&originalIndex.deviceData, originalIndex.dataSize, ACL_MEM_MALLOC_HUGE_FIRST);
+    ASSERT_EQ(status, 0);
+
+    // validIndex is not output in this scenario
+
+    atb::VariantPack variantPack;
+    variantPack.inTensors = {topK, idxArr};
+    variantPack.outTensors = {tokenIndex, cumSum, originalIndex};
+    std::vector<void *> workspaces;
+
+    for (size_t i = 0; i < 10; ++i) {
+        std::cout << "time: " << i + 1 << std::endl;
+        uint64_t workspaceSize = 0;
+        st = op->Setup(variantPack, workspaceSize, context);
+        ASSERT_EQ(st, 0);
+        void *workspace = nullptr;
+        status = aclrtMalloc(&workspace, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
+        std::cout << "workspace: " << workspace << " workspaceSize: " << workspaceSize << std::endl;
+        ASSERT_EQ(status, 0);
+        workspaces.push_back(workspace);
+
+        // Fill the workspace with dirty data
+        status = aclrtMemset(workspace, workspaceSize, 0xFF, workspaceSize);
+        ASSERT_EQ(status, 0);
+
+        st = op->Execute(variantPack, (uint8_t*)workspace, workspaceSize, context);
+        ASSERT_EQ(st, 0);
+        status = aclrtSynchronizeStream(stream);
+        ASSERT_EQ(status, 0);
+    }
+
+    for (void *workspace : workspaces) {
+        status = aclrtFree(workspace);
+        ASSERT_EQ(status, 0);
+    }
+    status = aclrtDestroyStream(stream);
+    ASSERT_EQ(status, 0);
+    st = atb::DestroyOperation(op);
+    ASSERT_EQ(st, 0);
+    st = atb::DestroyContext(context);
+    ASSERT_EQ(st, 0);
+
     for (size_t i = 0; i < variantPack.inTensors.size(); ++i) {
         atb::Tensor tensor = variantPack.inTensors.at(i);
         status = aclrtFree(tensor.deviceData);
-- 
Gitee