Reset shmem init state in shmem_finalize and extend the host init unit tests.

  * src/host/init/shmem_init.cpp: shmem_finalize clears
    g_state.is_shmem_created and g_state.is_shmem_initialized before calling
    smem_shm_uninit / smem_uninit.
  * tests/unittest/host/init/init_host_test.cpp: drop the per-test
    "exit(1) on HasFailure" epilogue; add init/finalize lifecycle tests and
    tests for shmem_set_attr, shmem_set_data_op_engine_type and
    shmem_set_timeout.
  * tests/unittest/host/main_test.cpp: forked children leave through _exit()
    with status 1 when a gtest failure was recorded, and the parent skips
    waitpid for pids that are <= 0.
  * tests/unittest/host/mem/shmem_host_heap_test.cpp: the "too large" requests
    now add SHMEM_EXTRA_SIZE to heap_memory_size.

diff --git a/src/host/init/shmem_init.cpp b/src/host/init/shmem_init.cpp
index 960634a676855cf583cac8d4d37f6255f35fa2e6..0be8ca661dfebe5f0e0e075f72e3de991f4f557e 100644
--- a/src/host/init/shmem_init.cpp
+++ b/src/host/init/shmem_init.cpp
@@ -823,6 +823,10 @@ int32_t shmem_finalize(void)
         }
         shm::g_smem_handle = nullptr;
     }
+
+    shm::g_state.is_shmem_created = false;
+    shm::g_state.is_shmem_initialized = false;
+
     smem_shm_uninit(0);
     smem_uninit();
     return SHMEM_SUCCESS;
diff --git a/tests/unittest/host/init/init_host_test.cpp b/tests/unittest/host/init/init_host_test.cpp
index c4b59298c8726bce33f21910bfae698f8320d5fb..ae0aa8317de73a148e09b19b38b702796fea3d2b 100644
--- a/tests/unittest/host/init/init_host_test.cpp
+++ b/tests/unittest/host/init/init_host_test.cpp
@@ -47,9 +47,6 @@ void test_shmem_init(int rank_id, int n_ranks, uint64_t local_mem_size)
     EXPECT_EQ(status, SHMEM_SUCCESS);
     EXPECT_EQ(aclrtResetDevice(device_id), 0);
     EXPECT_EQ(aclFinalize(), 0);
-    if (::testing::Test::HasFailure()) {
-        exit(1);
-    }
 }
 
 void test_shmem_init_attr(int rank_id, int n_ranks, uint64_t local_mem_size)
@@ -81,9 +78,6 @@ void test_shmem_init_attr(int rank_id, int n_ranks, uint64_t local_mem_size)
     EXPECT_EQ(status, SHMEM_SUCCESS);
     EXPECT_EQ(aclrtResetDevice(device_id), 0);
     EXPECT_EQ(aclFinalize(), 0);
-    if (::testing::Test::HasFailure()) {
-        exit(1);
-    }
 }
 
 void test_shmem_init_invalid_rank_id(int rank_id, int n_ranks, uint64_t local_mem_size)
@@ -102,9 +96,6 @@ void test_shmem_init_invalid_rank_id(int rank_id, int n_ranks, uint64_t local_me
     EXPECT_EQ(status, SHMEM_STATUS_NOT_INITIALIZED);
     EXPECT_EQ(aclrtResetDevice(device_id), 0);
     EXPECT_EQ(aclFinalize(), 0);
-    if (::testing::Test::HasFailure()) {
-        exit(1);
-    }
 }
 
 void test_shmem_init_invalid_n_ranks(int rank_id, int n_ranks, uint64_t local_mem_size)
@@ -127,9 +118,6 @@ void test_shmem_init_invalid_n_ranks(int rank_id, int n_ranks, uint64_t local_me
     EXPECT_EQ(status, SHMEM_STATUS_NOT_INITIALIZED);
     EXPECT_EQ(aclrtResetDevice(device_id), 0);
     EXPECT_EQ(aclFinalize(), 0);
-    if (::testing::Test::HasFailure()) {
-        exit(1);
-    }
 }
 
 void test_shmem_init_rank_id_over_size(int rank_id, int n_ranks, uint64_t local_mem_size)
@@ -147,9 +135,6 @@ void test_shmem_init_rank_id_over_size(int rank_id, int n_ranks, uint64_t local_
     EXPECT_EQ(status, SHMEM_STATUS_NOT_INITIALIZED);
     EXPECT_EQ(aclrtResetDevice(device_id), 0);
     EXPECT_EQ(aclFinalize(), 0);
-    if (::testing::Test::HasFailure()) {
-        exit(1);
-    }
 }
 
 void test_shmem_init_zero_mem(int rank_id, int n_ranks, uint64_t local_mem_size)
@@ -168,9 +153,6 @@ void test_shmem_init_zero_mem(int rank_id, int n_ranks, uint64_t local_mem_size)
     EXPECT_EQ(status, SHMEM_STATUS_NOT_INITIALIZED);
     EXPECT_EQ(aclrtResetDevice(device_id), 0);
     EXPECT_EQ(aclFinalize(), 0);
-    if (::testing::Test::HasFailure()) {
-        exit(1);
-    }
 }
 
 void test_shmem_init_invalid_mem(int rank_id, int n_ranks, uint64_t local_mem_size)
@@ -189,9 +171,6 @@ void test_shmem_init_invalid_mem(int rank_id, int n_ranks, uint64_t local_mem_si
     EXPECT_EQ(status, SHMEM_STATUS_NOT_INITIALIZED);
     EXPECT_EQ(aclrtResetDevice(device_id), 0);
     EXPECT_EQ(aclFinalize(), 0);
-    if (::testing::Test::HasFailure()) {
-        exit(1);
-    }
 }
 
 void test_shmem_init_set_config(int rank_id, int n_ranks, uint64_t local_mem_size)
@@ -228,9 +207,6 @@ void test_shmem_init_set_config(int rank_id, int n_ranks, uint64_t local_mem_siz
     EXPECT_EQ(status, SHMEM_SUCCESS);
     EXPECT_EQ(aclrtResetDevice(device_id), 0);
     EXPECT_EQ(aclFinalize(), 0);
-    if (::testing::Test::HasFailure()) {
-        exit(1);
-    }
 }
 
 void test_shmem_global_exit(int rank_id, int n_ranks, uint64_t local_mem_size)
@@ -268,9 +244,132 @@ void test_shmem_global_exit(int rank_id, int n_ranks, uint64_t local_mem_size)
     shmem_global_exit(0);
     EXPECT_EQ(aclrtResetDevice(device_id), 0);
     EXPECT_EQ(aclFinalize(), 0);
-    if (::testing::Test::HasFailure()) {
-        exit(1);
+}
+
+void test_shmem_init_status_after_finalize(int rank_id, int n_ranks, uint64_t local_mem_size)
+{
+    uint32_t device_id = rank_id % test_gnpu_num + test_first_npu;
+    int status = SHMEM_SUCCESS;
+
+    EXPECT_EQ(aclInit(nullptr), 0);
+    EXPECT_EQ(status = aclrtSetDevice(device_id), 0);
+
+    shmem_init_attr_t *attributes = nullptr;
+    EXPECT_EQ(shmem_set_attr(rank_id, n_ranks, local_mem_size, test_global_ipport, &attributes), SHMEM_SUCCESS);
+    shmem_set_conf_store_tls(false, nullptr, 0);
+
+    status = shmem_init_attr(attributes);
+    EXPECT_EQ(status, SHMEM_SUCCESS);
+
+    status = shmem_init_status();
+    EXPECT_EQ(status, SHMEM_STATUS_IS_INITIALIZED);
+
+    status = shmem_finalize();
+    EXPECT_EQ(status, SHMEM_SUCCESS);
+
+    status = shmem_init_status();
+    EXPECT_EQ(status, SHMEM_STATUS_NOT_INITIALIZED);
+
+    EXPECT_EQ(aclrtResetDevice(device_id), 0);
+    EXPECT_EQ(aclFinalize(), 0);
+}
+
+void test_shmem_init_attr_null_attr(int rank_id, int n_ranks, uint64_t local_mem_size)
+{
+    uint32_t device_id = rank_id % test_gnpu_num + test_first_npu;
+    int status = SHMEM_SUCCESS;
+
+    EXPECT_EQ(aclInit(nullptr), 0);
+    EXPECT_EQ(status = aclrtSetDevice(device_id), 0);
+
+    status = shmem_init_attr(nullptr);
+    EXPECT_EQ(status, SHMEM_INVALID_PARAM);
+    EXPECT_EQ(shmem_init_status(), SHMEM_STATUS_NOT_INITIALIZED);
+
+    EXPECT_EQ(aclrtResetDevice(device_id), 0);
+    EXPECT_EQ(aclFinalize(), 0);
+}
+
+void test_shmem_init_attr_repeated_init(int rank_id, int n_ranks, uint64_t local_mem_size)
+{
+    uint32_t device_id = rank_id % test_gnpu_num + test_first_npu;
+    int status = SHMEM_SUCCESS;
+
+    EXPECT_EQ(aclInit(nullptr), 0);
+    EXPECT_EQ(status = aclrtSetDevice(device_id), 0);
+
+    shmem_init_attr_t *attributes = nullptr;
+    ASSERT_EQ(shmem_set_attr(rank_id, n_ranks, local_mem_size, test_global_ipport, &attributes), SHMEM_SUCCESS);
+    shmem_set_conf_store_tls(false, nullptr, 0);
+
+    status = shmem_init_attr(attributes);
+    EXPECT_EQ(status, SHMEM_SUCCESS);
+    EXPECT_EQ(shmem_init_status(), SHMEM_STATUS_IS_INITIALIZED);
+
+    int32_t status2 = shmem_init_attr(attributes);
+    EXPECT_NE(status2, SHMEM_SUCCESS);
+
+    EXPECT_EQ(shmem_finalize(), SHMEM_SUCCESS);
+    EXPECT_EQ(aclrtResetDevice(device_id), 0);
+    EXPECT_EQ(aclFinalize(), 0);
+}
+
+void test_shmem_finalize_exception_paths(int rank_id, int n_ranks, uint64_t local_mem_size)
+{
+    uint32_t device_id = rank_id % test_gnpu_num + test_first_npu;
+    int status = SHMEM_SUCCESS;
+
+    EXPECT_EQ(aclInit(nullptr), 0);
+    EXPECT_EQ(status = aclrtSetDevice(device_id), 0);
+
+    status = shmem_finalize();
+    EXPECT_EQ(status, SHMEM_SUCCESS);
+
+    shmem_init_attr_t *attributes = nullptr;
+    ASSERT_EQ(shmem_set_attr(rank_id, n_ranks, local_mem_size, test_global_ipport, &attributes), SHMEM_SUCCESS);
+    shmem_set_conf_store_tls(false, nullptr, 0);
+
+    status = shmem_init_attr(attributes);
+    EXPECT_EQ(status, SHMEM_SUCCESS);
+    EXPECT_EQ(shmem_init_status(), SHMEM_STATUS_IS_INITIALIZED);
+
+    status = shmem_finalize();
+    EXPECT_EQ(status, SHMEM_SUCCESS);
+    EXPECT_EQ(shmem_init_status(), SHMEM_STATUS_NOT_INITIALIZED);
+
+    status = shmem_finalize();
+    EXPECT_EQ(status, SHMEM_SUCCESS);
+
+    EXPECT_EQ(aclrtResetDevice(device_id), 0);
+    EXPECT_EQ(aclFinalize(), 0);
+}
+
+void test_shmem_init_finalize_loop(int rank_id, int n_ranks, uint64_t local_mem_size)
+{
+    uint32_t device_id = rank_id % test_gnpu_num + test_first_npu;
+    int status = SHMEM_SUCCESS;
+
+    EXPECT_EQ(aclInit(nullptr), 0);
+    EXPECT_EQ(status = aclrtSetDevice(device_id), 0);
+
+    const int loop_time = 3;
+    for (int i = 0; i < loop_time; ++i) {
+        shmem_init_attr_t *attributes = nullptr;
+        ASSERT_EQ(shmem_set_attr(rank_id, n_ranks, local_mem_size, test_global_ipport, &attributes), SHMEM_SUCCESS);
+        shmem_set_conf_store_tls(false, nullptr, 0);
+
+        status = shmem_init_attr(attributes);
+        EXPECT_EQ(status, SHMEM_SUCCESS);
+        EXPECT_EQ(shmem_init_status(), SHMEM_STATUS_IS_INITIALIZED);
+
+        status = shmem_finalize();
+        sleep(5);
+        EXPECT_EQ(status, SHMEM_SUCCESS);
+        EXPECT_EQ(shmem_init_status(), SHMEM_STATUS_NOT_INITIALIZED);
     }
+
+    EXPECT_EQ(aclrtResetDevice(device_id), 0);
+    EXPECT_EQ(aclFinalize(), 0);
 }
 
 TEST(TestInitAPI, TestShmemInit)
@@ -372,6 +471,41 @@ TEST(TestInitAPI, TestShmemGlobalExit)
     test_mutil_task(test_shmem_global_exit, local_mem_size, process_count);
 }
 
+TEST(TestInitAPI, TestShmemInitStatusAfterFinalize)
+{
+    const int process_count = test_gnpu_num;
+    uint64_t local_mem_size = 1024UL * 1024UL * 1024;
+    test_mutil_task(test_shmem_init_status_after_finalize, local_mem_size, process_count);
+}
+
+TEST(TestInitAPI, TestShmemInitAttrNullAttr)
+{
+    const int process_count = test_gnpu_num;
+    uint64_t local_mem_size = 1024UL * 1024UL * 1024;
+    test_mutil_task(test_shmem_init_attr_null_attr, local_mem_size, process_count);
+}
+
+TEST(TestInitAPI, TestShmemInitAttrRepeatedInit)
+{
+    const int process_count = test_gnpu_num;
+    uint64_t local_mem_size = 1024UL * 1024UL * 1024;
+    test_mutil_task(test_shmem_init_attr_repeated_init, local_mem_size, process_count);
+}
+
+TEST(TestInitAPI, TestShmemFinalizeExceptionPaths)
+{
+    const int process_count = test_gnpu_num;
+    uint64_t local_mem_size = 1024UL * 1024UL * 1024;
+    test_mutil_task(test_shmem_finalize_exception_paths, local_mem_size, process_count);
+}
+
+TEST(TestInitAPI, TestShmemInitFinalizeLoop)
+{
+    const int process_count = test_gnpu_num;
+    uint64_t local_mem_size = 1024UL * 1024UL * 1024;
+    test_mutil_task(test_shmem_init_finalize_loop, local_mem_size, process_count);
+}
+
 TEST(TestInitAPI, TestShmemSetLogLevel)
 {
     auto ret = shmem_set_log_level(shm::DEBUG_LEVEL);
@@ -401,6 +535,153 @@ TEST(TestInitAPI, TestShmemSetExternLogger)
     EXPECT_EQ(ret, 0);
 }
 
+TEST(TestInitAPI, TestShmemSetAttrBasic)
+{
+    shmem_init_attr_t *attributes = nullptr;
+    int my_rank = 0;
+    int n_ranks = 4;
+    uint64_t local_mem_size = 1024UL * 1024UL;
+    const char *ip_port = test_global_ipport;
+
+    int32_t ret = shmem_set_attr(my_rank, n_ranks, local_mem_size, ip_port, &attributes);
+    EXPECT_EQ(ret, SHMEM_SUCCESS);
+    ASSERT_NE(attributes, nullptr);
+    EXPECT_EQ(attributes, &shm::g_attr);
+
+    EXPECT_EQ(attributes->my_rank, my_rank);
+    EXPECT_EQ(attributes->n_ranks, n_ranks);
+    EXPECT_EQ(attributes->local_mem_size, local_mem_size);
+    EXPECT_STREQ(attributes->ip_port, ip_port);
+
+    int expect_version = (1 << 16) + static_cast<int>(sizeof(shmem_init_attr_t));
+    EXPECT_EQ(attributes->option_attr.version, expect_version);
+    EXPECT_EQ(attributes->option_attr.data_op_engine_type, SHMEM_DATA_OP_MTE);
+    EXPECT_EQ(attributes->option_attr.shm_init_timeout, 120U);
+    EXPECT_EQ(attributes->option_attr.shm_create_timeout, 120U);
+    EXPECT_EQ(attributes->option_attr.control_operation_timeout, 120U);
+}
+
+TEST(TestInitAPI, TestShmemSetAttrIpPortNull)
+{
+    shmem_init_attr_t *attributes = nullptr;
+    int my_rank = 0;
+    int n_ranks = 4;
+    uint64_t local_mem_size = 1024UL * 1024UL;
+
+    int32_t ret = shmem_set_attr(my_rank, n_ranks, local_mem_size, nullptr, &attributes);
+    EXPECT_EQ(ret, SHMEM_SUCCESS);
+    ASSERT_NE(attributes, nullptr);
+    EXPECT_EQ(attributes->my_rank, my_rank);
+    EXPECT_EQ(attributes->n_ranks, n_ranks);
+    EXPECT_EQ(attributes->local_mem_size, local_mem_size);
+    EXPECT_STREQ(attributes->ip_port, "");
+}
+
+TEST(TestInitAPI, TestShmemSetAttrIpPortEmpty)
+{
+    shmem_init_attr_t *attributes = nullptr;
+    int my_rank = 0;
+    int n_ranks = 4;
+    uint64_t local_mem_size = 1024UL * 1024UL;
+    const char *ip_port = "";
+
+    int32_t ret = shmem_set_attr(my_rank, n_ranks, local_mem_size, ip_port, &attributes);
+    EXPECT_EQ(ret, SHMEM_INVALID_VALUE);
+    ASSERT_NE(attributes, nullptr);
+    EXPECT_STREQ(attributes->ip_port, "");
+}
+
+TEST(TestInitAPI, TestShmemSetDataOpEngineTypeBasicAndOverwrite)
+{
+    shmem_init_attr_t *attributes = nullptr;
+    int my_rank = 0;
+    int n_ranks = 4;
+    uint64_t local_mem_size = 1024UL * 1024UL;
+
+    ASSERT_EQ(shmem_set_attr(my_rank, n_ranks, local_mem_size, test_global_ipport, &attributes), SHMEM_SUCCESS);
+    ASSERT_NE(attributes, nullptr);
+    EXPECT_EQ(attributes->option_attr.data_op_engine_type, SHMEM_DATA_OP_MTE);
+
+    int32_t ret = shmem_set_data_op_engine_type(attributes, SHMEM_DATA_OP_SDMA);
+    EXPECT_EQ(ret, SHMEM_SUCCESS);
+    EXPECT_EQ(attributes->option_attr.data_op_engine_type, SHMEM_DATA_OP_SDMA);
+
+    ASSERT_EQ(shmem_set_attr(my_rank, n_ranks, local_mem_size, test_global_ipport, &attributes), SHMEM_SUCCESS);
+    EXPECT_EQ(attributes->option_attr.data_op_engine_type, SHMEM_DATA_OP_MTE);
+}
+
+TEST(TestInitAPI, TestShmemSetDataOpEngineTypeNullAttr)
+{
+    int32_t ret = shmem_set_data_op_engine_type(nullptr, SHMEM_DATA_OP_MTE);
+    EXPECT_EQ(ret, SHMEM_INVALID_PARAM);
+}
+
+TEST(TestInitAPI, TestShmemSetDataOpEngineTypeInvalidValue)
+{
+    shmem_init_attr_t attr{};
+    auto invalid_value = static_cast<data_op_engine_type_t>(0);
+
+    int32_t ret = shmem_set_data_op_engine_type(&attr, invalid_value);
+    EXPECT_EQ(ret, SHMEM_SUCCESS);
+    EXPECT_EQ(attr.option_attr.data_op_engine_type, invalid_value);
+}
+
+TEST(TestInitAPI, TestShmemSetTimeoutBoundaryValues)
+{
+    shmem_init_attr_t attr{};
+
+    int32_t ret = shmem_set_timeout(&attr, 0U);
+    EXPECT_EQ(ret, SHMEM_SUCCESS);
+    EXPECT_EQ(attr.option_attr.shm_init_timeout, 0U);
+    EXPECT_EQ(attr.option_attr.shm_create_timeout, 0U);
+    EXPECT_EQ(attr.option_attr.control_operation_timeout, 0U);
+
+    ret = shmem_set_timeout(&attr, 1U);
+    EXPECT_EQ(ret, SHMEM_SUCCESS);
+    EXPECT_EQ(attr.option_attr.shm_init_timeout, 1U);
+    EXPECT_EQ(attr.option_attr.shm_create_timeout, 1U);
+    EXPECT_EQ(attr.option_attr.control_operation_timeout, 1U);
+
+    uint32_t max_val = 0xFFFFFFFFu;
+    ret = shmem_set_timeout(&attr, max_val);
+    EXPECT_EQ(ret, SHMEM_SUCCESS);
+    EXPECT_EQ(attr.option_attr.shm_init_timeout, max_val);
+    EXPECT_EQ(attr.option_attr.shm_create_timeout, max_val);
+    EXPECT_EQ(attr.option_attr.control_operation_timeout, max_val);
+}
+
+TEST(TestInitAPI, TestShmemSetTimeoutNullAttr)
+{
+    int32_t ret = shmem_set_timeout(nullptr, 1U);
+    EXPECT_EQ(ret, SHMEM_INVALID_PARAM);
+}
+
+TEST(TestInitAPI, TestShmemSetTimeoutOverwriteBySetAttr)
+{
+    shmem_init_attr_t *attributes = nullptr;
+    int my_rank = 0;
+    int n_ranks = 4;
+    uint64_t local_mem_size = 1024UL * 1024UL;
+
+    ASSERT_EQ(shmem_set_attr(my_rank, n_ranks, local_mem_size, test_global_ipport, &attributes), SHMEM_SUCCESS);
+    ASSERT_NE(attributes, nullptr);
+
+    EXPECT_EQ(attributes->option_attr.shm_init_timeout, 120U);
+    EXPECT_EQ(attributes->option_attr.shm_create_timeout, 120U);
+    EXPECT_EQ(attributes->option_attr.control_operation_timeout, 120U);
+
+    int32_t ret = shmem_set_timeout(attributes, shm::timeout);
+    EXPECT_EQ(ret, SHMEM_SUCCESS);
+    EXPECT_EQ(attributes->option_attr.shm_init_timeout, shm::timeout);
+    EXPECT_EQ(attributes->option_attr.shm_create_timeout, shm::timeout);
+    EXPECT_EQ(attributes->option_attr.control_operation_timeout, shm::timeout);
+
+    ASSERT_EQ(shmem_set_attr(my_rank, n_ranks, local_mem_size, test_global_ipport, &attributes), SHMEM_SUCCESS);
+    EXPECT_EQ(attributes->option_attr.shm_init_timeout, 120U);
+    EXPECT_EQ(attributes->option_attr.shm_create_timeout, 120U);
+    EXPECT_EQ(attributes->option_attr.control_operation_timeout, 120U);
+}
+
 TEST(TestInitAPI, TestShmemGetUniqueId)
 {
     const char *ipInfo = std::getenv("SHMEM_UID_SOCK_IFNAM");
diff --git a/tests/unittest/host/main_test.cpp b/tests/unittest/host/main_test.cpp
index a91dbaba20a880f3c19bf83266054db255bee07c..8b14b1522c367c756e76b510c49e5e72f10f5b8a 100644
--- a/tests/unittest/host/main_test.cpp
+++ b/tests/unittest/host/main_test.cpp
@@ -88,10 +88,16 @@ void test_mutil_task(std::function<void(int, int, uint64_t)> func, uint64_t loca
             std::cout << "fork failed ! " << pids[i] << std::endl;
         } else if (pids[i] == 0) {
             func(i + test_first_rank, test_global_ranks, local_mem_size);
-            exit(0);
+            if (::testing::Test::HasFailure()) {
+                _exit(1);
+            }
+            _exit(0);
         }
     }
     for (int i = 0; i < process_count; ++i) {
+        if (pids[i] <= 0) {
+            continue;
+        }
         waitpid(pids[i], &status[i], 0);
         if (WIFEXITED(status[i]) && WEXITSTATUS(status[i]) != 0) {
             FAIL();
diff --git a/tests/unittest/host/mem/shmem_host_heap_test.cpp b/tests/unittest/host/mem/shmem_host_heap_test.cpp
index 43256604498b290cccfdbd00ced368390673550b..8c1610c8c8a957caa973fe53f45fc6d6f8d3fb8b 100644
--- a/tests/unittest/host/mem/shmem_host_heap_test.cpp
+++ b/tests/unittest/host/mem/shmem_host_heap_test.cpp
@@ -103,7 +103,7 @@ TEST_F(ShareMemoryManagerTest, allocate_large_memory_failed)
             int32_t device_id = rank_id % test_gnpu_num + test_first_npu;
             aclrtStream stream;
             test_init(rank_id, n_ranks, local_mem_size, &stream);
-            auto ptr = shmem_malloc(heap_memory_size + 1UL);
+            auto ptr = shmem_malloc(heap_memory_size + SHMEM_EXTRA_SIZE + 1UL);
             EXPECT_EQ(nullptr, ptr);
             test_finalize(stream, device_id);
         },
@@ -186,7 +186,7 @@ TEST_F(ShareMemoryManagerTest, calloc_large_memory_failed)
             aclrtStream stream;
             test_init(rank_id, n_ranks, local_mem_size, &stream);
             const size_t nmemb = 16;
-            auto ptr = shmem_calloc(nmemb, heap_memory_size / nmemb + 1UL);
+            auto ptr = shmem_calloc(nmemb, (heap_memory_size + SHMEM_EXTRA_SIZE) / nmemb + 1UL);
             EXPECT_EQ(nullptr, ptr);
             test_finalize(stream, device_id);
         },
@@ -257,7 +257,7 @@ TEST_F(ShareMemoryManagerTest, align_large_memory_failed)
             aclrtStream stream;
             test_init(rank_id, n_ranks, local_mem_size, &stream);
             const size_t alignment = 16;
-            auto ptr = shmem_align(alignment, heap_memory_size + 1UL);
+            auto ptr = shmem_align(alignment, heap_memory_size + SHMEM_EXTRA_SIZE + 1UL);
             EXPECT_EQ(nullptr, ptr);
             test_finalize(stream, device_id);
         },